summaryrefslogtreecommitdiff
path: root/src/collection/owned.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/collection/owned.rs')
-rw-r--r--src/collection/owned.rs283
1 files changed, 204 insertions, 79 deletions
diff --git a/src/collection/owned.rs b/src/collection/owned.rs
index b9cf313..68170d1 100644
--- a/src/collection/owned.rs
+++ b/src/collection/owned.rs
@@ -3,6 +3,7 @@ use crate::lockable::{
};
use crate::{Keyable, ThreadKey};
+use super::utils::{scoped_read, scoped_try_read, scoped_try_write, scoped_write};
use super::{utils, LockGuard, OwnedLockCollection};
unsafe impl<L: Lockable> RawLock for OwnedLockCollection<L> {
@@ -15,19 +16,19 @@ unsafe impl<L: Lockable> RawLock for OwnedLockCollection<L> {
}
}
- unsafe fn raw_lock(&self) {
- utils::ordered_lock(&utils::get_locks_unsorted(&self.data))
+ unsafe fn raw_write(&self) {
+ utils::ordered_write(&utils::get_locks_unsorted(&self.data))
}
- unsafe fn raw_try_lock(&self) -> bool {
+ unsafe fn raw_try_write(&self) -> bool {
let locks = utils::get_locks_unsorted(&self.data);
- utils::ordered_try_lock(&locks)
+ utils::ordered_try_write(&locks)
}
- unsafe fn raw_unlock(&self) {
+ unsafe fn raw_unlock_write(&self) {
let locks = utils::get_locks_unsorted(&self.data);
for lock in locks {
- lock.raw_unlock();
+ lock.raw_unlock_write();
}
}
@@ -62,7 +63,7 @@ unsafe impl<L: Lockable> Lockable for OwnedLockCollection<L> {
#[mutants::skip] // It's hard to test lkocks in an OwnedLockCollection, because they're owned
#[cfg(not(tarpaulin_include))]
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
- self.data.get_ptrs(ptrs)
+ ptrs.push(self)
}
unsafe fn guard(&self) -> Self::Guard<'_> {
@@ -146,7 +147,7 @@ impl<E: OwnedLockable + Extend<L>, L: OwnedLockable> Extend<L> for OwnedLockColl
// invariant that there is only one way to lock the collection. AsMut is fine,
// because the collection can't be locked as long as the reference is valid.
-impl<T, L: AsMut<T>> AsMut<T> for OwnedLockCollection<L> {
+impl<T: ?Sized, L: AsMut<T>> AsMut<T> for OwnedLockCollection<L> {
fn as_mut(&mut self) -> &mut T {
self.data.as_mut()
}
@@ -185,44 +186,16 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
Self { data }
}
- pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_lock();
-
- // safety: the data was just locked
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_lock<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataMut<'a>) -> R) -> R {
+ scoped_write(self, key, f)
}
- pub fn scoped_try_lock<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_lock<'a, Key: Keyable, R>(
+ &'a self,
key: Key,
- f: impl Fn(L::DataMut<'_>) -> R,
+ f: impl Fn(L::DataMut<'a>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_lock() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_write(self, key, f)
}
/// Locks the collection
@@ -249,7 +222,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
let guard = unsafe {
// safety: we have the thread key, and these locks happen in a
// predetermined order
- self.raw_lock();
+ self.raw_write();
// safety: we've locked all of this already
self.data.guard()
@@ -290,7 +263,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
/// ```
pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
let guard = unsafe {
- if !self.raw_try_lock() {
+ if !self.raw_try_write() {
return Err(key);
}
@@ -327,44 +300,16 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
}
impl<L: Sharable> OwnedLockCollection<L> {
- pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_read();
-
- // safety: the data was just locked
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R {
+ scoped_read(self, key, f)
}
- pub fn scoped_try_read<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_read<'a, Key: Keyable, R>(
+ &'a self,
key: Key,
- f: impl Fn(L::DataRef<'_>) -> R,
+ f: impl Fn(L::DataRef<'a>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_read() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_read(self, key, f)
}
/// Locks the collection, so that other threads can still read from it
@@ -554,7 +499,7 @@ impl<L: LockableIntoInner> OwnedLockCollection<L> {
#[cfg(test)]
mod tests {
use super::*;
- use crate::{Mutex, ThreadKey};
+ use crate::{Mutex, RwLock, ThreadKey};
#[test]
fn get_mut_applies_changes() {
@@ -604,6 +549,63 @@ mod tests {
}
#[test]
+ fn scoped_read_works() {
+ let mut key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new([RwLock::new(24), RwLock::new(42)]);
+ let sum = collection.scoped_read(&mut key, |guard| guard[0] + guard[1]);
+ assert_eq!(sum, 24 + 42);
+ }
+
+ #[test]
+ fn scoped_lock_works() {
+ let mut key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new([RwLock::new(24), RwLock::new(42)]);
+ collection.scoped_lock(&mut key, |guard| *guard[0] += *guard[1]);
+
+ let sum = collection.scoped_lock(&mut key, |guard| {
+ assert_eq!(*guard[0], 24 + 42);
+ assert_eq!(*guard[1], 42);
+ *guard[0] + *guard[1]
+ });
+
+ assert_eq!(sum, 24 + 42 + 42);
+ }
+
+ #[test]
+ fn scoped_try_lock_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new([Mutex::new(1), Mutex::new(2)]);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_lock(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn scoped_try_read_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new([RwLock::new(1), RwLock::new(2)]);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_read(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
fn try_lock_works_on_unlocked() {
let key = ThreadKey::get().unwrap();
let collection = OwnedLockCollection::new((Mutex::new(0), Mutex::new(1)));
@@ -630,6 +632,74 @@ mod tests {
}
#[test]
+ fn try_read_succeeds_for_unlocked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = OwnedLockCollection::new(mutexes);
+ let guard = collection.try_read(key).unwrap();
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
+ fn try_read_fails_on_locked() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new((RwLock::new(0), RwLock::new(1)));
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ #[allow(unused)]
+ let guard = collection.lock(key);
+ std::mem::forget(guard);
+ });
+ });
+
+ assert!(collection.try_read(key).is_err());
+ }
+
+ #[test]
+ fn can_read_twice_on_different_threads() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = OwnedLockCollection::new(mutexes);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.read(key);
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ std::mem::forget(guard);
+ });
+ });
+
+ let guard = collection.try_read(key).unwrap();
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
+ fn unlock_collection_works() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new((Mutex::new("foo"), Mutex::new("bar")));
+ let guard = collection.lock(key);
+
+ let key = OwnedLockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard);
+ assert!(collection.try_lock(key).is_ok())
+ }
+
+ #[test]
+ fn read_unlock_collection_works() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new((RwLock::new("foo"), RwLock::new("bar")));
+ let guard = collection.read(key);
+
+ let key = OwnedLockCollection::<(&RwLock<_>, &RwLock<_>)>::unlock_read(guard);
+ assert!(collection.try_lock(key).is_ok())
+ }
+
+ #[test]
fn default_works() {
type MyCollection = OwnedLockCollection<(Mutex<i32>, Mutex<Option<i32>>, Mutex<String>)>;
let collection = MyCollection::default();
@@ -649,4 +719,59 @@ mod tests {
assert_eq!(collection.data.len(), 3);
}
+
+ #[test]
+ fn works_in_collection() {
+ let key = ThreadKey::get().unwrap();
+ let collection =
+ OwnedLockCollection::new(OwnedLockCollection::new([RwLock::new(0), RwLock::new(1)]));
+
+ let mut guard = collection.lock(key);
+ assert_eq!(*guard[0], 0);
+ assert_eq!(*guard[1], 1);
+ *guard[1] = 2;
+
+ let key = OwnedLockCollection::<OwnedLockCollection<[RwLock<_>; 2]>>::unlock(guard);
+ let guard = collection.read(key);
+ assert_eq!(*guard[0], 0);
+ assert_eq!(*guard[1], 2);
+ }
+
+ #[test]
+ fn as_mut_works() {
+ let mut mutexes = [Mutex::new(0), Mutex::new(1)];
+ let mut collection = OwnedLockCollection::new(&mut mutexes);
+
+ collection.as_mut()[0] = Mutex::new(42);
+
+ assert_eq!(*collection.as_mut()[0].get_mut(), 42);
+ }
+
+ #[test]
+ fn child_mut_works() {
+ let mut mutexes = [Mutex::new(0), Mutex::new(1)];
+ let mut collection = OwnedLockCollection::new(&mut mutexes);
+
+ collection.child_mut()[0] = Mutex::new(42);
+
+ assert_eq!(*collection.child_mut()[0].get_mut(), 42);
+ }
+
+ #[test]
+ fn into_child_works() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let mut collection = OwnedLockCollection::new(mutexes);
+
+ collection.child_mut()[0] = Mutex::new(42);
+
+ assert_eq!(
+ *collection
+ .into_child()
+ .as_mut()
+ .get_mut(0)
+ .unwrap()
+ .get_mut(),
+ 42
+ );
+ }
}