summaryrefslogtreecommitdiff
path: root/src/collection/retry.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/collection/retry.rs')
-rw-r--r--src/collection/retry.rs268
1 files changed, 80 insertions, 188 deletions
diff --git a/src/collection/retry.rs b/src/collection/retry.rs
index 8a10fc3..05adc3e 100644
--- a/src/collection/retry.rs
+++ b/src/collection/retry.rs
@@ -8,11 +8,18 @@ use std::marker::PhantomData;
use super::{LockGuard, RetryingLockCollection};
+/// Get all raw locks in the collection
+fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
+ let mut locks = Vec::new();
+ data.get_ptrs(&mut locks);
+ locks
+}
+
/// Checks that a collection contains no duplicate references to a lock.
fn contains_duplicates<L: Lockable>(data: L) -> bool {
let mut locks = Vec::new();
data.get_ptrs(&mut locks);
- let locks = locks.into_iter().map(|l| l as *const dyn RawLock);
+ let locks = locks.into_iter().map(|l| &raw const *l);
let mut locks_set = HashSet::with_capacity(locks.len());
for lock in locks {
@@ -24,11 +31,17 @@ fn contains_duplicates<L: Lockable>(data: L) -> bool {
false
}
-unsafe impl<L: Lockable + Send + Sync> RawLock for RetryingLockCollection<L> {
- unsafe fn lock(&self) {
+unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
+ fn kill(&self) {
+ let locks = get_locks(&self.data);
+ for lock in locks {
+ lock.kill();
+ }
+ }
+
+ unsafe fn raw_lock(&self) {
let mut first_index = 0;
- let mut locks = Vec::new();
- self.data.get_ptrs(&mut locks);
+ let locks = get_locks(&self.data);
if locks.is_empty() {
return;
@@ -37,23 +50,27 @@ unsafe impl<L: Lockable + Send + Sync> RawLock for RetryingLockCollection<L> {
unsafe {
'outer: loop {
// safety: we have the thread key
- locks[first_index].lock();
+ locks[first_index].raw_lock();
for (i, lock) in locks.iter().enumerate() {
if i == first_index {
continue;
}
+ // If the lock has been killed, then this returns false
+ // instead of panicking. This sounds like a problem, but if
+ // it does return false, then the lock function is called
+ // immediately after, causing a panic
// safety: we have the thread key
- if !lock.try_lock() {
+ if !lock.raw_try_lock() {
for lock in locks.iter().take(i) {
// safety: we already locked all of these
- lock.unlock();
+ lock.raw_unlock();
}
if first_index >= i {
// safety: this is already locked and can't be unlocked
// by the previous loop
- locks[first_index].unlock();
+ locks[first_index].raw_unlock();
}
first_index = i;
@@ -67,9 +84,8 @@ unsafe impl<L: Lockable + Send + Sync> RawLock for RetryingLockCollection<L> {
};
}
- unsafe fn try_lock(&self) -> bool {
- let mut locks = Vec::new();
- self.data.get_ptrs(&mut locks);
+ unsafe fn raw_try_lock(&self) -> bool {
+ let locks = get_locks(&self.data);
if locks.is_empty() {
return true;
@@ -78,10 +94,10 @@ unsafe impl<L: Lockable + Send + Sync> RawLock for RetryingLockCollection<L> {
unsafe {
for (i, lock) in locks.iter().enumerate() {
// safety: we have the thread key
- if !lock.try_lock() {
+ if !lock.raw_try_lock() {
for lock in locks.iter().take(i) {
// safety: we already locked all of these
- lock.unlock();
+ lock.raw_unlock();
}
return false;
}
@@ -91,39 +107,37 @@ unsafe impl<L: Lockable + Send + Sync> RawLock for RetryingLockCollection<L> {
true
}
- unsafe fn unlock(&self) {
- let mut locks = Vec::new();
- self.get_ptrs(&mut locks);
+ unsafe fn raw_unlock(&self) {
+ let locks = get_locks(&self.data);
for lock in locks {
- lock.unlock();
+ lock.raw_unlock();
}
}
- unsafe fn read(&self) {
+ unsafe fn raw_read(&self) {
let mut first_index = 0;
- let mut locks = Vec::new();
- self.data.get_ptrs(&mut locks);
+ let locks = get_locks(&self.data);
'outer: loop {
// safety: we have the thread key
- locks[first_index].read();
+ locks[first_index].raw_read();
for (i, lock) in locks.iter().enumerate() {
if i == first_index {
continue;
}
// safety: we have the thread key
- if !lock.try_read() {
+ if !lock.raw_try_read() {
for lock in locks.iter().take(i) {
// safety: we already locked all of these
- lock.unlock_read();
+ lock.raw_unlock_read();
}
if first_index >= i {
// safety: this is already locked and can't be unlocked
// by the previous loop
- locks[first_index].unlock_read();
+ locks[first_index].raw_unlock_read();
}
first_index = i;
@@ -133,9 +147,8 @@ unsafe impl<L: Lockable + Send + Sync> RawLock for RetryingLockCollection<L> {
}
}
- unsafe fn try_read(&self) -> bool {
- let mut locks = Vec::new();
- self.data.get_ptrs(&mut locks);
+ unsafe fn raw_try_read(&self) -> bool {
+ let locks = get_locks(&self.data);
if locks.is_empty() {
return true;
@@ -144,10 +157,10 @@ unsafe impl<L: Lockable + Send + Sync> RawLock for RetryingLockCollection<L> {
unsafe {
for (i, lock) in locks.iter().enumerate() {
// safety: we have the thread key
- if !lock.try_read() {
+ if !lock.raw_try_read() {
for lock in locks.iter().take(i) {
// safety: we already locked all of these
- lock.unlock_read();
+ lock.raw_unlock_read();
}
return false;
}
@@ -157,20 +170,25 @@ unsafe impl<L: Lockable + Send + Sync> RawLock for RetryingLockCollection<L> {
true
}
- unsafe fn unlock_read(&self) {
- let mut locks = Vec::new();
- self.get_ptrs(&mut locks);
+ unsafe fn raw_unlock_read(&self) {
+ let locks = get_locks(&self.data);
for lock in locks {
- lock.unlock_read();
+ lock.raw_unlock_read();
}
}
}
unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> {
- type Guard<'g> = L::Guard<'g> where Self: 'g;
+ type Guard<'g>
+ = L::Guard<'g>
+ where
+ Self: 'g;
- type ReadGuard<'g> = L::ReadGuard<'g> where Self: 'g;
+ type ReadGuard<'g>
+ = L::ReadGuard<'g>
+ where
+ Self: 'g;
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
self.data.get_ptrs(ptrs)
@@ -186,7 +204,8 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> {
}
impl<L: LockableAsMut> LockableAsMut for RetryingLockCollection<L> {
- type Inner<'a> = L::Inner<'a>
+ type Inner<'a>
+ = L::Inner<'a>
where
Self: 'a;
@@ -419,55 +438,16 @@ impl<L: Lockable> RetryingLockCollection<L> {
&'g self,
key: Key,
) -> LockGuard<'key, L::Guard<'g>, Key> {
- let mut first_index = 0;
- let mut locks = Vec::new();
- self.data.get_ptrs(&mut locks);
+ unsafe {
+ // safety: we're taking the thread key
+ self.raw_lock();
- if locks.is_empty() {
- return LockGuard {
- // safety: there's no data being returned
- guard: unsafe { self.data.guard() },
+ LockGuard {
+ // safety: we just locked the collection
+ guard: self.guard(),
key,
_phantom: PhantomData,
- };
- }
-
- let guard = unsafe {
- 'outer: loop {
- // safety: we have the thread key
- locks[first_index].lock();
- for (i, lock) in locks.iter().enumerate() {
- if i == first_index {
- continue;
- }
-
- // safety: we have the thread key
- if !lock.try_lock() {
- for lock in locks.iter().take(i) {
- // safety: we already locked all of these
- lock.unlock();
- }
-
- if first_index >= i {
- // safety: this is already locked and can't be unlocked
- // by the previous loop
- locks[first_index].unlock();
- }
-
- first_index = i;
- continue 'outer;
- }
- }
-
- // safety: we locked all the data
- break self.data.guard();
}
- };
-
- LockGuard {
- guard,
- key,
- _phantom: PhantomData,
}
}
@@ -500,39 +480,15 @@ impl<L: Lockable> RetryingLockCollection<L> {
&'g self,
key: Key,
) -> Option<LockGuard<'key, L::Guard<'g>, Key>> {
- let mut locks = Vec::new();
- self.data.get_ptrs(&mut locks);
-
- if locks.is_empty() {
- return Some(LockGuard {
- // safety: there's no data being returned
- guard: unsafe { self.data.guard() },
+ unsafe {
+ // safety: we're taking the thread key
+ self.raw_try_lock().then(|| LockGuard {
+ // safety: we just succeeded in locking everything
+ guard: self.guard(),
key,
_phantom: PhantomData,
- });
+ })
}
-
- let guard = unsafe {
- for (i, lock) in locks.iter().enumerate() {
- // safety: we have the thread key
- if !lock.try_lock() {
- for lock in locks.iter().take(i) {
- // safety: we already locked all of these
- lock.unlock();
- }
- return None;
- }
- }
-
- // safety: we locked all the data
- self.data.guard()
- };
-
- Some(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -584,55 +540,16 @@ impl<L: Sharable> RetryingLockCollection<L> {
&'g self,
key: Key,
) -> LockGuard<'key, L::ReadGuard<'g>, Key> {
- let mut first_index = 0;
- let mut locks = Vec::new();
- self.data.get_ptrs(&mut locks);
+ unsafe {
+ // safety: we're taking the thread key
+ self.raw_read();
- if locks.is_empty() {
- return LockGuard {
- // safety: there's no data being returned
- guard: unsafe { self.data.read_guard() },
+ LockGuard {
+ // safety: we just locked the collection
+ guard: self.read_guard(),
key,
_phantom: PhantomData,
- };
- }
-
- let guard = unsafe {
- 'outer: loop {
- // safety: we have the thread key
- locks[first_index].read();
- for (i, lock) in locks.iter().enumerate() {
- if i == first_index {
- continue;
- }
-
- // safety: we have the thread key
- if !lock.try_read() {
- for lock in locks.iter().take(i) {
- // safety: we already locked all of these
- lock.unlock_read();
- }
-
- if first_index >= i {
- // safety: this is already locked and can't be unlocked
- // by the previous loop
- locks[first_index].unlock_read();
- }
-
- first_index = i;
- continue 'outer;
- }
- }
-
- // safety: we locked all the data
- break self.data.read_guard();
}
- };
-
- LockGuard {
- guard,
- key,
- _phantom: PhantomData,
}
}
@@ -666,39 +583,15 @@ impl<L: Sharable> RetryingLockCollection<L> {
&'g self,
key: Key,
) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> {
- let mut locks = Vec::new();
- self.data.get_ptrs(&mut locks);
-
- if locks.is_empty() {
- return Some(LockGuard {
- // safety: there's no data being returned
- guard: unsafe { self.data.read_guard() },
+ unsafe {
+ // safety: we're taking the thread key
+ self.raw_try_lock().then(|| LockGuard {
+ // safety: we just succeeded in locking everything
+ guard: self.read_guard(),
key,
_phantom: PhantomData,
- });
+ })
}
-
- let guard = unsafe {
- for (i, lock) in locks.iter().enumerate() {
- // safety: we have the thread key
- if !lock.try_read() {
- for lock in locks.iter().take(i) {
- // safety: we already locked all of these
- lock.unlock_read();
- }
- return None;
- }
- }
-
- // safety: we locked all the data
- self.data.read_guard()
- };
-
- Some(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -786,7 +679,6 @@ mod tests {
use super::*;
use crate::collection::BoxedLockCollection;
use crate::{Mutex, RwLock, ThreadKey};
- use lock_api::{RawMutex, RawRwLock};
#[test]
fn nonduplicate_lock_references_are_allowed() {