summaryrefslogtreecommitdiff
path: root/src/collection/retry.rs
diff options
context:
space:
mode:
authorBotahamec <botahamec@outlook.com>2025-02-28 16:09:11 -0500
committerBotahamec <botahamec@outlook.com>2025-02-28 16:09:11 -0500
commit4ba03be97e6cc7e790bbc9bfc18caaa228c8a262 (patch)
treea257184577a93ddf240aba698755c2886188788b /src/collection/retry.rs
parent4a5ec04a29cba07c5960792528bd66b0f99ee3ee (diff)
Scoped lock API
Diffstat (limited to 'src/collection/retry.rs')
-rw-r--r--src/collection/retry.rs176
1 files changed, 127 insertions, 49 deletions
diff --git a/src/collection/retry.rs b/src/collection/retry.rs
index 331b669..775ea29 100644
--- a/src/collection/retry.rs
+++ b/src/collection/retry.rs
@@ -1,24 +1,18 @@
use std::cell::Cell;
use std::collections::HashSet;
-use std::marker::PhantomData;
use crate::collection::utils;
use crate::handle_unwind::handle_unwind;
use crate::lockable::{
Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable,
};
-use crate::Keyable;
+use crate::{Keyable, ThreadKey};
-use super::utils::{attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic};
+use super::utils::{
+ attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic, get_locks_unsorted,
+};
use super::{LockGuard, RetryingLockCollection};
-/// Get all raw locks in the collection
-fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
- let mut locks = Vec::new();
- data.get_ptrs(&mut locks);
- locks
-}
-
/// Checks that a collection contains no duplicate references to a lock.
fn contains_duplicates<L: Lockable>(data: L) -> bool {
let mut locks = Vec::new();
@@ -40,14 +34,14 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
#[mutants::skip] // this should never run
#[cfg(not(tarpaulin_include))]
fn poison(&self) {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
for lock in locks {
lock.poison();
}
}
unsafe fn raw_lock(&self) {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
if locks.is_empty() {
// this probably prevents a panic later
@@ -109,7 +103,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
unsafe fn raw_try_lock(&self) -> bool {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
if locks.is_empty() {
// this is an interesting case, but it doesn't give us access to
@@ -139,7 +133,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
unsafe fn raw_unlock(&self) {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
for lock in locks {
lock.raw_unlock();
@@ -147,7 +141,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
unsafe fn raw_read(&self) {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
if locks.is_empty() {
// this probably prevents a panic later
@@ -200,7 +194,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
unsafe fn raw_try_read(&self) -> bool {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
if locks.is_empty() {
// this is an interesting case, but it doesn't give us access to
@@ -229,7 +223,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
unsafe fn raw_unlock_read(&self) {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
for lock in locks {
lock.raw_unlock_read();
@@ -243,6 +237,11 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> {
where
Self: 'g;
+ type DataMut<'a>
+ = L::DataMut<'a>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
self.data.get_ptrs(ptrs)
}
@@ -250,6 +249,10 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> {
unsafe fn guard(&self) -> Self::Guard<'_> {
self.data.guard()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.data.data_mut()
+ }
}
unsafe impl<L: Sharable> Sharable for RetryingLockCollection<L> {
@@ -258,9 +261,18 @@ unsafe impl<L: Sharable> Sharable for RetryingLockCollection<L> {
where
Self: 'g;
+ type DataRef<'a>
+ = L::DataRef<'a>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
self.data.read_guard()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.data.data_ref()
+ }
}
unsafe impl<L: OwnedLockable> OwnedLockable for RetryingLockCollection<L> {}
@@ -516,6 +528,46 @@ impl<L: Lockable> RetryingLockCollection<L> {
(!contains_duplicates(&data)).then_some(Self { data })
}
+ pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: the data was just locked
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_lock<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataMut<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection
///
/// This function returns a guard that can be used to access the underlying
@@ -536,10 +588,7 @@ impl<L: Lockable> RetryingLockCollection<L> {
/// *guard.0 += 1;
/// *guard.1 = "1";
/// ```
- pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::Guard<'g>, Key> {
+ pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
unsafe {
// safety: we're taking the thread key
self.raw_lock();
@@ -548,7 +597,6 @@ impl<L: Lockable> RetryingLockCollection<L> {
// safety: we just locked the collection
guard: self.guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -583,10 +631,7 @@ impl<L: Lockable> RetryingLockCollection<L> {
/// };
///
/// ```
- pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
unsafe {
// safety: we're taking the thread key
if self.raw_try_lock() {
@@ -594,7 +639,6 @@ impl<L: Lockable> RetryingLockCollection<L> {
// safety: we just succeeded in locking everything
guard: self.guard(),
key,
- _phantom: PhantomData,
})
} else {
Err(key)
@@ -620,13 +664,53 @@ impl<L: Lockable> RetryingLockCollection<L> {
/// *guard.1 = "1";
/// let key = RetryingLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard);
/// ```
- pub fn unlock<'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'_>, Key>) -> Key {
+ pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
}
impl<L: Sharable> RetryingLockCollection<L> {
+ pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_read();
+
+ // safety: the data was just locked
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_read<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataRef<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection, so that other threads can still read from it
///
/// This function returns a guard that can be used to access the underlying
@@ -647,10 +731,7 @@ impl<L: Sharable> RetryingLockCollection<L> {
/// assert_eq!(*guard.0, 0);
/// assert_eq!(*guard.1, "");
/// ```
- pub fn read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::ReadGuard<'g>, Key> {
+ pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> {
unsafe {
// safety: we're taking the thread key
self.raw_read();
@@ -659,7 +740,6 @@ impl<L: Sharable> RetryingLockCollection<L> {
// safety: we just locked the collection
guard: self.read_guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -687,25 +767,25 @@ impl<L: Sharable> RetryingLockCollection<L> {
/// let lock = RetryingLockCollection::new(data);
///
/// match lock.try_read(key) {
- /// Some(mut guard) => {
+ /// Ok(mut guard) => {
/// assert_eq!(*guard.0, 5);
/// assert_eq!(*guard.1, "6");
/// },
- /// None => unreachable!(),
+ /// Err(_) => unreachable!(),
/// };
///
/// ```
- pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> {
+ pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> {
unsafe {
// safety: we're taking the thread key
- self.raw_try_lock().then(|| LockGuard {
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ Ok(LockGuard {
// safety: we just succeeded in locking everything
guard: self.read_guard(),
key,
- _phantom: PhantomData,
})
}
}
@@ -726,9 +806,7 @@ impl<L: Sharable> RetryingLockCollection<L> {
/// let mut guard = lock.read(key);
/// let key = RetryingLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard);
/// ```
- pub fn unlock_read<'key, Key: Keyable + 'key>(
- guard: LockGuard<'key, L::ReadGuard<'_>, Key>,
- ) -> Key {
+ pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
@@ -833,7 +911,7 @@ where
mod tests {
use super::*;
use crate::collection::BoxedLockCollection;
- use crate::{Mutex, RwLock, ThreadKey};
+ use crate::{LockCollection, Mutex, RwLock, ThreadKey};
#[test]
fn nonduplicate_lock_references_are_allowed() {
@@ -869,7 +947,6 @@ mod tests {
let rwlock1 = RwLock::new(0);
let rwlock2 = RwLock::new(0);
let collection = RetryingLockCollection::try_new([&rwlock1, &rwlock2]).unwrap();
- // TODO Poisonable::read
let guard = collection.read(key);
@@ -909,13 +986,14 @@ mod tests {
#[test]
fn lock_empty_lock_collection() {
- let mut key = ThreadKey::get().unwrap();
+ let key = ThreadKey::get().unwrap();
let collection: RetryingLockCollection<[RwLock<i32>; 0]> = RetryingLockCollection::new([]);
- let guard = collection.lock(&mut key);
+ let guard = collection.lock(key);
assert!(guard.len() == 0);
+ let key = LockCollection::<[RwLock<_>; 0]>::unlock(guard);
- let guard = collection.read(&mut key);
+ let guard = collection.read(key);
assert!(guard.len() == 0);
}
}