From 58abf5872023aca7ee6459fa3b2e067d57923ba5 Mon Sep 17 00:00:00 2001 From: Mica White Date: Sun, 9 Mar 2025 20:49:56 -0400 Subject: Finish testing and fixing --- src/collection.rs | 2 +- src/collection/boxed.rs | 186 +++++++++++--------- src/collection/owned.rs | 283 +++++++++++++++++++++--------- src/collection/ref.rs | 261 +++++++++++++++++++--------- src/collection/retry.rs | 403 +++++++++++++++++++++++++++++++++---------- src/collection/utils.rs | 113 ++++++++++-- src/lockable.rs | 127 +++++++++++++- src/mutex.rs | 25 ++- src/mutex/guard.rs | 6 +- src/mutex/mutex.rs | 85 ++++----- src/poisonable.rs | 164 +++++++++++++++++- src/poisonable/poisonable.rs | 57 ++++-- src/rwlock.rs | 323 +++++++++++++++++++++++++++++++++- src/rwlock/read_guard.rs | 4 +- src/rwlock/read_lock.rs | 54 +++++- src/rwlock/rwlock.rs | 115 +++--------- src/rwlock/write_guard.rs | 6 +- src/rwlock/write_lock.rs | 52 +++++- src/thread.rs | 19 ++ src/thread/scope.rst | 47 +++++ 20 files changed, 1798 insertions(+), 534 deletions(-) create mode 100644 src/thread.rs create mode 100644 src/thread/scope.rst (limited to 'src') diff --git a/src/collection.rs b/src/collection.rs index e50cc30..f8c31d7 100644 --- a/src/collection.rs +++ b/src/collection.rs @@ -7,7 +7,7 @@ mod guard; mod owned; mod r#ref; mod retry; -mod utils; +pub(crate) mod utils; /// Locks a collection of locks, which cannot be shared immutably. /// diff --git a/src/collection/boxed.rs b/src/collection/boxed.rs index 364ec97..1891119 100644 --- a/src/collection/boxed.rs +++ b/src/collection/boxed.rs @@ -4,7 +4,9 @@ use std::fmt::Debug; use crate::lockable::{Lockable, LockableIntoInner, OwnedLockable, RawLock, Sharable}; use crate::{Keyable, ThreadKey}; -use super::utils::ordered_contains_duplicates; +use super::utils::{ + ordered_contains_duplicates, scoped_read, scoped_try_read, scoped_try_write, scoped_write, +}; use super::{utils, BoxedLockCollection, LockGuard}; unsafe impl RawLock for BoxedLockCollection { @@ -16,18 +18,18 @@ unsafe impl RawLock for BoxedLockCollection { } } - unsafe fn raw_lock(&self) { - utils::ordered_lock(self.locks()) + unsafe fn raw_write(&self) { + utils::ordered_write(self.locks()) } - unsafe fn raw_try_lock(&self) -> bool { + unsafe fn raw_try_write(&self) -> bool { println!("{}", self.locks().len()); - utils::ordered_try_lock(self.locks()) + utils::ordered_try_write(self.locks()) } - unsafe fn raw_unlock(&self) { + unsafe fn raw_unlock_write(&self) { for lock in self.locks() { - lock.raw_unlock(); + lock.raw_unlock_write(); } } @@ -58,7 +60,7 @@ unsafe impl Lockable for BoxedLockCollection { Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - ptrs.extend(self.locks()) + ptrs.push(self); } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -156,7 +158,7 @@ impl Drop for BoxedLockCollection { } } -impl> AsRef for BoxedLockCollection { +impl> AsRef for BoxedLockCollection { fn as_ref(&self) -> &T { self.child().as_ref() } @@ -364,44 +366,16 @@ impl BoxedLockCollection { } } - pub fn scoped_lock(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_lock(); - - // safety: the data was just locked - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_lock<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataMut<'a>) -> R) -> R { + scoped_write(self, key, f) } - pub fn scoped_try_lock( - &self, + pub fn scoped_try_lock<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataMut<'_>) -> R, + f: impl Fn(L::DataMut<'a>) -> R, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_lock() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_write(self, key, f) } /// Locks the collection @@ -427,7 +401,7 @@ impl BoxedLockCollection { pub fn lock(&self, key: ThreadKey) -> LockGuard> { unsafe { // safety: we have the thread key - self.raw_lock(); + self.raw_write(); LockGuard { // safety: we've already acquired the lock @@ -468,7 +442,7 @@ impl BoxedLockCollection { /// ``` pub fn try_lock(&self, key: ThreadKey) -> Result>, ThreadKey> { let guard = unsafe { - if !self.raw_try_lock() { + if !self.raw_try_write() { return Err(key); } @@ -503,44 +477,16 @@ impl BoxedLockCollection { } impl BoxedLockCollection { - pub fn scoped_read(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_read(); - - // safety: the data was just locked - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R { + scoped_read(self, key, f) } - pub fn scoped_try_read( - &self, + pub fn scoped_try_read<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataRef<'_>) -> R, + f: impl Fn(L::DataRef<'a>) -> R, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_read() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_read(self, key, f) } /// Locks the collection, so that other threads can still read from it @@ -764,6 +710,56 @@ mod tests { assert!(BoxedLockCollection::try_new([&mutex1, &mutex1]).is_none()) } + #[test] + fn scoped_read_sees_changes() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = BoxedLockCollection::new(mutexes); + collection.scoped_lock(&mut key, |guard| *guard[0] = 128); + + let sum = collection.scoped_read(&mut key, |guard| { + assert_eq!(*guard[0], 128); + assert_eq!(*guard[1], 42); + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 128 + 42); + } + + #[test] + fn scoped_try_lock_can_fail() { + let key = ThreadKey::get().unwrap(); + let collection = BoxedLockCollection::new([Mutex::new(1), Mutex::new(2)]); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_lock(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn scoped_try_read_can_fail() { + let key = ThreadKey::get().unwrap(); + let collection = BoxedLockCollection::new([RwLock::new(1), RwLock::new(2)]); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_read(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + #[test] fn try_lock_works() { let key = ThreadKey::get().unwrap(); @@ -884,15 +880,41 @@ mod tests { #[test] fn works_in_collection() { let key = ThreadKey::get().unwrap(); - let mutex1 = Mutex::new(0); - let mutex2 = Mutex::new(1); + let mutex1 = RwLock::new(0); + let mutex2 = RwLock::new(1); let collection = BoxedLockCollection::try_new(BoxedLockCollection::try_new([&mutex1, &mutex2]).unwrap()) .unwrap(); - let guard = collection.lock(key); + let mut guard = collection.lock(key); + assert!(mutex1.is_locked()); + assert!(mutex2.is_locked()); + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 1); + *guard[0] = 2; + let key = BoxedLockCollection::; 2]>>::unlock(guard); + + let guard = collection.read(key); assert!(mutex1.is_locked()); assert!(mutex2.is_locked()); + assert_eq!(*guard[0], 2); + assert_eq!(*guard[1], 1); drop(guard); } + + #[test] + fn as_ref_works() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = BoxedLockCollection::new_ref(&mutexes); + + assert!(std::ptr::addr_eq(&mutexes, collection.as_ref())) + } + + #[test] + fn child() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = BoxedLockCollection::new_ref(&mutexes); + + assert!(std::ptr::addr_eq(&mutexes, *collection.child())) + } } diff --git a/src/collection/owned.rs b/src/collection/owned.rs index b9cf313..68170d1 100644 --- a/src/collection/owned.rs +++ b/src/collection/owned.rs @@ -3,6 +3,7 @@ use crate::lockable::{ }; use crate::{Keyable, ThreadKey}; +use super::utils::{scoped_read, scoped_try_read, scoped_try_write, scoped_write}; use super::{utils, LockGuard, OwnedLockCollection}; unsafe impl RawLock for OwnedLockCollection { @@ -15,19 +16,19 @@ unsafe impl RawLock for OwnedLockCollection { } } - unsafe fn raw_lock(&self) { - utils::ordered_lock(&utils::get_locks_unsorted(&self.data)) + unsafe fn raw_write(&self) { + utils::ordered_write(&utils::get_locks_unsorted(&self.data)) } - unsafe fn raw_try_lock(&self) -> bool { + unsafe fn raw_try_write(&self) -> bool { let locks = utils::get_locks_unsorted(&self.data); - utils::ordered_try_lock(&locks) + utils::ordered_try_write(&locks) } - unsafe fn raw_unlock(&self) { + unsafe fn raw_unlock_write(&self) { let locks = utils::get_locks_unsorted(&self.data); for lock in locks { - lock.raw_unlock(); + lock.raw_unlock_write(); } } @@ -62,7 +63,7 @@ unsafe impl Lockable for OwnedLockCollection { #[mutants::skip] // It's hard to test lkocks in an OwnedLockCollection, because they're owned #[cfg(not(tarpaulin_include))] fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - self.data.get_ptrs(ptrs) + ptrs.push(self) } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -146,7 +147,7 @@ impl, L: OwnedLockable> Extend for OwnedLockColl // invariant that there is only one way to lock the collection. AsMut is fine, // because the collection can't be locked as long as the reference is valid. -impl> AsMut for OwnedLockCollection { +impl> AsMut for OwnedLockCollection { fn as_mut(&mut self) -> &mut T { self.data.as_mut() } @@ -185,44 +186,16 @@ impl OwnedLockCollection { Self { data } } - pub fn scoped_lock(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_lock(); - - // safety: the data was just locked - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_lock<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataMut<'a>) -> R) -> R { + scoped_write(self, key, f) } - pub fn scoped_try_lock( - &self, + pub fn scoped_try_lock<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataMut<'_>) -> R, + f: impl Fn(L::DataMut<'a>) -> R, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_lock() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_write(self, key, f) } /// Locks the collection @@ -249,7 +222,7 @@ impl OwnedLockCollection { let guard = unsafe { // safety: we have the thread key, and these locks happen in a // predetermined order - self.raw_lock(); + self.raw_write(); // safety: we've locked all of this already self.data.guard() @@ -290,7 +263,7 @@ impl OwnedLockCollection { /// ``` pub fn try_lock(&self, key: ThreadKey) -> Result>, ThreadKey> { let guard = unsafe { - if !self.raw_try_lock() { + if !self.raw_try_write() { return Err(key); } @@ -327,44 +300,16 @@ impl OwnedLockCollection { } impl OwnedLockCollection { - pub fn scoped_read(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_read(); - - // safety: the data was just locked - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R { + scoped_read(self, key, f) } - pub fn scoped_try_read( - &self, + pub fn scoped_try_read<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataRef<'_>) -> R, + f: impl Fn(L::DataRef<'a>) -> R, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_read() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_read(self, key, f) } /// Locks the collection, so that other threads can still read from it @@ -554,7 +499,7 @@ impl OwnedLockCollection { #[cfg(test)] mod tests { use super::*; - use crate::{Mutex, ThreadKey}; + use crate::{Mutex, RwLock, ThreadKey}; #[test] fn get_mut_applies_changes() { @@ -603,6 +548,63 @@ mod tests { } } + #[test] + fn scoped_read_works() { + let mut key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new([RwLock::new(24), RwLock::new(42)]); + let sum = collection.scoped_read(&mut key, |guard| guard[0] + guard[1]); + assert_eq!(sum, 24 + 42); + } + + #[test] + fn scoped_lock_works() { + let mut key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new([RwLock::new(24), RwLock::new(42)]); + collection.scoped_lock(&mut key, |guard| *guard[0] += *guard[1]); + + let sum = collection.scoped_lock(&mut key, |guard| { + assert_eq!(*guard[0], 24 + 42); + assert_eq!(*guard[1], 42); + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 24 + 42 + 42); + } + + #[test] + fn scoped_try_lock_can_fail() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new([Mutex::new(1), Mutex::new(2)]); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_lock(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn scoped_try_read_can_fail() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new([RwLock::new(1), RwLock::new(2)]); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_read(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + #[test] fn try_lock_works_on_unlocked() { let key = ThreadKey::get().unwrap(); @@ -629,6 +631,74 @@ mod tests { assert!(collection.try_lock(key).is_err()); } + #[test] + fn try_read_succeeds_for_unlocked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = OwnedLockCollection::new(mutexes); + let guard = collection.try_read(key).unwrap(); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + } + + #[test] + fn try_read_fails_on_locked() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new((RwLock::new(0), RwLock::new(1))); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + #[allow(unused)] + let guard = collection.lock(key); + std::mem::forget(guard); + }); + }); + + assert!(collection.try_read(key).is_err()); + } + + #[test] + fn can_read_twice_on_different_threads() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = OwnedLockCollection::new(mutexes); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.read(key); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + std::mem::forget(guard); + }); + }); + + let guard = collection.try_read(key).unwrap(); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + } + + #[test] + fn unlock_collection_works() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new((Mutex::new("foo"), Mutex::new("bar"))); + let guard = collection.lock(key); + + let key = OwnedLockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard); + assert!(collection.try_lock(key).is_ok()) + } + + #[test] + fn read_unlock_collection_works() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new((RwLock::new("foo"), RwLock::new("bar"))); + let guard = collection.read(key); + + let key = OwnedLockCollection::<(&RwLock<_>, &RwLock<_>)>::unlock_read(guard); + assert!(collection.try_lock(key).is_ok()) + } + #[test] fn default_works() { type MyCollection = OwnedLockCollection<(Mutex, Mutex>, Mutex)>; @@ -649,4 +719,59 @@ mod tests { assert_eq!(collection.data.len(), 3); } + + #[test] + fn works_in_collection() { + let key = ThreadKey::get().unwrap(); + let collection = + OwnedLockCollection::new(OwnedLockCollection::new([RwLock::new(0), RwLock::new(1)])); + + let mut guard = collection.lock(key); + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 1); + *guard[1] = 2; + + let key = OwnedLockCollection::; 2]>>::unlock(guard); + let guard = collection.read(key); + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 2); + } + + #[test] + fn as_mut_works() { + let mut mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = OwnedLockCollection::new(&mut mutexes); + + collection.as_mut()[0] = Mutex::new(42); + + assert_eq!(*collection.as_mut()[0].get_mut(), 42); + } + + #[test] + fn child_mut_works() { + let mut mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = OwnedLockCollection::new(&mut mutexes); + + collection.child_mut()[0] = Mutex::new(42); + + assert_eq!(*collection.child_mut()[0].get_mut(), 42); + } + + #[test] + fn into_child_works() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = OwnedLockCollection::new(mutexes); + + collection.child_mut()[0] = Mutex::new(42); + + assert_eq!( + *collection + .into_child() + .as_mut() + .get_mut(0) + .unwrap() + .get_mut(), + 42 + ); + } } diff --git a/src/collection/ref.rs b/src/collection/ref.rs index b68b72f..5f96533 100644 --- a/src/collection/ref.rs +++ b/src/collection/ref.rs @@ -3,7 +3,10 @@ use std::fmt::Debug; use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable}; use crate::{Keyable, ThreadKey}; -use super::utils::{get_locks, ordered_contains_duplicates}; +use super::utils::{ + get_locks, ordered_contains_duplicates, scoped_read, scoped_try_read, scoped_try_write, + scoped_write, +}; use super::{utils, LockGuard, RefLockCollection}; impl<'a, L> IntoIterator for &'a RefLockCollection<'a, L> @@ -27,17 +30,17 @@ unsafe impl RawLock for RefLockCollection<'_, L> { } } - unsafe fn raw_lock(&self) { - utils::ordered_lock(&self.locks) + unsafe fn raw_write(&self) { + utils::ordered_write(&self.locks) } - unsafe fn raw_try_lock(&self) -> bool { - utils::ordered_try_lock(&self.locks) + unsafe fn raw_try_write(&self) -> bool { + utils::ordered_try_write(&self.locks) } - unsafe fn raw_unlock(&self) { + unsafe fn raw_unlock_write(&self) { for lock in &self.locks { - lock.raw_unlock(); + lock.raw_unlock_write(); } } @@ -68,7 +71,7 @@ unsafe impl Lockable for RefLockCollection<'_, L> { Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - ptrs.extend_from_slice(&self.locks); + ptrs.push(self) } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -100,7 +103,7 @@ unsafe impl Sharable for RefLockCollection<'_, L> { } } -impl> AsRef for RefLockCollection<'_, L> { +impl> AsRef for RefLockCollection<'_, L> { fn as_ref(&self) -> &T { self.data.as_ref() } @@ -234,44 +237,16 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { Some(Self { data, locks }) } - pub fn scoped_lock(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_lock(); - - // safety: the data was just locked - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_lock<'s, R>(&'s self, key: impl Keyable, f: impl Fn(L::DataMut<'s>) -> R) -> R { + scoped_write(self, key, f) } - pub fn scoped_try_lock( - &self, + pub fn scoped_try_lock<'s, Key: Keyable, R>( + &'s self, key: Key, - f: impl Fn(L::DataMut<'_>) -> R, + f: impl Fn(L::DataMut<'s>) -> R, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_lock() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_write(self, key, f) } /// Locks the collection @@ -298,7 +273,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { pub fn lock(&self, key: ThreadKey) -> LockGuard> { let guard = unsafe { // safety: we have the thread key - self.raw_lock(); + self.raw_write(); // safety: we've locked all of this already self.data.guard() @@ -339,7 +314,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// ``` pub fn try_lock(&self, key: ThreadKey) -> Result>, ThreadKey> { let guard = unsafe { - if !self.raw_try_lock() { + if !self.raw_try_write() { return Err(key); } @@ -376,44 +351,16 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { } impl RefLockCollection<'_, L> { - pub fn scoped_read(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_read(); - - // safety: the data was just locked - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R { + scoped_read(self, key, f) } - pub fn scoped_try_read( - &self, + pub fn scoped_try_read<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataRef<'_>) -> R, + f: impl Fn(L::DataRef<'a>) -> R, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_read() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_read(self, key, f) } /// Locks the collection, so that other threads can still read from it @@ -564,6 +511,88 @@ mod tests { assert!(RefLockCollection::try_new(&[&mutex1, &mutex1]).is_none()) } + #[test] + fn from() { + let key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]; + let collection = RefLockCollection::from(&mutexes); + let guard = collection.lock(key); + assert_eq!(*guard[0], "foo"); + assert_eq!(*guard[1], "bar"); + assert_eq!(*guard[2], "baz"); + } + + #[test] + fn scoped_lock_changes_collection() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(24), Mutex::new(42)]; + let collection = RefLockCollection::new(&mutexes); + let sum = collection.scoped_lock(&mut key, |guard| { + *guard[0] = 128; + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 128 + 42); + + let guard = collection.lock(key); + assert_eq!(*guard[0], 128); + assert_eq!(*guard[1], 42); + } + + #[test] + fn scoped_read_sees_changes() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RefLockCollection::new(&mutexes); + collection.scoped_lock(&mut key, |guard| { + *guard[0] = 128; + }); + + let sum = collection.scoped_read(&mut key, |guard| { + assert_eq!(*guard[0], 128); + assert_eq!(*guard[1], 42); + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 128 + 42); + } + + #[test] + fn scoped_try_lock_can_fail() { + let key = ThreadKey::get().unwrap(); + let locks = [Mutex::new(1), Mutex::new(2)]; + let collection = RefLockCollection::new(&locks); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_lock(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn scoped_try_read_can_fail() { + let key = ThreadKey::get().unwrap(); + let locks = [RwLock::new(1), RwLock::new(2)]; + let collection = RefLockCollection::new(&locks); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_read(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + #[test] fn try_lock_succeeds_for_unlocked_collection() { let key = ThreadKey::get().unwrap(); @@ -643,18 +672,86 @@ mod tests { assert_eq!(*guard[1], 42); } + #[test] + fn into_ref_iterator() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(0), Mutex::new(1), Mutex::new(2)]; + let collection = RefLockCollection::new(&mutexes); + for (i, mutex) in (&collection).into_iter().enumerate() { + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) + } + } + + #[test] + fn ref_iterator() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(0), Mutex::new(1), Mutex::new(2)]; + let collection = RefLockCollection::new(&mutexes); + for (i, mutex) in collection.iter().enumerate() { + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) + } + } + #[test] fn works_in_collection() { let key = ThreadKey::get().unwrap(); - let mutex1 = Mutex::new(0); - let mutex2 = Mutex::new(1); + let mutex1 = RwLock::new(0); + let mutex2 = RwLock::new(1); let collection0 = [&mutex1, &mutex2]; let collection1 = RefLockCollection::try_new(&collection0).unwrap(); let collection = RefLockCollection::try_new(&collection1).unwrap(); - let guard = collection.lock(key); + let mut guard = collection.lock(key); assert!(mutex1.is_locked()); assert!(mutex2.is_locked()); + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 1); + *guard[1] = 2; drop(guard); + + let key = ThreadKey::get().unwrap(); + let guard = collection.read(key); + assert!(mutex1.is_locked()); + assert!(mutex2.is_locked()); + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 2); + } + + #[test] + fn unlock_collection_works() { + let key = ThreadKey::get().unwrap(); + let mutexes = (Mutex::new("foo"), Mutex::new("bar")); + let collection = RefLockCollection::new(&mutexes); + let guard = collection.lock(key); + + let key = RefLockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard); + assert!(collection.try_lock(key).is_ok()) + } + + #[test] + fn read_unlock_collection_works() { + let key = ThreadKey::get().unwrap(); + let locks = (RwLock::new("foo"), RwLock::new("bar")); + let collection = RefLockCollection::new(&locks); + let guard = collection.read(key); + + let key = RefLockCollection::<(&RwLock<_>, &RwLock<_>)>::unlock_read(guard); + assert!(collection.try_lock(key).is_ok()) + } + + #[test] + fn as_ref_works() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = RefLockCollection::new(&mutexes); + + assert!(std::ptr::addr_eq(&mutexes, collection.as_ref())) + } + + #[test] + fn child() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = RefLockCollection::new(&mutexes); + + assert!(std::ptr::addr_eq(&mutexes, collection.child())) } } diff --git a/src/collection/retry.rs b/src/collection/retry.rs index 775ea29..70e5183 100644 --- a/src/collection/retry.rs +++ b/src/collection/retry.rs @@ -9,7 +9,8 @@ use crate::lockable::{ use crate::{Keyable, ThreadKey}; use super::utils::{ - attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic, get_locks_unsorted, + attempt_to_recover_reads_from_panic, attempt_to_recover_writes_from_panic, get_locks_unsorted, + scoped_read, scoped_try_read, scoped_try_write, scoped_write, }; use super::{LockGuard, RetryingLockCollection}; @@ -40,7 +41,7 @@ unsafe impl RawLock for RetryingLockCollection { } } - unsafe fn raw_lock(&self) { + unsafe fn raw_write(&self) { let locks = get_locks_unsorted(&self.data); if locks.is_empty() { @@ -57,7 +58,7 @@ unsafe impl RawLock for RetryingLockCollection { // This prevents us from entering a spin loop waiting for // the same lock to be unlocked // safety: we have the thread key - locks[first_index.get()].raw_lock(); + locks[first_index.get()].raw_write(); for (i, lock) in locks.iter().enumerate() { if i == first_index.get() { // we've already locked this one @@ -69,15 +70,15 @@ unsafe impl RawLock for RetryingLockCollection { // it does return false, then the lock function is called // immediately after, causing a panic // safety: we have the thread key - if lock.raw_try_lock() { + if lock.raw_try_write() { locked.set(locked.get() + 1); } else { // safety: we already locked all of these - attempt_to_recover_locks_from_panic(&locks[0..i]); + attempt_to_recover_writes_from_panic(&locks[0..i]); if first_index.get() >= i { // safety: this is already locked and can't be // unlocked by the previous loop - locks[first_index.get()].raw_unlock(); + locks[first_index.get()].raw_unlock_write(); } // nothing is locked anymore @@ -94,15 +95,15 @@ unsafe impl RawLock for RetryingLockCollection { } }, || { - utils::attempt_to_recover_locks_from_panic(&locks[0..locked.get()]); + utils::attempt_to_recover_writes_from_panic(&locks[0..locked.get()]); if first_index.get() >= locked.get() { - locks[first_index.get()].raw_unlock(); + locks[first_index.get()].raw_unlock_write(); } }, ) } - unsafe fn raw_try_lock(&self) -> bool { + unsafe fn raw_try_write(&self) -> bool { let locks = get_locks_unsorted(&self.data); if locks.is_empty() { @@ -117,26 +118,26 @@ unsafe impl RawLock for RetryingLockCollection { || unsafe { for (i, lock) in locks.iter().enumerate() { // safety: we have the thread key - if lock.raw_try_lock() { + if lock.raw_try_write() { locked.set(locked.get() + 1); } else { // safety: we already locked all of these - attempt_to_recover_locks_from_panic(&locks[0..i]); + attempt_to_recover_writes_from_panic(&locks[0..i]); return false; } } true }, - || utils::attempt_to_recover_locks_from_panic(&locks[0..locked.get()]), + || utils::attempt_to_recover_writes_from_panic(&locks[0..locked.get()]), ) } - unsafe fn raw_unlock(&self) { + unsafe fn raw_unlock_write(&self) { let locks = get_locks_unsorted(&self.data); for lock in locks { - lock.raw_unlock(); + lock.raw_unlock_write(); } } @@ -243,7 +244,7 @@ unsafe impl Lockable for RetryingLockCollection { Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - self.data.get_ptrs(ptrs) + ptrs.push(self) } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -347,13 +348,13 @@ impl, L: OwnedLockable> Extend for RetryingLockC } } -impl> AsRef for RetryingLockCollection { +impl> AsRef for RetryingLockCollection { fn as_ref(&self) -> &T { self.data.as_ref() } } -impl> AsMut for RetryingLockCollection { +impl> AsMut for RetryingLockCollection { fn as_mut(&mut self) -> &mut T { self.data.as_mut() } @@ -389,7 +390,8 @@ impl RetryingLockCollection { /// ``` #[must_use] pub const fn new(data: L) -> Self { - Self { data } + // safety: the data cannot cannot contain references + unsafe { Self::new_unchecked(data) } } } @@ -410,7 +412,8 @@ impl<'a, L: OwnedLockable> RetryingLockCollection<&'a L> { /// ``` #[must_use] pub const fn new_ref(data: &'a L) -> Self { - Self { data } + // safety: the data cannot cannot contain references + unsafe { Self::new_unchecked(data) } } } @@ -525,47 +528,20 @@ impl RetryingLockCollection { /// ``` #[must_use] pub fn try_new(data: L) -> Option { - (!contains_duplicates(&data)).then_some(Self { data }) + // safety: the data is checked for duplicates before returning the collection + (!contains_duplicates(&data)).then_some(unsafe { Self::new_unchecked(data) }) } - pub fn scoped_lock(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_lock(); - - // safety: the data was just locked - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_lock<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataMut<'a>) -> R) -> R { + scoped_write(self, key, f) } - pub fn scoped_try_lock( - &self, + pub fn scoped_try_lock<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataMut<'_>) -> R, + f: impl Fn(L::DataMut<'a>) -> R, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_lock() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_mut()); - - // safety: the collection is still locked - self.raw_unlock(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_write(self, key, f) } /// Locks the collection @@ -591,7 +567,7 @@ impl RetryingLockCollection { pub fn lock(&self, key: ThreadKey) -> LockGuard> { unsafe { // safety: we're taking the thread key - self.raw_lock(); + self.raw_write(); LockGuard { // safety: we just locked the collection @@ -634,7 +610,7 @@ impl RetryingLockCollection { pub fn try_lock(&self, key: ThreadKey) -> Result>, ThreadKey> { unsafe { // safety: we're taking the thread key - if self.raw_try_lock() { + if self.raw_try_write() { Ok(LockGuard { // safety: we just succeeded in locking everything guard: self.guard(), @@ -671,44 +647,16 @@ impl RetryingLockCollection { } impl RetryingLockCollection { - pub fn scoped_read(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { - unsafe { - // safety: we have the thread key - self.raw_read(); - - // safety: the data was just locked - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensure the key stays alive long enough - - r - } + pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R { + scoped_read(self, key, f) } - pub fn scoped_try_read( - &self, + pub fn scoped_try_read<'a, Key: Keyable, R>( + &'a self, key: Key, - f: impl Fn(L::DataRef<'_>) -> R, + f: impl Fn(L::DataRef<'a>) -> R, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_read() { - return Err(key); - } - - // safety: we just locked the collection - let r = f(self.data_ref()); - - // safety: the collection is still locked - self.raw_unlock_read(); - - drop(key); // ensures the key stays valid long enough - - Ok(r) - } + scoped_try_read(self, key, f) } /// Locks the collection, so that other threads can still read from it @@ -778,7 +726,7 @@ impl RetryingLockCollection { pub fn try_read(&self, key: ThreadKey) -> Result>, ThreadKey> { unsafe { // safety: we're taking the thread key - if !self.raw_try_lock() { + if !self.raw_try_read() { return Err(key); } @@ -911,7 +859,7 @@ where mod tests { use super::*; use crate::collection::BoxedLockCollection; - use crate::{LockCollection, Mutex, RwLock, ThreadKey}; + use crate::{Mutex, RwLock, ThreadKey}; #[test] fn nonduplicate_lock_references_are_allowed() { @@ -926,6 +874,159 @@ mod tests { assert!(RetryingLockCollection::try_new([&mutex, &mutex]).is_none()); } + #[test] + #[allow(clippy::float_cmp)] + fn uses_correct_default() { + let collection = + RetryingLockCollection::<(RwLock, Mutex>, Mutex)>::default(); + let tuple = collection.into_inner(); + assert_eq!(tuple.0, 0.0); + assert!(tuple.1.is_none()); + assert_eq!(tuple.2, 0) + } + + #[test] + fn from() { + let key = ThreadKey::get().unwrap(); + let collection = + RetryingLockCollection::from([Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]); + let guard = collection.lock(key); + assert_eq!(*guard[0], "foo"); + assert_eq!(*guard[1], "bar"); + assert_eq!(*guard[2], "baz"); + } + + #[test] + fn new_ref_works() { + let key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = RetryingLockCollection::new_ref(&mutexes); + collection.scoped_lock(key, |guard| { + assert_eq!(*guard[0], 0); + assert_eq!(*guard[1], 1); + }) + } + + #[test] + fn scoped_read_sees_changes() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RetryingLockCollection::new(mutexes); + collection.scoped_lock(&mut key, |guard| *guard[0] = 128); + + let sum = collection.scoped_read(&mut key, |guard| { + assert_eq!(*guard[0], 128); + assert_eq!(*guard[1], 42); + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 128 + 42); + } + + #[test] + fn get_mut_affects_scoped_read() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let mut collection = RetryingLockCollection::new(mutexes); + let guard = collection.get_mut(); + *guard[0] = 128; + + let sum = collection.scoped_read(&mut key, |guard| { + assert_eq!(*guard[0], 128); + assert_eq!(*guard[1], 42); + *guard[0] + *guard[1] + }); + + assert_eq!(sum, 128 + 42); + } + + #[test] + fn scoped_try_lock_can_fail() { + let key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([Mutex::new(1), Mutex::new(2)]); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_lock(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn scoped_try_read_can_fail() { + let key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([RwLock::new(1), RwLock::new(2)]); + let guard = collection.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = collection.scoped_try_read(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn try_lock_works() { + let key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([Mutex::new(1), Mutex::new(2)]); + let guard = collection.try_lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.try_lock(key); + assert!(guard.is_err()); + }); + }); + + assert!(guard.is_ok()); + } + + #[test] + fn try_read_works() { + let key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([RwLock::new(1), RwLock::new(2)]); + let guard = collection.try_read(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.try_read(key); + assert!(guard.is_ok()); + }); + }); + + assert!(guard.is_ok()); + } + + #[test] + fn try_read_fails_for_locked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RetryingLockCollection::new_ref(&mutexes); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = mutexes[1].write(key); + assert_eq!(*guard, 42); + std::mem::forget(guard); + }); + }); + + let guard = collection.try_read(key); + assert!(guard.is_err()); + } + #[test] fn locks_all_inner_mutexes() { let key = ThreadKey::get().unwrap(); @@ -973,6 +1074,55 @@ mod tests { drop(guard); } + #[test] + fn from_iterator() { + let key = ThreadKey::get().unwrap(); + let collection: RetryingLockCollection>> = + [Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")] + .into_iter() + .collect(); + let guard = collection.lock(key); + assert_eq!(*guard[0], "foo"); + assert_eq!(*guard[1], "bar"); + assert_eq!(*guard[2], "baz"); + } + + #[test] + fn into_owned_iterator() { + let collection = RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in collection.into_iter().enumerate() { + assert_eq!(mutex.into_inner(), i); + } + } + + #[test] + fn into_ref_iterator() { + let mut key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in (&collection).into_iter().enumerate() { + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) + } + } + + #[test] + fn ref_iterator() { + let mut key = ThreadKey::get().unwrap(); + let collection = RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in collection.iter().enumerate() { + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) + } + } + + #[test] + fn mut_iterator() { + let mut key = ThreadKey::get().unwrap(); + let mut collection = + RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in collection.iter_mut().enumerate() { + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) + } + } + #[test] fn extend_collection() { let mutex1 = Mutex::new(0); @@ -991,9 +1141,76 @@ mod tests { let guard = collection.lock(key); assert!(guard.len() == 0); - let key = LockCollection::<[RwLock<_>; 0]>::unlock(guard); + let key = RetryingLockCollection::<[RwLock<_>; 0]>::unlock(guard); + + let guard = collection.read(key); + assert!(guard.len() == 0); + } + + #[test] + fn read_empty_lock_collection() { + let key = ThreadKey::get().unwrap(); + let collection: RetryingLockCollection<[RwLock; 0]> = RetryingLockCollection::new([]); let guard = collection.read(key); assert!(guard.len() == 0); + let key = RetryingLockCollection::<[RwLock<_>; 0]>::unlock_read(guard); + + let guard = collection.lock(key); + assert!(guard.len() == 0); + } + + #[test] + fn as_ref_works() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = RetryingLockCollection::new_ref(&mutexes); + + assert!(std::ptr::addr_eq(&mutexes, collection.as_ref())) + } + + #[test] + fn as_mut_works() { + let mut mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = RetryingLockCollection::new(&mut mutexes); + + collection.as_mut()[0] = Mutex::new(42); + + assert_eq!(*collection.as_mut()[0].get_mut(), 42); + } + + #[test] + fn child() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let collection = RetryingLockCollection::new_ref(&mutexes); + + assert!(std::ptr::addr_eq(&mutexes, *collection.child())) + } + + #[test] + fn child_mut_works() { + let mut mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = RetryingLockCollection::new(&mut mutexes); + + collection.child_mut()[0] = Mutex::new(42); + + assert_eq!(*collection.child_mut()[0].get_mut(), 42); + } + + #[test] + fn into_child_works() { + let mutexes = [Mutex::new(0), Mutex::new(1)]; + let mut collection = RetryingLockCollection::new(mutexes); + + collection.child_mut()[0] = Mutex::new(42); + + assert_eq!( + *collection + .into_child() + .as_mut() + .get_mut(0) + .unwrap() + .get_mut(), + 42 + ); } } diff --git a/src/collection/utils.rs b/src/collection/utils.rs index 1d96e5c..59a68da 100644 --- a/src/collection/utils.rs +++ b/src/collection/utils.rs @@ -1,7 +1,8 @@ use std::cell::Cell; use crate::handle_unwind::handle_unwind; -use crate::lockable::{Lockable, RawLock}; +use crate::lockable::{Lockable, RawLock, Sharable}; +use crate::Keyable; #[must_use] pub fn get_locks(data: &L) -> Vec<&dyn RawLock> { @@ -32,18 +33,18 @@ pub fn ordered_contains_duplicates(l: &[&dyn RawLock]) -> bool { } /// Lock a set of locks in the given order. It's UB to call this without a `ThreadKey` -pub unsafe fn ordered_lock(locks: &[&dyn RawLock]) { +pub unsafe fn ordered_write(locks: &[&dyn RawLock]) { // these will be unlocked in case of a panic let locked = Cell::new(0); handle_unwind( || { for lock in locks { - lock.raw_lock(); + lock.raw_write(); locked.set(locked.get() + 1); } }, - || attempt_to_recover_locks_from_panic(&locks[0..locked.get()]), + || attempt_to_recover_writes_from_panic(&locks[0..locked.get()]), ) } @@ -65,19 +66,19 @@ pub unsafe fn ordered_read(locks: &[&dyn RawLock]) { /// Locks the locks in the order they are given. This causes deadlock if the /// locks contain duplicates, or if this is called by multiple threads with the /// locks in different orders. -pub unsafe fn ordered_try_lock(locks: &[&dyn RawLock]) -> bool { +pub unsafe fn ordered_try_write(locks: &[&dyn RawLock]) -> bool { let locked = Cell::new(0); handle_unwind( || unsafe { for (i, lock) in locks.iter().enumerate() { // safety: we have the thread key - if lock.raw_try_lock() { + if lock.raw_try_write() { locked.set(locked.get() + 1); } else { for lock in &locks[0..i] { // safety: this lock was already acquired - lock.raw_unlock(); + lock.raw_unlock_write(); } return false; } @@ -87,7 +88,7 @@ pub unsafe fn ordered_try_lock(locks: &[&dyn RawLock]) -> bool { }, || // safety: everything in locked is locked - attempt_to_recover_locks_from_panic(&locks[0..locked.get()]), + attempt_to_recover_writes_from_panic(&locks[0..locked.get()]), ) } @@ -120,12 +121,104 @@ pub unsafe fn ordered_try_read(locks: &[&dyn RawLock]) -> bool { ) } +pub fn scoped_write<'a, L: RawLock + Lockable, R>( + collection: &'a L, + key: impl Keyable, + f: impl FnOnce(L::DataMut<'a>) -> R, +) -> R { + unsafe { + // safety: we have the key + collection.raw_write(); + + // safety: we just locked this + let r = f(collection.data_mut()); + + // this ensures the key is held long enough + drop(key); + + // safety: we've locked already, and aren't using the data again + collection.raw_unlock_write(); + + r + } +} + +pub fn scoped_try_write<'a, L: RawLock + Lockable, Key: Keyable, R>( + collection: &'a L, + key: Key, + f: impl FnOnce(L::DataMut<'a>) -> R, +) -> Result { + unsafe { + // safety: we have the key + if !collection.raw_try_write() { + return Err(key); + } + + // safety: we just locked this + let r = f(collection.data_mut()); + + // this ensures the key is held long enough + drop(key); + + // safety: we've locked already, and aren't using the data again + collection.raw_unlock_write(); + + Ok(r) + } +} + +pub fn scoped_read<'a, L: RawLock + Sharable, R>( + collection: &'a L, + key: impl Keyable, + f: impl FnOnce(L::DataRef<'a>) -> R, +) -> R { + unsafe { + // safety: we have the key + collection.raw_read(); + + // safety: we just locked this + let r = f(collection.data_ref()); + + // this ensures the key is held long enough + drop(key); + + // safety: we've locked already, and aren't using the data again + collection.raw_unlock_read(); + + r + } +} + +pub fn scoped_try_read<'a, L: RawLock + Sharable, Key: Keyable, R>( + collection: &'a L, + key: Key, + f: impl FnOnce(L::DataRef<'a>) -> R, +) -> Result { + unsafe { + // safety: we have the key + if !collection.raw_try_read() { + return Err(key); + } + + // safety: we just locked this + let r = f(collection.data_ref()); + + // this ensures the key is held long enough + drop(key); + + // safety: we've locked already, and aren't using the data again + collection.raw_unlock_read(); + + Ok(r) + } +} + /// Unlocks the already locked locks in order to recover from a panic -pub unsafe fn attempt_to_recover_locks_from_panic(locks: &[&dyn RawLock]) { +pub unsafe fn attempt_to_recover_writes_from_panic(locks: &[&dyn RawLock]) { handle_unwind( || { // safety: the caller assumes that these are already locked - locks.iter().for_each(|lock| lock.raw_unlock()); + locks.iter().for_each(|lock| lock.raw_unlock_write()); }, // if we get another panic in here, we'll just have to poison what remains || locks.iter().for_each(|l| l.poison()), diff --git a/src/lockable.rs b/src/lockable.rs index f125c02..94042ea 100644 --- a/src/lockable.rs +++ b/src/lockable.rs @@ -28,7 +28,7 @@ pub unsafe trait RawLock { /// value is alive. /// /// [`ThreadKey`]: `crate::ThreadKey` - unsafe fn raw_lock(&self); + unsafe fn raw_write(&self); /// Attempt to lock without blocking. /// @@ -41,14 +41,14 @@ pub unsafe trait RawLock { /// value is alive. /// /// [`ThreadKey`]: `crate::ThreadKey` - unsafe fn raw_try_lock(&self) -> bool; + unsafe fn raw_try_write(&self) -> bool; /// Releases the lock /// /// # Safety /// /// It is undefined behavior to use this if the lock is not acquired - unsafe fn raw_unlock(&self); + unsafe fn raw_unlock_write(&self); /// Blocks until the data the lock protects can be safely read. /// @@ -625,7 +625,7 @@ unsafe impl OwnedLockable for Vec {} #[cfg(test)] mod tests { use super::*; - use crate::{Mutex, RwLock}; + use crate::{LockCollection, Mutex, RwLock, ThreadKey}; #[test] fn mut_ref_get_ptrs() { @@ -718,6 +718,57 @@ mod tests { assert_eq!(lock_ptrs[1], 2); } + #[test] + fn vec_guard_ref() { + let key = ThreadKey::get().unwrap(); + let locks = vec![RwLock::new(1), RwLock::new(2)]; + let collection = LockCollection::new(locks); + + let mut guard = collection.lock(key); + assert_eq!(*guard[0], 1); + assert_eq!(*guard[1], 2); + *guard[0] = 3; + + let key = LockCollection::>>::unlock(guard); + let guard = collection.read(key); + assert_eq!(*guard[0], 3); + assert_eq!(*guard[1], 2); + } + + #[test] + fn vec_data_mut() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = vec![Mutex::new(1), Mutex::new(2)]; + let collection = LockCollection::new(mutexes); + collection.scoped_lock(&mut key, |guard| { + assert_eq!(*guard[0], 1); + assert_eq!(*guard[1], 2); + *guard[0] = 3; + }); + + collection.scoped_lock(&mut key, |guard| { + assert_eq!(*guard[0], 3); + assert_eq!(*guard[1], 2); + }) + } + + #[test] + fn vec_data_ref() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = vec![RwLock::new(1), RwLock::new(2)]; + let collection = LockCollection::new(mutexes); + collection.scoped_lock(&mut key, |guard| { + assert_eq!(*guard[0], 1); + assert_eq!(*guard[1], 2); + *guard[0] = 3; + }); + + collection.scoped_read(&mut key, |guard| { + assert_eq!(*guard[0], 3); + assert_eq!(*guard[1], 2); + }) + } + #[test] fn box_get_ptrs_empty() { let locks: Box<[Mutex<()>]> = Box::from([]); @@ -757,4 +808,72 @@ mod tests { assert_eq!(*lock_ptrs[0], 1); assert_eq!(*lock_ptrs[1], 2); } + + #[test] + fn box_guard_mut() { + let key = ThreadKey::get().unwrap(); + let x = [Mutex::new(1), Mutex::new(2)]; + let collection: LockCollection]>> = LockCollection::new(Box::new(x)); + + let mut guard = collection.lock(key); + assert_eq!(*guard[0], 1); + assert_eq!(*guard[1], 2); + *guard[0] = 3; + + let key = LockCollection::]>>::unlock(guard); + let guard = collection.lock(key); + assert_eq!(*guard[0], 3); + assert_eq!(*guard[1], 2); + } + + #[test] + fn box_data_mut() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = vec![Mutex::new(1), Mutex::new(2)].into_boxed_slice(); + let collection = LockCollection::new(mutexes); + collection.scoped_lock(&mut key, |guard| { + assert_eq!(*guard[0], 1); + assert_eq!(*guard[1], 2); + *guard[0] = 3; + }); + + collection.scoped_lock(&mut key, |guard| { + assert_eq!(*guard[0], 3); + assert_eq!(*guard[1], 2); + }); + } + + #[test] + fn box_guard_ref() { + let key = ThreadKey::get().unwrap(); + let locks = [RwLock::new(1), RwLock::new(2)]; + let collection: LockCollection]>> = LockCollection::new(Box::new(locks)); + + let mut guard = collection.lock(key); + assert_eq!(*guard[0], 1); + assert_eq!(*guard[1], 2); + *guard[0] = 3; + + let key = LockCollection::]>>::unlock(guard); + let guard = collection.read(key); + assert_eq!(*guard[0], 3); + assert_eq!(*guard[1], 2); + } + + #[test] + fn box_data_ref() { + let mut key = ThreadKey::get().unwrap(); + let mutexes = vec![RwLock::new(1), RwLock::new(2)].into_boxed_slice(); + let collection = LockCollection::new(mutexes); + collection.scoped_lock(&mut key, |guard| { + assert_eq!(*guard[0], 1); + assert_eq!(*guard[1], 2); + *guard[0] = 3; + }); + + collection.scoped_read(&mut key, |guard| { + assert_eq!(*guard[0], 3); + assert_eq!(*guard[1], 2); + }); + } } diff --git a/src/mutex.rs b/src/mutex.rs index 2022501..413bd8a 100644 --- a/src/mutex.rs +++ b/src/mutex.rs @@ -136,10 +136,7 @@ pub struct Mutex { /// A reference to a mutex that unlocks it when dropped. /// /// This is similar to [`MutexGuard`], except it does not hold a [`Keyable`]. -pub struct MutexRef<'a, T: ?Sized + 'a, R: RawMutex>( - &'a Mutex, - PhantomData<(&'a mut T, R::GuardMarker)>, -); +pub struct MutexRef<'a, T: ?Sized + 'a, R: RawMutex>(&'a Mutex, PhantomData); /// An RAII implementation of a “scoped lock” of a mutex. /// @@ -182,6 +179,26 @@ mod tests { drop(guard) } + #[test] + fn from_works() { + let key = ThreadKey::get().unwrap(); + let mutex: crate::Mutex<_> = Mutex::from("Hello, world!"); + + let guard = mutex.lock(key); + assert_eq!(*guard, "Hello, world!"); + } + + #[test] + fn as_mut_works() { + let key = ThreadKey::get().unwrap(); + let mut mutex = crate::Mutex::from(42); + + let mut_ref = mutex.as_mut(); + *mut_ref = 24; + + mutex.scoped_lock(key, |guard| assert_eq!(*guard, 24)) + } + #[test] fn display_works_for_guard() { let key = ThreadKey::get().unwrap(); diff --git a/src/mutex/guard.rs b/src/mutex/guard.rs index 22e59c1..d88fded 100644 --- a/src/mutex/guard.rs +++ b/src/mutex/guard.rs @@ -39,7 +39,7 @@ impl Drop for MutexRef<'_, T, R> { fn drop(&mut self) { // safety: this guard is being destroyed, so the data cannot be // accessed without locking again - unsafe { self.0.raw_unlock() } + unsafe { self.0.raw_unlock_write() } } } @@ -79,7 +79,7 @@ impl<'a, T: ?Sized, R: RawMutex> MutexRef<'a, T, R> { /// Creates a reference to the underlying data of a mutex without /// attempting to lock it or take ownership of the key. But it's also quite /// dangerous to drop. - pub(crate) unsafe fn new(mutex: &'a Mutex) -> Self { + pub(crate) const unsafe fn new(mutex: &'a Mutex) -> Self { Self(mutex, PhantomData) } } @@ -139,7 +139,7 @@ impl<'a, T: ?Sized, R: RawMutex> MutexGuard<'a, T, R> { /// Create a guard to the given mutex. Undefined if multiple guards to the /// same mutex exist at once. #[must_use] - pub(super) unsafe fn new(mutex: &'a Mutex, thread_key: ThreadKey) -> Self { + pub(super) const unsafe fn new(mutex: &'a Mutex, thread_key: ThreadKey) -> Self { Self { mutex: MutexRef(mutex, PhantomData), thread_key, diff --git a/src/mutex/mutex.rs b/src/mutex/mutex.rs index 1d8ce8b..f0fb680 100644 --- a/src/mutex/mutex.rs +++ b/src/mutex/mutex.rs @@ -5,6 +5,7 @@ use std::panic::AssertUnwindSafe; use lock_api::RawMutex; +use crate::collection::utils; use crate::handle_unwind::handle_unwind; use crate::lockable::{Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock}; use crate::poisonable::PoisonFlag; @@ -17,7 +18,7 @@ unsafe impl RawLock for Mutex { self.poison.poison(); } - unsafe fn raw_lock(&self) { + unsafe fn raw_write(&self) { assert!(!self.poison.is_poisoned(), "The mutex has been killed"); // if the closure unwraps, then the mutex will be killed @@ -25,7 +26,7 @@ unsafe impl RawLock for Mutex { handle_unwind(|| this.raw.lock(), || self.poison()) } - unsafe fn raw_try_lock(&self) -> bool { + unsafe fn raw_try_write(&self) -> bool { if self.poison.is_poisoned() { return false; } @@ -35,7 +36,7 @@ unsafe impl RawLock for Mutex { handle_unwind(|| this.raw.try_lock(), || self.poison()) } - unsafe fn raw_unlock(&self) { + unsafe fn raw_unlock_write(&self) { // if the closure unwraps, then the mutex will be killed let this = AssertUnwindSafe(self); handle_unwind(|| this.raw.unlock(), || self.poison()) @@ -43,20 +44,26 @@ unsafe impl RawLock for Mutex { // this is the closest thing to a read we can get, but Sharable isn't // implemented for this + #[mutants::skip] + #[cfg(not(tarpaulin_include))] unsafe fn raw_read(&self) { - self.raw_lock() + self.raw_write() } + #[mutants::skip] + #[cfg(not(tarpaulin_include))] unsafe fn raw_try_read(&self) -> bool { - self.raw_try_lock() + self.raw_try_write() } + #[mutants::skip] + #[cfg(not(tarpaulin_include))] unsafe fn raw_unlock_read(&self) { - self.raw_unlock() + self.raw_unlock_write() } } -unsafe impl Lockable for Mutex { +unsafe impl Lockable for Mutex { type Guard<'g> = MutexRef<'g, T, R> where @@ -80,7 +87,7 @@ unsafe impl Lockable for Mutex { } } -impl LockableIntoInner for Mutex { +impl LockableIntoInner for Mutex { type Inner = T; fn into_inner(self) -> Self::Inner { @@ -88,7 +95,7 @@ impl LockableIntoInner for Mutex { } } -impl LockableGetMut for Mutex { +impl LockableGetMut for Mutex { type Inner<'a> = &'a mut T where @@ -99,7 +106,7 @@ impl LockableGetMut for Mutex { } } -unsafe impl OwnedLockable for Mutex {} +unsafe impl OwnedLockable for Mutex {} impl Mutex { /// Create a new unlocked `Mutex`. @@ -140,7 +147,7 @@ impl Mutex { #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl Debug for Mutex { +impl Debug for Mutex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // safety: this is just a try lock, and the value is dropped // immediately after, so there's no risk of blocking ourselves @@ -222,45 +229,21 @@ impl Mutex { } } -impl Mutex { - pub fn scoped_lock(&self, key: impl Keyable, f: impl FnOnce(&mut T) -> Ret) -> Ret { - unsafe { - // safety: we have the thread key - self.raw_lock(); - - // safety: the mutex was just locked - let r = f(self.data.get().as_mut().unwrap_unchecked()); - - // safety: we locked the mutex already - self.raw_unlock(); - - drop(key); // ensures we drop the key in the correct place - - r - } +impl Mutex { + pub fn scoped_lock<'a, Ret>( + &'a self, + key: impl Keyable, + f: impl FnOnce(&'a mut T) -> Ret, + ) -> Ret { + utils::scoped_write(self, key, f) } - pub fn scoped_try_lock( - &self, + pub fn scoped_try_lock<'a, Key: Keyable, Ret>( + &'a self, key: Key, - f: impl FnOnce(&mut T) -> Ret, + f: impl FnOnce(&'a mut T) -> Ret, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_lock() { - return Err(key); - } - - // safety: the mutex was just locked - let r = f(self.data.get().as_mut().unwrap_unchecked()); - - // safety: we locked the mutex already - self.raw_unlock(); - - drop(key); // ensures we drop the key in the correct place - - Ok(r) - } + utils::scoped_try_write(self, key, f) } /// Block the thread until this mutex can be locked, and lock it. @@ -289,7 +272,7 @@ impl Mutex { pub fn lock(&self, key: ThreadKey) -> MutexGuard<'_, T, R> { unsafe { // safety: we have the thread key - self.raw_lock(); + self.raw_write(); // safety: we just locked the mutex MutexGuard::new(self, key) @@ -332,7 +315,7 @@ impl Mutex { pub fn try_lock(&self, key: ThreadKey) -> Result, ThreadKey> { unsafe { // safety: we have the key to the mutex - if self.raw_try_lock() { + if self.raw_try_write() { // safety: we just locked the mutex Ok(MutexGuard::new(self, key)) } else { @@ -350,7 +333,7 @@ impl Mutex { /// Lock without a [`ThreadKey`]. It is undefined behavior to do this without /// owning the [`ThreadKey`]. pub(crate) unsafe fn try_lock_no_key(&self) -> Option> { - self.raw_try_lock().then_some(MutexRef(self, PhantomData)) + self.raw_try_write().then_some(MutexRef(self, PhantomData)) } /// Consumes the [`MutexGuard`], and consequently unlocks its `Mutex`. @@ -370,9 +353,7 @@ impl Mutex { /// ``` #[must_use] pub fn unlock(guard: MutexGuard<'_, T, R>) -> ThreadKey { - unsafe { - guard.mutex.0.raw_unlock(); - } + drop(guard.mutex); guard.thread_key } } diff --git a/src/poisonable.rs b/src/poisonable.rs index 8d5a810..f9b0622 100644 --- a/src/poisonable.rs +++ b/src/poisonable.rs @@ -99,7 +99,7 @@ mod tests { use super::*; use crate::lockable::Lockable; - use crate::{LockCollection, Mutex, ThreadKey}; + use crate::{LockCollection, Mutex, RwLock, ThreadKey}; #[test] fn locking_poisoned_mutex_returns_error_in_collection() { @@ -126,6 +126,31 @@ mod tests { assert_eq!(***error.get_ref(), 42); } + #[test] + fn locking_poisoned_rwlock_returns_error_in_collection() { + let key = ThreadKey::get().unwrap(); + let mutex = LockCollection::new(Poisonable::new(RwLock::new(42))); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let mut guard1 = mutex.read(key); + let guard = guard1.as_deref_mut().unwrap(); + assert_eq!(**guard, 42); + panic!(); + + #[allow(unreachable_code)] + drop(guard1); + }) + .join() + .unwrap_err(); + }); + + let error = mutex.read(key); + let error = error.as_deref().unwrap_err(); + assert_eq!(***error.get_ref(), 42); + } + #[test] fn non_poisoned_get_mut_is_ok() { let mut mutex = Poisonable::new(Mutex::new(42)); @@ -200,6 +225,118 @@ mod tests { assert_eq!(error.into_inner().into_inner(), "foo"); } + #[test] + fn scoped_lock_can_poison() { + let key = ThreadKey::get().unwrap(); + let mutex = Poisonable::new(Mutex::new(42)); + + let r = std::panic::catch_unwind(|| { + mutex.scoped_lock(key, |num| { + *num.unwrap() = 56; + panic!(); + }) + }); + assert!(r.is_err()); + + let key = ThreadKey::get().unwrap(); + assert!(mutex.is_poisoned()); + mutex.scoped_lock(key, |num| { + let Err(error) = num else { panic!() }; + mutex.clear_poison(); + let guard = error.into_inner(); + assert_eq!(*guard, 56); + }); + assert!(!mutex.is_poisoned()); + } + + #[test] + fn scoped_try_lock_can_fail() { + let key = ThreadKey::get().unwrap(); + let mutex = Poisonable::new(Mutex::new(42)); + let guard = mutex.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = mutex.scoped_try_lock(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn scoped_try_lock_can_succeed() { + let rwlock = Poisonable::new(RwLock::new(42)); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = rwlock.scoped_try_lock(key, |guard| { + assert_eq!(*guard.unwrap(), 42); + }); + assert!(r.is_ok()); + }); + }); + } + + #[test] + fn scoped_read_can_poison() { + let key = ThreadKey::get().unwrap(); + let mutex = Poisonable::new(RwLock::new(42)); + + let r = std::panic::catch_unwind(|| { + mutex.scoped_read(key, |num| { + assert_eq!(*num.unwrap(), 42); + panic!(); + }) + }); + assert!(r.is_err()); + + let key = ThreadKey::get().unwrap(); + assert!(mutex.is_poisoned()); + mutex.scoped_read(key, |num| { + let Err(error) = num else { panic!() }; + mutex.clear_poison(); + let guard = error.into_inner(); + assert_eq!(*guard, 42); + }); + assert!(!mutex.is_poisoned()); + } + + #[test] + fn scoped_try_read_can_fail() { + let key = ThreadKey::get().unwrap(); + let rwlock = Poisonable::new(RwLock::new(42)); + let guard = rwlock.lock(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = rwlock.scoped_try_read(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn scoped_try_read_can_succeed() { + let rwlock = Poisonable::new(RwLock::new(42)); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = rwlock.scoped_try_read(key, |guard| { + assert_eq!(*guard.unwrap(), 42); + }); + assert!(r.is_ok()); + }); + }); + } + #[test] fn display_works() { let key = ThreadKey::get().unwrap(); @@ -317,6 +454,31 @@ mod tests { assert!(!mutex.is_poisoned()); } + #[test] + fn clear_poison_for_poisoned_rwlock() { + let lock = Arc::new(Poisonable::new(RwLock::new(0))); + let c_mutex = Arc::clone(&lock); + + let _ = std::thread::spawn(move || { + let key = ThreadKey::get().unwrap(); + let lock = c_mutex.read(key).unwrap(); + assert_eq!(*lock, 42); + panic!(); // the mutex gets poisoned + }) + .join(); + + assert!(lock.is_poisoned()); + + let key = ThreadKey::get().unwrap(); + let _ = lock.lock(key).unwrap_or_else(|mut e| { + **e.get_mut() = 1; + lock.clear_poison(); + e.into_inner() + }); + + assert!(!lock.is_poisoned()); + } + #[test] fn error_as_ref() { let mutex = Poisonable::new(Mutex::new("foo")); diff --git a/src/poisonable/poisonable.rs b/src/poisonable/poisonable.rs index efe4ed0..ff78330 100644 --- a/src/poisonable/poisonable.rs +++ b/src/poisonable/poisonable.rs @@ -1,5 +1,6 @@ use std::panic::{RefUnwindSafe, UnwindSafe}; +use crate::handle_unwind::handle_unwind; use crate::lockable::{ Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable, }; @@ -17,16 +18,16 @@ unsafe impl RawLock for Poisonable { self.inner.poison() } - unsafe fn raw_lock(&self) { - self.inner.raw_lock() + unsafe fn raw_write(&self) { + self.inner.raw_write() } - unsafe fn raw_try_lock(&self) -> bool { - self.inner.raw_try_lock() + unsafe fn raw_try_write(&self) -> bool { + self.inner.raw_try_write() } - unsafe fn raw_unlock(&self) { - self.inner.raw_unlock() + unsafe fn raw_unlock_write(&self) { + self.inner.raw_unlock_write() } unsafe fn raw_read(&self) { @@ -313,13 +314,19 @@ impl Poisonable { ) -> R { unsafe { // safety: we have the thread key - self.raw_lock(); + self.raw_write(); // safety: the data was just locked - let r = f(self.data_mut()); + let r = handle_unwind( + || f(self.data_mut()), + || { + self.poisoned.poison(); + self.raw_unlock_write(); + }, + ); // safety: the collection is still locked - self.raw_unlock(); + self.raw_unlock_write(); drop(key); // ensure the key stays alive long enough @@ -334,15 +341,21 @@ impl Poisonable { ) -> Result { unsafe { // safety: we have the thread key - if !self.raw_try_lock() { + if !self.raw_try_write() { return Err(key); } // safety: we just locked the collection - let r = f(self.data_mut()); + let r = handle_unwind( + || f(self.data_mut()), + || { + self.poisoned.poison(); + self.raw_unlock_write(); + }, + ); // safety: the collection is still locked - self.raw_unlock(); + self.raw_unlock_write(); drop(key); // ensures the key stays valid long enough @@ -383,7 +396,7 @@ impl Poisonable { /// ``` pub fn lock(&self, key: ThreadKey) -> PoisonResult>> { unsafe { - self.inner.raw_lock(); + self.inner.raw_write(); self.guard(key) } } @@ -434,7 +447,7 @@ impl Poisonable { /// [`WouldBlock`]: `TryLockPoisonableError::WouldBlock` pub fn try_lock(&self, key: ThreadKey) -> TryLockPoisonableResult<'_, L::Guard<'_>> { unsafe { - if self.inner.raw_try_lock() { + if self.inner.raw_try_write() { Ok(self.guard(key)?) } else { Err(TryLockPoisonableError::WouldBlock(key)) @@ -487,7 +500,13 @@ impl Poisonable { self.raw_read(); // safety: the data was just locked - let r = f(self.data_ref()); + let r = handle_unwind( + || f(self.data_ref()), + || { + self.poisoned.poison(); + self.raw_unlock_read(); + }, + ); // safety: the collection is still locked self.raw_unlock_read(); @@ -510,7 +529,13 @@ impl Poisonable { } // safety: we just locked the collection - let r = f(self.data_ref()); + let r = handle_unwind( + || f(self.data_ref()), + || { + self.poisoned.poison(); + self.raw_unlock_read(); + }, + ); // safety: the collection is still locked self.raw_unlock_read(); diff --git a/src/rwlock.rs b/src/rwlock.rs index b604370..2d3dd85 100644 --- a/src/rwlock.rs +++ b/src/rwlock.rs @@ -75,7 +75,7 @@ pub struct WriteLock<'l, T: ?Sized, R>(&'l RwLock); /// [`Keyable`]. pub struct RwLockReadRef<'a, T: ?Sized, R: RawRwLock>( &'a RwLock, - PhantomData<(&'a mut T, R::GuardMarker)>, + PhantomData, ); /// RAII structure that unlocks the exclusive write access to a [`RwLock`] @@ -84,7 +84,7 @@ pub struct RwLockReadRef<'a, T: ?Sized, R: RawRwLock>( /// [`Keyable`]. pub struct RwLockWriteRef<'a, T: ?Sized, R: RawRwLock>( &'a RwLock, - PhantomData<(&'a mut T, R::GuardMarker)>, + PhantomData, ); /// RAII structure used to release the shared read access of a lock when @@ -115,6 +115,8 @@ pub struct RwLockWriteGuard<'a, T: ?Sized, R: RawRwLock> { #[cfg(test)] mod tests { use crate::lockable::Lockable; + use crate::lockable::RawLock; + use crate::LockCollection; use crate::RwLock; use crate::ThreadKey; @@ -148,6 +150,33 @@ mod tests { assert_eq!(*guard, "Hello, world!"); } + #[test] + fn read_lock_scoped_works() { + let mut key = ThreadKey::get().unwrap(); + let lock: crate::RwLock<_> = RwLock::new(42); + let reader = ReadLock::new(&lock); + + reader.scoped_lock(&mut key, |num| assert_eq!(*num, 42)); + } + + #[test] + fn read_lock_scoped_try_fails_during_write() { + let key = ThreadKey::get().unwrap(); + let lock: crate::RwLock<_> = RwLock::new(42); + let reader = ReadLock::new(&lock); + let guard = lock.write(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = reader.scoped_try_lock(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + #[test] fn write_lock_unlocked_when_initialized() { let key = ThreadKey::get().unwrap(); @@ -165,7 +194,7 @@ mod tests { readlock.get_ptrs(&mut lock_ptrs); assert_eq!(lock_ptrs.len(), 1); - assert!(std::ptr::addr_eq(lock_ptrs[0], &rwlock)); + assert!(std::ptr::addr_eq(lock_ptrs[0], &readlock)); } #[test] @@ -176,7 +205,34 @@ mod tests { writelock.get_ptrs(&mut lock_ptrs); assert_eq!(lock_ptrs.len(), 1); - assert!(std::ptr::addr_eq(lock_ptrs[0], &rwlock)); + assert!(std::ptr::addr_eq(lock_ptrs[0], &writelock)); + } + + #[test] + fn write_lock_scoped_works() { + let mut key = ThreadKey::get().unwrap(); + let lock: crate::RwLock<_> = RwLock::new(42); + let writer = WriteLock::new(&lock); + + writer.scoped_lock(&mut key, |num| assert_eq!(*num, 42)); + } + + #[test] + fn write_lock_scoped_try_fails_during_write() { + let key = ThreadKey::get().unwrap(); + let lock: crate::RwLock<_> = RwLock::new(42); + let writer = WriteLock::new(&lock); + let guard = lock.write(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = writer.scoped_try_lock(key, |_| {}); + assert!(r.is_err()); + }); + }); + + drop(guard); } #[test] @@ -225,6 +281,69 @@ mod tests { drop(guard) } + #[test] + fn locked_after_scoped_write() { + let mut key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("Hello, world!"); + + lock.scoped_write(&mut key, |guard| { + assert!(lock.is_locked()); + assert_eq!(*guard, "Hello, world!"); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + assert!(lock.try_read(key).is_err()); + }); + }) + }) + } + + #[test] + fn get_mut_works() { + let key = ThreadKey::get().unwrap(); + let mut lock = crate::RwLock::from(42); + + let mut_ref = lock.get_mut(); + *mut_ref = 24; + + lock.scoped_read(key, |guard| assert_eq!(*guard, 24)) + } + + #[test] + fn try_write_can_fail() { + let key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("Hello"); + let guard = lock.write(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = lock.try_write(key); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn try_read_can_fail() { + let key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("Hello"); + let guard = lock.write(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let r = lock.try_read(key); + assert!(r.is_err()); + }); + }); + + drop(guard); + } + #[test] fn read_display_works() { let key = ThreadKey::get().unwrap(); @@ -275,4 +394,200 @@ mod tests { assert!(!lock.is_locked()); } + + #[test] + fn unlock_write() { + let key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("Hello, world"); + + let mut guard = lock.write(key); + *guard = "Goodbye, world!"; + let key = RwLock::unlock_write(guard); + + let guard = lock.read(key); + assert_eq!(*guard, "Goodbye, world!"); + } + + #[test] + fn unlock_read() { + let key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("Hello, world"); + + let guard = lock.read(key); + assert_eq!(*guard, "Hello, world"); + let key = RwLock::unlock_read(guard); + + let guard = lock.write(key); + assert_eq!(*guard, "Hello, world"); + } + + #[test] + fn unlock_read_lock() { + let key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("Hello, world"); + let reader = ReadLock::new(&lock); + + let guard = reader.lock(key); + let key = ReadLock::unlock(guard); + + lock.write(key); + } + + #[test] + fn unlock_write_lock() { + let key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("Hello, world"); + let writer = WriteLock::from(&lock); + + let guard = writer.lock(key); + let key = WriteLock::unlock(guard); + + lock.write(key); + } + + #[test] + fn read_lock_in_collection() { + let mut key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("hi"); + let collection = LockCollection::try_new(ReadLock::new(&lock)).unwrap(); + + collection.scoped_lock(&mut key, |guard| { + assert_eq!(*guard, "hi"); + }); + collection.scoped_read(&mut key, |guard| { + assert_eq!(*guard, "hi"); + }); + assert!(collection + .scoped_try_lock(&mut key, |guard| { + assert_eq!(*guard, "hi"); + }) + .is_ok()); + assert!(collection + .scoped_try_read(&mut key, |guard| { + assert_eq!(*guard, "hi"); + }) + .is_ok()); + + let guard = collection.lock(key); + assert_eq!(**guard, "hi"); + + let key = LockCollection::>::unlock(guard); + let guard = collection.read(key); + assert_eq!(**guard, "hi"); + + let key = LockCollection::>::unlock(guard); + let guard = lock.write(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.try_lock(key); + assert!(guard.is_err()); + }); + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.try_read(key); + assert!(guard.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn write_lock_in_collection() { + let mut key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("hi"); + let collection = LockCollection::try_new(WriteLock::new(&lock)).unwrap(); + + collection.scoped_lock(&mut key, |guard| { + assert_eq!(*guard, "hi"); + }); + assert!(collection + .scoped_try_lock(&mut key, |guard| { + assert_eq!(*guard, "hi"); + }) + .is_ok()); + + let guard = collection.lock(key); + assert_eq!(**guard, "hi"); + + let key = LockCollection::>::unlock(guard); + let guard = lock.write(key); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.try_lock(key); + assert!(guard.is_err()); + }); + }); + + drop(guard); + } + + #[test] + fn read_ref_as_ref() { + let key = ThreadKey::get().unwrap(); + let lock = LockCollection::new(crate::RwLock::new("hi")); + let guard = lock.read(key); + + assert_eq!(*(*guard).as_ref(), "hi"); + } + + #[test] + fn read_guard_as_ref() { + let key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("hi"); + let guard = lock.read(key); + + assert_eq!(*guard.as_ref(), "hi"); + } + + #[test] + fn write_ref_as_ref() { + let key = ThreadKey::get().unwrap(); + let lock = LockCollection::new(crate::RwLock::new("hi")); + let guard = lock.lock(key); + + assert_eq!(*(*guard).as_ref(), "hi"); + } + + #[test] + fn write_guard_as_ref() { + let key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("hi"); + let guard = lock.write(key); + + assert_eq!(*guard.as_ref(), "hi"); + } + + #[test] + fn write_guard_as_mut() { + let key = ThreadKey::get().unwrap(); + let lock = crate::RwLock::new("hi"); + let mut guard = lock.write(key); + + assert_eq!(*guard.as_mut(), "hi"); + *guard.as_mut() = "foo"; + assert_eq!(*guard.as_mut(), "foo"); + } + + #[test] + fn poison_read_lock() { + let lock = crate::RwLock::new("hi"); + let reader = ReadLock::new(&lock); + + reader.poison(); + assert!(lock.poison.is_poisoned()); + } + + #[test] + fn poison_write_lock() { + let lock = crate::RwLock::new("hi"); + let reader = WriteLock::new(&lock); + + reader.poison(); + assert!(lock.poison.is_poisoned()); + } } diff --git a/src/rwlock/read_guard.rs b/src/rwlock/read_guard.rs index 0d68c75..5b26c06 100644 --- a/src/rwlock/read_guard.rs +++ b/src/rwlock/read_guard.rs @@ -64,7 +64,7 @@ impl<'a, T: ?Sized, R: RawRwLock> RwLockReadRef<'a, T, R> { /// Creates an immutable reference for the underlying data of an [`RwLock`] /// without locking it or taking ownership of the key. #[must_use] - pub(crate) unsafe fn new(mutex: &'a RwLock) -> Self { + pub(crate) const unsafe fn new(mutex: &'a RwLock) -> Self { Self(mutex, PhantomData) } } @@ -109,7 +109,7 @@ impl<'a, T: ?Sized, R: RawRwLock> RwLockReadGuard<'a, T, R> { /// Create a guard to the given mutex. Undefined if multiple guards to the /// same mutex exist at once. #[must_use] - pub(super) unsafe fn new(rwlock: &'a RwLock, thread_key: ThreadKey) -> Self { + pub(super) const unsafe fn new(rwlock: &'a RwLock, thread_key: ThreadKey) -> Self { Self { rwlock: RwLockReadRef(rwlock, PhantomData), thread_key, diff --git a/src/rwlock/read_lock.rs b/src/rwlock/read_lock.rs index 05b184a..dd9e42f 100644 --- a/src/rwlock/read_lock.rs +++ b/src/rwlock/read_lock.rs @@ -3,11 +3,41 @@ use std::fmt::Debug; use lock_api::RawRwLock; use crate::lockable::{Lockable, RawLock, Sharable}; -use crate::ThreadKey; +use crate::{Keyable, ThreadKey}; use super::{ReadLock, RwLock, RwLockReadGuard, RwLockReadRef}; -unsafe impl Lockable for ReadLock<'_, T, R> { +unsafe impl RawLock for ReadLock<'_, T, R> { + fn poison(&self) { + self.0.poison() + } + + unsafe fn raw_write(&self) { + self.0.raw_read() + } + + unsafe fn raw_try_write(&self) -> bool { + self.0.raw_try_read() + } + + unsafe fn raw_unlock_write(&self) { + self.0.raw_unlock_read() + } + + unsafe fn raw_read(&self) { + self.0.raw_read() + } + + unsafe fn raw_try_read(&self) -> bool { + self.0.raw_try_read() + } + + unsafe fn raw_unlock_read(&self) { + self.0.raw_unlock_read() + } +} + +unsafe impl Lockable for ReadLock<'_, T, R> { type Guard<'g> = RwLockReadRef<'g, T, R> where @@ -19,7 +49,7 @@ unsafe impl Lockable for ReadLock<'_, T, R> Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - ptrs.push(self.as_ref()); + ptrs.push(self); } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -31,7 +61,7 @@ unsafe impl Lockable for ReadLock<'_, T, R> } } -unsafe impl Sharable for ReadLock<'_, T, R> { +unsafe impl Sharable for ReadLock<'_, T, R> { type ReadGuard<'g> = RwLockReadRef<'g, T, R> where @@ -53,7 +83,7 @@ unsafe impl Sharable for ReadLock<'_, T, R> #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl Debug for ReadLock<'_, T, R> { +impl Debug for ReadLock<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // safety: this is just a try lock, and the value is dropped // immediately after, so there's no risk of blocking ourselves @@ -104,7 +134,19 @@ impl<'l, T, R> ReadLock<'l, T, R> { } } -impl ReadLock<'_, T, R> { +impl ReadLock<'_, T, R> { + pub fn scoped_lock<'a, Ret>(&'a self, key: impl Keyable, f: impl Fn(&'a T) -> Ret) -> Ret { + self.0.scoped_read(key, f) + } + + pub fn scoped_try_lock<'a, Key: Keyable, Ret>( + &'a self, + key: Key, + f: impl Fn(&'a T) -> Ret, + ) -> Result { + self.0.scoped_try_read(key, f) + } + /// Locks the underlying [`RwLock`] with shared read access, blocking the /// current thread until it can be acquired. /// diff --git a/src/rwlock/rwlock.rs b/src/rwlock/rwlock.rs index 905ecf8..5f407d1 100644 --- a/src/rwlock/rwlock.rs +++ b/src/rwlock/rwlock.rs @@ -5,6 +5,7 @@ use std::panic::AssertUnwindSafe; use lock_api::RawRwLock; +use crate::collection::utils; use crate::handle_unwind::handle_unwind; use crate::lockable::{ Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable, @@ -18,7 +19,7 @@ unsafe impl RawLock for RwLock { self.poison.poison(); } - unsafe fn raw_lock(&self) { + unsafe fn raw_write(&self) { assert!( !self.poison.is_poisoned(), "The read-write lock has been killed" @@ -29,7 +30,7 @@ unsafe impl RawLock for RwLock { handle_unwind(|| this.raw.lock_exclusive(), || self.poison()) } - unsafe fn raw_try_lock(&self) -> bool { + unsafe fn raw_try_write(&self) -> bool { if self.poison.is_poisoned() { return false; } @@ -39,7 +40,7 @@ unsafe impl RawLock for RwLock { handle_unwind(|| this.raw.try_lock_exclusive(), || self.poison()) } - unsafe fn raw_unlock(&self) { + unsafe fn raw_unlock_write(&self) { // if the closure unwraps, then the mutex will be killed let this = AssertUnwindSafe(self); handle_unwind(|| this.raw.unlock_exclusive(), || self.poison()) @@ -73,7 +74,7 @@ unsafe impl RawLock for RwLock { } } -unsafe impl Lockable for RwLock { +unsafe impl Lockable for RwLock { type Guard<'g> = RwLockWriteRef<'g, T, R> where @@ -97,7 +98,7 @@ unsafe impl Lockable for RwLock { } } -unsafe impl Sharable for RwLock { +unsafe impl Sharable for RwLock { type ReadGuard<'g> = RwLockReadRef<'g, T, R> where @@ -117,9 +118,9 @@ unsafe impl Sharable for RwLock { } } -unsafe impl OwnedLockable for RwLock {} +unsafe impl OwnedLockable for RwLock {} -impl LockableIntoInner for RwLock { +impl LockableIntoInner for RwLock { type Inner = T; fn into_inner(self) -> Self::Inner { @@ -127,7 +128,7 @@ impl LockableIntoInner for RwLock { } } -impl LockableGetMut for RwLock { +impl LockableGetMut for RwLock { type Inner<'a> = &'a mut T where @@ -160,7 +161,7 @@ impl RwLock { #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl Debug for RwLock { +impl Debug for RwLock { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // safety: this is just a try lock, and the value is dropped // immediately after, so there's no risk of blocking ourselves @@ -247,85 +248,29 @@ impl RwLock { } } -impl RwLock { - pub fn scoped_read(&self, key: impl Keyable, f: impl Fn(&T) -> Ret) -> Ret { - unsafe { - // safety: we have the thread key - self.raw_read(); - - // safety: the rwlock was just locked - let r = f(self.data.get().as_ref().unwrap_unchecked()); - - // safety: the rwlock is already locked - self.raw_unlock_read(); - - drop(key); // ensure the key stays valid for long enough - - r - } +impl RwLock { + pub fn scoped_read<'a, Ret>(&'a self, key: impl Keyable, f: impl Fn(&'a T) -> Ret) -> Ret { + utils::scoped_read(self, key, f) } - pub fn scoped_try_read( - &self, + pub fn scoped_try_read<'a, Key: Keyable, Ret>( + &'a self, key: Key, - f: impl Fn(&T) -> Ret, + f: impl Fn(&'a T) -> Ret, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_read() { - return Err(key); - } - - // safety: the rwlock was just locked - let r = f(self.data.get().as_ref().unwrap_unchecked()); - - // safety: the rwlock is already locked - self.raw_unlock_read(); - - drop(key); // ensure the key stays valid for long enough - - Ok(r) - } + utils::scoped_try_read(self, key, f) } - pub fn scoped_write(&self, key: impl Keyable, f: impl Fn(&mut T) -> Ret) -> Ret { - unsafe { - // safety: we have the thread key - self.raw_lock(); - - // safety: we just locked the rwlock - let r = f(self.data.get().as_mut().unwrap_unchecked()); - - // safety: the rwlock is already locked - self.raw_unlock(); - - drop(key); // ensure the key stays valid for long enough - - r - } + pub fn scoped_write<'a, Ret>(&'a self, key: impl Keyable, f: impl Fn(&'a mut T) -> Ret) -> Ret { + utils::scoped_write(self, key, f) } - pub fn scoped_try_write( - &self, + pub fn scoped_try_write<'a, Key: Keyable, Ret>( + &'a self, key: Key, - f: impl Fn(&mut T) -> Ret, + f: impl Fn(&'a mut T) -> Ret, ) -> Result { - unsafe { - // safety: we have the thread key - if !self.raw_try_lock() { - return Err(key); - } - - // safety: the rwlock was just locked - let r = f(self.data.get().as_mut().unwrap_unchecked()); - - // safety: the rwlock is already locked - self.raw_unlock(); - - drop(key); // ensure the key stays valid for long enough - - Ok(r) - } + utils::scoped_try_write(self, key, f) } /// Locks this `RwLock` with shared read access, blocking the current @@ -426,7 +371,7 @@ impl RwLock { /// without exclusive access to the key is undefined behavior. #[cfg(test)] pub(crate) unsafe fn try_write_no_key(&self) -> Option> { - if self.raw_try_lock() { + if self.raw_try_write() { // safety: the lock is locked first Some(RwLockWriteRef(self, PhantomData)) } else { @@ -463,7 +408,7 @@ impl RwLock { /// [`ThreadKey`]: `crate::ThreadKey` pub fn write(&self, key: ThreadKey) -> RwLockWriteGuard<'_, T, R> { unsafe { - self.raw_lock(); + self.raw_write(); // safety: the lock is locked first RwLockWriteGuard::new(self, key) @@ -498,7 +443,7 @@ impl RwLock { /// ``` pub fn try_write(&self, key: ThreadKey) -> Result, ThreadKey> { unsafe { - if self.raw_try_lock() { + if self.raw_try_write() { // safety: the lock is locked first Ok(RwLockWriteGuard::new(self, key)) } else { @@ -533,9 +478,7 @@ impl RwLock { /// ``` #[must_use] pub fn unlock_read(guard: RwLockReadGuard<'_, T, R>) -> ThreadKey { - unsafe { - guard.rwlock.0.raw_unlock_read(); - } + drop(guard.rwlock); guard.thread_key } @@ -560,9 +503,7 @@ impl RwLock { /// ``` #[must_use] pub fn unlock_write(guard: RwLockWriteGuard<'_, T, R>) -> ThreadKey { - unsafe { - guard.rwlock.0.raw_unlock(); - } + drop(guard.rwlock); guard.thread_key } } diff --git a/src/rwlock/write_guard.rs b/src/rwlock/write_guard.rs index 3fabf8e..c7676b5 100644 --- a/src/rwlock/write_guard.rs +++ b/src/rwlock/write_guard.rs @@ -71,7 +71,7 @@ impl Drop for RwLockWriteRef<'_, T, R> { fn drop(&mut self) { // safety: this guard is being destroyed, so the data cannot be // accessed without locking again - unsafe { self.0.raw_unlock() } + unsafe { self.0.raw_unlock_write() } } } @@ -79,7 +79,7 @@ impl<'a, T: ?Sized + 'a, R: RawRwLock> RwLockWriteRef<'a, T, R> { /// Creates a reference to the underlying data of an [`RwLock`] without /// locking or taking ownership of the key. #[must_use] - pub(crate) unsafe fn new(mutex: &'a RwLock) -> Self { + pub(crate) const unsafe fn new(mutex: &'a RwLock) -> Self { Self(mutex, PhantomData) } } @@ -136,7 +136,7 @@ impl<'a, T: ?Sized + 'a, R: RawRwLock> RwLockWriteGuard<'a, T, R> { /// Create a guard to the given mutex. Undefined if multiple guards to the /// same mutex exist at once. #[must_use] - pub(super) unsafe fn new(rwlock: &'a RwLock, thread_key: ThreadKey) -> Self { + pub(super) const unsafe fn new(rwlock: &'a RwLock, thread_key: ThreadKey) -> Self { Self { rwlock: RwLockWriteRef(rwlock, PhantomData), thread_key, diff --git a/src/rwlock/write_lock.rs b/src/rwlock/write_lock.rs index 8a44a2d..5ae4dda 100644 --- a/src/rwlock/write_lock.rs +++ b/src/rwlock/write_lock.rs @@ -3,11 +3,41 @@ use std::fmt::Debug; use lock_api::RawRwLock; use crate::lockable::{Lockable, RawLock}; -use crate::ThreadKey; +use crate::{Keyable, ThreadKey}; use super::{RwLock, RwLockWriteGuard, RwLockWriteRef, WriteLock}; -unsafe impl Lockable for WriteLock<'_, T, R> { +unsafe impl RawLock for WriteLock<'_, T, R> { + fn poison(&self) { + self.0.poison() + } + + unsafe fn raw_write(&self) { + self.0.raw_write() + } + + unsafe fn raw_try_write(&self) -> bool { + self.0.raw_try_write() + } + + unsafe fn raw_unlock_write(&self) { + self.0.raw_unlock_write() + } + + unsafe fn raw_read(&self) { + self.0.raw_write() + } + + unsafe fn raw_try_read(&self) -> bool { + self.0.raw_try_write() + } + + unsafe fn raw_unlock_read(&self) { + self.0.raw_unlock_write() + } +} + +unsafe impl Lockable for WriteLock<'_, T, R> { type Guard<'g> = RwLockWriteRef<'g, T, R> where @@ -19,7 +49,7 @@ unsafe impl Lockable for WriteLock<'_, T, R Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - ptrs.push(self.as_ref()); + ptrs.push(self) } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -36,7 +66,7 @@ unsafe impl Lockable for WriteLock<'_, T, R #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl Debug for WriteLock<'_, T, R> { +impl Debug for WriteLock<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // safety: this is just a try lock, and the value is dropped // immediately after, so there's no risk of blocking ourselves @@ -89,7 +119,19 @@ impl<'l, T, R> WriteLock<'l, T, R> { } } -impl WriteLock<'_, T, R> { +impl WriteLock<'_, T, R> { + pub fn scoped_lock<'a, Ret>(&'a self, key: impl Keyable, f: impl Fn(&'a mut T) -> Ret) -> Ret { + self.0.scoped_write(key, f) + } + + pub fn scoped_try_lock<'a, Key: Keyable, Ret>( + &'a self, + key: Key, + f: impl Fn(&'a mut T) -> Ret, + ) -> Result { + self.0.scoped_try_write(key, f) + } + /// Locks the underlying [`RwLock`] with exclusive write access, blocking /// the current until it can be acquired. /// diff --git a/src/thread.rs b/src/thread.rs new file mode 100644 index 0000000..6e9c270 --- /dev/null +++ b/src/thread.rs @@ -0,0 +1,19 @@ +use std::marker::PhantomData; + +mod scope; + +#[derive(Debug)] +pub struct Scope<'scope, 'env: 'scope>(PhantomData<(&'env (), &'scope ())>); + +#[derive(Debug)] +pub struct ScopedJoinHandle<'scope, T> { + handle: std::thread::JoinHandle, + _phantom: PhantomData<&'scope ()>, +} + +pub struct JoinHandle { + handle: std::thread::JoinHandle, + key: crate::ThreadKey, +} + +pub struct ThreadBuilder(std::thread::Builder); diff --git a/src/thread/scope.rst b/src/thread/scope.rst new file mode 100644 index 0000000..09319cb --- /dev/null +++ b/src/thread/scope.rst @@ -0,0 +1,47 @@ +use std::marker::PhantomData; + +use crate::{Keyable, ThreadKey}; + +use super::{Scope, ScopedJoinHandle}; + +pub fn scope<'env, F, T>(key: impl Keyable, f: F) -> T +where + F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T, +{ + let scope = Scope(PhantomData); + let t = f(&scope); + drop(key); + t +} + +impl<'scope> Scope<'scope, '_> { + #[allow(clippy::unused_self)] + pub fn spawn( + &self, + f: impl FnOnce(ThreadKey) -> T + Send + 'scope, + ) -> std::io::Result> { + unsafe { + // safety: the lifetimes ensure that the data lives long enough + let handle = std::thread::Builder::new().spawn_unchecked(|| { + // safety: the thread just started, so the key cannot be acquired yet + let key = ThreadKey::get().unwrap_unchecked(); + f(key) + })?; + + Ok(ScopedJoinHandle { + handle, + _phantom: PhantomData, + }) + } + } +} + +impl ScopedJoinHandle<'_, T> { + pub fn is_finished(&self) -> bool { + self.handle.is_finished() + } + + pub fn join(self) -> std::thread::Result { + self.handle.join() + } +} -- cgit v1.2.3