diff options
| -rw-r--r-- | happylock.md | 107 | ||||
| -rw-r--r-- | src/collection/boxed.rs | 9 | ||||
| -rw-r--r-- | src/collection/guard.rs | 5 | ||||
| -rw-r--r-- | src/collection/owned.rs | 4 | ||||
| -rw-r--r-- | src/collection/ref.rs | 6 | ||||
| -rw-r--r-- | src/collection/retry.rs | 4 | ||||
| -rw-r--r-- | src/collection/utils.rs | 29 | ||||
| -rw-r--r-- | src/lockable.rs | 55 | ||||
| -rw-r--r-- | src/mutex/mutex.rs | 53 | ||||
| -rw-r--r-- | src/rwlock.rs | 7 | ||||
| -rw-r--r-- | src/rwlock/read_lock.rs | 2 | ||||
| -rw-r--r-- | src/rwlock/rwlock.rs | 89 | ||||
| -rw-r--r-- | src/rwlock/write_lock.rs | 2 |
13 files changed, 299 insertions, 73 deletions
diff --git a/happylock.md b/happylock.md index c157e19..37b0640 100644 --- a/happylock.md +++ b/happylock.md @@ -465,6 +465,98 @@ This is what we were trying to avoid earlier --- +## Keyable + +```rust +unsafe trait Keyable: Sealed {} +unsafe impl Keyable for ThreadKey {} +unsafe impl Keyable for &mut ThreadKey {} +``` + +This is helpful because you can get the thread key back immediately. + +```rust +impl<T, R> Mutex<T, R> { + pub fn lock<'a, 'key, Key: Keyable + 'key>( + &'a self, + key: Key + ) -> MutexGuard<'a, 'key, T, R, Key>; +} +``` + +--- + +## Keyable + +So conveniently, this compiles. + +```rust +let mut key = ThreadKey::get().unwrap(); +let guard = MUTEX1.lock(&mut key); + +// the first guard can no longer be used here +let guard = MUTEX1.lock(&mut key); +``` + +The problem is that this also compiles + +```rust +let guard = MUTEX1.lock(&mut key); +std::mem::forget(guard); + +// wait, the mutex is still locked! +let guard = MUTEX1.lock(&mut key); +// deadlocked now +``` + +--- + +## Scoped Threads + +Let's take inspiration from scoped threads: + +```rust +fn scope<'env, F, T>(f: F) -> T +where + F: for<'scope> FnOnce(&'scope Scope<'scope, env>) -> T; + +let mut a = vec![1, 2, 3]; +let mut x = 0; + +scope(|scope| { + scope.spawn(|| { + println!("we can borrow `a` here"); + dbg!(a) + }); + scope.spawn(|| { + println!("we can even borrow mutably"); + println!("no other threads will use it"); + x += a[0] + a[2]; + }); + println!("hello from the main thread"); +}); +``` + +The `Drop` implementation of the `Scope` type will join all of the spawned +threads. And because we only have a reference to the `Scope`, we'll never be +able to `mem::forget` it. + +--- + +## Scoped Locks + +Let's try the same thing for locks + +```rust +let mut key = ThreadKey::get().unwrap(); +let mutex_plus_one = MUTEX.scoped_lock(|guard: &mut i32| *guard + 1); +``` + +If you use scoped locks, then you can guarantee that locks will always be +unlocked (assuming you never immediately abort the thread). + +--- + ## RwLocks in collections This is what I used in HappyLock 0.1: @@ -537,7 +629,7 @@ Allows: `Poisonable<LockCollection>` and `LockCollection<Poisonable>` --- -# `LockableGetMut` +## `LockableGetMut` ```rust fn Mutex::<T>::get_mut(&mut self) -> &mut T // already exists in std @@ -557,19 +649,6 @@ impl<A: LockableGetMut, B: LockableGetMut> LockableGetMut for (A, B) { } } ``` - ---- - -## Missing Features - -- `Condvar`/`Barrier` -- `OnceLock` or `LazyLock` -- Standard Library Backend -- Support for `no_std` -- Convenience methods: `lock_swap`, `lock_set`? -- `try_lock_swap` doesn't need a `ThreadKey` -- Going further: `LockCell` API (preemptive allocation) - --- <!--_class: invert lead --> diff --git a/src/collection/boxed.rs b/src/collection/boxed.rs index 1891119..0a30eac 100644 --- a/src/collection/boxed.rs +++ b/src/collection/boxed.rs @@ -23,7 +23,6 @@ unsafe impl<L: Lockable> RawLock for BoxedLockCollection<L> { } unsafe fn raw_try_write(&self) -> bool { - println!("{}", self.locks().len()); utils::ordered_try_write(self.locks()) } @@ -60,7 +59,10 @@ unsafe impl<L: Lockable> Lockable for BoxedLockCollection<L> { Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - ptrs.push(self); + // Doing it this way means that if a boxed collection is put inside a + // different collection, it will use the other method of locking. However, + // this prevents duplicate locks in a collection. + ptrs.extend_from_slice(&self.locks); } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -170,6 +172,7 @@ impl<L: Debug> Debug for BoxedLockCollection<L> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct(stringify!(BoxedLockCollection)) .field("data", &self.data) + // there's not much reason to show the sorted locks .finish_non_exhaustive() } } @@ -331,7 +334,7 @@ impl<L: Lockable> BoxedLockCollection<L> { // cast to *const () because fat pointers can't be converted to usize locks.sort_by_key(|lock| (&raw const **lock).cast::<()>() as usize); - // safety we're just changing the lifetimes + // safety: we're just changing the lifetimes let locks: Vec<&'static dyn RawLock> = std::mem::transmute(locks); let data = &raw const *data; Self { data, locks } diff --git a/src/collection/guard.rs b/src/collection/guard.rs index 78d9895..ab66ffe 100644 --- a/src/collection/guard.rs +++ b/src/collection/guard.rs @@ -12,6 +12,11 @@ impl<Guard: Hash> Hash for LockGuard<Guard> { } } +// No implementations of Eq, PartialEq, PartialOrd, or Ord +// You can't implement both PartialEq<Self> and PartialEq<T> +// It's easier to just implement neither and ask users to dereference +// This is less of a problem when using the scoped lock API + #[mutants::skip] #[cfg(not(tarpaulin_include))] impl<Guard: Debug> Debug for LockGuard<Guard> { diff --git a/src/collection/owned.rs b/src/collection/owned.rs index 68170d1..866d778 100644 --- a/src/collection/owned.rs +++ b/src/collection/owned.rs @@ -63,6 +63,9 @@ unsafe impl<L: Lockable> Lockable for OwnedLockCollection<L> { #[mutants::skip] // It's hard to test lkocks in an OwnedLockCollection, because they're owned #[cfg(not(tarpaulin_include))] fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { + // It's ok to use self here, because the values in the collection already + // cannot be referenced anywhere else. It's necessary to use self as the lock + // because otherwise we will be handing out shared references to the child ptrs.push(self) } @@ -263,6 +266,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { /// ``` pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { let guard = unsafe { + // safety: we've acquired the key if !self.raw_try_write() { return Err(key); } diff --git a/src/collection/ref.rs b/src/collection/ref.rs index 5f96533..e71624d 100644 --- a/src/collection/ref.rs +++ b/src/collection/ref.rs @@ -71,7 +71,9 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> { Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - ptrs.push(self) + // Just like with BoxedLockCollection, we need to return all the individual + // locks to avoid duplicates + ptrs.extend_from_slice(&self.locks); } unsafe fn guard(&self) -> Self::Guard<'_> { @@ -115,6 +117,7 @@ impl<L: Debug> Debug for RefLockCollection<'_, L> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct(stringify!(RefLockCollection)) .field("data", self.data) + // there's not much reason to show the sorting order .finish_non_exhaustive() } } @@ -314,6 +317,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// ``` pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { let guard = unsafe { + // safety: we have the thread key if !self.raw_try_write() { return Err(key); } diff --git a/src/collection/retry.rs b/src/collection/retry.rs index 70e5183..15f626d 100644 --- a/src/collection/retry.rs +++ b/src/collection/retry.rs @@ -244,7 +244,9 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> { Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - ptrs.push(self) + // this collection, just like the sorting collection, must return all of its + // locks in order to check for duplication + self.data.get_ptrs(ptrs) } unsafe fn guard(&self) -> Self::Guard<'_> { diff --git a/src/collection/utils.rs b/src/collection/utils.rs index 59a68da..71a023e 100644 --- a/src/collection/utils.rs +++ b/src/collection/utils.rs @@ -4,14 +4,17 @@ use crate::handle_unwind::handle_unwind; use crate::lockable::{Lockable, RawLock, Sharable}; use crate::Keyable; +/// Returns a list of locks in the given collection and sorts them by their +/// memory address #[must_use] pub fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> { - let mut locks = Vec::new(); - data.get_ptrs(&mut locks); + let mut locks = get_locks_unsorted(data); locks.sort_by_key(|lock| &raw const **lock); locks } +/// Returns a list of locks from the data. Unlike the above function, this does +/// not do any sorting of the locks. #[must_use] pub fn get_locks_unsorted<L: Lockable>(data: &L) -> Vec<&dyn RawLock> { let mut locks = Vec::new(); @@ -121,7 +124,7 @@ pub unsafe fn ordered_try_read(locks: &[&dyn RawLock]) -> bool { ) } -pub fn scoped_write<'a, L: RawLock + Lockable, R>( +pub fn scoped_write<'a, L: RawLock + Lockable + ?Sized, R>( collection: &'a L, key: impl Keyable, f: impl FnOnce(L::DataMut<'a>) -> R, @@ -131,7 +134,10 @@ pub fn scoped_write<'a, L: RawLock + Lockable, R>( collection.raw_write(); // safety: we just locked this - let r = f(collection.data_mut()); + let r = handle_unwind( + || f(collection.data_mut()), + || collection.raw_unlock_write(), + ); // this ensures the key is held long enough drop(key); @@ -143,7 +149,7 @@ pub fn scoped_write<'a, L: RawLock + Lockable, R>( } } -pub fn scoped_try_write<'a, L: RawLock + Lockable, Key: Keyable, R>( +pub fn scoped_try_write<'a, L: RawLock + Lockable + ?Sized, Key: Keyable, R>( collection: &'a L, key: Key, f: impl FnOnce(L::DataMut<'a>) -> R, @@ -155,7 +161,10 @@ pub fn scoped_try_write<'a, L: RawLock + Lockable, Key: Keyable, R>( } // safety: we just locked this - let r = f(collection.data_mut()); + let r = handle_unwind( + || f(collection.data_mut()), + || collection.raw_unlock_write(), + ); // this ensures the key is held long enough drop(key); @@ -167,7 +176,7 @@ pub fn scoped_try_write<'a, L: RawLock + Lockable, Key: Keyable, R>( } } -pub fn scoped_read<'a, L: RawLock + Sharable, R>( +pub fn scoped_read<'a, L: RawLock + Sharable + ?Sized, R>( collection: &'a L, key: impl Keyable, f: impl FnOnce(L::DataRef<'a>) -> R, @@ -177,7 +186,7 @@ pub fn scoped_read<'a, L: RawLock + Sharable, R>( collection.raw_read(); // safety: we just locked this - let r = f(collection.data_ref()); + let r = handle_unwind(|| f(collection.data_ref()), || collection.raw_unlock_read()); // this ensures the key is held long enough drop(key); @@ -189,7 +198,7 @@ pub fn scoped_read<'a, L: RawLock + Sharable, R>( } } -pub fn scoped_try_read<'a, L: RawLock + Sharable, Key: Keyable, R>( +pub fn scoped_try_read<'a, L: RawLock + Sharable + ?Sized, Key: Keyable, R>( collection: &'a L, key: Key, f: impl FnOnce(L::DataRef<'a>) -> R, @@ -201,7 +210,7 @@ pub fn scoped_try_read<'a, L: RawLock + Sharable, Key: Keyable, R>( } // safety: we just locked this - let r = f(collection.data_ref()); + let r = handle_unwind(|| f(collection.data_ref()), || collection.raw_unlock_read()); // this ensures the key is held long enough drop(key); diff --git a/src/lockable.rs b/src/lockable.rs index 94042ea..16e3968 100644 --- a/src/lockable.rs +++ b/src/lockable.rs @@ -105,6 +105,9 @@ pub unsafe trait RawLock { /// The order of the resulting list from `get_ptrs` must be deterministic. As /// long as the value is not mutated, the references must always be in the same /// order. +/// +/// The list returned by `get_ptrs` must contain any lock which could possibly +/// be referenced in another collection. pub unsafe trait Lockable { /// The exclusive guard that does not hold a key type Guard<'g> @@ -333,8 +336,6 @@ macro_rules! tuple_impls { } unsafe fn guard(&self) -> Self::Guard<'_> { - // It's weird that this works - // I don't think any other way of doing it compiles ($(self.$value.guard(),)*) } @@ -525,27 +526,17 @@ impl<T: LockableGetMut + 'static> LockableGetMut for Box<[T]> { } } -unsafe impl<T: Sharable> Sharable for Box<[T]> { - type ReadGuard<'g> - = Box<[T::ReadGuard<'g>]> - where - Self: 'g; - - type DataRef<'a> - = Box<[T::DataRef<'a>]> - where - Self: 'a; - - unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { - self.iter().map(|lock| lock.read_guard()).collect() - } +impl<T: LockableIntoInner + 'static> LockableIntoInner for Box<[T]> { + type Inner = Box<[T::Inner]>; - unsafe fn data_ref(&self) -> Self::DataRef<'_> { - self.iter().map(|lock| lock.data_ref()).collect() + fn into_inner(self) -> Self::Inner { + Self::into_iter(self) + .map(LockableIntoInner::into_inner) + .collect() } } -unsafe impl<T: Sharable> Sharable for Vec<T> { +unsafe impl<T: Sharable> Sharable for Box<[T]> { type ReadGuard<'g> = Box<[T::ReadGuard<'g>]> where @@ -565,8 +556,6 @@ unsafe impl<T: Sharable> Sharable for Vec<T> { } } -unsafe impl<T: OwnedLockable> OwnedLockable for Box<[T]> {} - unsafe impl<T: Lockable> Lockable for Vec<T> { // There's no reason why I'd ever want to extend a list of lock guards type Guard<'g> @@ -594,11 +583,31 @@ unsafe impl<T: Lockable> Lockable for Vec<T> { } } +unsafe impl<T: Sharable> Sharable for Vec<T> { + type ReadGuard<'g> + = Box<[T::ReadGuard<'g>]> + where + Self: 'g; + + type DataRef<'a> + = Box<[T::DataRef<'a>]> + where + Self: 'a; + + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { + self.iter().map(|lock| lock.read_guard()).collect() + } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.iter().map(|lock| lock.data_ref()).collect() + } +} + +unsafe impl<T: OwnedLockable> OwnedLockable for Box<[T]> {} + // I'd make a generic impl<T: Lockable, I: IntoIterator<Item=T>> Lockable for I // but I think that'd require sealing up this trait -// TODO: using edition 2024, impl LockableIntoInner for Box<[T]> - impl<T: LockableGetMut + 'static> LockableGetMut for Vec<T> { type Inner<'a> = Box<[T::Inner<'a>]> diff --git a/src/mutex/mutex.rs b/src/mutex/mutex.rs index f0fb680..a2813a1 100644 --- a/src/mutex/mutex.rs +++ b/src/mutex/mutex.rs @@ -5,7 +5,6 @@ use std::panic::AssertUnwindSafe; use lock_api::RawMutex; -use crate::collection::utils; use crate::handle_unwind::handle_unwind; use crate::lockable::{Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock}; use crate::poisonable::PoisonFlag; @@ -87,7 +86,7 @@ unsafe impl<T, R: RawMutex> Lockable for Mutex<T, R> { } } -impl<T: Send, R: RawMutex> LockableIntoInner for Mutex<T, R> { +impl<T, R: RawMutex> LockableIntoInner for Mutex<T, R> { type Inner = T; fn into_inner(self) -> Self::Inner { @@ -95,7 +94,7 @@ impl<T: Send, R: RawMutex> LockableIntoInner for Mutex<T, R> { } } -impl<T: Send, R: RawMutex> LockableGetMut for Mutex<T, R> { +impl<T, R: RawMutex> LockableGetMut for Mutex<T, R> { type Inner<'a> = &'a mut T where @@ -106,7 +105,7 @@ impl<T: Send, R: RawMutex> LockableGetMut for Mutex<T, R> { } } -unsafe impl<T: Send, R: RawMutex> OwnedLockable for Mutex<T, R> {} +unsafe impl<T, R: RawMutex> OwnedLockable for Mutex<T, R> {} impl<T, R: RawMutex> Mutex<T, R> { /// Create a new unlocked `Mutex`. @@ -147,7 +146,7 @@ impl<T, R: RawMutex> Mutex<T, R> { #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl<T: Debug, R: RawMutex> Debug for Mutex<T, R> { +impl<T: ?Sized + Debug, R: RawMutex> Debug for Mutex<T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // safety: this is just a try lock, and the value is dropped // immediately after, so there's no risk of blocking ourselves @@ -229,13 +228,30 @@ impl<T: ?Sized, R> Mutex<T, R> { } } -impl<T, R: RawMutex> Mutex<T, R> { +impl<T: ?Sized, R: RawMutex> Mutex<T, R> { pub fn scoped_lock<'a, Ret>( &'a self, key: impl Keyable, f: impl FnOnce(&'a mut T) -> Ret, ) -> Ret { - utils::scoped_write(self, key, f) + unsafe { + // safety: we have the key + self.raw_write(); + + // safety: the data has been locked + let r = handle_unwind( + || f(self.data.get().as_mut().unwrap_unchecked()), + || self.raw_unlock_write(), + ); + + // ensures the key is held long enough + drop(key); + + // safety: the mutex is still locked + self.raw_unlock_write(); + + r + } } pub fn scoped_try_lock<'a, Key: Keyable, Ret>( @@ -243,9 +259,30 @@ impl<T, R: RawMutex> Mutex<T, R> { key: Key, f: impl FnOnce(&'a mut T) -> Ret, ) -> Result<Ret, Key> { - utils::scoped_try_write(self, key, f) + unsafe { + // safety: we have the key + if !self.raw_try_write() { + return Err(key); + } + + // safety: the data has been locked + let r = handle_unwind( + || f(self.data.get().as_mut().unwrap_unchecked()), + || self.raw_unlock_write(), + ); + + // ensures the key is held long enough + drop(key); + + // safety: the mutex is still locked + self.raw_unlock_write(); + + Ok(r) + } } +} +impl<T: ?Sized, R: RawMutex> Mutex<T, R> { /// Block the thread until this mutex can be locked, and lock it. /// /// Upon returning, the thread is the only thread with a lock on the diff --git a/src/rwlock.rs b/src/rwlock.rs index 2d3dd85..f5c0ec5 100644 --- a/src/rwlock.rs +++ b/src/rwlock.rs @@ -58,7 +58,7 @@ pub struct RwLock<T: ?Sized, R> { /// /// [`LockCollection`]: `crate::LockCollection` #[repr(transparent)] -pub struct ReadLock<'l, T: ?Sized, R>(&'l RwLock<T, R>); +struct ReadLock<'l, T: ?Sized, R>(&'l RwLock<T, R>); /// Grants write access to an [`RwLock`] /// @@ -67,7 +67,7 @@ pub struct ReadLock<'l, T: ?Sized, R>(&'l RwLock<T, R>); /// /// [`LockCollection`]: `crate::LockCollection` #[repr(transparent)] -pub struct WriteLock<'l, T: ?Sized, R>(&'l RwLock<T, R>); +struct WriteLock<'l, T: ?Sized, R>(&'l RwLock<T, R>); /// RAII structure that unlocks the shared read access to a [`RwLock`] /// @@ -187,6 +187,7 @@ mod tests { } #[test] + #[ignore = "We've removed ReadLock"] fn read_lock_get_ptrs() { let rwlock = RwLock::new(5); let readlock = ReadLock::new(&rwlock); @@ -198,6 +199,7 @@ mod tests { } #[test] + #[ignore = "We've removed WriteLock"] fn write_lock_get_ptrs() { let rwlock = RwLock::new(5); let writelock = WriteLock::new(&rwlock); @@ -446,6 +448,7 @@ mod tests { } #[test] + #[ignore = "We've removed ReadLock"] fn read_lock_in_collection() { let mut key = ThreadKey::get().unwrap(); let lock = crate::RwLock::new("hi"); diff --git a/src/rwlock/read_lock.rs b/src/rwlock/read_lock.rs index dd9e42f..f13f2b9 100644 --- a/src/rwlock/read_lock.rs +++ b/src/rwlock/read_lock.rs @@ -49,7 +49,7 @@ unsafe impl<T, R: RawRwLock> Lockable for ReadLock<'_, T, R> { Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - ptrs.push(self); + ptrs.push(self.0); } unsafe fn guard(&self) -> Self::Guard<'_> { diff --git a/src/rwlock/rwlock.rs b/src/rwlock/rwlock.rs index 5f407d1..f1cdca5 100644 --- a/src/rwlock/rwlock.rs +++ b/src/rwlock/rwlock.rs @@ -5,7 +5,6 @@ use std::panic::AssertUnwindSafe; use lock_api::RawRwLock; -use crate::collection::utils; use crate::handle_unwind::handle_unwind; use crate::lockable::{ Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable, @@ -118,9 +117,9 @@ unsafe impl<T, R: RawRwLock> Sharable for RwLock<T, R> { } } -unsafe impl<T: Send, R: RawRwLock> OwnedLockable for RwLock<T, R> {} +unsafe impl<T, R: RawRwLock> OwnedLockable for RwLock<T, R> {} -impl<T: Send, R: RawRwLock> LockableIntoInner for RwLock<T, R> { +impl<T, R: RawRwLock> LockableIntoInner for RwLock<T, R> { type Inner = T; fn into_inner(self) -> Self::Inner { @@ -128,7 +127,7 @@ impl<T: Send, R: RawRwLock> LockableIntoInner for RwLock<T, R> { } } -impl<T: Send, R: RawRwLock> LockableGetMut for RwLock<T, R> { +impl<T, R: RawRwLock> LockableGetMut for RwLock<T, R> { type Inner<'a> = &'a mut T where @@ -248,9 +247,26 @@ impl<T: ?Sized, R> RwLock<T, R> { } } -impl<T, R: RawRwLock> RwLock<T, R> { +impl<T: ?Sized, R: RawRwLock> RwLock<T, R> { pub fn scoped_read<'a, Ret>(&'a self, key: impl Keyable, f: impl Fn(&'a T) -> Ret) -> Ret { - utils::scoped_read(self, key, f) + unsafe { + // safety: we have the key + self.raw_read(); + + // safety: the data has been locked + let r = handle_unwind( + || f(self.data.get().as_ref().unwrap_unchecked()), + || self.raw_unlock_read(), + ); + + // ensures the key is held long enough + drop(key); + + // safety: the mutex is still locked + self.raw_unlock_read(); + + r + } } pub fn scoped_try_read<'a, Key: Keyable, Ret>( @@ -258,11 +274,47 @@ impl<T, R: RawRwLock> RwLock<T, R> { key: Key, f: impl Fn(&'a T) -> Ret, ) -> Result<Ret, Key> { - utils::scoped_try_read(self, key, f) + unsafe { + // safety: we have the key + if !self.raw_try_read() { + return Err(key); + } + + // safety: the data has been locked + let r = handle_unwind( + || f(self.data.get().as_ref().unwrap_unchecked()), + || self.raw_unlock_read(), + ); + + // ensures the key is held long enough + drop(key); + + // safety: the mutex is still locked + self.raw_unlock_read(); + + Ok(r) + } } pub fn scoped_write<'a, Ret>(&'a self, key: impl Keyable, f: impl Fn(&'a mut T) -> Ret) -> Ret { - utils::scoped_write(self, key, f) + unsafe { + // safety: we have the key + self.raw_write(); + + // safety: the data has been locked + let r = handle_unwind( + || f(self.data.get().as_mut().unwrap_unchecked()), + || self.raw_unlock_write(), + ); + + // ensures the key is held long enough + drop(key); + + // safety: the mutex is still locked + self.raw_unlock_write(); + + r + } } pub fn scoped_try_write<'a, Key: Keyable, Ret>( @@ -270,7 +322,26 @@ impl<T, R: RawRwLock> RwLock<T, R> { key: Key, f: impl Fn(&'a mut T) -> Ret, ) -> Result<Ret, Key> { - utils::scoped_try_write(self, key, f) + unsafe { + // safety: we have the key + if !self.raw_try_write() { + return Err(key); + } + + // safety: the data has been locked + let r = handle_unwind( + || f(self.data.get().as_mut().unwrap_unchecked()), + || self.raw_unlock_write(), + ); + + // ensures the key is held long enough + drop(key); + + // safety: the mutex is still locked + self.raw_unlock_write(); + + Ok(r) + } } /// Locks this `RwLock` with shared read access, blocking the current diff --git a/src/rwlock/write_lock.rs b/src/rwlock/write_lock.rs index 5ae4dda..6469a67 100644 --- a/src/rwlock/write_lock.rs +++ b/src/rwlock/write_lock.rs @@ -49,7 +49,7 @@ unsafe impl<T, R: RawRwLock> Lockable for WriteLock<'_, T, R> { Self: 'a; fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { - ptrs.push(self) + ptrs.push(self.0); } unsafe fn guard(&self) -> Self::Guard<'_> { |
