diff options
Diffstat (limited to 'src/collection/boxed.rs')
| -rw-r--r-- | src/collection/boxed.rs | 170 |
1 files changed, 113 insertions, 57 deletions
diff --git a/src/collection/boxed.rs b/src/collection/boxed.rs index 0597e90..364ec97 100644 --- a/src/collection/boxed.rs +++ b/src/collection/boxed.rs @@ -1,25 +1,12 @@ use std::cell::UnsafeCell; use std::fmt::Debug; -use std::marker::PhantomData; use crate::lockable::{Lockable, LockableIntoInner, OwnedLockable, RawLock, Sharable}; -use crate::Keyable; +use crate::{Keyable, ThreadKey}; +use super::utils::ordered_contains_duplicates; use super::{utils, BoxedLockCollection, LockGuard}; -/// returns `true` if the sorted list contains a duplicate -#[must_use] -fn contains_duplicates(l: &[&dyn RawLock]) -> bool { - if l.is_empty() { - // Return early to prevent panic in the below call to `windows` - return false; - } - - l.windows(2) - // NOTE: addr_eq is necessary because eq would also compare the v-table pointers - .any(|window| std::ptr::addr_eq(window[0], window[1])) -} - unsafe impl<L: Lockable> RawLock for BoxedLockCollection<L> { #[mutants::skip] // this should never be called #[cfg(not(tarpaulin_include))] @@ -65,6 +52,11 @@ unsafe impl<L: Lockable> Lockable for BoxedLockCollection<L> { where Self: 'g; + type DataMut<'a> + = L::DataMut<'a> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.extend(self.locks()) } @@ -72,6 +64,10 @@ unsafe impl<L: Lockable> Lockable for BoxedLockCollection<L> { unsafe fn guard(&self) -> Self::Guard<'_> { self.child().guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.child().data_mut() + } } unsafe impl<L: Sharable> Sharable for BoxedLockCollection<L> { @@ -80,9 +76,18 @@ unsafe impl<L: Sharable> Sharable for BoxedLockCollection<L> { where Self: 'g; + type DataRef<'a> + = L::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.child().read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.child().data_ref() + } } unsafe impl<L: OwnedLockable> OwnedLockable for BoxedLockCollection<L> {} @@ -352,13 +357,53 @@ impl<L: Lockable> BoxedLockCollection<L> { // safety: we are checking for duplicates before returning unsafe { let this = Self::new_unchecked(data); - if contains_duplicates(this.locks()) { + if ordered_contains_duplicates(this.locks()) { return None; } Some(this) } } + pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the data was just locked + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_lock<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataMut<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection /// /// This function returns a guard that can be used to access the underlying @@ -378,10 +423,8 @@ impl<L: Lockable> BoxedLockCollection<L> { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::Guard<'g>, Key> { + #[must_use] + pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> { unsafe { // safety: we have the thread key self.raw_lock(); @@ -390,7 +433,6 @@ impl<L: Lockable> BoxedLockCollection<L> { // safety: we've already acquired the lock guard: self.child().guard(), key, - _phantom: PhantomData, } } } @@ -424,10 +466,7 @@ impl<L: Lockable> BoxedLockCollection<L> { /// }; /// /// ``` - pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { let guard = unsafe { if !self.raw_try_lock() { return Err(key); @@ -437,11 +476,7 @@ impl<L: Lockable> BoxedLockCollection<L> { self.child().guard() }; - Ok(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -461,13 +496,53 @@ impl<L: Lockable> BoxedLockCollection<L> { /// *guard.1 = "1"; /// let key = LockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard); /// ``` - pub fn unlock<'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'_>, Key>) -> Key { + pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } } impl<L: Sharable> BoxedLockCollection<L> { + pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the data was just locked + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_read<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataRef<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection, so that other threads can still read from it /// /// This function returns a guard that can be used to access the underlying @@ -487,10 +562,8 @@ impl<L: Sharable> BoxedLockCollection<L> { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { + #[must_use] + pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> { unsafe { // safety: we have the thread key self.raw_read(); @@ -499,7 +572,6 @@ impl<L: Sharable> BoxedLockCollection<L> { // safety: we've already acquired the lock guard: self.child().read_guard(), key, - _phantom: PhantomData, } } } @@ -534,10 +606,7 @@ impl<L: Sharable> BoxedLockCollection<L> { /// }; /// /// ``` - pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result<LockGuard<'key, L::ReadGuard<'g>, Key>, Key> { + pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> { let guard = unsafe { // safety: we have the thread key if !self.raw_try_read() { @@ -548,11 +617,7 @@ impl<L: Sharable> BoxedLockCollection<L> { self.child().read_guard() }; - Ok(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -570,9 +635,7 @@ impl<L: Sharable> BoxedLockCollection<L> { /// let mut guard = lock.read(key); /// let key = LockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard); /// ``` - pub fn unlock_read<'key, Key: Keyable + 'key>( - guard: LockGuard<'key, L::ReadGuard<'_>, Key>, - ) -> Key { + pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } @@ -635,7 +698,6 @@ mod tests { .into_iter() .collect(); let guard = collection.lock(key); - // TODO impl PartialEq<T> for MutexRef<T> assert_eq!(*guard[0], "foo"); assert_eq!(*guard[1], "bar"); assert_eq!(*guard[2], "baz"); @@ -647,7 +709,6 @@ mod tests { let collection = BoxedLockCollection::from([Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]); let guard = collection.lock(key); - // TODO impl PartialEq<T> for MutexRef<T> assert_eq!(*guard[0], "foo"); assert_eq!(*guard[1], "bar"); assert_eq!(*guard[2], "baz"); @@ -666,7 +727,7 @@ mod tests { let mut key = ThreadKey::get().unwrap(); let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); for (i, mutex) in (&collection).into_iter().enumerate() { - assert_eq!(*mutex.lock(&mut key), i); + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) } } @@ -675,7 +736,7 @@ mod tests { let mut key = ThreadKey::get().unwrap(); let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); for (i, mutex) in collection.iter().enumerate() { - assert_eq!(*mutex.lock(&mut key), i); + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) } } @@ -704,11 +765,6 @@ mod tests { } #[test] - fn contains_duplicates_empty() { - assert!(!contains_duplicates(&[])) - } - - #[test] fn try_lock_works() { let key = ThreadKey::get().unwrap(); let collection = BoxedLockCollection::new([Mutex::new(1), Mutex::new(2)]); |
