summaryrefslogtreecommitdiff
path: root/src/rwlock/rwlock.rs
diff options
context:
space:
mode:
authorBotahamec <botahamec@outlook.com>2025-02-28 16:09:11 -0500
committerBotahamec <botahamec@outlook.com>2025-02-28 16:09:11 -0500
commit4ba03be97e6cc7e790bbc9bfc18caaa228c8a262 (patch)
treea257184577a93ddf240aba698755c2886188788b /src/rwlock/rwlock.rs
parent4a5ec04a29cba07c5960792528bd66b0f99ee3ee (diff)
Scoped lock API
Diffstat (limited to 'src/rwlock/rwlock.rs')
-rw-r--r--src/rwlock/rwlock.rs130
1 files changed, 107 insertions, 23 deletions
diff --git a/src/rwlock/rwlock.rs b/src/rwlock/rwlock.rs
index 038e6c7..905ecf8 100644
--- a/src/rwlock/rwlock.rs
+++ b/src/rwlock/rwlock.rs
@@ -6,10 +6,10 @@ use std::panic::AssertUnwindSafe;
use lock_api::RawRwLock;
use crate::handle_unwind::handle_unwind;
-use crate::key::Keyable;
use crate::lockable::{
Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable,
};
+use crate::{Keyable, ThreadKey};
use super::{PoisonFlag, RwLock, RwLockReadGuard, RwLockReadRef, RwLockWriteGuard, RwLockWriteRef};
@@ -79,6 +79,11 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for RwLock<T, R> {
where
Self: 'g;
+ type DataMut<'a>
+ = &'a mut T
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
ptrs.push(self);
}
@@ -86,6 +91,10 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for RwLock<T, R> {
unsafe fn guard(&self) -> Self::Guard<'_> {
RwLockWriteRef::new(self)
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.data.get().as_mut().unwrap_unchecked()
+ }
}
unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for RwLock<T, R> {
@@ -94,9 +103,18 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for RwLock<T, R> {
where
Self: 'g;
+ type DataRef<'a>
+ = &'a T
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
RwLockReadRef::new(self)
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.data.get().as_ref().unwrap_unchecked()
+ }
}
unsafe impl<T: Send, R: RawRwLock + Send + Sync> OwnedLockable for RwLock<T, R> {}
@@ -230,6 +248,86 @@ impl<T: ?Sized, R> RwLock<T, R> {
}
impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
+ pub fn scoped_read<Ret>(&self, key: impl Keyable, f: impl Fn(&T) -> Ret) -> Ret {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_read();
+
+ // safety: the rwlock was just locked
+ let r = f(self.data.get().as_ref().unwrap_unchecked());
+
+ // safety: the rwlock is already locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays valid for long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_read<Key: Keyable, Ret>(
+ &self,
+ key: Key,
+ f: impl Fn(&T) -> Ret,
+ ) -> Result<Ret, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: the rwlock was just locked
+ let r = f(self.data.get().as_ref().unwrap_unchecked());
+
+ // safety: the rwlock is already locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays valid for long enough
+
+ Ok(r)
+ }
+ }
+
+ pub fn scoped_write<Ret>(&self, key: impl Keyable, f: impl Fn(&mut T) -> Ret) -> Ret {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: we just locked the rwlock
+ let r = f(self.data.get().as_mut().unwrap_unchecked());
+
+ // safety: the rwlock is already locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays valid for long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_write<Key: Keyable, Ret>(
+ &self,
+ key: Key,
+ f: impl Fn(&mut T) -> Ret,
+ ) -> Result<Ret, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: the rwlock was just locked
+ let r = f(self.data.get().as_mut().unwrap_unchecked());
+
+ // safety: the rwlock is already locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays valid for long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks this `RwLock` with shared read access, blocking the current
/// thread until it can be acquired.
///
@@ -264,10 +362,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// ```
///
/// [`ThreadKey`]: `crate::ThreadKey`
- pub fn read<'s, 'key: 's, Key: Keyable>(
- &'s self,
- key: Key,
- ) -> RwLockReadGuard<'s, 'key, T, Key, R> {
+ pub fn read(&self, key: ThreadKey) -> RwLockReadGuard<'_, T, R> {
unsafe {
self.raw_read();
@@ -305,10 +400,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// Err(_) => unreachable!(),
/// };
/// ```
- pub fn try_read<'s, 'key: 's, Key: Keyable>(
- &'s self,
- key: Key,
- ) -> Result<RwLockReadGuard<'s, 'key, T, Key, R>, Key> {
+ pub fn try_read(&self, key: ThreadKey) -> Result<RwLockReadGuard<'_, T, R>, ThreadKey> {
unsafe {
if self.raw_try_read() {
// safety: the lock is locked first
@@ -369,10 +461,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// ```
///
/// [`ThreadKey`]: `crate::ThreadKey`
- pub fn write<'s, 'key: 's, Key: Keyable>(
- &'s self,
- key: Key,
- ) -> RwLockWriteGuard<'s, 'key, T, Key, R> {
+ pub fn write(&self, key: ThreadKey) -> RwLockWriteGuard<'_, T, R> {
unsafe {
self.raw_lock();
@@ -407,10 +496,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// let n = lock.read(key);
/// assert_eq!(*n, 1);
/// ```
- pub fn try_write<'s, 'key: 's, Key: Keyable>(
- &'s self,
- key: Key,
- ) -> Result<RwLockWriteGuard<'s, 'key, T, Key, R>, Key> {
+ pub fn try_write(&self, key: ThreadKey) -> Result<RwLockWriteGuard<'_, T, R>, ThreadKey> {
unsafe {
if self.raw_try_lock() {
// safety: the lock is locked first
@@ -445,9 +531,8 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// assert_eq!(*guard, 0);
/// let key = RwLock::unlock_read(guard);
/// ```
- pub fn unlock_read<'key, Key: Keyable + 'key>(
- guard: RwLockReadGuard<'_, 'key, T, Key, R>,
- ) -> Key {
+ #[must_use]
+ pub fn unlock_read(guard: RwLockReadGuard<'_, T, R>) -> ThreadKey {
unsafe {
guard.rwlock.0.raw_unlock_read();
}
@@ -473,9 +558,8 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// *guard += 20;
/// let key = RwLock::unlock_write(guard);
/// ```
- pub fn unlock_write<'key, Key: Keyable + 'key>(
- guard: RwLockWriteGuard<'_, 'key, T, Key, R>,
- ) -> Key {
+ #[must_use]
+ pub fn unlock_write(guard: RwLockWriteGuard<'_, T, R>) -> ThreadKey {
unsafe {
guard.rwlock.0.raw_unlock();
}