use safina::threadpool::{NewThreadPoolError, ThreadPool, TryScheduleError, INTERNAL_MAX_THREADS};
use std::ops::Range;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

// Serializes tests in this file, since they all mutate the process-wide
// `INTERNAL_MAX_THREADS` value.
static LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

/// Asserts that the time elapsed since `before` falls inside `range_ms` (milliseconds).
fn assert_elapsed(before: Instant, range_ms: Range<u64>) {
    assert!(!range_ms.is_empty(), "invalid range {:?}", range_ms);
    let elapsed = before.elapsed();
    println!("elapsed = {} ms", elapsed.as_millis());
    let duration_range =
        Duration::from_millis(range_ms.start)..Duration::from_millis(range_ms.end);
    assert!(
        duration_range.contains(&elapsed),
        "{:?} elapsed, out of range {:?}",
        elapsed,
        duration_range
    );
}

fn set_max_threads(n: usize) {
    println!("set_max_threads({n})");
    INTERNAL_MAX_THREADS.store(n, Ordering::Release);
}

fn sleep_ms(ms: u64) {
    std::thread::sleep(Duration::from_millis(ms));
}

/// Schedules `num` jobs that panic, killing `num` of the pool's worker threads.
/// The jobs wait until all of them are scheduled, then panic together.
fn panic_threads(pool: &ThreadPool, num: usize) {
    let pause = Arc::new(AtomicBool::new(true));
    for _ in 0..num {
        let pause_clone = pause.clone();
        pool.try_schedule(move || {
            println!(
                "thread {:?} waiting",
                std::thread::current().name().unwrap_or("")
            );
            while pause_clone.load(Ordering::Acquire) {
                sleep_ms(10);
            }
            println!(
                "panicking thread {:?}",
                std::thread::current().name().unwrap_or("")
            );
            panic!("ignore this panic");
        })
        .unwrap();
    }
    sleep_ms(100);
    pause.store(false, Ordering::Release);
    sleep_ms(100);
}

#[test]
fn new_thread_pool_at_max() {
    let _guard = LOCK.lock().unwrap();
    set_max_threads(2);
    ThreadPool::new("test", 2).unwrap();
}

#[test]
fn new_thread_pool_error_spawn() {
    let _guard = LOCK.lock().unwrap();
    set_max_threads(2);
    let err = ThreadPool::new("test", 3).unwrap_err();
    assert!(matches!(err, NewThreadPoolError::Spawn(_)));
}

#[test]
fn schedule_retries_thread_start() {
    let _guard = LOCK.lock().unwrap();
    set_max_threads(3);
    let pool = ThreadPool::new("test", 3).unwrap();
    // Kill all worker threads and prevent respawning.
    panic_threads(&pool, 3);
    set_max_threads(0);
    let before = Instant::now();
    // Allow one thread to start after 100 ms; `schedule` should retry until then.
    std::thread::spawn(|| {
        sleep_ms(100);
        set_max_threads(1);
    });
    let (sender, receiver) = std::sync::mpsc::channel();
    pool.schedule(move || {
        println!("sending ()");
        sender.send(()).unwrap();
    });
    receiver.recv_timeout(Duration::from_millis(500)).unwrap();
    assert_elapsed(before, 100..200);
}

// schedule_retries_when_queue_full is in tests/test.rs.
// try_schedule_queue_full is in tests/test.rs.

#[test]
fn try_schedule_no_threads() {
    let _guard = LOCK.lock().unwrap();
    set_max_threads(2);
    let pool = ThreadPool::new("test", 2).unwrap();
    panic_threads(&pool, 2);
    set_max_threads(0);
    let result = pool.try_schedule(|| {});
    assert!(
        matches!(result, Err(TryScheduleError::NoThreads(_))),
        "{:?}",
        result
    );
}

#[test]
fn try_schedule_respawn() {
    let _guard = LOCK.lock().unwrap();
    set_max_threads(2);
    let pool = ThreadPool::new("test", 2).unwrap();
    panic_threads(&pool, 1);
    set_max_threads(1);
    let result = pool.try_schedule(|| {});
    assert!(
        matches!(result, Err(TryScheduleError::Respawn(_))),
        "{:?}",
        result
    );
}

#[test]
fn threads_stop_after_pool_drops() {
    let _guard = LOCK.lock().unwrap();
    set_max_threads(2);
    let pool = ThreadPool::new("test", 2).unwrap();
    let num_live_threads_fn = pool.num_live_threads_fn();
    drop(pool);
    std::thread::sleep(Duration::from_millis(100));
    assert_eq!(0, num_live_threads_fn());
}