diff options
| -rw-r--r-- | src/guard.rs | 47 |
1 files changed, 46 insertions, 1 deletions
diff --git a/src/guard.rs b/src/guard.rs index d9e8426..818e35a 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -1,4 +1,7 @@ -use std::ops::{Deref, DerefMut}; +use std::{ + mem::MaybeUninit, + ops::{Deref, DerefMut}, +}; use crate::{ mutex::{Mutex, MutexRef, RawMutex}, @@ -125,6 +128,48 @@ impl<'a, A: Lockable<'a>, B: Lockable<'a>> Lockable<'a> for (A, B) { } } +impl<'a, T: Lockable<'a>, const N: usize> Lockable<'a> for [T; N] { + type Output = [T::Output; N]; + + unsafe fn lock(&'a self) -> Self::Output { + loop { + if let Some(guard) = self.try_lock() { + return guard; + } + } + } + + unsafe fn try_lock(&'a self) -> Option<Self::Output> { + unsafe fn unlock_partial<'a, T: Lockable<'a>, const N: usize>( + guards: [MaybeUninit<T::Output>; N], + upto: usize, + ) { + for (i, guard) in guards.into_iter().enumerate() { + if i == upto { + break; + } + T::unlock(guard.assume_init()); + } + } + + let mut outputs = MaybeUninit::<[MaybeUninit<T::Output>; N]>::uninit().assume_init(); + for i in 0..N { + if let Some(guard) = self[i].try_lock() { + outputs[i].write(guard) + } else { + unlock_partial::<T, N>(outputs, i); + return None; + }; + } + + Some(outputs.map(|mu| mu.assume_init())) + } + + fn unlock(guard: Self::Output) { + guard.map(T::unlock); + } +} + pub struct LockGuard<'a, L: Lockable<'a>> { guard: L::Output, key: ThreadKey, |
