summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/collection.rs2
-rw-r--r--src/collection/boxed.rs186
-rw-r--r--src/collection/owned.rs283
-rw-r--r--src/collection/ref.rs261
-rw-r--r--src/collection/retry.rs403
-rw-r--r--src/collection/utils.rs113
-rw-r--r--src/lockable.rs127
-rw-r--r--src/mutex.rs25
-rw-r--r--src/mutex/guard.rs6
-rw-r--r--src/mutex/mutex.rs85
-rw-r--r--src/poisonable.rs164
-rw-r--r--src/poisonable/poisonable.rs57
-rw-r--r--src/rwlock.rs323
-rw-r--r--src/rwlock/read_guard.rs4
-rw-r--r--src/rwlock/read_lock.rs54
-rw-r--r--src/rwlock/rwlock.rs115
-rw-r--r--src/rwlock/write_guard.rs6
-rw-r--r--src/rwlock/write_lock.rs52
-rw-r--r--src/thread.rs19
-rw-r--r--src/thread/scope.rst47
20 files changed, 1798 insertions, 534 deletions
diff --git a/src/collection.rs b/src/collection.rs
index e50cc30..f8c31d7 100644
--- a/src/collection.rs
+++ b/src/collection.rs
@@ -7,7 +7,7 @@ mod guard;
mod owned;
mod r#ref;
mod retry;
-mod utils;
+pub(crate) mod utils;
/// Locks a collection of locks, which cannot be shared immutably.
///
diff --git a/src/collection/boxed.rs b/src/collection/boxed.rs
index 364ec97..1891119 100644
--- a/src/collection/boxed.rs
+++ b/src/collection/boxed.rs
@@ -4,7 +4,9 @@ use std::fmt::Debug;
use crate::lockable::{Lockable, LockableIntoInner, OwnedLockable, RawLock, Sharable};
use crate::{Keyable, ThreadKey};
-use super::utils::ordered_contains_duplicates;
+use super::utils::{
+ ordered_contains_duplicates, scoped_read, scoped_try_read, scoped_try_write, scoped_write,
+};
use super::{utils, BoxedLockCollection, LockGuard};
unsafe impl<L: Lockable> RawLock for BoxedLockCollection<L> {
@@ -16,18 +18,18 @@ unsafe impl<L: Lockable> RawLock for BoxedLockCollection<L> {
}
}
- unsafe fn raw_lock(&self) {
- utils::ordered_lock(self.locks())
+ unsafe fn raw_write(&self) {
+ utils::ordered_write(self.locks())
}
- unsafe fn raw_try_lock(&self) -> bool {
+ unsafe fn raw_try_write(&self) -> bool {
println!("{}", self.locks().len());
- utils::ordered_try_lock(self.locks())
+ utils::ordered_try_write(self.locks())
}
- unsafe fn raw_unlock(&self) {
+ unsafe fn raw_unlock_write(&self) {
for lock in self.locks() {
- lock.raw_unlock();
+ lock.raw_unlock_write();
}
}
@@ -58,7 +60,7 @@ unsafe impl<L: Lockable> Lockable for BoxedLockCollection<L> {
Self: 'a;
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
- ptrs.extend(self.locks())
+ ptrs.push(self);
}
unsafe fn guard(&self) -> Self::Guard<'_> {
@@ -156,7 +158,7 @@ impl<L> Drop for BoxedLockCollection<L> {
}
}
-impl<T, L: AsRef<T>> AsRef<T> for BoxedLockCollection<L> {
+impl<T: ?Sized, L: AsRef<T>> AsRef<T> for BoxedLockCollection<L> {
fn as_ref(&self) -> &T {
self.child().as_ref()
}
@@ -364,44 +366,16 @@ impl<L: Lockable> BoxedLockCollection<L> {
}
}
- pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_lock();
-
- // safety: the data was just locked
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_lock<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataMut<'a>) -> R) -> R {
+ scoped_write(self, key, f)
}
- pub fn scoped_try_lock<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_lock<'a, Key: Keyable, R>(
+ &'a self,
key: Key,
- f: impl Fn(L::DataMut<'_>) -> R,
+ f: impl Fn(L::DataMut<'a>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_lock() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_write(self, key, f)
}
/// Locks the collection
@@ -427,7 +401,7 @@ impl<L: Lockable> BoxedLockCollection<L> {
pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
unsafe {
// safety: we have the thread key
- self.raw_lock();
+ self.raw_write();
LockGuard {
// safety: we've already acquired the lock
@@ -468,7 +442,7 @@ impl<L: Lockable> BoxedLockCollection<L> {
/// ```
pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
let guard = unsafe {
- if !self.raw_try_lock() {
+ if !self.raw_try_write() {
return Err(key);
}
@@ -503,44 +477,16 @@ impl<L: Lockable> BoxedLockCollection<L> {
}
impl<L: Sharable> BoxedLockCollection<L> {
- pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_read();
-
- // safety: the data was just locked
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R {
+ scoped_read(self, key, f)
}
- pub fn scoped_try_read<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_read<'a, Key: Keyable, R>(
+ &'a self,
key: Key,
- f: impl Fn(L::DataRef<'_>) -> R,
+ f: impl Fn(L::DataRef<'a>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_read() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_read(self, key, f)
}
/// Locks the collection, so that other threads can still read from it
@@ -765,6 +711,56 @@ mod tests {
}
#[test]
+ fn scoped_read_sees_changes() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = BoxedLockCollection::new(mutexes);
+ collection.scoped_lock(&mut key, |guard| *guard[0] = 128);
+
+ let sum = collection.scoped_read(&mut key, |guard| {
+ assert_eq!(*guard[0], 128);
+ assert_eq!(*guard[1], 42);
+ *guard[0] + *guard[1]
+ });
+
+ assert_eq!(sum, 128 + 42);
+ }
+
+ #[test]
+ fn scoped_try_lock_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let collection = BoxedLockCollection::new([Mutex::new(1), Mutex::new(2)]);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_lock(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn scoped_try_read_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let collection = BoxedLockCollection::new([RwLock::new(1), RwLock::new(2)]);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_read(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
fn try_lock_works() {
let key = ThreadKey::get().unwrap();
let collection = BoxedLockCollection::new([Mutex::new(1), Mutex::new(2)]);
@@ -884,15 +880,41 @@ mod tests {
#[test]
fn works_in_collection() {
let key = ThreadKey::get().unwrap();
- let mutex1 = Mutex::new(0);
- let mutex2 = Mutex::new(1);
+ let mutex1 = RwLock::new(0);
+ let mutex2 = RwLock::new(1);
let collection =
BoxedLockCollection::try_new(BoxedLockCollection::try_new([&mutex1, &mutex2]).unwrap())
.unwrap();
- let guard = collection.lock(key);
+ let mut guard = collection.lock(key);
+ assert!(mutex1.is_locked());
+ assert!(mutex2.is_locked());
+ assert_eq!(*guard[0], 0);
+ assert_eq!(*guard[1], 1);
+ *guard[0] = 2;
+ let key = BoxedLockCollection::<BoxedLockCollection<[&RwLock<_>; 2]>>::unlock(guard);
+
+ let guard = collection.read(key);
assert!(mutex1.is_locked());
assert!(mutex2.is_locked());
+ assert_eq!(*guard[0], 2);
+ assert_eq!(*guard[1], 1);
drop(guard);
}
+
+ #[test]
+ fn as_ref_works() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let collection = BoxedLockCollection::new_ref(&mutexes);
+
+ assert!(std::ptr::addr_eq(&mutexes, collection.as_ref()))
+ }
+
+ #[test]
+ fn child() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let collection = BoxedLockCollection::new_ref(&mutexes);
+
+ assert!(std::ptr::addr_eq(&mutexes, *collection.child()))
+ }
}
diff --git a/src/collection/owned.rs b/src/collection/owned.rs
index b9cf313..68170d1 100644
--- a/src/collection/owned.rs
+++ b/src/collection/owned.rs
@@ -3,6 +3,7 @@ use crate::lockable::{
};
use crate::{Keyable, ThreadKey};
+use super::utils::{scoped_read, scoped_try_read, scoped_try_write, scoped_write};
use super::{utils, LockGuard, OwnedLockCollection};
unsafe impl<L: Lockable> RawLock for OwnedLockCollection<L> {
@@ -15,19 +16,19 @@ unsafe impl<L: Lockable> RawLock for OwnedLockCollection<L> {
}
}
- unsafe fn raw_lock(&self) {
- utils::ordered_lock(&utils::get_locks_unsorted(&self.data))
+ unsafe fn raw_write(&self) {
+ utils::ordered_write(&utils::get_locks_unsorted(&self.data))
}
- unsafe fn raw_try_lock(&self) -> bool {
+ unsafe fn raw_try_write(&self) -> bool {
let locks = utils::get_locks_unsorted(&self.data);
- utils::ordered_try_lock(&locks)
+ utils::ordered_try_write(&locks)
}
- unsafe fn raw_unlock(&self) {
+ unsafe fn raw_unlock_write(&self) {
let locks = utils::get_locks_unsorted(&self.data);
for lock in locks {
- lock.raw_unlock();
+ lock.raw_unlock_write();
}
}
@@ -62,7 +63,7 @@ unsafe impl<L: Lockable> Lockable for OwnedLockCollection<L> {
#[mutants::skip] // It's hard to test locks in an OwnedLockCollection, because they're owned
#[cfg(not(tarpaulin_include))]
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
- self.data.get_ptrs(ptrs)
+ ptrs.push(self)
}
unsafe fn guard(&self) -> Self::Guard<'_> {
@@ -146,7 +147,7 @@ impl<E: OwnedLockable + Extend<L>, L: OwnedLockable> Extend<L> for OwnedLockColl
// invariant that there is only one way to lock the collection. AsMut is fine,
// because the collection can't be locked as long as the reference is valid.
-impl<T, L: AsMut<T>> AsMut<T> for OwnedLockCollection<L> {
+impl<T: ?Sized, L: AsMut<T>> AsMut<T> for OwnedLockCollection<L> {
fn as_mut(&mut self) -> &mut T {
self.data.as_mut()
}
@@ -185,44 +186,16 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
Self { data }
}
- pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_lock();
-
- // safety: the data was just locked
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_lock<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataMut<'a>) -> R) -> R {
+ scoped_write(self, key, f)
}
- pub fn scoped_try_lock<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_lock<'a, Key: Keyable, R>(
+ &'a self,
key: Key,
- f: impl Fn(L::DataMut<'_>) -> R,
+ f: impl Fn(L::DataMut<'a>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_lock() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_write(self, key, f)
}
/// Locks the collection
@@ -249,7 +222,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
let guard = unsafe {
// safety: we have the thread key, and these locks happen in a
// predetermined order
- self.raw_lock();
+ self.raw_write();
// safety: we've locked all of this already
self.data.guard()
@@ -290,7 +263,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
/// ```
pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
let guard = unsafe {
- if !self.raw_try_lock() {
+ if !self.raw_try_write() {
return Err(key);
}
@@ -327,44 +300,16 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
}
impl<L: Sharable> OwnedLockCollection<L> {
- pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_read();
-
- // safety: the data was just locked
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R {
+ scoped_read(self, key, f)
}
- pub fn scoped_try_read<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_read<'a, Key: Keyable, R>(
+ &'a self,
key: Key,
- f: impl Fn(L::DataRef<'_>) -> R,
+ f: impl Fn(L::DataRef<'a>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_read() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_read(self, key, f)
}
/// Locks the collection, so that other threads can still read from it
@@ -554,7 +499,7 @@ impl<L: LockableIntoInner> OwnedLockCollection<L> {
#[cfg(test)]
mod tests {
use super::*;
- use crate::{Mutex, ThreadKey};
+ use crate::{Mutex, RwLock, ThreadKey};
#[test]
fn get_mut_applies_changes() {
@@ -604,6 +549,63 @@ mod tests {
}
#[test]
+ fn scoped_read_works() {
+ let mut key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new([RwLock::new(24), RwLock::new(42)]);
+ let sum = collection.scoped_read(&mut key, |guard| guard[0] + guard[1]);
+ assert_eq!(sum, 24 + 42);
+ }
+
+ #[test]
+ fn scoped_lock_works() {
+ let mut key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new([RwLock::new(24), RwLock::new(42)]);
+ collection.scoped_lock(&mut key, |guard| *guard[0] += *guard[1]);
+
+ let sum = collection.scoped_lock(&mut key, |guard| {
+ assert_eq!(*guard[0], 24 + 42);
+ assert_eq!(*guard[1], 42);
+ *guard[0] + *guard[1]
+ });
+
+ assert_eq!(sum, 24 + 42 + 42);
+ }
+
+ #[test]
+ fn scoped_try_lock_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new([Mutex::new(1), Mutex::new(2)]);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_lock(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn scoped_try_read_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new([RwLock::new(1), RwLock::new(2)]);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_read(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
fn try_lock_works_on_unlocked() {
let key = ThreadKey::get().unwrap();
let collection = OwnedLockCollection::new((Mutex::new(0), Mutex::new(1)));
@@ -630,6 +632,74 @@ mod tests {
}
#[test]
+ fn try_read_succeeds_for_unlocked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = OwnedLockCollection::new(mutexes);
+ let guard = collection.try_read(key).unwrap();
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
+ fn try_read_fails_on_locked() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new((RwLock::new(0), RwLock::new(1)));
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ #[allow(unused)]
+ let guard = collection.lock(key);
+ std::mem::forget(guard);
+ });
+ });
+
+ assert!(collection.try_read(key).is_err());
+ }
+
+ #[test]
+ fn can_read_twice_on_different_threads() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = OwnedLockCollection::new(mutexes);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.read(key);
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ std::mem::forget(guard);
+ });
+ });
+
+ let guard = collection.try_read(key).unwrap();
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
+ fn unlock_collection_works() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new((Mutex::new("foo"), Mutex::new("bar")));
+ let guard = collection.lock(key);
+
+ let key = OwnedLockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard);
+ assert!(collection.try_lock(key).is_ok())
+ }
+
+ #[test]
+ fn read_unlock_collection_works() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new((RwLock::new("foo"), RwLock::new("bar")));
+ let guard = collection.read(key);
+
+ let key = OwnedLockCollection::<(&RwLock<_>, &RwLock<_>)>::unlock_read(guard);
+ assert!(collection.try_lock(key).is_ok())
+ }
+
+ #[test]
fn default_works() {
type MyCollection = OwnedLockCollection<(Mutex<i32>, Mutex<Option<i32>>, Mutex<String>)>;
let collection = MyCollection::default();
@@ -649,4 +719,59 @@ mod tests {
assert_eq!(collection.data.len(), 3);
}
+
+ #[test]
+ fn works_in_collection() {
+ let key = ThreadKey::get().unwrap();
+ let collection =
+ OwnedLockCollection::new(OwnedLockCollection::new([RwLock::new(0), RwLock::new(1)]));
+
+ let mut guard = collection.lock(key);
+ assert_eq!(*guard[0], 0);
+ assert_eq!(*guard[1], 1);
+ *guard[1] = 2;
+
+ let key = OwnedLockCollection::<OwnedLockCollection<[RwLock<_>; 2]>>::unlock(guard);
+ let guard = collection.read(key);
+ assert_eq!(*guard[0], 0);
+ assert_eq!(*guard[1], 2);
+ }
+
+ #[test]
+ fn as_mut_works() {
+ let mut mutexes = [Mutex::new(0), Mutex::new(1)];
+ let mut collection = OwnedLockCollection::new(&mut mutexes);
+
+ collection.as_mut()[0] = Mutex::new(42);
+
+ assert_eq!(*collection.as_mut()[0].get_mut(), 42);
+ }
+
+ #[test]
+ fn child_mut_works() {
+ let mut mutexes = [Mutex::new(0), Mutex::new(1)];
+ let mut collection = OwnedLockCollection::new(&mut mutexes);
+
+ collection.child_mut()[0] = Mutex::new(42);
+
+ assert_eq!(*collection.child_mut()[0].get_mut(), 42);
+ }
+
+ #[test]
+ fn into_child_works() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let mut collection = OwnedLockCollection::new(mutexes);
+
+ collection.child_mut()[0] = Mutex::new(42);
+
+ assert_eq!(
+ *collection
+ .into_child()
+ .as_mut()
+ .get_mut(0)
+ .unwrap()
+ .get_mut(),
+ 42
+ );
+ }
}
diff --git a/src/collection/ref.rs b/src/collection/ref.rs
index b68b72f..5f96533 100644
--- a/src/collection/ref.rs
+++ b/src/collection/ref.rs
@@ -3,7 +3,10 @@ use std::fmt::Debug;
use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable};
use crate::{Keyable, ThreadKey};
-use super::utils::{get_locks, ordered_contains_duplicates};
+use super::utils::{
+ get_locks, ordered_contains_duplicates, scoped_read, scoped_try_read, scoped_try_write,
+ scoped_write,
+};
use super::{utils, LockGuard, RefLockCollection};
impl<'a, L> IntoIterator for &'a RefLockCollection<'a, L>
@@ -27,17 +30,17 @@ unsafe impl<L: Lockable> RawLock for RefLockCollection<'_, L> {
}
}
- unsafe fn raw_lock(&self) {
- utils::ordered_lock(&self.locks)
+ unsafe fn raw_write(&self) {
+ utils::ordered_write(&self.locks)
}
- unsafe fn raw_try_lock(&self) -> bool {
- utils::ordered_try_lock(&self.locks)
+ unsafe fn raw_try_write(&self) -> bool {
+ utils::ordered_try_write(&self.locks)
}
- unsafe fn raw_unlock(&self) {
+ unsafe fn raw_unlock_write(&self) {
for lock in &self.locks {
- lock.raw_unlock();
+ lock.raw_unlock_write();
}
}
@@ -68,7 +71,7 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> {
Self: 'a;
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
- ptrs.extend_from_slice(&self.locks);
+ ptrs.push(self)
}
unsafe fn guard(&self) -> Self::Guard<'_> {
@@ -100,7 +103,7 @@ unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> {
}
}
-impl<T, L: AsRef<T>> AsRef<T> for RefLockCollection<'_, L> {
+impl<T: ?Sized, L: AsRef<T>> AsRef<T> for RefLockCollection<'_, L> {
fn as_ref(&self) -> &T {
self.data.as_ref()
}
@@ -234,44 +237,16 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
Some(Self { data, locks })
}
- pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_lock();
-
- // safety: the data was just locked
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_lock<'s, R>(&'s self, key: impl Keyable, f: impl Fn(L::DataMut<'s>) -> R) -> R {
+ scoped_write(self, key, f)
}
- pub fn scoped_try_lock<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_lock<'s, Key: Keyable, R>(
+ &'s self,
key: Key,
- f: impl Fn(L::DataMut<'_>) -> R,
+ f: impl Fn(L::DataMut<'s>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_lock() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_write(self, key, f)
}
/// Locks the collection
@@ -298,7 +273,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
let guard = unsafe {
// safety: we have the thread key
- self.raw_lock();
+ self.raw_write();
// safety: we've locked all of this already
self.data.guard()
@@ -339,7 +314,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
/// ```
pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
let guard = unsafe {
- if !self.raw_try_lock() {
+ if !self.raw_try_write() {
return Err(key);
}
@@ -376,44 +351,16 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
}
impl<L: Sharable> RefLockCollection<'_, L> {
- pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_read();
-
- // safety: the data was just locked
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R {
+ scoped_read(self, key, f)
}
- pub fn scoped_try_read<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_read<'a, Key: Keyable, R>(
+ &'a self,
key: Key,
- f: impl Fn(L::DataRef<'_>) -> R,
+ f: impl Fn(L::DataRef<'a>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_read() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_read(self, key, f)
}
/// Locks the collection, so that other threads can still read from it
@@ -565,6 +512,88 @@ mod tests {
}
#[test]
+ fn from() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")];
+ let collection = RefLockCollection::from(&mutexes);
+ let guard = collection.lock(key);
+ assert_eq!(*guard[0], "foo");
+ assert_eq!(*guard[1], "bar");
+ assert_eq!(*guard[2], "baz");
+ }
+
+ #[test]
+ fn scoped_lock_changes_collection() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(24), Mutex::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+ let sum = collection.scoped_lock(&mut key, |guard| {
+ *guard[0] = 128;
+ *guard[0] + *guard[1]
+ });
+
+ assert_eq!(sum, 128 + 42);
+
+ let guard = collection.lock(key);
+ assert_eq!(*guard[0], 128);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
+ fn scoped_read_sees_changes() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+ collection.scoped_lock(&mut key, |guard| {
+ *guard[0] = 128;
+ });
+
+ let sum = collection.scoped_read(&mut key, |guard| {
+ assert_eq!(*guard[0], 128);
+ assert_eq!(*guard[1], 42);
+ *guard[0] + *guard[1]
+ });
+
+ assert_eq!(sum, 128 + 42);
+ }
+
+ #[test]
+ fn scoped_try_lock_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let locks = [Mutex::new(1), Mutex::new(2)];
+ let collection = RefLockCollection::new(&locks);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_lock(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn scoped_try_read_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let locks = [RwLock::new(1), RwLock::new(2)];
+ let collection = RefLockCollection::new(&locks);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_read(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
fn try_lock_succeeds_for_unlocked_collection() {
let key = ThreadKey::get().unwrap();
let mutexes = [Mutex::new(24), Mutex::new(42)];
@@ -644,17 +673,85 @@ mod tests {
}
#[test]
+ fn into_ref_iterator() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(0), Mutex::new(1), Mutex::new(2)];
+ let collection = RefLockCollection::new(&mutexes);
+ for (i, mutex) in (&collection).into_iter().enumerate() {
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
+ }
+ }
+
+ #[test]
+ fn ref_iterator() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(0), Mutex::new(1), Mutex::new(2)];
+ let collection = RefLockCollection::new(&mutexes);
+ for (i, mutex) in collection.iter().enumerate() {
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
+ }
+ }
+
+ #[test]
fn works_in_collection() {
let key = ThreadKey::get().unwrap();
- let mutex1 = Mutex::new(0);
- let mutex2 = Mutex::new(1);
+ let mutex1 = RwLock::new(0);
+ let mutex2 = RwLock::new(1);
let collection0 = [&mutex1, &mutex2];
let collection1 = RefLockCollection::try_new(&collection0).unwrap();
let collection = RefLockCollection::try_new(&collection1).unwrap();
- let guard = collection.lock(key);
+ let mut guard = collection.lock(key);
assert!(mutex1.is_locked());
assert!(mutex2.is_locked());
+ assert_eq!(*guard[0], 0);
+ assert_eq!(*guard[1], 1);
+ *guard[1] = 2;
drop(guard);
+
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.read(key);
+ assert!(mutex1.is_locked());
+ assert!(mutex2.is_locked());
+ assert_eq!(*guard[0], 0);
+ assert_eq!(*guard[1], 2);
+ }
+
+ #[test]
+ fn unlock_collection_works() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = (Mutex::new("foo"), Mutex::new("bar"));
+ let collection = RefLockCollection::new(&mutexes);
+ let guard = collection.lock(key);
+
+ let key = RefLockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard);
+ assert!(collection.try_lock(key).is_ok())
+ }
+
+ #[test]
+ fn read_unlock_collection_works() {
+ let key = ThreadKey::get().unwrap();
+ let locks = (RwLock::new("foo"), RwLock::new("bar"));
+ let collection = RefLockCollection::new(&locks);
+ let guard = collection.read(key);
+
+ let key = RefLockCollection::<(&RwLock<_>, &RwLock<_>)>::unlock_read(guard);
+ assert!(collection.try_lock(key).is_ok())
+ }
+
+ #[test]
+ fn as_ref_works() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let collection = RefLockCollection::new(&mutexes);
+
+ assert!(std::ptr::addr_eq(&mutexes, collection.as_ref()))
+ }
+
+ #[test]
+ fn child() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let collection = RefLockCollection::new(&mutexes);
+
+ assert!(std::ptr::addr_eq(&mutexes, collection.child()))
}
}
diff --git a/src/collection/retry.rs b/src/collection/retry.rs
index 775ea29..70e5183 100644
--- a/src/collection/retry.rs
+++ b/src/collection/retry.rs
@@ -9,7 +9,8 @@ use crate::lockable::{
use crate::{Keyable, ThreadKey};
use super::utils::{
- attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic, get_locks_unsorted,
+ attempt_to_recover_reads_from_panic, attempt_to_recover_writes_from_panic, get_locks_unsorted,
+ scoped_read, scoped_try_read, scoped_try_write, scoped_write,
};
use super::{LockGuard, RetryingLockCollection};
@@ -40,7 +41,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
}
- unsafe fn raw_lock(&self) {
+ unsafe fn raw_write(&self) {
let locks = get_locks_unsorted(&self.data);
if locks.is_empty() {
@@ -57,7 +58,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
// This prevents us from entering a spin loop waiting for
// the same lock to be unlocked
// safety: we have the thread key
- locks[first_index.get()].raw_lock();
+ locks[first_index.get()].raw_write();
for (i, lock) in locks.iter().enumerate() {
if i == first_index.get() {
// we've already locked this one
@@ -69,15 +70,15 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
// it does return false, then the lock function is called
// immediately after, causing a panic
// safety: we have the thread key
- if lock.raw_try_lock() {
+ if lock.raw_try_write() {
locked.set(locked.get() + 1);
} else {
// safety: we already locked all of these
- attempt_to_recover_locks_from_panic(&locks[0..i]);
+ attempt_to_recover_writes_from_panic(&locks[0..i]);
if first_index.get() >= i {
// safety: this is already locked and can't be
// unlocked by the previous loop
- locks[first_index.get()].raw_unlock();
+ locks[first_index.get()].raw_unlock_write();
}
// nothing is locked anymore
@@ -94,15 +95,15 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
},
|| {
- utils::attempt_to_recover_locks_from_panic(&locks[0..locked.get()]);
+ utils::attempt_to_recover_writes_from_panic(&locks[0..locked.get()]);
if first_index.get() >= locked.get() {
- locks[first_index.get()].raw_unlock();
+ locks[first_index.get()].raw_unlock_write();
}
},
)
}
- unsafe fn raw_try_lock(&self) -> bool {
+ unsafe fn raw_try_write(&self) -> bool {
let locks = get_locks_unsorted(&self.data);
if locks.is_empty() {
@@ -117,26 +118,26 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
|| unsafe {
for (i, lock) in locks.iter().enumerate() {
// safety: we have the thread key
- if lock.raw_try_lock() {
+ if lock.raw_try_write() {
locked.set(locked.get() + 1);
} else {
// safety: we already locked all of these
- attempt_to_recover_locks_from_panic(&locks[0..i]);
+ attempt_to_recover_writes_from_panic(&locks[0..i]);
return false;
}
}
true
},
- || utils::attempt_to_recover_locks_from_panic(&locks[0..locked.get()]),
+ || utils::attempt_to_recover_writes_from_panic(&locks[0..locked.get()]),
)
}
- unsafe fn raw_unlock(&self) {
+ unsafe fn raw_unlock_write(&self) {
let locks = get_locks_unsorted(&self.data);
for lock in locks {
- lock.raw_unlock();
+ lock.raw_unlock_write();
}
}
@@ -243,7 +244,7 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> {
Self: 'a;
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
- self.data.get_ptrs(ptrs)
+ ptrs.push(self)
}
unsafe fn guard(&self) -> Self::Guard<'_> {
@@ -347,13 +348,13 @@ impl<E: OwnedLockable + Extend<L>, L: OwnedLockable> Extend<L> for RetryingLockC
}
}
-impl<T, L: AsRef<T>> AsRef<T> for RetryingLockCollection<L> {
+impl<T: ?Sized, L: AsRef<T>> AsRef<T> for RetryingLockCollection<L> {
fn as_ref(&self) -> &T {
self.data.as_ref()
}
}
-impl<T, L: AsMut<T>> AsMut<T> for RetryingLockCollection<L> {
+impl<T: ?Sized, L: AsMut<T>> AsMut<T> for RetryingLockCollection<L> {
fn as_mut(&mut self) -> &mut T {
self.data.as_mut()
}
@@ -389,7 +390,8 @@ impl<L: OwnedLockable> RetryingLockCollection<L> {
/// ```
#[must_use]
pub const fn new(data: L) -> Self {
- Self { data }
+ // safety: the data cannot contain references
+ unsafe { Self::new_unchecked(data) }
}
}
@@ -410,7 +412,8 @@ impl<'a, L: OwnedLockable> RetryingLockCollection<&'a L> {
/// ```
#[must_use]
pub const fn new_ref(data: &'a L) -> Self {
- Self { data }
+ // safety: the data cannot contain references
+ unsafe { Self::new_unchecked(data) }
}
}
@@ -525,47 +528,20 @@ impl<L: Lockable> RetryingLockCollection<L> {
/// ```
#[must_use]
pub fn try_new(data: L) -> Option<Self> {
- (!contains_duplicates(&data)).then_some(Self { data })
+ // safety: the data is checked for duplicates before returning the collection
+ (!contains_duplicates(&data)).then_some(unsafe { Self::new_unchecked(data) })
}
- pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_lock();
-
- // safety: the data was just locked
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_lock<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataMut<'a>) -> R) -> R {
+ scoped_write(self, key, f)
}
- pub fn scoped_try_lock<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_lock<'a, Key: Keyable, R>(
+ &'a self,
key: Key,
- f: impl Fn(L::DataMut<'_>) -> R,
+ f: impl Fn(L::DataMut<'a>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_lock() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_write(self, key, f)
}
/// Locks the collection
@@ -591,7 +567,7 @@ impl<L: Lockable> RetryingLockCollection<L> {
pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
unsafe {
// safety: we're taking the thread key
- self.raw_lock();
+ self.raw_write();
LockGuard {
// safety: we just locked the collection
@@ -634,7 +610,7 @@ impl<L: Lockable> RetryingLockCollection<L> {
pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
unsafe {
// safety: we're taking the thread key
- if self.raw_try_lock() {
+ if self.raw_try_write() {
Ok(LockGuard {
// safety: we just succeeded in locking everything
guard: self.guard(),
@@ -671,44 +647,16 @@ impl<L: Lockable> RetryingLockCollection<L> {
}
impl<L: Sharable> RetryingLockCollection<L> {
- pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_read();
-
- // safety: the data was just locked
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R {
+ scoped_read(self, key, f)
}
- pub fn scoped_try_read<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_read<'a, Key: Keyable, R>(
+ &'a self,
key: Key,
- f: impl Fn(L::DataRef<'_>) -> R,
+ f: impl Fn(L::DataRef<'a>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_read() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_read(self, key, f)
}
/// Locks the collection, so that other threads can still read from it
@@ -778,7 +726,7 @@ impl<L: Sharable> RetryingLockCollection<L> {
pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> {
unsafe {
// safety: we're taking the thread key
- if !self.raw_try_lock() {
+ if !self.raw_try_read() {
return Err(key);
}
@@ -911,7 +859,7 @@ where
mod tests {
use super::*;
use crate::collection::BoxedLockCollection;
- use crate::{LockCollection, Mutex, RwLock, ThreadKey};
+ use crate::{Mutex, RwLock, ThreadKey};
#[test]
fn nonduplicate_lock_references_are_allowed() {
@@ -927,6 +875,159 @@ mod tests {
}
#[test]
+ #[allow(clippy::float_cmp)]
+ fn uses_correct_default() {
+ let collection =
+ RetryingLockCollection::<(RwLock<f64>, Mutex<Option<i32>>, Mutex<usize>)>::default();
+ let tuple = collection.into_inner();
+ assert_eq!(tuple.0, 0.0);
+ assert!(tuple.1.is_none());
+ assert_eq!(tuple.2, 0)
+ }
+
+ #[test]
+ fn from() {
+ let key = ThreadKey::get().unwrap();
+ let collection =
+ RetryingLockCollection::from([Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]);
+ let guard = collection.lock(key);
+ assert_eq!(*guard[0], "foo");
+ assert_eq!(*guard[1], "bar");
+ assert_eq!(*guard[2], "baz");
+ }
+
+ #[test]
+ fn new_ref_works() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let collection = RetryingLockCollection::new_ref(&mutexes);
+ collection.scoped_lock(key, |guard| {
+ assert_eq!(*guard[0], 0);
+ assert_eq!(*guard[1], 1);
+ })
+ }
+
+ #[test]
+ fn scoped_read_sees_changes() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = RetryingLockCollection::new(mutexes);
+ collection.scoped_lock(&mut key, |guard| *guard[0] = 128);
+
+ let sum = collection.scoped_read(&mut key, |guard| {
+ assert_eq!(*guard[0], 128);
+ assert_eq!(*guard[1], 42);
+ *guard[0] + *guard[1]
+ });
+
+ assert_eq!(sum, 128 + 42);
+ }
+
+ #[test]
+ fn get_mut_affects_scoped_read() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let mut collection = RetryingLockCollection::new(mutexes);
+ let guard = collection.get_mut();
+ *guard[0] = 128;
+
+ let sum = collection.scoped_read(&mut key, |guard| {
+ assert_eq!(*guard[0], 128);
+ assert_eq!(*guard[1], 42);
+ *guard[0] + *guard[1]
+ });
+
+ assert_eq!(sum, 128 + 42);
+ }
+
+ #[test]
+ fn scoped_try_lock_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let collection = RetryingLockCollection::new([Mutex::new(1), Mutex::new(2)]);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_lock(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn scoped_try_read_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let collection = RetryingLockCollection::new([RwLock::new(1), RwLock::new(2)]);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_read(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn try_lock_works() {
+ let key = ThreadKey::get().unwrap();
+ let collection = RetryingLockCollection::new([Mutex::new(1), Mutex::new(2)]);
+ let guard = collection.try_lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.try_lock(key);
+ assert!(guard.is_err());
+ });
+ });
+
+ assert!(guard.is_ok());
+ }
+
+ #[test]
+ fn try_read_works() {
+ let key = ThreadKey::get().unwrap();
+ let collection = RetryingLockCollection::new([RwLock::new(1), RwLock::new(2)]);
+ let guard = collection.try_read(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.try_read(key);
+ assert!(guard.is_ok());
+ });
+ });
+
+ assert!(guard.is_ok());
+ }
+
+ #[test]
+ fn try_read_fails_for_locked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = RetryingLockCollection::new_ref(&mutexes);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = mutexes[1].write(key);
+ assert_eq!(*guard, 42);
+ std::mem::forget(guard);
+ });
+ });
+
+ let guard = collection.try_read(key);
+ assert!(guard.is_err());
+ }
+
+ #[test]
fn locks_all_inner_mutexes() {
let key = ThreadKey::get().unwrap();
let mutex1 = Mutex::new(0);
@@ -974,6 +1075,55 @@ mod tests {
}
#[test]
+ fn from_iterator() {
+ let key = ThreadKey::get().unwrap();
+ let collection: RetryingLockCollection<Vec<Mutex<&str>>> =
+ [Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]
+ .into_iter()
+ .collect();
+ let guard = collection.lock(key);
+ assert_eq!(*guard[0], "foo");
+ assert_eq!(*guard[1], "bar");
+ assert_eq!(*guard[2], "baz");
+ }
+
+ #[test]
+ fn into_owned_iterator() {
+ let collection = RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]);
+ for (i, mutex) in collection.into_iter().enumerate() {
+ assert_eq!(mutex.into_inner(), i);
+ }
+ }
+
+ #[test]
+ fn into_ref_iterator() {
+ let mut key = ThreadKey::get().unwrap();
+ let collection = RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]);
+ for (i, mutex) in (&collection).into_iter().enumerate() {
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
+ }
+ }
+
+ #[test]
+ fn ref_iterator() {
+ let mut key = ThreadKey::get().unwrap();
+ let collection = RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]);
+ for (i, mutex) in collection.iter().enumerate() {
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
+ }
+ }
+
+ #[test]
+ fn mut_iterator() {
+ let mut key = ThreadKey::get().unwrap();
+ let mut collection =
+ RetryingLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]);
+ for (i, mutex) in collection.iter_mut().enumerate() {
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
+ }
+ }
+
+ #[test]
fn extend_collection() {
let mutex1 = Mutex::new(0);
let mutex2 = Mutex::new(0);
@@ -991,9 +1141,76 @@ mod tests {
let guard = collection.lock(key);
assert!(guard.len() == 0);
- let key = LockCollection::<[RwLock<_>; 0]>::unlock(guard);
+ let key = RetryingLockCollection::<[RwLock<_>; 0]>::unlock(guard);
+
+ let guard = collection.read(key);
+ assert!(guard.len() == 0);
+ }
+
+ #[test]
+ fn read_empty_lock_collection() {
+ let key = ThreadKey::get().unwrap();
+ let collection: RetryingLockCollection<[RwLock<i32>; 0]> = RetryingLockCollection::new([]);
let guard = collection.read(key);
assert!(guard.len() == 0);
+ let key = RetryingLockCollection::<[RwLock<_>; 0]>::unlock_read(guard);
+
+ let guard = collection.lock(key);
+ assert!(guard.len() == 0);
+ }
+
+ #[test]
+ fn as_ref_works() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let collection = RetryingLockCollection::new_ref(&mutexes);
+
+ assert!(std::ptr::addr_eq(&mutexes, collection.as_ref()))
+ }
+
+ #[test]
+ fn as_mut_works() {
+ let mut mutexes = [Mutex::new(0), Mutex::new(1)];
+ let mut collection = RetryingLockCollection::new(&mut mutexes);
+
+ collection.as_mut()[0] = Mutex::new(42);
+
+ assert_eq!(*collection.as_mut()[0].get_mut(), 42);
+ }
+
+ #[test]
+ fn child() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let collection = RetryingLockCollection::new_ref(&mutexes);
+
+ assert!(std::ptr::addr_eq(&mutexes, *collection.child()))
+ }
+
+ #[test]
+ fn child_mut_works() {
+ let mut mutexes = [Mutex::new(0), Mutex::new(1)];
+ let mut collection = RetryingLockCollection::new(&mut mutexes);
+
+ collection.child_mut()[0] = Mutex::new(42);
+
+ assert_eq!(*collection.child_mut()[0].get_mut(), 42);
+ }
+
+ #[test]
+ fn into_child_works() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let mut collection = RetryingLockCollection::new(mutexes);
+
+ collection.child_mut()[0] = Mutex::new(42);
+
+ assert_eq!(
+ *collection
+ .into_child()
+ .as_mut()
+ .get_mut(0)
+ .unwrap()
+ .get_mut(),
+ 42
+ );
}
}
diff --git a/src/collection/utils.rs b/src/collection/utils.rs
index 1d96e5c..59a68da 100644
--- a/src/collection/utils.rs
+++ b/src/collection/utils.rs
@@ -1,7 +1,8 @@
use std::cell::Cell;
use crate::handle_unwind::handle_unwind;
-use crate::lockable::{Lockable, RawLock};
+use crate::lockable::{Lockable, RawLock, Sharable};
+use crate::Keyable;
#[must_use]
pub fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
@@ -32,18 +33,18 @@ pub fn ordered_contains_duplicates(l: &[&dyn RawLock]) -> bool {
}
/// Lock a set of locks in the given order. It's UB to call this without a `ThreadKey`
-pub unsafe fn ordered_lock(locks: &[&dyn RawLock]) {
+pub unsafe fn ordered_write(locks: &[&dyn RawLock]) {
// these will be unlocked in case of a panic
let locked = Cell::new(0);
handle_unwind(
|| {
for lock in locks {
- lock.raw_lock();
+ lock.raw_write();
locked.set(locked.get() + 1);
}
},
- || attempt_to_recover_locks_from_panic(&locks[0..locked.get()]),
+ || attempt_to_recover_writes_from_panic(&locks[0..locked.get()]),
)
}
@@ -65,19 +66,19 @@ pub unsafe fn ordered_read(locks: &[&dyn RawLock]) {
/// Locks the locks in the order they are given. This causes deadlock if the
/// locks contain duplicates, or if this is called by multiple threads with the
/// locks in different orders.
-pub unsafe fn ordered_try_lock(locks: &[&dyn RawLock]) -> bool {
+pub unsafe fn ordered_try_write(locks: &[&dyn RawLock]) -> bool {
let locked = Cell::new(0);
handle_unwind(
|| unsafe {
for (i, lock) in locks.iter().enumerate() {
// safety: we have the thread key
- if lock.raw_try_lock() {
+ if lock.raw_try_write() {
locked.set(locked.get() + 1);
} else {
for lock in &locks[0..i] {
// safety: this lock was already acquired
- lock.raw_unlock();
+ lock.raw_unlock_write();
}
return false;
}
@@ -87,7 +88,7 @@ pub unsafe fn ordered_try_lock(locks: &[&dyn RawLock]) -> bool {
},
||
// safety: everything in locked is locked
- attempt_to_recover_locks_from_panic(&locks[0..locked.get()]),
+ attempt_to_recover_writes_from_panic(&locks[0..locked.get()]),
)
}
@@ -120,12 +121,104 @@ pub unsafe fn ordered_try_read(locks: &[&dyn RawLock]) -> bool {
)
}
+pub fn scoped_write<'a, L: RawLock + Lockable, R>(
+ collection: &'a L,
+ key: impl Keyable,
+ f: impl FnOnce(L::DataMut<'a>) -> R,
+) -> R {
+ unsafe {
+ // safety: we have the key
+ collection.raw_write();
+
+ // safety: we just locked this
+ let r = f(collection.data_mut());
+
+ // this ensures the key is held long enough
+ drop(key);
+
+ // safety: we've locked already, and aren't using the data again
+ collection.raw_unlock_write();
+
+ r
+ }
+}
+
+pub fn scoped_try_write<'a, L: RawLock + Lockable, Key: Keyable, R>(
+ collection: &'a L,
+ key: Key,
+ f: impl FnOnce(L::DataMut<'a>) -> R,
+) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the key
+ if !collection.raw_try_write() {
+ return Err(key);
+ }
+
+ // safety: we just locked this
+ let r = f(collection.data_mut());
+
+ // this ensures the key is held long enough
+ drop(key);
+
+ // safety: we've locked already, and aren't using the data again
+ collection.raw_unlock_write();
+
+ Ok(r)
+ }
+}
+
+pub fn scoped_read<'a, L: RawLock + Sharable, R>(
+ collection: &'a L,
+ key: impl Keyable,
+ f: impl FnOnce(L::DataRef<'a>) -> R,
+) -> R {
+ unsafe {
+ // safety: we have the key
+ collection.raw_read();
+
+ // safety: we just locked this
+ let r = f(collection.data_ref());
+
+ // this ensures the key is held long enough
+ drop(key);
+
+ // safety: we've locked already, and aren't using the data again
+ collection.raw_unlock_read();
+
+ r
+ }
+}
+
+pub fn scoped_try_read<'a, L: RawLock + Sharable, Key: Keyable, R>(
+ collection: &'a L,
+ key: Key,
+ f: impl FnOnce(L::DataRef<'a>) -> R,
+) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the key
+ if !collection.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: we just locked this
+ let r = f(collection.data_ref());
+
+ // this ensures the key is held long enough
+ drop(key);
+
+ // safety: we've locked already, and aren't using the data again
+ collection.raw_unlock_read();
+
+ Ok(r)
+ }
+}
+
/// Unlocks the already locked locks in order to recover from a panic
-pub unsafe fn attempt_to_recover_locks_from_panic(locks: &[&dyn RawLock]) {
+pub unsafe fn attempt_to_recover_writes_from_panic(locks: &[&dyn RawLock]) {
handle_unwind(
|| {
// safety: the caller assumes that these are already locked
- locks.iter().for_each(|lock| lock.raw_unlock());
+ locks.iter().for_each(|lock| lock.raw_unlock_write());
},
// if we get another panic in here, we'll just have to poison what remains
|| locks.iter().for_each(|l| l.poison()),
diff --git a/src/lockable.rs b/src/lockable.rs
index f125c02..94042ea 100644
--- a/src/lockable.rs
+++ b/src/lockable.rs
@@ -28,7 +28,7 @@ pub unsafe trait RawLock {
/// value is alive.
///
/// [`ThreadKey`]: `crate::ThreadKey`
- unsafe fn raw_lock(&self);
+ unsafe fn raw_write(&self);
/// Attempt to lock without blocking.
///
@@ -41,14 +41,14 @@ pub unsafe trait RawLock {
/// value is alive.
///
/// [`ThreadKey`]: `crate::ThreadKey`
- unsafe fn raw_try_lock(&self) -> bool;
+ unsafe fn raw_try_write(&self) -> bool;
/// Releases the lock
///
/// # Safety
///
/// It is undefined behavior to use this if the lock is not acquired
- unsafe fn raw_unlock(&self);
+ unsafe fn raw_unlock_write(&self);
/// Blocks until the data the lock protects can be safely read.
///
@@ -625,7 +625,7 @@ unsafe impl<T: OwnedLockable> OwnedLockable for Vec<T> {}
#[cfg(test)]
mod tests {
use super::*;
- use crate::{Mutex, RwLock};
+ use crate::{LockCollection, Mutex, RwLock, ThreadKey};
#[test]
fn mut_ref_get_ptrs() {
@@ -719,6 +719,57 @@ mod tests {
}
#[test]
+ fn vec_guard_ref() {
+ let key = ThreadKey::get().unwrap();
+ let locks = vec![RwLock::new(1), RwLock::new(2)];
+ let collection = LockCollection::new(locks);
+
+ let mut guard = collection.lock(key);
+ assert_eq!(*guard[0], 1);
+ assert_eq!(*guard[1], 2);
+ *guard[0] = 3;
+
+ let key = LockCollection::<Vec<RwLock<_>>>::unlock(guard);
+ let guard = collection.read(key);
+ assert_eq!(*guard[0], 3);
+ assert_eq!(*guard[1], 2);
+ }
+
+ #[test]
+ fn vec_data_mut() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = vec![Mutex::new(1), Mutex::new(2)];
+ let collection = LockCollection::new(mutexes);
+ collection.scoped_lock(&mut key, |guard| {
+ assert_eq!(*guard[0], 1);
+ assert_eq!(*guard[1], 2);
+ *guard[0] = 3;
+ });
+
+ collection.scoped_lock(&mut key, |guard| {
+ assert_eq!(*guard[0], 3);
+ assert_eq!(*guard[1], 2);
+ })
+ }
+
+ #[test]
+ fn vec_data_ref() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = vec![RwLock::new(1), RwLock::new(2)];
+ let collection = LockCollection::new(mutexes);
+ collection.scoped_lock(&mut key, |guard| {
+ assert_eq!(*guard[0], 1);
+ assert_eq!(*guard[1], 2);
+ *guard[0] = 3;
+ });
+
+ collection.scoped_read(&mut key, |guard| {
+ assert_eq!(*guard[0], 3);
+ assert_eq!(*guard[1], 2);
+ })
+ }
+
+ #[test]
fn box_get_ptrs_empty() {
let locks: Box<[Mutex<()>]> = Box::from([]);
let mut lock_ptrs = Vec::new();
@@ -757,4 +808,72 @@ mod tests {
assert_eq!(*lock_ptrs[0], 1);
assert_eq!(*lock_ptrs[1], 2);
}
+
+ #[test]
+ fn box_guard_mut() {
+ let key = ThreadKey::get().unwrap();
+ let x = [Mutex::new(1), Mutex::new(2)];
+ let collection: LockCollection<Box<[Mutex<_>]>> = LockCollection::new(Box::new(x));
+
+ let mut guard = collection.lock(key);
+ assert_eq!(*guard[0], 1);
+ assert_eq!(*guard[1], 2);
+ *guard[0] = 3;
+
+ let key = LockCollection::<Box<[Mutex<_>]>>::unlock(guard);
+ let guard = collection.lock(key);
+ assert_eq!(*guard[0], 3);
+ assert_eq!(*guard[1], 2);
+ }
+
+ #[test]
+ fn box_data_mut() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = vec![Mutex::new(1), Mutex::new(2)].into_boxed_slice();
+ let collection = LockCollection::new(mutexes);
+ collection.scoped_lock(&mut key, |guard| {
+ assert_eq!(*guard[0], 1);
+ assert_eq!(*guard[1], 2);
+ *guard[0] = 3;
+ });
+
+ collection.scoped_lock(&mut key, |guard| {
+ assert_eq!(*guard[0], 3);
+ assert_eq!(*guard[1], 2);
+ });
+ }
+
+ #[test]
+ fn box_guard_ref() {
+ let key = ThreadKey::get().unwrap();
+ let locks = [RwLock::new(1), RwLock::new(2)];
+ let collection: LockCollection<Box<[RwLock<_>]>> = LockCollection::new(Box::new(locks));
+
+ let mut guard = collection.lock(key);
+ assert_eq!(*guard[0], 1);
+ assert_eq!(*guard[1], 2);
+ *guard[0] = 3;
+
+ let key = LockCollection::<Box<[RwLock<_>]>>::unlock(guard);
+ let guard = collection.read(key);
+ assert_eq!(*guard[0], 3);
+ assert_eq!(*guard[1], 2);
+ }
+
+ #[test]
+ fn box_data_ref() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = vec![RwLock::new(1), RwLock::new(2)].into_boxed_slice();
+ let collection = LockCollection::new(mutexes);
+ collection.scoped_lock(&mut key, |guard| {
+ assert_eq!(*guard[0], 1);
+ assert_eq!(*guard[1], 2);
+ *guard[0] = 3;
+ });
+
+ collection.scoped_read(&mut key, |guard| {
+ assert_eq!(*guard[0], 3);
+ assert_eq!(*guard[1], 2);
+ });
+ }
}
diff --git a/src/mutex.rs b/src/mutex.rs
index 2022501..413bd8a 100644
--- a/src/mutex.rs
+++ b/src/mutex.rs
@@ -136,10 +136,7 @@ pub struct Mutex<T: ?Sized, R> {
/// A reference to a mutex that unlocks it when dropped.
///
/// This is similar to [`MutexGuard`], except it does not hold a [`Keyable`].
-pub struct MutexRef<'a, T: ?Sized + 'a, R: RawMutex>(
- &'a Mutex<T, R>,
- PhantomData<(&'a mut T, R::GuardMarker)>,
-);
+pub struct MutexRef<'a, T: ?Sized + 'a, R: RawMutex>(&'a Mutex<T, R>, PhantomData<R::GuardMarker>);
/// An RAII implementation of a “scoped lock” of a mutex.
///
@@ -183,6 +180,26 @@ mod tests {
}
#[test]
+ fn from_works() {
+ let key = ThreadKey::get().unwrap();
+ let mutex: crate::Mutex<_> = Mutex::from("Hello, world!");
+
+ let guard = mutex.lock(key);
+ assert_eq!(*guard, "Hello, world!");
+ }
+
+ #[test]
+ fn as_mut_works() {
+ let key = ThreadKey::get().unwrap();
+ let mut mutex = crate::Mutex::from(42);
+
+ let mut_ref = mutex.as_mut();
+ *mut_ref = 24;
+
+ mutex.scoped_lock(key, |guard| assert_eq!(*guard, 24))
+ }
+
+ #[test]
fn display_works_for_guard() {
let key = ThreadKey::get().unwrap();
let mutex: crate::Mutex<_> = Mutex::new("Hello, world!");
diff --git a/src/mutex/guard.rs b/src/mutex/guard.rs
index 22e59c1..d88fded 100644
--- a/src/mutex/guard.rs
+++ b/src/mutex/guard.rs
@@ -39,7 +39,7 @@ impl<T: ?Sized, R: RawMutex> Drop for MutexRef<'_, T, R> {
fn drop(&mut self) {
// safety: this guard is being destroyed, so the data cannot be
// accessed without locking again
- unsafe { self.0.raw_unlock() }
+ unsafe { self.0.raw_unlock_write() }
}
}
@@ -79,7 +79,7 @@ impl<'a, T: ?Sized, R: RawMutex> MutexRef<'a, T, R> {
/// Creates a reference to the underlying data of a mutex without
/// attempting to lock it or take ownership of the key. But it's also quite
/// dangerous to drop.
- pub(crate) unsafe fn new(mutex: &'a Mutex<T, R>) -> Self {
+ pub(crate) const unsafe fn new(mutex: &'a Mutex<T, R>) -> Self {
Self(mutex, PhantomData)
}
}
@@ -139,7 +139,7 @@ impl<'a, T: ?Sized, R: RawMutex> MutexGuard<'a, T, R> {
/// Create a guard to the given mutex. Undefined if multiple guards to the
/// same mutex exist at once.
#[must_use]
- pub(super) unsafe fn new(mutex: &'a Mutex<T, R>, thread_key: ThreadKey) -> Self {
+ pub(super) const unsafe fn new(mutex: &'a Mutex<T, R>, thread_key: ThreadKey) -> Self {
Self {
mutex: MutexRef(mutex, PhantomData),
thread_key,
diff --git a/src/mutex/mutex.rs b/src/mutex/mutex.rs
index 1d8ce8b..f0fb680 100644
--- a/src/mutex/mutex.rs
+++ b/src/mutex/mutex.rs
@@ -5,6 +5,7 @@ use std::panic::AssertUnwindSafe;
use lock_api::RawMutex;
+use crate::collection::utils;
use crate::handle_unwind::handle_unwind;
use crate::lockable::{Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock};
use crate::poisonable::PoisonFlag;
@@ -17,7 +18,7 @@ unsafe impl<T: ?Sized, R: RawMutex> RawLock for Mutex<T, R> {
self.poison.poison();
}
- unsafe fn raw_lock(&self) {
+ unsafe fn raw_write(&self) {
assert!(!self.poison.is_poisoned(), "The mutex has been killed");
// if the closure unwraps, then the mutex will be killed
@@ -25,7 +26,7 @@ unsafe impl<T: ?Sized, R: RawMutex> RawLock for Mutex<T, R> {
handle_unwind(|| this.raw.lock(), || self.poison())
}
- unsafe fn raw_try_lock(&self) -> bool {
+ unsafe fn raw_try_write(&self) -> bool {
if self.poison.is_poisoned() {
return false;
}
@@ -35,7 +36,7 @@ unsafe impl<T: ?Sized, R: RawMutex> RawLock for Mutex<T, R> {
handle_unwind(|| this.raw.try_lock(), || self.poison())
}
- unsafe fn raw_unlock(&self) {
+ unsafe fn raw_unlock_write(&self) {
// if the closure unwraps, then the mutex will be killed
let this = AssertUnwindSafe(self);
handle_unwind(|| this.raw.unlock(), || self.poison())
@@ -43,20 +44,26 @@ unsafe impl<T: ?Sized, R: RawMutex> RawLock for Mutex<T, R> {
// this is the closest thing to a read we can get, but Sharable isn't
// implemented for this
+ #[mutants::skip]
+ #[cfg(not(tarpaulin_include))]
unsafe fn raw_read(&self) {
- self.raw_lock()
+ self.raw_write()
}
+ #[mutants::skip]
+ #[cfg(not(tarpaulin_include))]
unsafe fn raw_try_read(&self) -> bool {
- self.raw_try_lock()
+ self.raw_try_write()
}
+ #[mutants::skip]
+ #[cfg(not(tarpaulin_include))]
unsafe fn raw_unlock_read(&self) {
- self.raw_unlock()
+ self.raw_unlock_write()
}
}
-unsafe impl<T: Send, R: RawMutex + Send + Sync> Lockable for Mutex<T, R> {
+unsafe impl<T, R: RawMutex> Lockable for Mutex<T, R> {
type Guard<'g>
= MutexRef<'g, T, R>
where
@@ -80,7 +87,7 @@ unsafe impl<T: Send, R: RawMutex + Send + Sync> Lockable for Mutex<T, R> {
}
}
-impl<T: Send, R: RawMutex + Send + Sync> LockableIntoInner for Mutex<T, R> {
+impl<T: Send, R: RawMutex> LockableIntoInner for Mutex<T, R> {
type Inner = T;
fn into_inner(self) -> Self::Inner {
@@ -88,7 +95,7 @@ impl<T: Send, R: RawMutex + Send + Sync> LockableIntoInner for Mutex<T, R> {
}
}
-impl<T: Send, R: RawMutex + Send + Sync> LockableGetMut for Mutex<T, R> {
+impl<T: Send, R: RawMutex> LockableGetMut for Mutex<T, R> {
type Inner<'a>
= &'a mut T
where
@@ -99,7 +106,7 @@ impl<T: Send, R: RawMutex + Send + Sync> LockableGetMut for Mutex<T, R> {
}
}
-unsafe impl<T: Send, R: RawMutex + Send + Sync> OwnedLockable for Mutex<T, R> {}
+unsafe impl<T: Send, R: RawMutex> OwnedLockable for Mutex<T, R> {}
impl<T, R: RawMutex> Mutex<T, R> {
/// Create a new unlocked `Mutex`.
@@ -140,7 +147,7 @@ impl<T, R: RawMutex> Mutex<T, R> {
#[mutants::skip]
#[cfg(not(tarpaulin_include))]
-impl<T: ?Sized + Debug, R: RawMutex> Debug for Mutex<T, R> {
+impl<T: Debug, R: RawMutex> Debug for Mutex<T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// safety: this is just a try lock, and the value is dropped
// immediately after, so there's no risk of blocking ourselves
@@ -222,45 +229,21 @@ impl<T: ?Sized, R> Mutex<T, R> {
}
}
-impl<T: ?Sized, R: RawMutex> Mutex<T, R> {
- pub fn scoped_lock<Ret>(&self, key: impl Keyable, f: impl FnOnce(&mut T) -> Ret) -> Ret {
- unsafe {
- // safety: we have the thread key
- self.raw_lock();
-
- // safety: the mutex was just locked
- let r = f(self.data.get().as_mut().unwrap_unchecked());
-
- // safety: we locked the mutex already
- self.raw_unlock();
-
- drop(key); // ensures we drop the key in the correct place
-
- r
- }
+impl<T, R: RawMutex> Mutex<T, R> {
+ pub fn scoped_lock<'a, Ret>(
+ &'a self,
+ key: impl Keyable,
+ f: impl FnOnce(&'a mut T) -> Ret,
+ ) -> Ret {
+ utils::scoped_write(self, key, f)
}
- pub fn scoped_try_lock<Key: Keyable, Ret>(
- &self,
+ pub fn scoped_try_lock<'a, Key: Keyable, Ret>(
+ &'a self,
key: Key,
- f: impl FnOnce(&mut T) -> Ret,
+ f: impl FnOnce(&'a mut T) -> Ret,
) -> Result<Ret, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_lock() {
- return Err(key);
- }
-
- // safety: the mutex was just locked
- let r = f(self.data.get().as_mut().unwrap_unchecked());
-
- // safety: we locked the mutex already
- self.raw_unlock();
-
- drop(key); // ensures we drop the key in the correct place
-
- Ok(r)
- }
+ utils::scoped_try_write(self, key, f)
}
/// Block the thread until this mutex can be locked, and lock it.
@@ -289,7 +272,7 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> {
pub fn lock(&self, key: ThreadKey) -> MutexGuard<'_, T, R> {
unsafe {
// safety: we have the thread key
- self.raw_lock();
+ self.raw_write();
// safety: we just locked the mutex
MutexGuard::new(self, key)
@@ -332,7 +315,7 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> {
pub fn try_lock(&self, key: ThreadKey) -> Result<MutexGuard<'_, T, R>, ThreadKey> {
unsafe {
// safety: we have the key to the mutex
- if self.raw_try_lock() {
+ if self.raw_try_write() {
// safety: we just locked the mutex
Ok(MutexGuard::new(self, key))
} else {
@@ -350,7 +333,7 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> {
/// Lock without a [`ThreadKey`]. It is undefined behavior to do this without
/// owning the [`ThreadKey`].
pub(crate) unsafe fn try_lock_no_key(&self) -> Option<MutexRef<'_, T, R>> {
- self.raw_try_lock().then_some(MutexRef(self, PhantomData))
+ self.raw_try_write().then_some(MutexRef(self, PhantomData))
}
/// Consumes the [`MutexGuard`], and consequently unlocks its `Mutex`.
@@ -370,9 +353,7 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> {
/// ```
#[must_use]
pub fn unlock(guard: MutexGuard<'_, T, R>) -> ThreadKey {
- unsafe {
- guard.mutex.0.raw_unlock();
- }
+ drop(guard.mutex);
guard.thread_key
}
}
diff --git a/src/poisonable.rs b/src/poisonable.rs
index 8d5a810..f9b0622 100644
--- a/src/poisonable.rs
+++ b/src/poisonable.rs
@@ -99,7 +99,7 @@ mod tests {
use super::*;
use crate::lockable::Lockable;
- use crate::{LockCollection, Mutex, ThreadKey};
+ use crate::{LockCollection, Mutex, RwLock, ThreadKey};
#[test]
fn locking_poisoned_mutex_returns_error_in_collection() {
@@ -127,6 +127,31 @@ mod tests {
}
#[test]
+ fn locking_poisoned_rwlock_returns_error_in_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutex = LockCollection::new(Poisonable::new(RwLock::new(42)));
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let mut guard1 = mutex.read(key);
+ let guard = guard1.as_deref_mut().unwrap();
+ assert_eq!(**guard, 42);
+ panic!();
+
+ #[allow(unreachable_code)]
+ drop(guard1);
+ })
+ .join()
+ .unwrap_err();
+ });
+
+ let error = mutex.read(key);
+ let error = error.as_deref().unwrap_err();
+ assert_eq!(***error.get_ref(), 42);
+ }
+
+ #[test]
fn non_poisoned_get_mut_is_ok() {
let mut mutex = Poisonable::new(Mutex::new(42));
let guard = mutex.get_mut();
@@ -201,6 +226,118 @@ mod tests {
}
#[test]
+ fn scoped_lock_can_poison() {
+ let key = ThreadKey::get().unwrap();
+ let mutex = Poisonable::new(Mutex::new(42));
+
+ let r = std::panic::catch_unwind(|| {
+ mutex.scoped_lock(key, |num| {
+ *num.unwrap() = 56;
+ panic!();
+ })
+ });
+ assert!(r.is_err());
+
+ let key = ThreadKey::get().unwrap();
+ assert!(mutex.is_poisoned());
+ mutex.scoped_lock(key, |num| {
+ let Err(error) = num else { panic!() };
+ mutex.clear_poison();
+ let guard = error.into_inner();
+ assert_eq!(*guard, 56);
+ });
+ assert!(!mutex.is_poisoned());
+ }
+
+ #[test]
+ fn scoped_try_lock_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let mutex = Poisonable::new(Mutex::new(42));
+ let guard = mutex.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = mutex.scoped_try_lock(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn scoped_try_lock_can_succeed() {
+ let rwlock = Poisonable::new(RwLock::new(42));
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = rwlock.scoped_try_lock(key, |guard| {
+ assert_eq!(*guard.unwrap(), 42);
+ });
+ assert!(r.is_ok());
+ });
+ });
+ }
+
+ #[test]
+ fn scoped_read_can_poison() {
+ let key = ThreadKey::get().unwrap();
+ let mutex = Poisonable::new(RwLock::new(42));
+
+ let r = std::panic::catch_unwind(|| {
+ mutex.scoped_read(key, |num| {
+ assert_eq!(*num.unwrap(), 42);
+ panic!();
+ })
+ });
+ assert!(r.is_err());
+
+ let key = ThreadKey::get().unwrap();
+ assert!(mutex.is_poisoned());
+ mutex.scoped_read(key, |num| {
+ let Err(error) = num else { panic!() };
+ mutex.clear_poison();
+ let guard = error.into_inner();
+ assert_eq!(*guard, 42);
+ });
+ assert!(!mutex.is_poisoned());
+ }
+
+ #[test]
+ fn scoped_try_read_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let rwlock = Poisonable::new(RwLock::new(42));
+ let guard = rwlock.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = rwlock.scoped_try_read(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn scoped_try_read_can_succeed() {
+ let rwlock = Poisonable::new(RwLock::new(42));
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = rwlock.scoped_try_read(key, |guard| {
+ assert_eq!(*guard.unwrap(), 42);
+ });
+ assert!(r.is_ok());
+ });
+ });
+ }
+
+ #[test]
fn display_works() {
let key = ThreadKey::get().unwrap();
let mutex = Poisonable::new(Mutex::new("Hello, world!"));
@@ -318,6 +455,31 @@ mod tests {
}
#[test]
+ fn clear_poison_for_poisoned_rwlock() {
+ let lock = Arc::new(Poisonable::new(RwLock::new(0)));
+ let c_mutex = Arc::clone(&lock);
+
+ let _ = std::thread::spawn(move || {
+ let key = ThreadKey::get().unwrap();
+ let lock = c_mutex.read(key).unwrap();
+ assert_eq!(*lock, 0);
+ panic!(); // the mutex gets poisoned
+ })
+ .join();
+
+ assert!(lock.is_poisoned());
+
+ let key = ThreadKey::get().unwrap();
+ let _ = lock.lock(key).unwrap_or_else(|mut e| {
+ **e.get_mut() = 1;
+ lock.clear_poison();
+ e.into_inner()
+ });
+
+ assert!(!lock.is_poisoned());
+ }
+
+ #[test]
fn error_as_ref() {
let mutex = Poisonable::new(Mutex::new("foo"));
diff --git a/src/poisonable/poisonable.rs b/src/poisonable/poisonable.rs
index efe4ed0..ff78330 100644
--- a/src/poisonable/poisonable.rs
+++ b/src/poisonable/poisonable.rs
@@ -1,5 +1,6 @@
use std::panic::{RefUnwindSafe, UnwindSafe};
+use crate::handle_unwind::handle_unwind;
use crate::lockable::{
Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable,
};
@@ -17,16 +18,16 @@ unsafe impl<L: Lockable + RawLock> RawLock for Poisonable<L> {
self.inner.poison()
}
- unsafe fn raw_lock(&self) {
- self.inner.raw_lock()
+ unsafe fn raw_write(&self) {
+ self.inner.raw_write()
}
- unsafe fn raw_try_lock(&self) -> bool {
- self.inner.raw_try_lock()
+ unsafe fn raw_try_write(&self) -> bool {
+ self.inner.raw_try_write()
}
- unsafe fn raw_unlock(&self) {
- self.inner.raw_unlock()
+ unsafe fn raw_unlock_write(&self) {
+ self.inner.raw_unlock_write()
}
unsafe fn raw_read(&self) {
@@ -313,13 +314,19 @@ impl<L: Lockable + RawLock> Poisonable<L> {
) -> R {
unsafe {
// safety: we have the thread key
- self.raw_lock();
+ self.raw_write();
// safety: the data was just locked
- let r = f(self.data_mut());
+ let r = handle_unwind(
+ || f(self.data_mut()),
+ || {
+ self.poisoned.poison();
+ self.raw_unlock_write();
+ },
+ );
// safety: the collection is still locked
- self.raw_unlock();
+ self.raw_unlock_write();
drop(key); // ensure the key stays alive long enough
@@ -334,15 +341,21 @@ impl<L: Lockable + RawLock> Poisonable<L> {
) -> Result<R, Key> {
unsafe {
// safety: we have the thread key
- if !self.raw_try_lock() {
+ if !self.raw_try_write() {
return Err(key);
}
// safety: we just locked the collection
- let r = f(self.data_mut());
+ let r = handle_unwind(
+ || f(self.data_mut()),
+ || {
+ self.poisoned.poison();
+ self.raw_unlock_write();
+ },
+ );
// safety: the collection is still locked
- self.raw_unlock();
+ self.raw_unlock_write();
drop(key); // ensures the key stays valid long enough
@@ -383,7 +396,7 @@ impl<L: Lockable + RawLock> Poisonable<L> {
/// ```
pub fn lock(&self, key: ThreadKey) -> PoisonResult<PoisonGuard<'_, L::Guard<'_>>> {
unsafe {
- self.inner.raw_lock();
+ self.inner.raw_write();
self.guard(key)
}
}
@@ -434,7 +447,7 @@ impl<L: Lockable + RawLock> Poisonable<L> {
/// [`WouldBlock`]: `TryLockPoisonableError::WouldBlock`
pub fn try_lock(&self, key: ThreadKey) -> TryLockPoisonableResult<'_, L::Guard<'_>> {
unsafe {
- if self.inner.raw_try_lock() {
+ if self.inner.raw_try_write() {
Ok(self.guard(key)?)
} else {
Err(TryLockPoisonableError::WouldBlock(key))
@@ -487,7 +500,13 @@ impl<L: Sharable + RawLock> Poisonable<L> {
self.raw_read();
// safety: the data was just locked
- let r = f(self.data_ref());
+ let r = handle_unwind(
+ || f(self.data_ref()),
+ || {
+ self.poisoned.poison();
+ self.raw_unlock_read();
+ },
+ );
// safety: the collection is still locked
self.raw_unlock_read();
@@ -510,7 +529,13 @@ impl<L: Sharable + RawLock> Poisonable<L> {
}
// safety: we just locked the collection
- let r = f(self.data_ref());
+ let r = handle_unwind(
+ || f(self.data_ref()),
+ || {
+ self.poisoned.poison();
+ self.raw_unlock_read();
+ },
+ );
// safety: the collection is still locked
self.raw_unlock_read();
diff --git a/src/rwlock.rs b/src/rwlock.rs
index b604370..2d3dd85 100644
--- a/src/rwlock.rs
+++ b/src/rwlock.rs
@@ -75,7 +75,7 @@ pub struct WriteLock<'l, T: ?Sized, R>(&'l RwLock<T, R>);
/// [`Keyable`].
pub struct RwLockReadRef<'a, T: ?Sized, R: RawRwLock>(
&'a RwLock<T, R>,
- PhantomData<(&'a mut T, R::GuardMarker)>,
+ PhantomData<R::GuardMarker>,
);
/// RAII structure that unlocks the exclusive write access to a [`RwLock`]
@@ -84,7 +84,7 @@ pub struct RwLockReadRef<'a, T: ?Sized, R: RawRwLock>(
/// [`Keyable`].
pub struct RwLockWriteRef<'a, T: ?Sized, R: RawRwLock>(
&'a RwLock<T, R>,
- PhantomData<(&'a mut T, R::GuardMarker)>,
+ PhantomData<R::GuardMarker>,
);
/// RAII structure used to release the shared read access of a lock when
@@ -115,6 +115,8 @@ pub struct RwLockWriteGuard<'a, T: ?Sized, R: RawRwLock> {
#[cfg(test)]
mod tests {
use crate::lockable::Lockable;
+ use crate::lockable::RawLock;
+ use crate::LockCollection;
use crate::RwLock;
use crate::ThreadKey;
@@ -149,6 +151,33 @@ mod tests {
}
#[test]
+ fn read_lock_scoped_works() {
+ let mut key = ThreadKey::get().unwrap();
+ let lock: crate::RwLock<_> = RwLock::new(42);
+ let reader = ReadLock::new(&lock);
+
+ reader.scoped_lock(&mut key, |num| assert_eq!(*num, 42));
+ }
+
+ #[test]
+ fn read_lock_scoped_try_fails_during_write() {
+ let key = ThreadKey::get().unwrap();
+ let lock: crate::RwLock<_> = RwLock::new(42);
+ let reader = ReadLock::new(&lock);
+ let guard = lock.write(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = reader.scoped_try_lock(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
fn write_lock_unlocked_when_initialized() {
let key = ThreadKey::get().unwrap();
let lock: crate::RwLock<_> = RwLock::new("Hello, world!");
@@ -165,7 +194,7 @@ mod tests {
readlock.get_ptrs(&mut lock_ptrs);
assert_eq!(lock_ptrs.len(), 1);
- assert!(std::ptr::addr_eq(lock_ptrs[0], &rwlock));
+ assert!(std::ptr::addr_eq(lock_ptrs[0], &readlock));
}
#[test]
@@ -176,7 +205,34 @@ mod tests {
writelock.get_ptrs(&mut lock_ptrs);
assert_eq!(lock_ptrs.len(), 1);
- assert!(std::ptr::addr_eq(lock_ptrs[0], &rwlock));
+ assert!(std::ptr::addr_eq(lock_ptrs[0], &writelock));
+ }
+
+ #[test]
+ fn write_lock_scoped_works() {
+ let mut key = ThreadKey::get().unwrap();
+ let lock: crate::RwLock<_> = RwLock::new(42);
+ let writer = WriteLock::new(&lock);
+
+ writer.scoped_lock(&mut key, |num| assert_eq!(*num, 42));
+ }
+
+ #[test]
+ fn write_lock_scoped_try_fails_during_write() {
+ let key = ThreadKey::get().unwrap();
+ let lock: crate::RwLock<_> = RwLock::new(42);
+ let writer = WriteLock::new(&lock);
+ let guard = lock.write(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = writer.scoped_try_lock(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
}
#[test]
@@ -226,6 +282,69 @@ mod tests {
}
#[test]
+ fn locked_after_scoped_write() {
+ let mut key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("Hello, world!");
+
+ lock.scoped_write(&mut key, |guard| {
+ assert!(lock.is_locked());
+ assert_eq!(*guard, "Hello, world!");
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ assert!(lock.try_read(key).is_err());
+ });
+ })
+ })
+ }
+
+ #[test]
+ fn get_mut_works() {
+ let key = ThreadKey::get().unwrap();
+ let mut lock = crate::RwLock::from(42);
+
+ let mut_ref = lock.get_mut();
+ *mut_ref = 24;
+
+ lock.scoped_read(key, |guard| assert_eq!(*guard, 24))
+ }
+
+ #[test]
+ fn try_write_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("Hello");
+ let guard = lock.write(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = lock.try_write(key);
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn try_read_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("Hello");
+ let guard = lock.write(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = lock.try_read(key);
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
fn read_display_works() {
let key = ThreadKey::get().unwrap();
let lock: crate::RwLock<_> = RwLock::new("Hello, world!");
@@ -275,4 +394,200 @@ mod tests {
assert!(!lock.is_locked());
}
+
+ #[test]
+ fn unlock_write() {
+ let key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("Hello, world");
+
+ let mut guard = lock.write(key);
+ *guard = "Goodbye, world!";
+ let key = RwLock::unlock_write(guard);
+
+ let guard = lock.read(key);
+ assert_eq!(*guard, "Goodbye, world!");
+ }
+
+ #[test]
+ fn unlock_read() {
+ let key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("Hello, world");
+
+ let guard = lock.read(key);
+ assert_eq!(*guard, "Hello, world");
+ let key = RwLock::unlock_read(guard);
+
+ let guard = lock.write(key);
+ assert_eq!(*guard, "Hello, world");
+ }
+
+ #[test]
+ fn unlock_read_lock() {
+ let key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("Hello, world");
+ let reader = ReadLock::new(&lock);
+
+ let guard = reader.lock(key);
+ let key = ReadLock::unlock(guard);
+
+ lock.write(key);
+ }
+
+ #[test]
+ fn unlock_write_lock() {
+ let key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("Hello, world");
+ let writer = WriteLock::from(&lock);
+
+ let guard = writer.lock(key);
+ let key = WriteLock::unlock(guard);
+
+ lock.write(key);
+ }
+
+ #[test]
+ fn read_lock_in_collection() {
+ let mut key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("hi");
+ let collection = LockCollection::try_new(ReadLock::new(&lock)).unwrap();
+
+ collection.scoped_lock(&mut key, |guard| {
+ assert_eq!(*guard, "hi");
+ });
+ collection.scoped_read(&mut key, |guard| {
+ assert_eq!(*guard, "hi");
+ });
+ assert!(collection
+ .scoped_try_lock(&mut key, |guard| {
+ assert_eq!(*guard, "hi");
+ })
+ .is_ok());
+ assert!(collection
+ .scoped_try_read(&mut key, |guard| {
+ assert_eq!(*guard, "hi");
+ })
+ .is_ok());
+
+ let guard = collection.lock(key);
+ assert_eq!(**guard, "hi");
+
+ let key = LockCollection::<ReadLock<_, _>>::unlock(guard);
+ let guard = collection.read(key);
+ assert_eq!(**guard, "hi");
+
+ let key = LockCollection::<ReadLock<_, _>>::unlock(guard);
+ let guard = lock.write(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.try_lock(key);
+ assert!(guard.is_err());
+ });
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.try_read(key);
+ assert!(guard.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn write_lock_in_collection() {
+ let mut key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("hi");
+ let collection = LockCollection::try_new(WriteLock::new(&lock)).unwrap();
+
+ collection.scoped_lock(&mut key, |guard| {
+ assert_eq!(*guard, "hi");
+ });
+ assert!(collection
+ .scoped_try_lock(&mut key, |guard| {
+ assert_eq!(*guard, "hi");
+ })
+ .is_ok());
+
+ let guard = collection.lock(key);
+ assert_eq!(**guard, "hi");
+
+ let key = LockCollection::<WriteLock<_, _>>::unlock(guard);
+ let guard = lock.write(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.try_lock(key);
+ assert!(guard.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn read_ref_as_ref() {
+ let key = ThreadKey::get().unwrap();
+ let lock = LockCollection::new(crate::RwLock::new("hi"));
+ let guard = lock.read(key);
+
+ assert_eq!(*(*guard).as_ref(), "hi");
+ }
+
+ #[test]
+ fn read_guard_as_ref() {
+ let key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("hi");
+ let guard = lock.read(key);
+
+ assert_eq!(*guard.as_ref(), "hi");
+ }
+
+ #[test]
+ fn write_ref_as_ref() {
+ let key = ThreadKey::get().unwrap();
+ let lock = LockCollection::new(crate::RwLock::new("hi"));
+ let guard = lock.lock(key);
+
+ assert_eq!(*(*guard).as_ref(), "hi");
+ }
+
+ #[test]
+ fn write_guard_as_ref() {
+ let key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("hi");
+ let guard = lock.write(key);
+
+ assert_eq!(*guard.as_ref(), "hi");
+ }
+
+ #[test]
+ fn write_guard_as_mut() {
+ let key = ThreadKey::get().unwrap();
+ let lock = crate::RwLock::new("hi");
+ let mut guard = lock.write(key);
+
+ assert_eq!(*guard.as_mut(), "hi");
+ *guard.as_mut() = "foo";
+ assert_eq!(*guard.as_mut(), "foo");
+ }
+
+ #[test]
+ fn poison_read_lock() {
+ let lock = crate::RwLock::new("hi");
+ let reader = ReadLock::new(&lock);
+
+ reader.poison();
+ assert!(lock.poison.is_poisoned());
+ }
+
+ #[test]
+ fn poison_write_lock() {
+ let lock = crate::RwLock::new("hi");
+ let reader = WriteLock::new(&lock);
+
+ reader.poison();
+ assert!(lock.poison.is_poisoned());
+ }
}
diff --git a/src/rwlock/read_guard.rs b/src/rwlock/read_guard.rs
index 0d68c75..5b26c06 100644
--- a/src/rwlock/read_guard.rs
+++ b/src/rwlock/read_guard.rs
@@ -64,7 +64,7 @@ impl<'a, T: ?Sized, R: RawRwLock> RwLockReadRef<'a, T, R> {
/// Creates an immutable reference for the underlying data of an [`RwLock`]
/// without locking it or taking ownership of the key.
#[must_use]
- pub(crate) unsafe fn new(mutex: &'a RwLock<T, R>) -> Self {
+ pub(crate) const unsafe fn new(mutex: &'a RwLock<T, R>) -> Self {
Self(mutex, PhantomData)
}
}
@@ -109,7 +109,7 @@ impl<'a, T: ?Sized, R: RawRwLock> RwLockReadGuard<'a, T, R> {
/// Create a guard to the given mutex. Undefined if multiple guards to the
/// same mutex exist at once.
#[must_use]
- pub(super) unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: ThreadKey) -> Self {
+ pub(super) const unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: ThreadKey) -> Self {
Self {
rwlock: RwLockReadRef(rwlock, PhantomData),
thread_key,
diff --git a/src/rwlock/read_lock.rs b/src/rwlock/read_lock.rs
index 05b184a..dd9e42f 100644
--- a/src/rwlock/read_lock.rs
+++ b/src/rwlock/read_lock.rs
@@ -3,11 +3,41 @@ use std::fmt::Debug;
use lock_api::RawRwLock;
use crate::lockable::{Lockable, RawLock, Sharable};
-use crate::ThreadKey;
+use crate::{Keyable, ThreadKey};
use super::{ReadLock, RwLock, RwLockReadGuard, RwLockReadRef};
-unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for ReadLock<'_, T, R> {
+unsafe impl<T, R: RawRwLock> RawLock for ReadLock<'_, T, R> {
+ fn poison(&self) {
+ self.0.poison()
+ }
+
+ unsafe fn raw_write(&self) {
+ self.0.raw_read()
+ }
+
+ unsafe fn raw_try_write(&self) -> bool {
+ self.0.raw_try_read()
+ }
+
+ unsafe fn raw_unlock_write(&self) {
+ self.0.raw_unlock_read()
+ }
+
+ unsafe fn raw_read(&self) {
+ self.0.raw_read()
+ }
+
+ unsafe fn raw_try_read(&self) -> bool {
+ self.0.raw_try_read()
+ }
+
+ unsafe fn raw_unlock_read(&self) {
+ self.0.raw_unlock_read()
+ }
+}
+
+unsafe impl<T, R: RawRwLock> Lockable for ReadLock<'_, T, R> {
type Guard<'g>
= RwLockReadRef<'g, T, R>
where
@@ -19,7 +49,7 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for ReadLock<'_, T, R>
Self: 'a;
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
- ptrs.push(self.as_ref());
+ ptrs.push(self);
}
unsafe fn guard(&self) -> Self::Guard<'_> {
@@ -31,7 +61,7 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for ReadLock<'_, T, R>
}
}
-unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for ReadLock<'_, T, R> {
+unsafe impl<T, R: RawRwLock> Sharable for ReadLock<'_, T, R> {
type ReadGuard<'g>
= RwLockReadRef<'g, T, R>
where
@@ -53,7 +83,7 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for ReadLock<'_, T, R>
#[mutants::skip]
#[cfg(not(tarpaulin_include))]
-impl<T: ?Sized + Debug, R: RawRwLock> Debug for ReadLock<'_, T, R> {
+impl<T: Debug, R: RawRwLock> Debug for ReadLock<'_, T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// safety: this is just a try lock, and the value is dropped
// immediately after, so there's no risk of blocking ourselves
@@ -104,7 +134,19 @@ impl<'l, T, R> ReadLock<'l, T, R> {
}
}
-impl<T: ?Sized, R: RawRwLock> ReadLock<'_, T, R> {
+impl<T, R: RawRwLock> ReadLock<'_, T, R> {
+ pub fn scoped_lock<'a, Ret>(&'a self, key: impl Keyable, f: impl Fn(&'a T) -> Ret) -> Ret {
+ self.0.scoped_read(key, f)
+ }
+
+ pub fn scoped_try_lock<'a, Key: Keyable, Ret>(
+ &'a self,
+ key: Key,
+ f: impl Fn(&'a T) -> Ret,
+ ) -> Result<Ret, Key> {
+ self.0.scoped_try_read(key, f)
+ }
+
/// Locks the underlying [`RwLock`] with shared read access, blocking the
/// current thread until it can be acquired.
///
diff --git a/src/rwlock/rwlock.rs b/src/rwlock/rwlock.rs
index 905ecf8..5f407d1 100644
--- a/src/rwlock/rwlock.rs
+++ b/src/rwlock/rwlock.rs
@@ -5,6 +5,7 @@ use std::panic::AssertUnwindSafe;
use lock_api::RawRwLock;
+use crate::collection::utils;
use crate::handle_unwind::handle_unwind;
use crate::lockable::{
Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable,
@@ -18,7 +19,7 @@ unsafe impl<T: ?Sized, R: RawRwLock> RawLock for RwLock<T, R> {
self.poison.poison();
}
- unsafe fn raw_lock(&self) {
+ unsafe fn raw_write(&self) {
assert!(
!self.poison.is_poisoned(),
"The read-write lock has been killed"
@@ -29,7 +30,7 @@ unsafe impl<T: ?Sized, R: RawRwLock> RawLock for RwLock<T, R> {
handle_unwind(|| this.raw.lock_exclusive(), || self.poison())
}
- unsafe fn raw_try_lock(&self) -> bool {
+ unsafe fn raw_try_write(&self) -> bool {
if self.poison.is_poisoned() {
return false;
}
@@ -39,7 +40,7 @@ unsafe impl<T: ?Sized, R: RawRwLock> RawLock for RwLock<T, R> {
handle_unwind(|| this.raw.try_lock_exclusive(), || self.poison())
}
- unsafe fn raw_unlock(&self) {
+ unsafe fn raw_unlock_write(&self) {
// if the closure unwraps, then the mutex will be killed
let this = AssertUnwindSafe(self);
handle_unwind(|| this.raw.unlock_exclusive(), || self.poison())
@@ -73,7 +74,7 @@ unsafe impl<T: ?Sized, R: RawRwLock> RawLock for RwLock<T, R> {
}
}
-unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for RwLock<T, R> {
+unsafe impl<T, R: RawRwLock> Lockable for RwLock<T, R> {
type Guard<'g>
= RwLockWriteRef<'g, T, R>
where
@@ -97,7 +98,7 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for RwLock<T, R> {
}
}
-unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for RwLock<T, R> {
+unsafe impl<T, R: RawRwLock> Sharable for RwLock<T, R> {
type ReadGuard<'g>
= RwLockReadRef<'g, T, R>
where
@@ -117,9 +118,9 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for RwLock<T, R> {
}
}
-unsafe impl<T: Send, R: RawRwLock + Send + Sync> OwnedLockable for RwLock<T, R> {}
+unsafe impl<T: Send, R: RawRwLock> OwnedLockable for RwLock<T, R> {}
-impl<T: Send, R: RawRwLock + Send + Sync> LockableIntoInner for RwLock<T, R> {
+impl<T: Send, R: RawRwLock> LockableIntoInner for RwLock<T, R> {
type Inner = T;
fn into_inner(self) -> Self::Inner {
@@ -127,7 +128,7 @@ impl<T: Send, R: RawRwLock + Send + Sync> LockableIntoInner for RwLock<T, R> {
}
}
-impl<T: Send, R: RawRwLock + Send + Sync> LockableGetMut for RwLock<T, R> {
+impl<T: Send, R: RawRwLock> LockableGetMut for RwLock<T, R> {
type Inner<'a>
= &'a mut T
where
@@ -160,7 +161,7 @@ impl<T, R: RawRwLock> RwLock<T, R> {
#[mutants::skip]
#[cfg(not(tarpaulin_include))]
-impl<T: ?Sized + Debug, R: RawRwLock> Debug for RwLock<T, R> {
+impl<T: Debug, R: RawRwLock> Debug for RwLock<T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// safety: this is just a try lock, and the value is dropped
// immediately after, so there's no risk of blocking ourselves
@@ -247,85 +248,29 @@ impl<T: ?Sized, R> RwLock<T, R> {
}
}
-impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
- pub fn scoped_read<Ret>(&self, key: impl Keyable, f: impl Fn(&T) -> Ret) -> Ret {
- unsafe {
- // safety: we have the thread key
- self.raw_read();
-
- // safety: the rwlock was just locked
- let r = f(self.data.get().as_ref().unwrap_unchecked());
-
- // safety: the rwlock is already locked
- self.raw_unlock_read();
-
- drop(key); // ensure the key stays valid for long enough
-
- r
- }
+impl<T, R: RawRwLock> RwLock<T, R> {
+ pub fn scoped_read<'a, Ret>(&'a self, key: impl Keyable, f: impl Fn(&'a T) -> Ret) -> Ret {
+ utils::scoped_read(self, key, f)
}
- pub fn scoped_try_read<Key: Keyable, Ret>(
- &self,
+ pub fn scoped_try_read<'a, Key: Keyable, Ret>(
+ &'a self,
key: Key,
- f: impl Fn(&T) -> Ret,
+ f: impl Fn(&'a T) -> Ret,
) -> Result<Ret, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_read() {
- return Err(key);
- }
-
- // safety: the rwlock was just locked
- let r = f(self.data.get().as_ref().unwrap_unchecked());
-
- // safety: the rwlock is already locked
- self.raw_unlock_read();
-
- drop(key); // ensure the key stays valid for long enough
-
- Ok(r)
- }
+ utils::scoped_try_read(self, key, f)
}
- pub fn scoped_write<Ret>(&self, key: impl Keyable, f: impl Fn(&mut T) -> Ret) -> Ret {
- unsafe {
- // safety: we have the thread key
- self.raw_lock();
-
- // safety: we just locked the rwlock
- let r = f(self.data.get().as_mut().unwrap_unchecked());
-
- // safety: the rwlock is already locked
- self.raw_unlock();
-
- drop(key); // ensure the key stays valid for long enough
-
- r
- }
+ pub fn scoped_write<'a, Ret>(&'a self, key: impl Keyable, f: impl Fn(&'a mut T) -> Ret) -> Ret {
+ utils::scoped_write(self, key, f)
}
- pub fn scoped_try_write<Key: Keyable, Ret>(
- &self,
+ pub fn scoped_try_write<'a, Key: Keyable, Ret>(
+ &'a self,
key: Key,
- f: impl Fn(&mut T) -> Ret,
+ f: impl Fn(&'a mut T) -> Ret,
) -> Result<Ret, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_lock() {
- return Err(key);
- }
-
- // safety: the rwlock was just locked
- let r = f(self.data.get().as_mut().unwrap_unchecked());
-
- // safety: the rwlock is already locked
- self.raw_unlock();
-
- drop(key); // ensure the key stays valid for long enough
-
- Ok(r)
- }
+ utils::scoped_try_write(self, key, f)
}
/// Locks this `RwLock` with shared read access, blocking the current
@@ -426,7 +371,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// without exclusive access to the key is undefined behavior.
#[cfg(test)]
pub(crate) unsafe fn try_write_no_key(&self) -> Option<RwLockWriteRef<'_, T, R>> {
- if self.raw_try_lock() {
+ if self.raw_try_write() {
// safety: the lock is locked first
Some(RwLockWriteRef(self, PhantomData))
} else {
@@ -463,7 +408,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// [`ThreadKey`]: `crate::ThreadKey`
pub fn write(&self, key: ThreadKey) -> RwLockWriteGuard<'_, T, R> {
unsafe {
- self.raw_lock();
+ self.raw_write();
// safety: the lock is locked first
RwLockWriteGuard::new(self, key)
@@ -498,7 +443,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// ```
pub fn try_write(&self, key: ThreadKey) -> Result<RwLockWriteGuard<'_, T, R>, ThreadKey> {
unsafe {
- if self.raw_try_lock() {
+ if self.raw_try_write() {
// safety: the lock is locked first
Ok(RwLockWriteGuard::new(self, key))
} else {
@@ -533,9 +478,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// ```
#[must_use]
pub fn unlock_read(guard: RwLockReadGuard<'_, T, R>) -> ThreadKey {
- unsafe {
- guard.rwlock.0.raw_unlock_read();
- }
+ drop(guard.rwlock);
guard.thread_key
}
@@ -560,9 +503,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// ```
#[must_use]
pub fn unlock_write(guard: RwLockWriteGuard<'_, T, R>) -> ThreadKey {
- unsafe {
- guard.rwlock.0.raw_unlock();
- }
+ drop(guard.rwlock);
guard.thread_key
}
}
diff --git a/src/rwlock/write_guard.rs b/src/rwlock/write_guard.rs
index 3fabf8e..c7676b5 100644
--- a/src/rwlock/write_guard.rs
+++ b/src/rwlock/write_guard.rs
@@ -71,7 +71,7 @@ impl<T: ?Sized, R: RawRwLock> Drop for RwLockWriteRef<'_, T, R> {
fn drop(&mut self) {
// safety: this guard is being destroyed, so the data cannot be
// accessed without locking again
- unsafe { self.0.raw_unlock() }
+ unsafe { self.0.raw_unlock_write() }
}
}
@@ -79,7 +79,7 @@ impl<'a, T: ?Sized + 'a, R: RawRwLock> RwLockWriteRef<'a, T, R> {
/// Creates a reference to the underlying data of an [`RwLock`] without
/// locking or taking ownership of the key.
#[must_use]
- pub(crate) unsafe fn new(mutex: &'a RwLock<T, R>) -> Self {
+ pub(crate) const unsafe fn new(mutex: &'a RwLock<T, R>) -> Self {
Self(mutex, PhantomData)
}
}
@@ -136,7 +136,7 @@ impl<'a, T: ?Sized + 'a, R: RawRwLock> RwLockWriteGuard<'a, T, R> {
/// Create a guard to the given mutex. Undefined if multiple guards to the
/// same mutex exist at once.
#[must_use]
- pub(super) unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: ThreadKey) -> Self {
+ pub(super) const unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: ThreadKey) -> Self {
Self {
rwlock: RwLockWriteRef(rwlock, PhantomData),
thread_key,
diff --git a/src/rwlock/write_lock.rs b/src/rwlock/write_lock.rs
index 8a44a2d..5ae4dda 100644
--- a/src/rwlock/write_lock.rs
+++ b/src/rwlock/write_lock.rs
@@ -3,11 +3,41 @@ use std::fmt::Debug;
use lock_api::RawRwLock;
use crate::lockable::{Lockable, RawLock};
-use crate::ThreadKey;
+use crate::{Keyable, ThreadKey};
use super::{RwLock, RwLockWriteGuard, RwLockWriteRef, WriteLock};
-unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for WriteLock<'_, T, R> {
+unsafe impl<T, R: RawRwLock> RawLock for WriteLock<'_, T, R> {
+ fn poison(&self) {
+ self.0.poison()
+ }
+
+ unsafe fn raw_write(&self) {
+ self.0.raw_write()
+ }
+
+ unsafe fn raw_try_write(&self) -> bool {
+ self.0.raw_try_write()
+ }
+
+ unsafe fn raw_unlock_write(&self) {
+ self.0.raw_unlock_write()
+ }
+
+ unsafe fn raw_read(&self) {
+ self.0.raw_write()
+ }
+
+ unsafe fn raw_try_read(&self) -> bool {
+ self.0.raw_try_write()
+ }
+
+ unsafe fn raw_unlock_read(&self) {
+ self.0.raw_unlock_write()
+ }
+}
+
+unsafe impl<T, R: RawRwLock> Lockable for WriteLock<'_, T, R> {
type Guard<'g>
= RwLockWriteRef<'g, T, R>
where
@@ -19,7 +49,7 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for WriteLock<'_, T, R
Self: 'a;
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
- ptrs.push(self.as_ref());
+ ptrs.push(self)
}
unsafe fn guard(&self) -> Self::Guard<'_> {
@@ -36,7 +66,7 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for WriteLock<'_, T, R
#[mutants::skip]
#[cfg(not(tarpaulin_include))]
-impl<T: ?Sized + Debug, R: RawRwLock> Debug for WriteLock<'_, T, R> {
+impl<T: Debug, R: RawRwLock> Debug for WriteLock<'_, T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// safety: this is just a try lock, and the value is dropped
// immediately after, so there's no risk of blocking ourselves
@@ -89,7 +119,19 @@ impl<'l, T, R> WriteLock<'l, T, R> {
}
}
-impl<T: ?Sized, R: RawRwLock> WriteLock<'_, T, R> {
+impl<T, R: RawRwLock> WriteLock<'_, T, R> {
+ pub fn scoped_lock<'a, Ret>(&'a self, key: impl Keyable, f: impl Fn(&'a mut T) -> Ret) -> Ret {
+ self.0.scoped_write(key, f)
+ }
+
+ pub fn scoped_try_lock<'a, Key: Keyable, Ret>(
+ &'a self,
+ key: Key,
+ f: impl Fn(&'a mut T) -> Ret,
+ ) -> Result<Ret, Key> {
+ self.0.scoped_try_write(key, f)
+ }
+
/// Locks the underlying [`RwLock`] with exclusive write access, blocking
/// the current until it can be acquired.
///
diff --git a/src/thread.rs b/src/thread.rs
new file mode 100644
index 0000000..6e9c270
--- /dev/null
+++ b/src/thread.rs
@@ -0,0 +1,19 @@
+use std::marker::PhantomData;
+
+mod scope;
+
+#[derive(Debug)]
+pub struct Scope<'scope, 'env: 'scope>(PhantomData<(&'env (), &'scope ())>);
+
+#[derive(Debug)]
+pub struct ScopedJoinHandle<'scope, T> {
+ handle: std::thread::JoinHandle<T>,
+ _phantom: PhantomData<&'scope ()>,
+}
+
+pub struct JoinHandle<T> {
+ handle: std::thread::JoinHandle<T>,
+ key: crate::ThreadKey,
+}
+
+pub struct ThreadBuilder(std::thread::Builder);
diff --git a/src/thread/scope.rs b/src/thread/scope.rs
new file mode 100644
index 0000000..09319cb
--- /dev/null
+++ b/src/thread/scope.rs
@@ -0,0 +1,47 @@
+use std::marker::PhantomData;
+
+use crate::{Keyable, ThreadKey};
+
+use super::{Scope, ScopedJoinHandle};
+
+pub fn scope<'env, F, T>(key: impl Keyable, f: F) -> T
+where
+ F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> T,
+{
+ let scope = Scope(PhantomData);
+ let t = f(&scope);
+ drop(key);
+ t
+}
+
+impl<'scope> Scope<'scope, '_> {
+ #[allow(clippy::unused_self)]
+ pub fn spawn<T: Send + 'scope>(
+ &self,
+ f: impl FnOnce(ThreadKey) -> T + Send + 'scope,
+ ) -> std::io::Result<ScopedJoinHandle<'scope, T>> {
+ unsafe {
+ // safety: the lifetimes ensure that the data lives long enough
+ let handle = std::thread::Builder::new().spawn_unchecked(|| {
+ // safety: the thread just started, so the key cannot be acquired yet
+ let key = ThreadKey::get().unwrap_unchecked();
+ f(key)
+ })?;
+
+ Ok(ScopedJoinHandle {
+ handle,
+ _phantom: PhantomData,
+ })
+ }
+ }
+}
+
+impl<T> ScopedJoinHandle<'_, T> {
+ pub fn is_finished(&self) -> bool {
+ self.handle.is_finished()
+ }
+
+ pub fn join(self) -> std::thread::Result<T> {
+ self.handle.join()
+ }
+}