diff options
Diffstat (limited to 'src/collection/ref.rs')
| -rw-r--r-- | src/collection/ref.rs | 261 |
1 files changed, 179 insertions, 82 deletions
diff --git a/src/collection/ref.rs b/src/collection/ref.rs index b68b72f..5f96533 100644 --- a/src/collection/ref.rs +++ b/src/collection/ref.rs @@ -3,7 +3,10 @@ use std::fmt::Debug; use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable}; use crate::{Keyable, ThreadKey}; -use super::utils::{get_locks, ordered_contains_duplicates}; +use super::utils::{ + get_locks, ordered_contains_duplicates, scoped_read, scoped_try_read, scoped_try_write, + scoped_write, +}; use super::{utils, LockGuard, RefLockCollection}; impl<'a, L> IntoIterator for &'a RefLockCollection<'a, L> @@ -27,17 +30,17 @@ unsafe impl<L: Lockable> RawLock for RefLockCollection<'_, L> { } } - unsafe fn raw_lock(&self) { - utils::ordered_lock(&self.locks) + unsafe fn raw_write(&self) { + utils::ordered_write(&self.locks) } - unsafe fn raw_try_lock(&self) -> bool { - utils::ordered_try_lock(&self.locks) + unsafe fn raw_try_write(&self) -> bool { + utils::ordered_try_write(&self.locks) } - unsafe fn raw_unlock(&self) { + unsafe fn raw_unlock_write(&self) { for lock in &self.locks { - lock.raw_unlock(); + lock.raw_unlock_write(); } } @@ -68,7 +71,7 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> { Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - ptrs.extend_from_slice(&self.locks); + ptrs.push(self) } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -100,7 +103,7 @@ unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> { } } -impl<T, L: AsRef<T>> AsRef<T> for RefLockCollection<'_, L> { +impl<T: ?Sized, L: AsRef<T>> AsRef<T> for RefLockCollection<'_, L> { fn as_ref(&self) -> &T { self.data.as_ref() } @@ -234,44 +237,16 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { Some(Self { data, locks }) } - pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_lock(); - - // safety: the data was just locked - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_lock<'s, R>(&'s self, key: impl Keyable, f: impl Fn(L::DataMut<'s>) -> R) -> R { + scoped_write(self, key, f) } - pub fn scoped_try_lock<Key: Keyable, R>( - &self, + pub fn scoped_try_lock<'s, Key: Keyable, R>( + &'s self, key: Key, - f: impl Fn(L::DataMut<'_>) -> R, + f: impl Fn(L::DataMut<'s>) -> R, ) -> Result<R, Key> { - unsafe { - // safety: we have the thread key - if !self.raw_try_lock() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_write(self, key, f) } /// Locks the collection @@ -298,7 +273,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> { let guard = unsafe { // safety: we have the thread key - self.raw_lock(); + self.raw_write(); // safety: we've locked all of this already self.data.guard() @@ -339,7 +314,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// ``` pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { let guard = unsafe { - if !self.raw_try_lock() { + if !self.raw_try_write() { return Err(key); } @@ -376,44 +351,16 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { } impl<L: Sharable> RefLockCollection<'_, L> { - pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_read(); - - // safety: the data was just locked - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R { + scoped_read(self, key, f) } - pub fn scoped_try_read<Key: Keyable, R>( - &self, + pub fn scoped_try_read<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataRef<'_>) -> R, + f: impl Fn(L::DataRef<'a>) -> R, ) -> Result<R, Key> { - unsafe { - // safety: we have the thread key - if !self.raw_try_read() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_read(self, key, f) } /// Locks the collection, so that other threads can still read from it @@ -565,6 +512,88 @@ mod tests { } #[test] + fn from() { + let key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]; + let collection = RefLockCollection::from(&mutexes); + let guard = collection.lock(key); + assert_eq!(*guard[0], "foo"); + assert_eq!(*guard[1], "bar"); + assert_eq!(*guard[2], "baz"); + } + + #[test] + fn scoped_lock_changes_collection() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(24), Mutex::new(42)]; + let collection = RefLockCollection::new(&mutexes); + let sum = collection.scoped_lock(&mut key, |guard| { + *guard[0] = 128; + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 128 + 42); + + let guard = collection.lock(key); + assert_eq!(*guard[0], 128); + assert_eq!(*guard[1], 42); + } + + #[test] + fn scoped_read_sees_changes() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RefLockCollection::new(&mutexes); + collection.scoped_lock(&mut key, |guard| { + *guard[0] = 128; + }); + + let sum = collection.scoped_read(&mut key, |guard| { + assert_eq!(*guard[0], 128); + assert_eq!(*guard[1], 42); + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 128 + 42); + } + + #[test] + fn scoped_try_lock_can_fail() { + let key = ThreadKey::get().unwrap(); + let locks = [Mutex::new(1), Mutex::new(2)]; + let collection = RefLockCollection::new(&locks); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_lock(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn scoped_try_read_can_fail() { + let key = ThreadKey::get().unwrap(); + let locks = [RwLock::new(1), RwLock::new(2)]; + let collection = RefLockCollection::new(&locks); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_read(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] fn try_lock_succeeds_for_unlocked_collection() { let key = ThreadKey::get().unwrap(); let mutexes = [Mutex::new(24), Mutex::new(42)]; @@ -644,17 +673,85 @@ mod tests { } #[test] + fn into_ref_iterator() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(0), Mutex::new(1), Mutex::new(2)]; + let collection = RefLockCollection::new(&mutexes); + for (i, mutex) in (&collection).into_iter().enumerate() { + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) + } + } + + #[test] + fn ref_iterator() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(0), Mutex::new(1), Mutex::new(2)]; + let collection = RefLockCollection::new(&mutexes); + for (i, mutex) in collection.iter().enumerate() { + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) + } + } + + #[test] fn works_in_collection() { let key = ThreadKey::get().unwrap(); - let mutex1 = Mutex::new(0); - let mutex2 = Mutex::new(1); + let mutex1 = RwLock::new(0); + let mutex2 = RwLock::new(1); let collection0 = [&mutex1, &mutex2]; let collection1 = RefLockCollection::try_new(&collection0).unwrap(); let collection = RefLockCollection::try_new(&collection1).unwrap(); - let guard = collection.lock(key); + let mut guard = collection.lock(key); assert!(mutex1.is_locked()); assert!(mutex2.is_locked()); + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 1); + *guard[1] = 2; drop(guard); + + let key = ThreadKey::get().unwrap(); + let guard = collection.read(key); + assert!(mutex1.is_locked()); + assert!(mutex2.is_locked()); + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 2); + } + + #[test] + fn unlock_collection_works() { + let key = ThreadKey::get().unwrap(); + let mutexes = (Mutex::new("foo"), Mutex::new("bar")); + let collection = RefLockCollection::new(&mutexes); + let guard = collection.lock(key); + + let key = RefLockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard); + assert!(collection.try_lock(key).is_ok()) + } + + #[test] + fn read_unlock_collection_works() { + let key = ThreadKey::get().unwrap(); + let locks = (RwLock::new("foo"), RwLock::new("bar")); + let collection = RefLockCollection::new(&locks); + let guard = collection.read(key); + + let key = RefLockCollection::<(&RwLock<_>, &RwLock<_>)>::unlock_read(guard); + assert!(collection.try_lock(key).is_ok()) + } + + #[test] + fn as_ref_works() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = RefLockCollection::new(&mutexes); + + assert!(std::ptr::addr_eq(&mutexes, collection.as_ref())) + } + + #[test] + fn child() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = RefLockCollection::new(&mutexes); + + assert!(std::ptr::addr_eq(&mutexes, collection.child())) } } |
