diff options
| author | Mica White <botahamec@outlook.com> | 2025-03-09 20:49:56 -0400 |
|---|---|---|
| committer | Mica White <botahamec@outlook.com> | 2025-03-09 20:49:56 -0400 |
| commit | 58abf5872023aca7ee6459fa3b2e067d57923ba5 (patch) | |
| tree | 196cadda0dd4386668477ef286f9c9b09480e713 /src/collection/retry.rs | |
| parent | 4ba03be97e6cc7e790bbc9bfc18caaa228c8a262 (diff) | |
Finish testing and fixing
Diffstat (limited to 'src/collection/retry.rs')
| -rw-r--r-- | src/collection/retry.rs | 403 |
1 files changed, 310 insertions, 93 deletions
diff --git a/src/collection/retry.rs b/src/collection/retry.rs index 775ea29..70e5183 100644 --- a/src/collection/retry.rs +++ b/src/collection/retry.rs @@ -9,7 +9,8 @@ use crate::lockable::{ use crate::{Keyable, ThreadKey}; use super::utils::{ - attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic, get_locks_unsorted, + attempt_to_recover_reads_from_panic, attempt_to_recover_writes_from_panic, get_locks_unsorted, + scoped_read, scoped_try_read, scoped_try_write, scoped_write, }; use super::{LockGuard, RetryingLockCollection}; @@ -40,7 +41,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } } - unsafe fn raw_lock(&self) { + unsafe fn raw_write(&self) { let locks = get_locks_unsorted(&self.data); if locks.is_empty() { @@ -57,7 +58,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { // This prevents us from entering a spin loop waiting for // the same lock to be unlocked // safety: we have the thread key - locks[first_index.get()].raw_lock(); + locks[first_index.get()].raw_write(); for (i, lock) in locks.iter().enumerate() { if i == first_index.get() { // we've already locked this one @@ -69,15 +70,15 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { // it does return false, then the lock function is called // immediately after, causing a panic // safety: we have the thread key - if lock.raw_try_lock() { + if lock.raw_try_write() { locked.set(locked.get() + 1); } else { // safety: we already locked all of these - attempt_to_recover_locks_from_panic(&locks[0..i]); + attempt_to_recover_writes_from_panic(&locks[0..i]); if first_index.get() >= i { // safety: this is already locked and can't be // unlocked by the previous loop - locks[first_index.get()].raw_unlock(); + locks[first_index.get()].raw_unlock_write(); } // nothing is locked anymore @@ -94,15 +95,15 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } }, || { - utils::attempt_to_recover_locks_from_panic(&locks[0..locked.get()]); + utils::attempt_to_recover_writes_from_panic(&locks[0..locked.get()]); if first_index.get() >= locked.get() { - locks[first_index.get()].raw_unlock(); + locks[first_index.get()].raw_unlock_write(); } }, ) } - unsafe fn raw_try_lock(&self) -> bool { + unsafe fn raw_try_write(&self) -> bool { let locks = get_locks_unsorted(&self.data); if locks.is_empty() { @@ -117,26 +118,26 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { || unsafe { for (i, lock) in locks.iter().enumerate() { // safety: we have the thread key - if lock.raw_try_lock() { + if lock.raw_try_write() { locked.set(locked.get() + 1); } else { // safety: we already locked all of these - attempt_to_recover_locks_from_panic(&locks[0..i]); + attempt_to_recover_writes_from_panic(&locks[0..i]); return false; } } true }, - || utils::attempt_to_recover_locks_from_panic(&locks[0..locked.get()]), + || utils::attempt_to_recover_writes_from_panic(&locks[0..locked.get()]), ) } - unsafe fn raw_unlock(&self) { + unsafe fn raw_unlock_write(&self) { let locks = get_locks_unsorted(&self.data); for lock in locks { - lock.raw_unlock(); + lock.raw_unlock_write(); } } @@ -243,7 +244,7 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> { Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - self.data.get_ptrs(ptrs) + ptrs.push(self) } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -347,13 +348,13 @@ impl<E: OwnedLockable + Extend<L>, L: OwnedLockable> Extend<L> for RetryingLockC } } -impl<T, L: AsRef<T>> AsRef<T> for RetryingLockCollection<L> { +impl<T: ?Sized, L: AsRef<T>> AsRef<T> for RetryingLockCollection<L> { fn as_ref(&self) -> &T { self.data.as_ref() } } -impl<T, L: AsMut<T>> AsMut<T> for RetryingLockCollection<L> { +impl<T: ?Sized, L: AsMut<T>> AsMut<T> for RetryingLockCollection<L> { fn as_mut(&mut self) -> &mut T { self.data.as_mut() } @@ -389,7 +390,8 @@ impl<L: OwnedLockable> RetryingLockCollection<L> { /// ``` #[must_use] pub const fn new(data: L) -> Self { - Self { data } + // safety: the data cannot cannot contain references + unsafe { Self::new_unchecked(data) } } } @@ -410,7 +412,8 @@ impl<'a, L: OwnedLockable> RetryingLockCollection<&'a L> { /// ``` #[must_use] pub const fn new_ref(data: &'a L) -> Self { - Self { data } + // safety: the data cannot cannot contain references + unsafe { Self::new_unchecked(data) } } } @@ -525,47 +528,20 @@ impl<L: Lockable> RetryingLockCollection<L> { /// ``` #[must_use] pub fn try_new(data: L) -> Option<Self> { - (!contains_duplicates(&data)).then_some(Self { data }) + // safety: the data is checked for duplicates before returning the collection + (!contains_duplicates(&data)).then_some(unsafe { Self::new_unchecked(data) }) } - pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_lock(); - - // safety: the data was just locked - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_lock<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataMut<'a>) -> R) -> R { + scoped_write(self, key, f) } - pub fn scoped_try_lock<Key: Keyable, R>( - &self, + pub fn scoped_try_lock<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataMut<'_>) -> R, + f: impl Fn(L::DataMut<'a>) -> R, ) -> Result<R, Key> { - unsafe { - // safety: we have the thread key - if !self.raw_try_lock() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_write(self, key, f) } /// Locks the collection @@ -591,7 +567,7 @@ impl<L: Lockable> RetryingLockCollection<L> { pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> { unsafe { // safety: we're taking the thread key - self.raw_lock(); + self.raw_write(); LockGuard { // safety: we just locked the collection @@ -634,7 +610,7 @@ impl<L: Lockable> RetryingLockCollection<L> { pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { unsafe { // safety: we're taking the thread key - if self.raw_try_lock() { + if self.raw_try_write() { Ok(LockGuard { // safety: we just succeeded in locking everything guard: self.guard(), @@ -671,44 +647,16 @@ impl<L: Lockable> RetryingLockCollection<L> { } impl<L: Sharable> RetryingLockCollection<L> { - pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_read(); - - // safety: the data was just locked - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R { + scoped_read(self, key, f) } - pub fn scoped_try_read<Key: Keyable, R>( - &self, + pub fn scoped_try_read<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataRef<'_>) -> R, + f: impl Fn(L::DataRef<'a>) -> R, ) -> Result<R, Key> { - unsafe { - // safety: we have the thread key - if !self.raw_try_read() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_read(self, key, f) } /// Locks the collection, so that other threads can still read from it @@ -778,7 +726,7 @@ impl<L: Sharable> RetryingLockCollection<L> { pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> { unsafe { // safety: we're taking the thread key - if !self.raw_try_lock() { + if !self.raw_try_read() { return Err(key); } @@ -911,7 +859,7 @@ where mod tests { use super::*; use crate::collection::BoxedLockCollection; - use crate::{LockCollection, Mutex, RwLock, ThreadKey}; + use crate::{Mutex, RwLock, ThreadKey}; #[test] fn nonduplicate_lock_references_are_allowed() { @@ -927,6 +875,159 @@ mod tests { } #[test] + #[allow(clippy::float_cmp)] + fn uses_correct_default() { + let collection = + RetryingLockCollection::<(RwLock<f64>, Mutex<Option<i32>>, Mutex<usize>)>::default(); + let tuple = collection.into_inner(); + assert_eq!(tuple.0, 0.0); + assert!(tuple.1.is_none()); + assert_eq!(tuple.2, 0) + } + + #[test] + fn from() { + let key = ThreadKey::get().unwrap(); + let collection = + RetryingLockCollection::from([Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]); + let guard = collection.lock(key); + assert_eq!(*guard[0], "foo"); + assert_eq!(*guard[1], "bar"); + assert_eq!(*guard[2], "baz"); + } + + #[test] + fn new_ref_works() { + let key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = RetryingLockCollection::new_ref(&mutexes); + collection.scoped_lock(key, |guard| { + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 1); + }) + } + + #[test] + fn scoped_read_sees_changes() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RetryingLockCollection::new(mutexes); + collection.scoped_lock(&mut key, |guard| *guard[0] = 128); + + let sum = collection.scoped_read(&mut key, |guard| { + assert_eq!(*guard[0], 128); + assert_eq!(*guard[1], 42); + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 128 + 42); + } + + #[test] + fn get_mut_affects_scoped_read() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let mut collection = RetryingLockCollection::new(mutexes); + let guard = collection.get_mut(); + *guard[0] = 128; + + let sum = collection.scoped_read(&mut key, |guard| { + assert_eq!(*guard[0], 128); + assert_eq!(*guard[1], 42); + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 128 + 42); + } + + #[test] + fn scoped_try_lock_can_fail() { + let key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([Mutex::new(1), Mutex::new(2)]); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_lock(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn scoped_try_read_can_fail() { + let key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([RwLock::new(1), RwLock::new(2)]); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_read(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn try_lock_works() { + let key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([Mutex::new(1), Mutex::new(2)]); + let guard = collection.try_lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.try_lock(key); + assert!(guard.is_err()); + }); + }); + + assert!(guard.is_ok()); + } + + #[test] + fn try_read_works() { + let key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([RwLock::new(1), RwLock::new(2)]); + let guard = collection.try_read(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.try_read(key); + assert!(guard.is_ok()); + }); + }); + + assert!(guard.is_ok()); + } + + #[test] + fn try_read_fails_for_locked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RetryingLockCollection::new_ref(&mutexes); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = mutexes[1].write(key); + assert_eq!(*guard, 42); + std::mem::forget(guard); + }); + }); + + let guard = collection.try_read(key); + assert!(guard.is_err()); + } + + #[test] fn locks_all_inner_mutexes() { let key = ThreadKey::get().unwrap(); let mutex1 = Mutex::new(0); @@ -974,6 +1075,55 @@ mod tests { } #[test] + fn from_iterator() { + let key = ThreadKey::get().unwrap(); + let collection: RetryingLockCollection<Vec<Mutex<&str>>> = + [Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")] + .into_iter() + .collect(); + let guard = collection.lock(key); + assert_eq!(*guard[0], "foo"); + assert_eq!(*guard[1], "bar"); + assert_eq!(*guard[2], "baz"); + } + + #[test] + fn into_owned_iterator() { + let collection = RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in collection.into_iter().enumerate() { + assert_eq!(mutex.into_inner(), i); + } + } + + #[test] + fn into_ref_iterator() { + let mut key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in (&collection).into_iter().enumerate() { + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) + } + } + + #[test] + fn ref_iterator() { + let mut key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in collection.iter().enumerate() { + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) + } + } + + #[test] + fn mut_iterator() { + let mut key = ThreadKey::get().unwrap(); + let mut collection = + RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in collection.iter_mut().enumerate() { + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) + } + } + + #[test] fn extend_collection() { let mutex1 = Mutex::new(0); let mutex2 = Mutex::new(0); @@ -991,9 +1141,76 @@ mod tests { let guard = collection.lock(key); assert!(guard.len() == 0); - let key = LockCollection::<[RwLock<_>; 0]>::unlock(guard); + let key = RetryingLockCollection::<[RwLock<_>; 0]>::unlock(guard); + + let guard = collection.read(key); + assert!(guard.len() == 0); + } + + #[test] + fn read_empty_lock_collection() { + let key = ThreadKey::get().unwrap(); + let collection: RetryingLockCollection<[RwLock<i32>; 0]> = RetryingLockCollection::new([]); let guard = collection.read(key); assert!(guard.len() == 0); + let key = RetryingLockCollection::<[RwLock<_>; 0]>::unlock_read(guard); + + let guard = collection.lock(key); + assert!(guard.len() == 0); + } + + #[test] + fn as_ref_works() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = RetryingLockCollection::new_ref(&mutexes); + + assert!(std::ptr::addr_eq(&mutexes, collection.as_ref())) + } + + #[test] + fn as_mut_works() { + let mut mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = RetryingLockCollection::new(&mut mutexes); + + collection.as_mut()[0] = Mutex::new(42); + + assert_eq!(*collection.as_mut()[0].get_mut(), 42); + } + + #[test] + fn child() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = RetryingLockCollection::new_ref(&mutexes); + + assert!(std::ptr::addr_eq(&mutexes, *collection.child())) + } + + #[test] + fn child_mut_works() { + let mut mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = RetryingLockCollection::new(&mut mutexes); + + collection.child_mut()[0] = Mutex::new(42); + + assert_eq!(*collection.child_mut()[0].get_mut(), 42); + } + + #[test] + fn into_child_works() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = RetryingLockCollection::new(mutexes); + + collection.child_mut()[0] = Mutex::new(42); + + assert_eq!( + *collection + .into_child() + .as_mut() + .get_mut(0) + .unwrap() + .get_mut(), + 42 + ); } } |
