summaryrefslogtreecommitdiff
path: root/src/collection/retry_collection.rs
blob: 73f9e185173ee22085a5542aa4c1c682eeb1df31 (plain)
use std::marker::PhantomData;

use crate::{lockable::Lock, Keyable, Lockable, OwnedLockable};

use super::{LockGuard, RetryingLockCollection};

/// Reports whether `data` hands out the same lock more than once.
///
/// Every lock reference pushed by `data.get_ptrs` is reduced to its bare
/// address (the vtable half of the `dyn Lock` pointer is discarded), the
/// addresses are sorted, and neighbouring entries are compared. Two equal
/// neighbours mean one lock was listed twice.
fn contains_duplicates<L: Lockable>(data: L) -> bool {
	let mut lock_refs = Vec::new();
	data.get_ptrs(&mut lock_refs);

	// Only the data address identifies a lock, so drop the pointer metadata
	// before sorting and comparing.
	let mut addresses: Vec<*const ()> = lock_refs
		.into_iter()
		.map(|lock| (lock as *const dyn Lock).cast::<()>())
		.collect();
	addresses.sort_unstable();

	// After sorting, identical addresses are guaranteed to sit side by side.
	addresses.windows(2).any(|pair| pair[0] == pair[1])
}

impl<L: OwnedLockable> RetryingLockCollection<L> {
	#[must_use]
	pub const fn new(data: L) -> Self {
		Self { data }
	}
}

impl<'a, L: OwnedLockable> RetryingLockCollection<&'a L> {
	#[must_use]
	pub const fn new_ref(data: &'a L) -> Self {
		Self { data }
	}
}

impl<L: Lockable> RetryingLockCollection<L> {
	/// Builds a collection from `data` without checking it for duplicate
	/// locks.
	///
	/// # Safety
	///
	/// This constructor performs no validation at all. Presumably the caller
	/// must guarantee that `data` never reports the same lock twice through
	/// `get_ptrs`, which is the condition [`try_new`](Self::try_new) checks.
	/// [`lock`](Self::lock) blocks on one lock and then `try_lock`s every
	/// other entry, so a lock listed twice would have to be acquired while it
	/// is already held.
	/// NOTE(review): confirm the exact contract with the crate documentation.
	#[must_use]
	pub const unsafe fn new_unchecked(data: L) -> Self {
		Self { data }
	}

	/// Builds a collection from `data` after checking it for duplicate locks.
	///
	/// Returns `None` if `data` reports the same lock more than once, and
	/// `Some(collection)` otherwise.
	pub fn try_new(data: L) -> Option<Self> {
		// The collection is only valid when there are NO duplicates; the
		// check must be negated, otherwise valid input would be rejected and
		// invalid input accepted.
		(!contains_duplicates(&data)).then_some(Self { data })
	}

	/// Acquires every lock in the collection and returns a guard for the
	/// data.
	///
	/// The method blocks on a single lock (initially the first one reported
	/// by `get_ptrs`) and then calls `try_lock` on all the others in order.
	/// If one of those attempts fails, every lock acquired so far is
	/// released, and the next round blocks on the lock that could not be
	/// taken. This repeats until a round acquires all of the locks.
	///
	/// `key` is moved into the returned guard; [`unlock`](Self::unlock) hands
	/// it back.
	pub fn lock<'a, 'key: 'a, Key: Keyable + 'key>(
		&'a self,
		key: Key,
	) -> LockGuard<'a, 'key, L, Key> {
		// index of the lock that this round waits on with a blocking `lock`
		let mut first_index = 0;
		let mut locks = Vec::new();
		self.data.get_ptrs(&mut locks);

		// With no locks, `locks[first_index]` below would be out of bounds,
		// so the guard is produced directly.
		if locks.is_empty() {
			return LockGuard {
				// safety: there's no data being returned
				guard: unsafe { self.data.guard() },
				key,
				_phantom: PhantomData,
			};
		}

		let guard = unsafe {
			'outer: loop {
				// safety: we have the thread key
				locks[first_index].lock();
				for (i, lock) in locks.iter().enumerate() {
					// this one is already held from the blocking call above
					if i == first_index {
						continue;
					}

					// safety: we have the thread key
					if !lock.try_lock() {
						// Everything before `i` is held at this point: either
						// through a successful `try_lock` or because it is
						// the blocking lock itself.
						for lock in locks.iter().take(i) {
							// safety: we already locked all of these
							lock.unlock();
						}

						// `first_index == i` cannot happen (that index is
						// skipped above), so only the "after `i`" case is
						// left to handle.
						if first_index > i {
							// safety: this is already locked and can't be unlocked
							//         by the previous loop
							locks[first_index].unlock();
						}

						// wait on the contended lock in the next round
						first_index = i;
						continue 'outer;
					}
				}

				// safety: we locked all the data
				break self.data.guard();
			}
		};

		LockGuard {
			guard,
			key,
			_phantom: PhantomData,
		}
	}

	/// Attempts to acquire every lock in the collection using `try_lock`
	/// only.
	///
	/// The locks are tried in the order reported by `get_ptrs`. If any
	/// attempt fails, the locks acquired before it are released again and
	/// `None` is returned. On success the guard for the data is returned.
	///
	/// `key` is taken by value either way: on success it is stored in the
	/// guard, on failure it is dropped when the function returns.
	pub fn try_lock<'a, 'key: 'a, Key: Keyable + 'key>(
		&'a self,
		key: Key,
	) -> Option<LockGuard<'a, 'key, L, Key>> {
		let mut locks = Vec::new();
		self.data.get_ptrs(&mut locks);

		if locks.is_empty() {
			return Some(LockGuard {
				// safety: there's no data being returned
				guard: unsafe { self.data.guard() },
				key,
				_phantom: PhantomData,
			});
		}

		let guard = unsafe {
			for (i, lock) in locks.iter().enumerate() {
				// safety: we have the thread key
				if !lock.try_lock() {
					// roll back: exactly the first `i` locks were acquired
					for lock in locks.iter().take(i) {
						// safety: we already locked all of these
						lock.unlock();
					}
					return None;
				}
			}

			// safety: we locked all the data
			self.data.guard()
		};

		Some(LockGuard {
			guard,
			key,
			_phantom: PhantomData,
		})
	}

	/// Consumes `guard` and returns the key that was used to create it.
	///
	/// The inner guard is dropped before the key is handed back.
	/// NOTE(review): dropping the inner guard is presumably what releases the
	/// underlying locks — confirm against the guard types' `Drop` impls.
	pub fn unlock<'key, Key: Keyable + 'key>(guard: LockGuard<'_, 'key, L, Key>) -> Key {
		drop(guard.guard);
		guard.key
	}
}