From 9c9e17decbed3c03a423b3a8a9a1b05a59ca4d51 Mon Sep 17 00:00:00 2001 From: ljy9810 Date: Mon, 9 Sep 2024 10:07:01 +0800 Subject: [PATCH] =?UTF-8?q?cancel=5Fsafe=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: ljy9810 --- ylong_runtime/src/sync/mpsc/bounded/array.rs | 10 ++- ylong_runtime/src/sync/semaphore_inner.rs | 42 ++++++++++--- ylong_runtime/src/sync/wake_list.rs | 66 +++++++++++++++++--- ylong_runtime/src/sync/watch.rs | 9 ++- ylong_runtime/src/util/slots.rs | 15 +++++ ylong_runtime/tests/cancel_safe.rs | 51 +++++++++++++++ ylong_runtime/tests/entry.rs | 1 + 7 files changed, 172 insertions(+), 22 deletions(-) create mode 100644 ylong_runtime/tests/cancel_safe.rs diff --git a/ylong_runtime/src/sync/mpsc/bounded/array.rs b/ylong_runtime/src/sync/mpsc/bounded/array.rs index 872969a..2cc7a08 100644 --- a/ylong_runtime/src/sync/mpsc/bounded/array.rs +++ b/ylong_runtime/src/sync/mpsc/bounded/array.rs @@ -17,13 +17,14 @@ use std::mem::MaybeUninit; use std::pin::Pin; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::sync::Arc; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; use crate::sync::atomic_waker::AtomicWaker; use crate::sync::error::{RecvError, SendError, TryRecvError, TrySendError}; use crate::sync::mpsc::Container; -use crate::sync::wake_list::WakerList; +use crate::sync::wake_list::{ListItem, WakerList}; /// The offset of the index. const INDEX_SHIFT: usize = 1; @@ -236,8 +237,11 @@ impl Future for Position<'_, T> { SendPosition::Closed => return Ready(SendPosition::Closed), SendPosition::Full => {} } - - self.array.waiters.insert(cx.waker().clone()); + let wake = cx.waker().clone(); + self.array.waiters.insert(ListItem { + wake, + wait_permit: Arc::new(AtomicUsize::new(1)), + }); let tail = self.array.tail.load(Acquire); let index = (tail >> INDEX_SHIFT) % self.array.capacity; diff --git a/ylong_runtime/src/sync/semaphore_inner.rs b/ylong_runtime/src/sync/semaphore_inner.rs index 50c2b62..c4bd572 100644 --- a/ylong_runtime/src/sync/semaphore_inner.rs +++ b/ylong_runtime/src/sync/semaphore_inner.rs @@ -17,10 +17,11 @@ use std::future::Future; use std::pin::Pin; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering::{AcqRel, Acquire, Release}; +use std::sync::Arc; use std::task::Poll::{Pending, Ready}; use std::task::{Context, Poll}; -use crate::sync::wake_list::WakerList; +use crate::sync::wake_list::{ListItem, WakerList}; /// Maximum capacity of `Semaphore`. const MAX_PERMITS: usize = usize::MAX >> 1; @@ -38,6 +39,8 @@ pub(crate) struct Permit<'a> { semaphore: &'a SemaphoreInner, waker_index: Option, enqueue: bool, + // Sharing state in a multi-threaded environment + wait_permit: Arc, } /// Error returned by `Semaphore`. @@ -181,6 +184,7 @@ impl SemaphoreInner { cx: &mut Context<'_>, waker_index: &mut Option, enqueue: &mut bool, + wait_permit: Arc, ) -> Poll> { let mut curr = self.permits.load(Acquire); if curr & CLOSED == CLOSED { @@ -199,7 +203,11 @@ impl SemaphoreInner { return res; } } else if !(*enqueue) { - *waker_index = Some(self.waker_list.insert(cx.waker().clone())); + let wake = cx.waker().clone(); + *waker_index = Some(self.waker_list.insert(ListItem { + wake, + wait_permit: wait_permit.clone(), + })); *enqueue = true; curr = self.permits.load(Acquire); } else { @@ -218,11 +226,12 @@ impl Debug for SemaphoreInner { } impl<'a> Permit<'a> { - fn new(semaphore: &'a SemaphoreInner) -> Permit { + fn new(semaphore: &'a SemaphoreInner) -> Permit<'a> { Permit { semaphore, waker_index: None, enqueue: false, + wait_permit: Arc::new(AtomicUsize::new(1)), } } } @@ -231,20 +240,35 @@ impl Future for Permit<'_> { type Output = Result<(), SemaphoreError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let (semaphore, waker_index, enqueue) = unsafe { + let (semaphore, waker_index, enqueue, wait_permit) = unsafe { let me = self.get_unchecked_mut(); - (me.semaphore, &mut me.waker_index, &mut me.enqueue) + ( + me.semaphore, + &mut me.waker_index, + &mut me.enqueue, + me.wait_permit.clone(), + ) }; - semaphore.poll_acquire(cx, waker_index, enqueue) + semaphore.poll_acquire(cx, waker_index, enqueue, wait_permit) } } impl Drop for Permit<'_> { fn drop(&mut self) { if self.enqueue { - // if `enqueue` is true, `waker_index` must be `Some(_)`. - let _ = self.semaphore.waker_list.remove(self.waker_index.unwrap()); + let mut list = self.semaphore.waker_list.lock(); + let wait_permit = self.wait_permit.load(Acquire); + // if 'enqueue' is true, 'waker_index' must be 'Some(_)' + let index = self.waker_index.unwrap(); + let res = list.remove_permit(index, wait_permit); + if res { + let prev = self.semaphore.permits.fetch_add(1 << PERMIT_SHIFT, Release); + assert!( + (prev >> PERMIT_SHIFT) < MAX_PERMITS, + "the number of permits will overflow the capacity after addition" + ); + } } } -} +} \ No newline at end of file diff --git a/ylong_runtime/src/sync/wake_list.rs b/ylong_runtime/src/sync/wake_list.rs index ac5749a..b25c260 100644 --- a/ylong_runtime/src/sync/wake_list.rs +++ b/ylong_runtime/src/sync/wake_list.rs @@ -12,9 +12,12 @@ // limitations under the License. use std::cell::UnsafeCell; +use std::cmp; use std::hint::spin_loop; use std::ops::{Deref, DerefMut}; +use std::sync::atomic::Ordering::{AcqRel, Acquire}; use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; use std::task::Waker; use crate::util::slots::{Slots, SlotsError}; @@ -25,7 +28,14 @@ const LOCKED: usize = 1 << 0; const NOTIFIABLE: usize = 1 << 1; pub(crate) struct Inner { - wake_list: Slots, + wake_list: Slots, +} + +pub(crate) struct ListItem { + // Task waker + pub(crate) wake: Waker, + // The status of task to get semaphores + pub(crate) wait_permit: Arc, } /// Lists of Wakers @@ -52,14 +62,15 @@ impl WakerList { } /// Pushes a waker into the list and return its index in the list. - pub fn insert(&self, waker: Waker) -> usize { + pub fn insert(&self, waker: ListItem) -> usize { let mut list = self.lock(); list.wake_list.push_back(waker) } /// Removes the waker corresponding to the key. - pub fn remove(&self, key: usize) -> Result { + #[allow(dead_code)] + pub fn remove(&self, key: usize) -> Result { let mut inner = self.lock(); inner.wake_list.remove(key) } @@ -101,15 +112,41 @@ impl WakerList { } } +impl ListItem { + fn get_wait_permit(&self) -> usize { + self.wait_permit.load(Acquire) + } + fn change_permit(&self, curr: usize, next: usize) -> Result { + self.wait_permit. + compare_exchange(curr, next, AcqRel, Acquire) + } + + fn change_status(&self, acquired_permit: usize) -> bool { + let mut curr = self.get_wait_permit(); + loop { + let assign = cmp::min(curr, acquired_permit); + let next = curr - assign; + match self.change_permit(curr, next) { + Ok(_) => return next == 0, + Err(actual) => curr = actual, + } + } + } +} + impl Inner { /// Wakes up one or more members in the WakerList, and return the result. #[inline] fn notify(&mut self, notify_type: Notify) -> bool { let mut is_wake = false; - while let Some(waker) = self.wake_list.pop_front() { - waker.wake(); - is_wake = true; - + while let Some(list_item) = self.wake_list.get_first() { + let res= list_item.change_status(1); + if res { + // If entering this branch, 'wake_list.pop_front()' must be 'Some(_)' + let pop = self.wake_list.pop_front().expect("The list first is NULL"); + pop.wake.wake(); + is_wake = true; + } if notify_type == Notify::One { return is_wake; } @@ -136,6 +173,21 @@ pub(crate) struct Lock<'a> { waker_set: &'a WakerList, } +impl Lock<'_> { + pub(crate) fn remove_permit(&mut self, key: usize, wait_permit: usize) -> bool { + if let Some(list_item) = self.wake_list.get_by_index(key) { + let inner_wait_permit = list_item.get_wait_permit(); + if inner_wait_permit == wait_permit { + let _ = self.wake_list.remove(key); + } + } + if wait_permit == 0 { + return !self.notify_one(); + } + false + } +} + impl Drop for Lock<'_> { #[inline] fn drop(&mut self) { diff --git a/ylong_runtime/src/sync/watch.rs b/ylong_runtime/src/sync/watch.rs index cd34c33..5a229d3 100644 --- a/ylong_runtime/src/sync/watch.rs +++ b/ylong_runtime/src/sync/watch.rs @@ -23,7 +23,7 @@ use std::task::{Context, Poll}; use crate::futures::poll_fn; use crate::sync::error::{RecvError, SendError}; -use crate::sync::wake_list::WakerList; +use crate::sync::wake_list::{ListItem, WakerList}; /// The least significant bit that marks the version of channel. const VERSION_SHIFT: usize = 1; @@ -359,8 +359,11 @@ impl Receiver { Some(Err(e)) => return Ready(Err(e)), None => {} } - - self.channel.waker_list.insert(cx.waker().clone()); + let wake = cx.waker().clone(); + self.channel.waker_list.insert(ListItem { + wake, + wait_permit: Arc::new(AtomicUsize::new(1)), + }); match self.try_notified() { Some(Ok(())) => Ready(Ok(())), diff --git a/ylong_runtime/src/util/slots.rs b/ylong_runtime/src/util/slots.rs index fa73761..9253a53 100644 --- a/ylong_runtime/src/util/slots.rs +++ b/ylong_runtime/src/util/slots.rs @@ -232,6 +232,21 @@ impl Slots { len: 0, } } + pub(crate) fn get_by_index(&mut self, key: usize) -> Option<& T> { + if let Some(entry) = self.entries.get_mut(key) { + let val = entry.data.as_ref(); + return val; + } + None + } + pub(crate) fn get_first(&mut self) -> Option<& T> { + let curr = self.head; + if let Some(entry) = self.entries.get_mut(curr) { + let val = entry.data.as_ref(); + return val; + } + None + } } impl Default for Slots { diff --git a/ylong_runtime/tests/cancel_safe.rs b/ylong_runtime/tests/cancel_safe.rs new file mode 100644 index 0000000..c240e78 --- /dev/null +++ b/ylong_runtime/tests/cancel_safe.rs @@ -0,0 +1,51 @@ +// Copyright (c) 2023 Huawei Device Co., Ltd. +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#![cfg(feature = "sync")] +use std::sync::Arc; + +/// SDV test cases for Semaphore Mutex cancel-safe +/// +/// # Brief +/// 1. Create a counting auto-release-semaphore with an initial capacity. +/// 2. Asynchronously acquires a permit multiple times. +/// 3. Cancel half of the asynchronous tasks +/// 4. Execute remaining tasks +#[test] +fn sdv_semaphore_cancel_test() { + let sema = Arc::new(ylong_runtime::sync::AutoRelSemaphore::new(1).unwrap()); + let mut handles = vec![]; + let mut canceled_handles = vec![]; + for i in 0..100 { + let sema_cpy = sema.clone(); + let handle = ylong_runtime::spawn(async move { + for _ in 0..1000 { + let ret = sema_cpy.acquire().await.unwrap(); + drop(ret); + } + 1 + }); + if i % 2 == 0 { + handles.push(handle); + } else { + canceled_handles.push(handle); + } + } + for handle in canceled_handles { + handle.cancel(); + } + for handle in handles { + let ret = ylong_runtime::block_on(handle).unwrap(); + assert_eq!(ret, 1); + } +} \ No newline at end of file diff --git a/ylong_runtime/tests/entry.rs b/ylong_runtime/tests/entry.rs index 860d91e..5eea605 100644 --- a/ylong_runtime/tests/entry.rs +++ b/ylong_runtime/tests/entry.rs @@ -42,3 +42,4 @@ mod tcp_test; mod timer_test; mod udp_test; mod uds_test; +mod cancel_safe; -- Gitee