#![feature(allocator_api)]

use core::{
    cell::{Cell, UnsafeCell},
    marker::PhantomPinned,
    ops::{Deref, DerefMut},
    pin::Pin,
    sync::atomic::{AtomicBool, Ordering},
};
use std::{
    sync::Arc,
    thread::{self, park, sleep, Builder, Thread},
    time::Duration,
};

use pinned_init::*;
#[allow(unused_attributes)]
#[path = "./linked_list.rs"]
pub mod linked_list;
pub use linked_list::Error;
use linked_list::*;

/// A simple test-and-test-and-set spin lock protecting the mutex bookkeeping.
pub struct SpinLock {
    inner: AtomicBool,
}

impl SpinLock {
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        while self
            .inner
            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Spin on cheap relaxed loads until the lock looks free, then retry the CAS.
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    #[inline]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        self.0.inner.store(false, Ordering::Release);
    }
}

/// A mutex in the style of the Linux kernel: contended threads enqueue
/// themselves on an intrusive wait list and park until they are unparked.
#[pin_data]
pub struct CMutex<T> {
    #[pin]
    wait_list: ListHead,
    spin_lock: SpinLock,
    locked: Cell<bool>,
    #[pin]
    data: UnsafeCell<T>,
}

impl<T> CMutex<T> {
    #[inline]
    pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
        pin_init!(CMutex {
            wait_list <- ListHead::new(),
            spin_lock: SpinLock::new(),
            locked: Cell::new(false),
            data <- unsafe {
                pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
                    val.__pinned_init(slot.cast::<T>())
                })
            },
        })
    }

    #[inline]
    pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
        let mut sguard = self.spin_lock.acquire();
        if self.locked.get() {
            // Enqueue a wait entry for the current thread; it lives on this stack frame.
            stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
            // println!("wait list length: {}", self.wait_list.size());
            while self.locked.get() {
                drop(sguard);
                park();
                sguard = self.spin_lock.acquire();
            }
            // Dropping the entry unlinks it from the wait list (`ListHead` implements `Drop`).
            drop(wait_entry);
        }
        self.locked.set(true);
        unsafe {
            Pin::new_unchecked(CMutexGuard {
                mtx: self,
                _pin: PhantomPinned,
            })
        }
    }

    pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
        // SAFETY: we have an exclusive reference and thus nobody has access to data.
        unsafe { &mut *self.data.get() }
    }
}
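// A minimal usage sketch for `get_data_mut` (not part of the original example):
// a pinned exclusive reference rules out any live guards or waiters, so the
// protected data is accessible without taking the lock. The function name
// `bump` is illustrative only.
#[allow(dead_code)]
fn bump(mtx: Pin<&mut CMutex<usize>>) {
    // No locking required: `Pin<&mut CMutex<usize>>` proves exclusive access.
    *mtx.get_data_mut() += 1;
}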
// SAFETY: only the thread holding the lock can access `data`, so the mutex is
// `Send` and `Sync` whenever the protected value can be sent between threads.
unsafe impl<T: Send> Send for CMutex<T> {}
unsafe impl<T: Send> Sync for CMutex<T> {}

pub struct CMutexGuard<'a, T> {
    mtx: &'a CMutex<T>,
    _pin: PhantomPinned,
}

impl<'a, T> Drop for CMutexGuard<'a, T> {
    #[inline]
    fn drop(&mut self) {
        let sguard = self.mtx.spin_lock.acquire();
        self.mtx.locked.set(false);
        // Wake the first waiter, if any.
        if let Some(list_field) = self.mtx.wait_list.next() {
            // CAST: `wait_list` is the first field of the `repr(C)` `WaitEntry`,
            // so a pointer to it is also a pointer to the whole entry.
            let wait_entry = list_field.as_ptr().cast::<WaitEntry>();
            unsafe { (*wait_entry).thread.unpark() };
        }
        drop(sguard);
    }
}

impl<'a, T> Deref for CMutexGuard<'a, T> {
    type Target = T;

    #[inline]
    fn deref(&self) -> &Self::Target {
        unsafe { &*self.mtx.data.get() }
    }
}

impl<'a, T> DerefMut for CMutexGuard<'a, T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        unsafe { &mut *self.mtx.data.get() }
    }
}

#[pin_data]
#[repr(C)]
struct WaitEntry {
    #[pin]
    wait_list: ListHead,
    thread: Thread,
}

impl WaitEntry {
    #[inline]
    fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
        pin_init!(Self {
            thread: thread::current(),
            wait_list <- ListHead::insert_prev(list),
        })
    }
}

#[allow(dead_code)]
#[cfg_attr(test, test)]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let mut handles = vec![];
    let thread_count = 20;
    let workload = if cfg!(miri) { 100 } else { 1_000 };
    for i in 0..thread_count {
        let mtx = mtx.clone();
        handles.push(
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    for _ in 0..workload {
                        *mtx.lock() += 1;
                    }
                    println!("{i} halfway");
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail"),
        );
    }
    for h in handles {
        h.join().expect("thread panicked");
    }
    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}
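// Additional usage sketch (not from the original example): a `CMutex` does not
// need an `Arc`; `stack_pin_init!`, the same macro `lock` uses for its wait
// entries, can pin one on the current stack frame.
#[cfg_attr(test, test)]
#[allow(dead_code)]
fn stack_mutex() {
    stack_pin_init!(let mtx = CMutex::new(0usize));
    *mtx.lock() += 1;
    assert_eq!(*mtx.lock(), 1);
}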