summaryrefslogtreecommitdiff
path: root/src/collection
diff options
context:
space:
mode:
Diffstat (limited to 'src/collection')
-rw-r--r--src/collection/boxed.rs30
-rw-r--r--src/collection/owned.rs28
-rw-r--r--src/collection/ref.rs28
-rw-r--r--src/collection/utils.rs44
4 files changed, 61 insertions, 69 deletions
diff --git a/src/collection/boxed.rs b/src/collection/boxed.rs
index 224eedb..5ced6d1 100644
--- a/src/collection/boxed.rs
+++ b/src/collection/boxed.rs
@@ -4,7 +4,7 @@ use std::marker::PhantomData;
use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable};
use crate::Keyable;
-use super::{BoxedLockCollection, LockGuard};
+use super::{utils, BoxedLockCollection, LockGuard};
/// returns `true` if the sorted list contains a duplicate
#[must_use]
@@ -185,6 +185,8 @@ impl<L: Lockable> BoxedLockCollection<L> {
let data = Box::new(data);
let mut locks = Vec::new();
data.get_ptrs(&mut locks);
+
+ // cast to *const () because fat pointers can't be converted to usize
locks.sort_by_key(|lock| std::ptr::from_ref(*lock).cast::<()>() as usize);
// safety: the box will be dropped after the lock references, so it's
@@ -310,17 +312,8 @@ impl<L: Lockable> BoxedLockCollection<L> {
key: Key,
) -> Option<LockGuard<'key, L::Guard<'g>, Key>> {
let guard = unsafe {
- for (i, lock) in self.locks.iter().enumerate() {
- // safety: we have the thread key
- let success = lock.try_lock();
-
- if !success {
- for lock in &self.locks[0..i] {
- // safety: this lock was already acquired
- lock.unlock();
- }
- return None;
- }
+ if !utils::ordered_try_lock(&self.locks) {
+ return None;
}
// safety: we've acquired the locks
@@ -424,17 +417,8 @@ impl<L: Sharable> BoxedLockCollection<L> {
key: Key,
) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> {
let guard = unsafe {
- for (i, lock) in self.locks.iter().enumerate() {
- // safety: we have the thread key
- let success = lock.try_read();
-
- if !success {
- for lock in &self.locks[0..i] {
- // safety: this lock was already acquired
- lock.unlock_read();
- }
- return None;
- }
+ if !utils::ordered_try_read(&self.locks) {
+ return None;
}
// safety: we've acquired the locks
diff --git a/src/collection/owned.rs b/src/collection/owned.rs
index e1549b2..919c403 100644
--- a/src/collection/owned.rs
+++ b/src/collection/owned.rs
@@ -3,7 +3,7 @@ use std::marker::PhantomData;
use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable};
use crate::Keyable;
-use super::{LockGuard, OwnedLockCollection};
+use super::{utils, LockGuard, OwnedLockCollection};
fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
let mut locks = Vec::new();
@@ -191,17 +191,8 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
) -> Option<LockGuard<'key, L::Guard<'g>, Key>> {
let locks = get_locks(&self.data);
let guard = unsafe {
- for (i, lock) in locks.iter().enumerate() {
- // safety: we have the thread key
- let success = lock.try_lock();
-
- if !success {
- for lock in &locks[0..i] {
- // safety: this lock was already acquired
- lock.unlock();
- }
- return None;
- }
+ if !utils::ordered_try_lock(&locks) {
+ return None;
}
// safety: we've acquired the locks
@@ -315,17 +306,8 @@ impl<L: Sharable> OwnedLockCollection<L> {
) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> {
let locks = get_locks(&self.data);
let guard = unsafe {
- for (i, lock) in locks.iter().enumerate() {
- // safety: we have the thread key
- let success = lock.try_read();
-
- if !success {
- for lock in &locks[0..i] {
- // safety: this lock was already acquired
- lock.unlock();
- }
- return None;
- }
+ if !utils::ordered_try_read(&locks) {
+ return None;
}
// safety: we've acquired the locks
diff --git a/src/collection/ref.rs b/src/collection/ref.rs
index e5c548f..d8c7f2e 100644
--- a/src/collection/ref.rs
+++ b/src/collection/ref.rs
@@ -4,7 +4,7 @@ use std::marker::PhantomData;
use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable};
use crate::Keyable;
-use super::{LockGuard, RefLockCollection};
+use super::{utils, LockGuard, RefLockCollection};
#[must_use]
pub fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
@@ -221,17 +221,8 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
key: Key,
) -> Option<LockGuard<'key, L::Guard<'a>, Key>> {
let guard = unsafe {
- for (i, lock) in self.locks.iter().enumerate() {
- // safety: we have the thread key
- let success = lock.try_lock();
-
- if !success {
- for lock in &self.locks[0..i] {
- // safety: this lock was already acquired
- lock.unlock();
- }
- return None;
- }
+ if !utils::ordered_try_lock(&self.locks) {
+ return None;
}
// safety: we've acquired the locks
@@ -339,17 +330,8 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> {
key: Key,
) -> Option<LockGuard<'key, L::ReadGuard<'a>, Key>> {
let guard = unsafe {
- for (i, lock) in self.locks.iter().enumerate() {
- // safety: we have the thread key
- let success = lock.try_read();
-
- if !success {
- for lock in &self.locks[0..i] {
- // safety: this lock was already acquired
- lock.unlock_read();
- }
- return None;
- }
+ if !utils::ordered_try_read(&self.locks) {
+ return None;
}
// safety: we've acquired the locks
diff --git a/src/collection/utils.rs b/src/collection/utils.rs
new file mode 100644
index 0000000..dc58399
--- /dev/null
+++ b/src/collection/utils.rs
@@ -0,0 +1,44 @@
+use crate::lockable::RawLock;
+
+/// Locks the locks in the order they are given. This causes deadlock if the
+/// locks contain duplicates, or if this is called by multiple threads with the
+/// locks in different orders.
+pub unsafe fn ordered_try_lock(locks: &[&dyn RawLock]) -> bool {
+ unsafe {
+ for (i, lock) in locks.iter().enumerate() {
+ // safety: we have the thread key
+ let success = lock.try_lock();
+
+ if !success {
+ for lock in &locks[0..i] {
+ // safety: this lock was already acquired
+ lock.unlock();
+ }
+ return false;
+ }
+ }
+
+ true
+ }
+}
+
+/// Locks the locks in the order they are given. This causes deadlock f this is
+/// called by multiple threads with the locks in different orders.
+pub unsafe fn ordered_try_read(locks: &[&dyn RawLock]) -> bool {
+ unsafe {
+ for (i, lock) in locks.iter().enumerate() {
+ // safety: we have the thread key
+ let success = lock.try_read();
+
+ if !success {
+ for lock in &locks[0..i] {
+ // safety: this lock was already acquired
+ lock.unlock_read();
+ }
+ return false;
+ }
+ }
+
+ true
+ }
+}