diff options
Diffstat (limited to 'src/collection/owned.rs')
| -rw-r--r-- | src/collection/owned.rs | 283 |
1 files changed, 204 insertions, 79 deletions
diff --git a/src/collection/owned.rs b/src/collection/owned.rs index b9cf313..68170d1 100644 --- a/src/collection/owned.rs +++ b/src/collection/owned.rs @@ -3,6 +3,7 @@ use crate::lockable::{ }; use crate::{Keyable, ThreadKey}; +use super::utils::{scoped_read, scoped_try_read, scoped_try_write, scoped_write}; use super::{utils, LockGuard, OwnedLockCollection}; unsafe impl<L: Lockable> RawLock for OwnedLockCollection<L> { @@ -15,19 +16,19 @@ unsafe impl<L: Lockable> RawLock for OwnedLockCollection<L> { } } - unsafe fn raw_lock(&self) { - utils::ordered_lock(&utils::get_locks_unsorted(&self.data)) + unsafe fn raw_write(&self) { + utils::ordered_write(&utils::get_locks_unsorted(&self.data)) } - unsafe fn raw_try_lock(&self) -> bool { + unsafe fn raw_try_write(&self) -> bool { let locks = utils::get_locks_unsorted(&self.data); - utils::ordered_try_lock(&locks) + utils::ordered_try_write(&locks) } - unsafe fn raw_unlock(&self) { + unsafe fn raw_unlock_write(&self) { let locks = utils::get_locks_unsorted(&self.data); for lock in locks { - lock.raw_unlock(); + lock.raw_unlock_write(); } } @@ -62,7 +63,7 @@ unsafe impl<L: Lockable> Lockable for OwnedLockCollection<L> { #[mutants::skip] // It's hard to test lkocks in an OwnedLockCollection, because they're owned #[cfg(not(tarpaulin_include))] fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - self.data.get_ptrs(ptrs) + ptrs.push(self) } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -146,7 +147,7 @@ impl<E: OwnedLockable + Extend<L>, L: OwnedLockable> Extend<L> for OwnedLockColl // invariant that there is only one way to lock the collection. AsMut is fine, // because the collection can't be locked as long as the reference is valid. -impl<T, L: AsMut<T>> AsMut<T> for OwnedLockCollection<L> { +impl<T: ?Sized, L: AsMut<T>> AsMut<T> for OwnedLockCollection<L> { fn as_mut(&mut self) -> &mut T { self.data.as_mut() } @@ -185,44 +186,16 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { Self { data } } - pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_lock(); - - // safety: the data was just locked - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_lock<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataMut<'a>) -> R) -> R { + scoped_write(self, key, f) } - pub fn scoped_try_lock<Key: Keyable, R>( - &self, + pub fn scoped_try_lock<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataMut<'_>) -> R, + f: impl Fn(L::DataMut<'a>) -> R, ) -> Result<R, Key> { - unsafe { - // safety: we have the thread key - if !self.raw_try_lock() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_write(self, key, f) } /// Locks the collection @@ -249,7 +222,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { let guard = unsafe { // safety: we have the thread key, and these locks happen in a // predetermined order - self.raw_lock(); + self.raw_write(); // safety: we've locked all of this already self.data.guard() @@ -290,7 +263,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { /// ``` pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { let guard = unsafe { - if !self.raw_try_lock() { + if !self.raw_try_write() { return Err(key); } @@ -327,44 +300,16 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { } impl<L: Sharable> OwnedLockCollection<L> { - pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_read(); - - // safety: the data was just locked - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R { + scoped_read(self, key, f) } - pub fn scoped_try_read<Key: Keyable, R>( - &self, + pub fn scoped_try_read<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataRef<'_>) -> R, + f: impl Fn(L::DataRef<'a>) -> R, ) -> Result<R, Key> { - unsafe { - // safety: we have the thread key - if !self.raw_try_read() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_read(self, key, f) } /// Locks the collection, so that other threads can still read from it @@ -554,7 +499,7 @@ impl<L: LockableIntoInner> OwnedLockCollection<L> { #[cfg(test)] mod tests { use super::*; - use crate::{Mutex, ThreadKey}; + use crate::{Mutex, RwLock, ThreadKey}; #[test] fn get_mut_applies_changes() { @@ -604,6 +549,63 @@ mod tests { } #[test] + fn scoped_read_works() { + let mut key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new([RwLock::new(24), RwLock::new(42)]); + let sum = collection.scoped_read(&mut key, |guard| guard[0] + guard[1]); + assert_eq!(sum, 24 + 42); + } + + #[test] + fn scoped_lock_works() { + let mut key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new([RwLock::new(24), RwLock::new(42)]); + collection.scoped_lock(&mut key, |guard| *guard[0] += *guard[1]); + + let sum = collection.scoped_lock(&mut key, |guard| { + assert_eq!(*guard[0], 24 + 42); + assert_eq!(*guard[1], 42); + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 24 + 42 + 42); + } + + #[test] + fn scoped_try_lock_can_fail() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new([Mutex::new(1), Mutex::new(2)]); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_lock(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn scoped_try_read_can_fail() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new([RwLock::new(1), RwLock::new(2)]); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_read(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] fn try_lock_works_on_unlocked() { let key = ThreadKey::get().unwrap(); let collection = OwnedLockCollection::new((Mutex::new(0), Mutex::new(1))); @@ -630,6 +632,74 @@ mod tests { } #[test] + fn try_read_succeeds_for_unlocked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = OwnedLockCollection::new(mutexes); + let guard = collection.try_read(key).unwrap(); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + } + + #[test] + fn try_read_fails_on_locked() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new((RwLock::new(0), RwLock::new(1))); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + #[allow(unused)] + let guard = collection.lock(key); + std::mem::forget(guard); + }); + }); + + assert!(collection.try_read(key).is_err()); + } + + #[test] + fn can_read_twice_on_different_threads() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = OwnedLockCollection::new(mutexes); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.read(key); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + std::mem::forget(guard); + }); + }); + + let guard = collection.try_read(key).unwrap(); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + } + + #[test] + fn unlock_collection_works() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new((Mutex::new("foo"), Mutex::new("bar"))); + let guard = collection.lock(key); + + let key = OwnedLockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard); + assert!(collection.try_lock(key).is_ok()) + } + + #[test] + fn read_unlock_collection_works() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new((RwLock::new("foo"), RwLock::new("bar"))); + let guard = collection.read(key); + + let key = OwnedLockCollection::<(&RwLock<_>, &RwLock<_>)>::unlock_read(guard); + assert!(collection.try_lock(key).is_ok()) + } + + #[test] fn default_works() { type MyCollection = OwnedLockCollection<(Mutex<i32>, Mutex<Option<i32>>, Mutex<String>)>; let collection = MyCollection::default(); @@ -649,4 +719,59 @@ mod tests { assert_eq!(collection.data.len(), 3); } + + #[test] + fn works_in_collection() { + let key = ThreadKey::get().unwrap(); + let collection = + OwnedLockCollection::new(OwnedLockCollection::new([RwLock::new(0), RwLock::new(1)])); + + let mut guard = collection.lock(key); + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 1); + *guard[1] = 2; + + let key = OwnedLockCollection::<OwnedLockCollection<[RwLock<_>; 2]>>::unlock(guard); + let guard = collection.read(key); + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 2); + } + + #[test] + fn as_mut_works() { + let mut mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = OwnedLockCollection::new(&mut mutexes); + + collection.as_mut()[0] = Mutex::new(42); + + assert_eq!(*collection.as_mut()[0].get_mut(), 42); + } + + #[test] + fn child_mut_works() { + let mut mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = OwnedLockCollection::new(&mut mutexes); + + collection.child_mut()[0] = Mutex::new(42); + + assert_eq!(*collection.child_mut()[0].get_mut(), 42); + } + + #[test] + fn into_child_works() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = OwnedLockCollection::new(mutexes); + + collection.child_mut()[0] = Mutex::new(42); + + assert_eq!( + *collection + .into_child() + .as_mut() + .get_mut(0) + .unwrap() + .get_mut(), + 42 + ); + } } |
