From 4ba03be97e6cc7e790bbc9bfc18caaa228c8a262 Mon Sep 17 00:00:00 2001 From: Botahamec Date: Fri, 28 Feb 2025 16:09:11 -0500 Subject: Scoped lock API --- src/mutex/guard.rs | 71 ++++++++---------------------------------------------- src/mutex/mutex.rs | 61 ++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 64 insertions(+), 68 deletions(-) (limited to 'src/mutex') diff --git a/src/mutex/guard.rs b/src/mutex/guard.rs index 4e4d5f1..22e59c1 100644 --- a/src/mutex/guard.rs +++ b/src/mutex/guard.rs @@ -5,34 +5,14 @@ use std::ops::{Deref, DerefMut}; use lock_api::RawMutex; -use crate::key::Keyable; use crate::lockable::RawLock; +use crate::ThreadKey; use super::{Mutex, MutexGuard, MutexRef}; // These impls make things slightly easier because now you can use // `println!("{guard}")` instead of `println!("{}", *guard)` -impl PartialEq for MutexRef<'_, T, R> { - fn eq(&self, other: &Self) -> bool { - self.deref().eq(&**other) - } -} - -impl Eq for MutexRef<'_, T, R> {} - -impl PartialOrd for MutexRef<'_, T, R> { - fn partial_cmp(&self, other: &Self) -> Option { - self.deref().partial_cmp(&**other) - } -} - -impl Ord for MutexRef<'_, T, R> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.deref().cmp(&**other) - } -} - #[mutants::skip] // hashing involves RNG and is hard to test #[cfg(not(tarpaulin_include))] impl Hash for MutexRef<'_, T, R> { @@ -107,39 +87,9 @@ impl<'a, T: ?Sized, R: RawMutex> MutexRef<'a, T, R> { // it's kinda annoying to re-implement some of this stuff on guards // there's nothing i can do about that -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl PartialEq for MutexGuard<'_, '_, T, Key, R> { - fn eq(&self, other: &Self) -> bool { - self.deref().eq(&**other) - } -} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl Eq for MutexGuard<'_, '_, T, Key, R> {} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl PartialOrd - for MutexGuard<'_, '_, T, Key, R> -{ - fn partial_cmp(&self, other: &Self) -> Option { - self.deref().partial_cmp(&**other) - } -} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl Ord for MutexGuard<'_, '_, T, Key, R> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.deref().cmp(&**other) - } -} - #[mutants::skip] // hashing involves RNG and is hard to test #[cfg(not(tarpaulin_include))] -impl Hash for MutexGuard<'_, '_, T, Key, R> { +impl Hash for MutexGuard<'_, T, R> { fn hash(&self, state: &mut H) { self.deref().hash(state) } @@ -147,19 +97,19 @@ impl Hash for MutexGuard<'_, '_, T, #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl Debug for MutexGuard<'_, '_, T, Key, R> { +impl Debug for MutexGuard<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Debug::fmt(&**self, f) } } -impl Display for MutexGuard<'_, '_, T, Key, R> { +impl Display for MutexGuard<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Display::fmt(&**self, f) } } -impl Deref for MutexGuard<'_, '_, T, Key, R> { +impl Deref for MutexGuard<'_, T, R> { type Target = T; fn deref(&self) -> &Self::Target { @@ -167,33 +117,32 @@ impl Deref for MutexGuard<'_, '_, T, Key, } } -impl DerefMut for MutexGuard<'_, '_, T, Key, R> { +impl DerefMut for MutexGuard<'_, T, R> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.mutex } } -impl AsRef for MutexGuard<'_, '_, T, Key, R> { +impl AsRef for MutexGuard<'_, T, R> { fn as_ref(&self) -> &T { self } } -impl AsMut for MutexGuard<'_, '_, T, Key, R> { +impl AsMut for MutexGuard<'_, T, R> { fn as_mut(&mut self) -> &mut T { self } } -impl<'a, T: ?Sized, Key: Keyable, R: RawMutex> MutexGuard<'a, '_, T, Key, R> { +impl<'a, T: ?Sized, R: RawMutex> MutexGuard<'a, T, R> { /// Create a guard to the given mutex. Undefined if multiple guards to the /// same mutex exist at once. #[must_use] - pub(super) unsafe fn new(mutex: &'a Mutex, thread_key: Key) -> Self { + pub(super) unsafe fn new(mutex: &'a Mutex, thread_key: ThreadKey) -> Self { Self { mutex: MutexRef(mutex, PhantomData), thread_key, - _phantom: PhantomData, } } } diff --git a/src/mutex/mutex.rs b/src/mutex/mutex.rs index 0bd5286..1d8ce8b 100644 --- a/src/mutex/mutex.rs +++ b/src/mutex/mutex.rs @@ -6,9 +6,9 @@ use std::panic::AssertUnwindSafe; use lock_api::RawMutex; use crate::handle_unwind::handle_unwind; -use crate::key::Keyable; use crate::lockable::{Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock}; use crate::poisonable::PoisonFlag; +use crate::{Keyable, ThreadKey}; use super::{Mutex, MutexGuard, MutexRef}; @@ -62,6 +62,11 @@ unsafe impl Lockable for Mutex { where Self: 'g; + type DataMut<'a> + = &'a mut T + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.push(self); } @@ -69,6 +74,10 @@ unsafe impl Lockable for Mutex { unsafe fn guard(&self) -> Self::Guard<'_> { MutexRef::new(self) } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.get().as_mut().unwrap_unchecked() + } } impl LockableIntoInner for Mutex { @@ -214,6 +223,46 @@ impl Mutex { } impl Mutex { + pub fn scoped_lock(&self, key: impl Keyable, f: impl FnOnce(&mut T) -> Ret) -> Ret { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the mutex was just locked + let r = f(self.data.get().as_mut().unwrap_unchecked()); + + // safety: we locked the mutex already + self.raw_unlock(); + + drop(key); // ensures we drop the key in the correct place + + r + } + } + + pub fn scoped_try_lock( + &self, + key: Key, + f: impl FnOnce(&mut T) -> Ret, + ) -> Result { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: the mutex was just locked + let r = f(self.data.get().as_mut().unwrap_unchecked()); + + // safety: we locked the mutex already + self.raw_unlock(); + + drop(key); // ensures we drop the key in the correct place + + Ok(r) + } + } + /// Block the thread until this mutex can be locked, and lock it. /// /// Upon returning, the thread is the only thread with a lock on the @@ -237,7 +286,7 @@ impl Mutex { /// let key = ThreadKey::get().unwrap(); /// assert_eq!(*mutex.lock(key), 10); /// ``` - pub fn lock<'s, 'k: 's, Key: Keyable>(&'s self, key: Key) -> MutexGuard<'s, 'k, T, Key, R> { + pub fn lock(&self, key: ThreadKey) -> MutexGuard<'_, T, R> { unsafe { // safety: we have the thread key self.raw_lock(); @@ -280,10 +329,7 @@ impl Mutex { /// let key = ThreadKey::get().unwrap(); /// assert_eq!(*mutex.lock(key), 10); /// ``` - pub fn try_lock<'s, 'k: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> Result, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result, ThreadKey> { unsafe { // safety: we have the key to the mutex if self.raw_try_lock() { @@ -322,7 +368,8 @@ impl Mutex { /// /// let key = Mutex::unlock(guard); /// ``` - pub fn unlock<'a, 'k: 'a, Key: Keyable + 'k>(guard: MutexGuard<'a, 'k, T, Key, R>) -> Key { + #[must_use] + pub fn unlock(guard: MutexGuard<'_, T, R>) -> ThreadKey { unsafe { guard.mutex.0.raw_unlock(); } -- cgit v1.2.3