diff options
| author | Botahamec <botahamec@outlook.com> | 2025-02-28 16:09:11 -0500 |
|---|---|---|
| committer | Botahamec <botahamec@outlook.com> | 2025-02-28 16:09:11 -0500 |
| commit | 4ba03be97e6cc7e790bbc9bfc18caaa228c8a262 (patch) | |
| tree | a257184577a93ddf240aba698755c2886188788b /src/mutex | |
| parent | 4a5ec04a29cba07c5960792528bd66b0f99ee3ee (diff) | |
Scoped lock API
Diffstat (limited to 'src/mutex')
| -rw-r--r-- | src/mutex/guard.rs | 71 | ||||
| -rw-r--r-- | src/mutex/mutex.rs | 61 |
2 files changed, 64 insertions, 68 deletions
diff --git a/src/mutex/guard.rs b/src/mutex/guard.rs index 4e4d5f1..22e59c1 100644 --- a/src/mutex/guard.rs +++ b/src/mutex/guard.rs @@ -5,34 +5,14 @@ use std::ops::{Deref, DerefMut}; use lock_api::RawMutex; -use crate::key::Keyable; use crate::lockable::RawLock; +use crate::ThreadKey; use super::{Mutex, MutexGuard, MutexRef}; // These impls make things slightly easier because now you can use // `println!("{guard}")` instead of `println!("{}", *guard)` -impl<T: PartialEq + ?Sized, R: RawMutex> PartialEq for MutexRef<'_, T, R> { - fn eq(&self, other: &Self) -> bool { - self.deref().eq(&**other) - } -} - -impl<T: Eq + ?Sized, R: RawMutex> Eq for MutexRef<'_, T, R> {} - -impl<T: PartialOrd + ?Sized, R: RawMutex> PartialOrd for MutexRef<'_, T, R> { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.deref().partial_cmp(&**other) - } -} - -impl<T: Ord + ?Sized, R: RawMutex> Ord for MutexRef<'_, T, R> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.deref().cmp(&**other) - } -} - #[mutants::skip] // hashing involves RNG and is hard to test #[cfg(not(tarpaulin_include))] impl<T: Hash + ?Sized, R: RawMutex> Hash for MutexRef<'_, T, R> { @@ -107,39 +87,9 @@ impl<'a, T: ?Sized, R: RawMutex> MutexRef<'a, T, R> { // it's kinda annoying to re-implement some of this stuff on guards // there's nothing i can do about that -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<T: PartialEq + ?Sized, R: RawMutex, Key: Keyable> PartialEq for MutexGuard<'_, '_, T, Key, R> { - fn eq(&self, other: &Self) -> bool { - self.deref().eq(&**other) - } -} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<T: Eq + ?Sized, R: RawMutex, Key: Keyable> Eq for MutexGuard<'_, '_, T, Key, R> {} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<T: PartialOrd + ?Sized, R: RawMutex, Key: Keyable> PartialOrd - for MutexGuard<'_, '_, T, Key, R> -{ - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.deref().partial_cmp(&**other) - } -} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<T: Ord + ?Sized, R: RawMutex, Key: Keyable> Ord for MutexGuard<'_, '_, T, Key, R> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.deref().cmp(&**other) - } -} - #[mutants::skip] // hashing involves RNG and is hard to test #[cfg(not(tarpaulin_include))] -impl<T: Hash + ?Sized, R: RawMutex, Key: Keyable> Hash for MutexGuard<'_, '_, T, Key, R> { +impl<T: Hash + ?Sized, R: RawMutex> Hash for MutexGuard<'_, T, R> { fn hash<H: std::hash::Hasher>(&self, state: &mut H) { self.deref().hash(state) } @@ -147,19 +97,19 @@ impl<T: Hash + ?Sized, R: RawMutex, Key: Keyable> Hash for MutexGuard<'_, '_, T, #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl<T: Debug + ?Sized, Key: Keyable, R: RawMutex> Debug for MutexGuard<'_, '_, T, Key, R> { +impl<T: Debug + ?Sized, R: RawMutex> Debug for MutexGuard<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Debug::fmt(&**self, f) } } -impl<T: Display + ?Sized, Key: Keyable, R: RawMutex> Display for MutexGuard<'_, '_, T, Key, R> { +impl<T: Display + ?Sized, R: RawMutex> Display for MutexGuard<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Display::fmt(&**self, f) } } -impl<T: ?Sized, Key: Keyable, R: RawMutex> Deref for MutexGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawMutex> Deref for MutexGuard<'_, T, R> { type Target = T; fn deref(&self) -> &Self::Target { @@ -167,33 +117,32 @@ impl<T: ?Sized, Key: Keyable, R: RawMutex> Deref for MutexGuard<'_, '_, T, Key, } } -impl<T: ?Sized, Key: Keyable, R: RawMutex> DerefMut for MutexGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawMutex> DerefMut for MutexGuard<'_, T, R> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.mutex } } -impl<T: ?Sized, Key: Keyable, R: RawMutex> AsRef<T> for MutexGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawMutex> AsRef<T> for MutexGuard<'_, T, R> { fn as_ref(&self) -> &T { self } } -impl<T: ?Sized, Key: Keyable, R: RawMutex> AsMut<T> for MutexGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawMutex> AsMut<T> for MutexGuard<'_, T, R> { fn as_mut(&mut self) -> &mut T { self } } -impl<'a, T: ?Sized, Key: Keyable, R: RawMutex> MutexGuard<'a, '_, T, Key, R> { +impl<'a, T: ?Sized, R: RawMutex> MutexGuard<'a, T, R> { /// Create a guard to the given mutex. Undefined if multiple guards to the /// same mutex exist at once. #[must_use] - pub(super) unsafe fn new(mutex: &'a Mutex<T, R>, thread_key: Key) -> Self { + pub(super) unsafe fn new(mutex: &'a Mutex<T, R>, thread_key: ThreadKey) -> Self { Self { mutex: MutexRef(mutex, PhantomData), thread_key, - _phantom: PhantomData, } } } diff --git a/src/mutex/mutex.rs b/src/mutex/mutex.rs index 0bd5286..1d8ce8b 100644 --- a/src/mutex/mutex.rs +++ b/src/mutex/mutex.rs @@ -6,9 +6,9 @@ use std::panic::AssertUnwindSafe; use lock_api::RawMutex; use crate::handle_unwind::handle_unwind; -use crate::key::Keyable; use crate::lockable::{Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock}; use crate::poisonable::PoisonFlag; +use crate::{Keyable, ThreadKey}; use super::{Mutex, MutexGuard, MutexRef}; @@ -62,6 +62,11 @@ unsafe impl<T: Send, R: RawMutex + Send + Sync> Lockable for Mutex<T, R> { where Self: 'g; + type DataMut<'a> + = &'a mut T + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.push(self); } @@ -69,6 +74,10 @@ unsafe impl<T: Send, R: RawMutex + Send + Sync> Lockable for Mutex<T, R> { unsafe fn guard(&self) -> Self::Guard<'_> { MutexRef::new(self) } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.get().as_mut().unwrap_unchecked() + } } impl<T: Send, R: RawMutex + Send + Sync> LockableIntoInner for Mutex<T, R> { @@ -214,6 +223,46 @@ impl<T: ?Sized, R> Mutex<T, R> { } impl<T: ?Sized, R: RawMutex> Mutex<T, R> { + pub fn scoped_lock<Ret>(&self, key: impl Keyable, f: impl FnOnce(&mut T) -> Ret) -> Ret { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the mutex was just locked + let r = f(self.data.get().as_mut().unwrap_unchecked()); + + // safety: we locked the mutex already + self.raw_unlock(); + + drop(key); // ensures we drop the key in the correct place + + r + } + } + + pub fn scoped_try_lock<Key: Keyable, Ret>( + &self, + key: Key, + f: impl FnOnce(&mut T) -> Ret, + ) -> Result<Ret, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: the mutex was just locked + let r = f(self.data.get().as_mut().unwrap_unchecked()); + + // safety: we locked the mutex already + self.raw_unlock(); + + drop(key); // ensures we drop the key in the correct place + + Ok(r) + } + } + /// Block the thread until this mutex can be locked, and lock it. /// /// Upon returning, the thread is the only thread with a lock on the @@ -237,7 +286,7 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> { /// let key = ThreadKey::get().unwrap(); /// assert_eq!(*mutex.lock(key), 10); /// ``` - pub fn lock<'s, 'k: 's, Key: Keyable>(&'s self, key: Key) -> MutexGuard<'s, 'k, T, Key, R> { + pub fn lock(&self, key: ThreadKey) -> MutexGuard<'_, T, R> { unsafe { // safety: we have the thread key self.raw_lock(); @@ -280,10 +329,7 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> { /// let key = ThreadKey::get().unwrap(); /// assert_eq!(*mutex.lock(key), 10); /// ``` - pub fn try_lock<'s, 'k: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> Result<MutexGuard<'s, 'k, T, Key, R>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<MutexGuard<'_, T, R>, ThreadKey> { unsafe { // safety: we have the key to the mutex if self.raw_try_lock() { @@ -322,7 +368,8 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> { /// /// let key = Mutex::unlock(guard); /// ``` - pub fn unlock<'a, 'k: 'a, Key: Keyable + 'k>(guard: MutexGuard<'a, 'k, T, Key, R>) -> Key { + #[must_use] + pub fn unlock(guard: MutexGuard<'_, T, R>) -> ThreadKey { unsafe { guard.mutex.0.raw_unlock(); } |
