diff options
| author | Botahamec <botahamec@outlook.com> | 2025-02-28 16:09:11 -0500 |
|---|---|---|
| committer | Botahamec <botahamec@outlook.com> | 2025-02-28 16:09:11 -0500 |
| commit | 4ba03be97e6cc7e790bbc9bfc18caaa228c8a262 (patch) | |
| tree | a257184577a93ddf240aba698755c2886188788b /src/collection/ref.rs | |
| parent | 4a5ec04a29cba07c5960792528bd66b0f99ee3ee (diff) | |
Scoped lock API
Diffstat (limited to 'src/collection/ref.rs')
| -rw-r--r-- | src/collection/ref.rs | 261 |
1 files changed, 196 insertions, 65 deletions
diff --git a/src/collection/ref.rs b/src/collection/ref.rs index c86f298..b68b72f 100644 --- a/src/collection/ref.rs +++ b/src/collection/ref.rs @@ -1,32 +1,11 @@ use std::fmt::Debug; -use std::marker::PhantomData; use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable}; -use crate::Keyable; +use crate::{Keyable, ThreadKey}; +use super::utils::{get_locks, ordered_contains_duplicates}; use super::{utils, LockGuard, RefLockCollection}; -#[must_use] -pub fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> { - let mut locks = Vec::new(); - data.get_ptrs(&mut locks); - locks.sort_by_key(|lock| &raw const **lock); - locks -} - -/// returns `true` if the sorted list contains a duplicate -#[must_use] -fn contains_duplicates(l: &[&dyn RawLock]) -> bool { - if l.is_empty() { - // Return early to prevent panic in the below call to `windows` - return false; - } - - l.windows(2) - // NOTE: addr_eq is necessary because eq would also compare the v-table pointers - .any(|window| std::ptr::addr_eq(window[0], window[1])) -} - impl<'a, L> IntoIterator for &'a RefLockCollection<'a, L> where &'a L: IntoIterator, @@ -83,6 +62,11 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> { where Self: 'g; + type DataMut<'a> + = L::DataMut<'a> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.extend_from_slice(&self.locks); } @@ -90,6 +74,10 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> { unsafe fn guard(&self) -> Self::Guard<'_> { self.data.guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.data_mut() + } } unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> { @@ -98,9 +86,18 @@ unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> { where Self: 'g; + type DataRef<'a> + = L::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.data.read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.data.data_ref() + } } impl<T, L: AsRef<T>> AsRef<T> for RefLockCollection<'_, L> { @@ -230,13 +227,53 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { #[must_use] pub fn try_new(data: &'a L) -> Option<Self> { let locks = get_locks(data); - if contains_duplicates(&locks) { + if ordered_contains_duplicates(&locks) { return None; } Some(Self { data, locks }) } + pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the data was just locked + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_lock<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataMut<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection /// /// This function returns a guard that can be used to access the underlying @@ -257,10 +294,8 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::Guard<'g>, Key> { + #[must_use] + pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> { let guard = unsafe { // safety: we have the thread key self.raw_lock(); @@ -269,11 +304,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { self.data.guard() }; - LockGuard { - guard, - key, - _phantom: PhantomData, - } + LockGuard { guard, key } } /// Attempts to lock the without blocking. @@ -306,10 +337,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// }; /// /// ``` - pub fn try_lock<'g, 'key: 'a, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { let guard = unsafe { if !self.raw_try_lock() { return Err(key); @@ -319,11 +347,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { self.data.guard() }; - Ok(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -345,13 +369,53 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// let key = RefLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard); /// ``` #[allow(clippy::missing_const_for_fn)] - pub fn unlock<'g, 'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'g>, Key>) -> Key { + pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } } -impl<'a, L: Sharable> RefLockCollection<'a, L> { +impl<L: Sharable> RefLockCollection<'_, L> { + pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the data was just locked + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_read<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataRef<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection, so that other threads can still read from it /// /// This function returns a guard that can be used to access the underlying @@ -372,10 +436,8 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { + #[must_use] + pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> { unsafe { // safety: we have the thread key self.raw_read(); @@ -384,7 +446,6 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> { // safety: we've already acquired the lock guard: self.data.read_guard(), key, - _phantom: PhantomData, } } } @@ -412,33 +473,26 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> { /// let lock = RefLockCollection::new(&data); /// /// match lock.try_read(key) { - /// Some(mut guard) => { + /// Ok(mut guard) => { /// assert_eq!(*guard.0, 5); /// assert_eq!(*guard.1, "6"); /// }, - /// None => unreachable!(), + /// Err(_) => unreachable!(), /// }; /// /// ``` - pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> { + pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> { let guard = unsafe { // safety: we have the thread key if !self.raw_try_read() { - return None; + return Err(key); } // safety: we've acquired the locks self.data.read_guard() }; - Some(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -458,9 +512,7 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> { /// let key = RefLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard); /// ``` #[allow(clippy::missing_const_for_fn)] - pub fn unlock_read<'key: 'a, Key: Keyable + 'key>( - guard: LockGuard<'key, L::ReadGuard<'a>, Key>, - ) -> Key { + pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } @@ -497,7 +549,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{Mutex, ThreadKey}; + use crate::{Mutex, RwLock, ThreadKey}; #[test] fn non_duplicates_allowed() { @@ -513,6 +565,85 @@ mod tests { } #[test] + fn try_lock_succeeds_for_unlocked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(24), Mutex::new(42)]; + let collection = RefLockCollection::new(&mutexes); + let guard = collection.try_lock(key).unwrap(); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + } + + #[test] + fn try_lock_fails_for_locked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(24), Mutex::new(42)]; + let collection = RefLockCollection::new(&mutexes); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = mutexes[1].lock(key); + assert_eq!(*guard, 42); + std::mem::forget(guard); + }); + }); + + let guard = collection.try_lock(key); + assert!(guard.is_err()); + } + + #[test] + fn try_read_succeeds_for_unlocked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RefLockCollection::new(&mutexes); + let guard = collection.try_read(key).unwrap(); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + } + + #[test] + fn try_read_fails_for_locked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RefLockCollection::new(&mutexes); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = mutexes[1].write(key); + assert_eq!(*guard, 42); + std::mem::forget(guard); + }); + }); + + let guard = collection.try_read(key); + assert!(guard.is_err()); + } + + #[test] + fn can_read_twice_on_different_threads() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RefLockCollection::new(&mutexes); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.read(key); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + std::mem::forget(guard); + }); + }); + + let guard = collection.try_read(key).unwrap(); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + } + + #[test] fn works_in_collection() { let key = ThreadKey::get().unwrap(); let mutex1 = Mutex::new(0); |
