use std::collections::HashSet;
use std::sync::Arc;
use crate::basic::clocks::{check_clock, me};
use shuttle::sync::mpsc::{channel, sync_channel, RecvError, TryRecvError, TrySendError};
use shuttle::{check_dfs, check_random, thread};
use test_log::test;
// The following tests (prefixed with mpsc_loom) are from the
// loom test suite; see https://github.com/tokio-rs/loom/blob/master/tests/mpsc.rs
/// A single thread sends a value on an unbounded channel and then receives it back.
#[test]
fn mpsc_loom_basic_sequential_usage() {
    check_dfs(
        || {
            let (sender, receiver) = channel();
            sender.send(5).unwrap();
            let received = receiver.recv().unwrap();
            assert_eq!(received, 5);
        },
        None,
    );
}
/// A value sent from a spawned thread reaches the main thread. Receiving it
/// creates a causal dependency on the sender, which is checked via `check_clock`.
#[test]
fn mpsc_loom_basic_parallel_usage() {
    check_dfs(
        || {
            let (sender, receiver) = channel();
            thread::spawn(move || {
                assert_eq!(me(), 1);
                sender.send(5).unwrap();
            });
            // Before the receive, only the main thread has made progress in our view
            check_clock(|tid, clk| (clk > 0) == (tid == 0));
            let received = receiver.recv().unwrap();
            // After receiving a message from thread 1, we have a causal dependency on it
            check_clock(|tid, clk| (clk > 0) == (tid == 0 || tid == 1));
            assert_eq!(received, 5);
        },
        None,
    );
}
/// Two threads send on clones of one channel. The arrival order is
/// nondeterministic but the total is always 11. After each receive, the main
/// thread depends causally on exactly the senders whose messages it has seen.
#[test]
fn mpsc_loom_commutative_senders() {
    check_dfs(
        || {
            let (tx_a, rx) = channel();
            let tx_b = tx_a.clone();
            thread::spawn(move || {
                assert_eq!(me(), 1);
                tx_a.send(5).unwrap();
            });
            thread::spawn(move || {
                assert_eq!(me(), 2);
                tx_b.send(6).unwrap();
            });
            let first = rx.recv().unwrap();
            // Only the thread that produced `first` must have executed
            let sender_tid = match first {
                5 => 1, // thread 1 must have executed
                6 => 2, // thread 2 must have executed
                _ => unreachable!(),
            };
            check_clock(|tid, clk| (clk > 0) == (tid == 0 || tid == sender_tid));
            let total = first + rx.recv().unwrap();
            // both threads have executed
            check_clock(|tid, clk| (clk > 0) == (tid == 0 || tid == 1 || tid == 2));
            assert_eq!(total, 11);
        },
        None,
    );
}
/// Explicitly discards a `Result` whose `Ok`/`Err` outcome the test does not care about.
///
/// Generic over both the success and error types, so it accepts the results
/// of `send` and `recv` alike.
fn ignore_result<A, B>(_: Result<A, B>) {}
/// With two concurrent senders, DFS must reach a schedule where 6 arrives
/// first, so asserting that the first value is always 5 fails.
#[test]
#[should_panic(expected = "expected panic: sends can happen in any order")]
fn mpsc_loom_non_commutative_senders1() {
    check_dfs(
        || {
            let (tx_a, rx) = channel();
            let tx_b = tx_a.clone();
            thread::spawn(move || ignore_result(tx_a.send(5)));
            thread::spawn(move || ignore_result(tx_b.send(6)));
            let first = rx.recv().unwrap();
            assert_eq!(first, 5, "expected panic: sends can happen in any order");
            ignore_result(rx.recv());
        },
        None,
    );
}
/// Mirror of `mpsc_loom_non_commutative_senders1`: DFS must reach a schedule
/// where 5 arrives first, so asserting that the first value is always 6 fails.
#[test]
#[should_panic(expected = "expected panic: sends can happen in any order")]
fn mpsc_loom_non_commutative_senders2() {
    check_dfs(
        || {
            let (tx_a, rx) = channel();
            let tx_b = tx_a.clone();
            thread::spawn(move || ignore_result(tx_a.send(5)));
            thread::spawn(move || ignore_result(tx_b.send(6)));
            let first = rx.recv().unwrap();
            assert_eq!(first, 6, "expected panic: sends can happen in any order");
            ignore_result(rx.recv());
        },
        None,
    );
}
/// Dropping the only sender (in another thread) disconnects an unbounded
/// channel, so `recv` returns an error.
#[test]
fn mpsc_drop_sender_unbounded() {
    check_dfs(
        || {
            // The element type must be spelled out: nothing is ever sent.
            let (tx, rx) = channel::<i32>();
            thread::spawn(move || {
                drop(tx);
            });
            assert!(rx.recv().is_err());
            // no message was sent, hence no causal dependency
            check_clock(|i, c| (c > 0) == (i == 0));
        },
        None,
    );
}
/// Sending on an unbounded channel whose receiver has been dropped fails.
#[test]
fn mpsc_drop_receiver_unbounded() {
    check_dfs(
        || {
            let (sender, receiver) = channel();
            drop(receiver);
            let outcome = sender.send(1);
            assert!(outcome.is_err());
        },
        None,
    );
}
/// Dropping the only sender of a bounded channel disconnects it, so `recv` in
/// the spawned thread returns an error.
#[test]
fn mpsc_drop_sender_bounded() {
    check_dfs(
        || {
            let (tx, rx) = sync_channel::<i32>(10);
            thread::spawn(move || {
                assert!(rx.recv().is_err());
            });
            drop(tx);
            // Dropping a sender does not make the main thread depend on the receiver
            check_clock(|i, c| (c > 0) == (i == 0));
        },
        None,
    );
}
/// Sending on a bounded channel whose receiver has been dropped fails.
#[test]
fn mpsc_drop_receiver_bounded() {
    check_dfs(
        || {
            let (sender, receiver) = sync_channel(10);
            drop(receiver);
            let outcome = sender.send(1);
            assert!(outcome.is_err());
        },
        None,
    );
}
/// Dropping the only sender of a rendezvous (capacity 0) channel makes `recv`
/// return an error.
#[test]
fn mpsc_drop_sender_rendezvous() {
    check_dfs(
        || {
            let (tx, rx) = sync_channel::<i32>(0);
            drop(tx);
            assert!(rx.recv().is_err());
        },
        None,
    );
}
/// Sending on a rendezvous (capacity 0) channel whose receiver has been dropped fails.
#[test]
fn mpsc_drop_receiver_rendezvous() {
    check_dfs(
        || {
            let (sender, receiver) = sync_channel(0);
            drop(receiver);
            let outcome = sender.send(1);
            assert!(outcome.is_err());
        },
        None,
    );
}
// Example taken from the std::sync::mpsc documentation
// See "buffering behavior" example in
// https://doc.rust-lang.org/std/sync/mpsc/struct.Receiver.html
/// Values buffered before the sender disconnects can still be received in
/// order. Once they are exhausted, `recv` reports `RecvError`.
#[test]
fn mpsc_buffering_behavior() {
    check_dfs(
        || {
            let (tx, rx) = channel();
            let producer = thread::spawn(move || {
                for value in [1u8, 2, 3] {
                    tx.send(value).unwrap();
                }
                drop(tx);
            });
            // wait for the thread to join so we ensure the sender is dropped
            producer.join().unwrap();
            // values sent before the sender disconnects are still available afterwards
            for expected in [1u8, 2, 3] {
                assert_eq!(Ok(expected), rx.recv());
            }
            // but after the values are exhausted, recv() returns an error
            assert_eq!(Err(RecvError), rx.recv());
        },
        None,
    );
}
/// Thread 1 pushes five `1`s through a bounded channel of capacity 5 while a
/// second thread sums them. Every receive must advance the receiver's view of
/// the sender's clock.
#[test]
fn mpsc_bounded_sum() {
    check_dfs(
        || {
            let (tx, rx) = sync_channel::<i32>(5);
            thread::spawn(move || {
                assert_eq!(me(), 1);
                for _ in 0..5 {
                    tx.send(1).unwrap();
                }
            });
            let handle = thread::spawn(move || {
                let mut sum = 0;
                for _ in 0..5 {
                    let c1 = shuttle::current::clock().get(1); // save knowledge of sender's clock
                    sum += rx.recv().unwrap();
                    check_clock(|i, c| (i != 1) || (c > c1)); // sender's clock must have increased
                }
                sum
            });
            let r = handle.join().unwrap();
            assert_eq!(r, 5);
        },
        None,
    );
}
// Sending on a bounded channel doesn't block the sender if the channel isn't filled
/// Ten sends into a bounded channel of capacity 10 complete without any receive.
#[test]
fn mpsc_bounded_sender_buffered() {
    check_dfs(
        || {
            // `_rx` (not `_`) keeps the receiver alive; `_` would drop it
            // immediately and make every send fail.
            let (tx, _rx) = sync_channel::<i32>(10);
            let handle = thread::spawn(move || {
                for _ in 0..10 {
                    tx.send(1).unwrap();
                }
                42
            });
            let r = handle.join().unwrap();
            assert_eq!(r, 42);
        },
        None,
    );
}
// Sending on a bounded channel blocks the sender when the channel becomes full
/// An 11th send into a full bounded channel of capacity 10, with nobody
/// receiving, never completes, and Shuttle reports a deadlock.
#[test]
#[should_panic(expected = "deadlock")]
fn mpsc_bounded_sender_blocked() {
    check_dfs(
        || {
            // `_rx` keeps the receiver alive so the 11th send blocks instead of failing.
            let (tx, _rx) = sync_channel::<i32>(10);
            let handle = thread::spawn(move || {
                for _ in 0..11 {
                    tx.send(1).unwrap();
                }
                42
            });
            let r = handle.join().unwrap();
            assert_eq!(r, 42);
        },
        None,
    );
}
// The following tests (prefixed `mpsc_rendezvous_`) exercise rendezvous channels (`sync_channel(0)`).
/// A send on a rendezvous (capacity 0) channel is delivered to the main thread's `recv`.
#[test]
fn mpsc_rendezvous_channel() {
    check_dfs(
        || {
            let (tx, rx) = sync_channel::<i32>(0);
            thread::spawn(move || {
                // This will wait for the parent thread to start receiving
                tx.send(53).unwrap();
            });
            let v = rx.recv().unwrap();
            assert_eq!(v, 53);
        },
        None,
    );
}
/// A single thread sends on a rendezvous channel with no other thread to
/// receive. The send can never complete, so Shuttle reports a deadlock.
#[test]
#[should_panic(expected = "deadlock")]
fn mpsc_rendezvous_sender_block() {
    check_dfs(
        || {
            let (tx, rx) = sync_channel::<i32>(0);
            tx.send(53).unwrap();
            rx.recv().unwrap();
            rx.recv().unwrap();
        },
        None,
    );
}
/// Two threads each rendezvous with the main thread on a shared rendezvous
/// channel. The main thread receives both values in either order.
#[test]
fn mpsc_rendezvous_two_threads() {
    check_dfs(
        || {
            let (tx1, rx) = sync_channel::<i32>(0);
            let tx2 = tx1.clone();
            thread::spawn(move || {
                tx1.send(10).unwrap();
            });
            thread::spawn(move || {
                tx2.send(20).unwrap();
            });
            let v1 = rx.recv().unwrap();
            let v2 = rx.recv().unwrap();
            // Order is nondeterministic, so only the sum is checked
            assert_eq!(v1 + v2, 30);
        },
        None,
    );
}
// An mpsc Receiver is not cloneable and is !Sync, so it can't be shared, but
// it is Send, so it can be transferred between threads. In this example, the
// main thread rendezvouses with two separate threads in turn: the first thread
// receives from the rendezvous channel and then passes that channel's receiver
// to the second thread over a second, bounded channel.
/// Moves the receiving end of a rendezvous channel between threads. The first
/// spawned thread receives 10 and then hands `rx2` to the second spawned thread,
/// which receives 20.
#[test]
fn mpsc_rendezvous_transfer_receiver() {
    check_dfs(
        || {
            // First channel is used to send the receiver from one thread to another
            let (tx1, rx1) = sync_channel(1);
            // Second channel is a rendezvous channel used to synchronize with the main thread
            let (tx2, rx2) = sync_channel::<i32>(0);
            thread::spawn(move || {
                let p = rx2.recv().unwrap();
                assert_eq!(p, 10);
                // Send the receiver to the 2nd thread
                tx1.send(rx2).unwrap();
            });
            let handle = thread::spawn(move || {
                let rx2 = rx1.recv().unwrap();
                let q = rx2.recv().unwrap();
                assert_eq!(q, 20);
            });
            tx2.send(10).unwrap();
            tx2.send(20).unwrap();
            // Wait for the 2nd thread to finish
            handle.join().unwrap();
        },
        None,
    );
}
// From libstd test suite
/// Thread 1 signals on `tx1`, then expects seven `1`s on `rx2`. A second
/// thread, spawned only after that signal arrives, sends them.
#[test]
fn mpsc_send_from_outside_runtime() {
    check_dfs(
        || {
            let (tx1, rx1) = channel::<()>();
            let (tx2, rx2) = channel::<i32>();
            let t1 = thread::spawn(move || {
                tx1.send(()).unwrap();
                for _ in 0..7 {
                    assert_eq!(rx2.recv().unwrap(), 1);
                }
            });
            rx1.recv().unwrap();
            let t2 = thread::spawn(move || {
                for _ in 0..7 {
                    tx2.send(1).unwrap();
                }
            });
            t1.join().expect("thread panicked");
            t2.join().expect("thread panicked");
        },
        None,
    );
}
// From libstd test suite
/// A spawned thread receives ten `1`s sent by the main thread.
#[test]
fn mpsc_recv_from_outside_runtime() {
    check_dfs(
        || {
            let (tx, rx) = channel::<i32>();
            let t = thread::spawn(move || {
                for _ in 0..10 {
                    assert_eq!(rx.recv().unwrap(), 1);
                }
            });
            for _ in 0..10 {
                tx.send(1).unwrap();
            }
            t.join().expect("thread panicked");
        },
        None,
    );
}
// From libstd test suite
// TODO This test checks that joining on a child thread that panicked returns Err, but Shuttle
// TODO aborts a test as soon as any thread panics (and propagates the panic), so we never get a
// TODO chance to check the join result. If this abort behavior ever changes, this test should start
// TODO failing because it no longer propagates the thread panic.
/// `recv` on a channel whose sender was dropped returns `RecvError`. The
/// `unwrap` in the child thread panics, and Shuttle propagates that panic.
#[test]
#[should_panic(expected = "RecvError")]
fn mpsc_oneshot_single_thread_recv_chan_close() {
    check_dfs(
        || {
            // Receiving on a closed chan will panic and should propagate to the JoinHandle
            let res = thread::spawn(move || {
                let (tx, rx) = channel::<i32>();
                drop(tx);
                rx.recv().unwrap();
            })
            .join();
            assert!(res.is_err());
        },
        None,
    );
}
fn mpsc_senders_with_blocking_inner(num_senders: usize, channel_size: usize) {
assert!(num_senders >= channel_size);
let num_receives = num_senders - channel_size;
let (tx, rx) = sync_channel::(channel_size);
let senders = (0..num_senders)
.map(move |i| {
let tx = tx.clone();
thread::spawn(move || {
tx.send(i).unwrap();
})
})
.collect::>();
// Receive enough messages to ensure no sender will block
for _ in 0..num_receives {
rx.recv().unwrap();
}
for sender in senders {
sender.join().unwrap();
}
}
/// Exhaustively explores four senders contending for a two-slot bounded channel.
#[test]
fn mpsc_some_senders_with_blocking() {
    let scenario = || mpsc_senders_with_blocking_inner(4, 2);
    check_dfs(scenario, None);
}
/// Samples 10 random schedules of 1000 senders contending for a 500-slot bounded channel.
#[test]
fn mpsc_many_senders_with_blocking() {
    const NUM_SENDERS: usize = 1000;
    const CHANNEL_SIZE: usize = 500;
    let scenario = || mpsc_senders_with_blocking_inner(NUM_SENDERS, CHANNEL_SIZE);
    check_random(scenario, 10);
}
/// Dropping the receiver while senders may be blocked on a full bounded channel
/// must let every sender finish (their send results are ignored), so all of
/// them can be joined.
#[test]
fn mpsc_many_senders_drop_receiver() {
    const NUM_SENDERS: usize = 4;
    const CHANNEL_SIZE: usize = 2;
    check_dfs(
        || {
            let (tx, rx) = sync_channel::<usize>(CHANNEL_SIZE);
            let senders = (0..NUM_SENDERS)
                .map(move |i| {
                    let tx = tx.clone();
                    thread::spawn(move || {
                        // May fail once the receiver is gone; that is expected here
                        let _ = tx.send(i);
                    })
                })
                .collect::<Vec<_>>();
            // Drop the receiver; this will unblock any waiting senders
            drop(rx);
            // Make sure all senders finish
            for sender in senders {
                sender.join().unwrap();
            }
        },
        None,
    );
}
/// `Receiver::iter` on a rendezvous channel yields every sent value and ends
/// once the sender is dropped. The receiving thread reports the total (6) back
/// over a second rendezvous channel.
#[test]
fn test_nested_recv_iter() {
    check_dfs(
        || {
            let (tx, rx) = sync_channel::<i32>(0);
            let (total_tx, total_rx) = sync_channel::<i32>(0);
            let _t = thread::spawn(move || {
                let mut acc = 0;
                for x in rx.iter() {
                    acc += x;
                }
                total_tx.send(acc).unwrap();
            });
            tx.send(3).unwrap();
            tx.send(1).unwrap();
            tx.send(2).unwrap();
            // Disconnect so the receiver's `iter()` terminates
            drop(tx);
            assert_eq!(total_rx.recv().unwrap(), 6);
        },
        None,
    );
}
/// The main thread `try_send`s 3, 1, 2 into a capacity-1 channel while a second
/// thread drains it with `try_iter`. The drained sum plus any leftover buffered
/// message must equal what was successfully sent. The sum seen by the draining
/// thread is recorded in every explored schedule and the set of those sums is
/// checked at the end.
#[test]
fn mpsc_try_recv_iter_sync() {
    let observed_values = Arc::new(std::sync::Mutex::new(HashSet::new()));
    let observed_values_clone = Arc::clone(&observed_values);
    check_dfs(
        move || {
            let (tx, rx) = sync_channel::<i32>(1);
            let (total_tx, total_rx) = sync_channel::<(_, _)>(0);
            let _t = thread::spawn(move || {
                let mut acc = 0;
                for x in rx.try_iter() {
                    acc += x;
                    thread::yield_now();
                }
                // Hand the receiver back so the main thread can drain a leftover message
                total_tx.send((rx, acc)).unwrap();
            });
            let mut sent_acc = 0;
            for i in [3, 1, 2] {
                if tx.try_send(i).is_ok() {
                    sent_acc += i;
                    thread::yield_now();
                }
            }
            drop(tx);
            let (rx, mut recv_acc) = total_rx.recv().unwrap();
            observed_values_clone.lock().unwrap().insert(recv_acc);
            // With capacity 1, at most one message can still be buffered
            recv_acc += rx.try_recv().unwrap_or(0);
            assert_eq!(recv_acc, sent_acc);
        },
        None,
    );
    let observed_values = Arc::try_unwrap(observed_values).unwrap().into_inner().unwrap();
    // Could fail to rendezvous at any step of the sequence
    assert_eq!(observed_values, HashSet::from([0, 3, 4, 6]));
}
/// A thread polls a rendezvous channel with `try_iter` while the main thread
/// sends 3, 1, 2, ignoring failures. The receiver drops `rx` once `try_iter`
/// ends, so later sends fail rather than block. The total the receiver saw is
/// recorded in every explored schedule.
#[test]
fn mpsc_try_recv_iter_rendezvous() {
    let observed_values = Arc::new(std::sync::Mutex::new(HashSet::new()));
    let observed_values_clone = Arc::clone(&observed_values);
    check_dfs(
        move || {
            let (tx, rx) = sync_channel::<i32>(0);
            let (total_tx, total_rx) = sync_channel::<i32>(0);
            let _t = thread::spawn(move || {
                let mut acc = 0;
                for x in rx.try_iter() {
                    acc += x;
                    thread::yield_now();
                }
                drop(rx);
                total_tx.send(acc).unwrap();
            });
            for i in [3, 1, 2] {
                let _ = tx.send(i);
                thread::yield_now();
            }
            drop(tx);
            let result = total_rx.recv().unwrap();
            observed_values_clone.lock().unwrap().insert(result);
        },
        None,
    );
    let observed_values = Arc::try_unwrap(observed_values).unwrap().into_inner().unwrap();
    // Could fail to rendezvous at any step of the sequence
    assert_eq!(observed_values, HashSet::from([0, 3, 4, 6]));
}
/// Across all explored schedules, `try_recv` on a rendezvous channel must
/// succeed at least once. Presumably this is a schedule where the sender is
/// already waiting in `send`.
#[test]
fn try_recv_send_rendezvous() {
    let success = Arc::new(std::sync::Mutex::new(false));
    let success_clone = Arc::clone(&success);
    check_dfs(
        move || {
            let (tx, rx) = sync_channel::<i32>(0);
            // Fresh clone per execution: the spawned thread needs its own owned handle
            let success_clone = Arc::clone(&success_clone);
            thread::spawn(move || {
                let result = rx.try_recv();
                if result.is_ok() {
                    *success_clone.lock().unwrap() = true;
                }
            });
            let _ = tx.send(1);
        },
        None,
    );
    let success = Arc::try_unwrap(success).unwrap().into_inner().unwrap();
    assert!(success);
}
/// `try_recv` on a channel whose only sender was dropped reports
/// `Disconnected`, and keeps doing so on repeated calls.
#[test]
fn mpsc_oneshot_single_thread_peek_close() {
    check_dfs(
        || {
            let (tx, rx) = channel::<i32>();
            drop(tx);
            assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
            assert_eq!(rx.try_recv(), Err(TryRecvError::Disconnected));
        },
        None,
    );
}
/// `try_recv` reports `Empty` before a send, returns the value after it, and
/// reports `Empty` again once the channel is drained.
#[test]
fn mpsc_try_recv() {
    check_dfs(
        || {
            let (tx, rx) = channel::<i32>();
            assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
            tx.send(1).unwrap();
            assert_eq!(rx.try_recv(), Ok(1));
            assert_eq!(rx.try_recv(), Err(TryRecvError::Empty));
        },
        None,
    );
}
fn mpsc_try_recv_permutations(drop_sender: bool) {
let observed_values = Arc::new(std::sync::Mutex::new(vec![]));
let observed_values_clone = Arc::clone(&observed_values);
check_dfs(
move || {
let (tx, rx) = channel::();
let thd = thread::spawn(move || {
if !drop_sender {
tx.send(1).unwrap();
}
});
let result = rx.try_recv();
observed_values_clone.lock().unwrap().push(result);
// Should always fail the second time
assert!(rx.try_recv().is_err());
thd.join().unwrap();
},
None,
);
let observed_values = Arc::try_unwrap(observed_values).unwrap().into_inner().unwrap();
assert_eq!(observed_values.len(), 2);
assert!(observed_values.contains(&Err(TryRecvError::Empty)));
if drop_sender {
assert!(observed_values.contains(&Err(TryRecvError::Disconnected)));
} else {
assert!(observed_values.contains(&Ok(1)));
}
}
/// `try_recv` racing with a thread that sends one value.
#[test]
fn mpsc_try_recv_permutations_no_drop() {
    let drop_sender = false;
    mpsc_try_recv_permutations(drop_sender);
}
/// `try_recv` racing with a thread that only drops the sender.
#[test]
fn mpsc_try_recv_permutations_drop() {
    let drop_sender = true;
    mpsc_try_recv_permutations(drop_sender);
}
/// `try_send` on a capacity-1 channel succeeds while there is room. Once the
/// receiver is dropped it reports `Disconnected` and hands the value back.
#[test]
fn mpsc_try_send_buffered() {
    check_dfs(
        || {
            let (tx, rx) = sync_channel::<i32>(1);
            assert_eq!(tx.try_send(1), Ok(()));
            assert_eq!(rx.recv(), Ok(1));
            assert_eq!(tx.try_send(2), Ok(()));
            assert_eq!(rx.recv(), Ok(2));
            drop(rx);
            assert_eq!(tx.try_send(3), Err(TrySendError::Disconnected(3)));
        },
        None,
    );
}
/// `try_send` on a rendezvous channel with no receiver waiting reports `Full`.
#[test]
fn mpsc_try_send_rendezvous() {
    check_dfs(
        || {
            // `_rx` keeps the receiver alive so the error is `Full`, not `Disconnected`
            let (tx, _rx) = sync_channel::<i32>(0);
            assert_eq!(tx.try_send(1), Err(TrySendError::Full(1)));
        },
        None,
    );
}
fn mpsc_try_send_permutations(drop_receiver: bool, rendezvous: bool) {
let observed_values = Arc::new(std::sync::Mutex::new(vec![]));
let observed_values_clone = Arc::clone(&observed_values);
check_dfs(
move || {
let size = if rendezvous { 0 } else { 1 };
let (tx, rx) = sync_channel::(size);
let thd = thread::spawn(move || {
if !drop_receiver {
let _ = rx.recv();
}
});
let result = tx.try_send(1);
observed_values_clone.lock().unwrap().push(result);
drop(tx);
thd.join().unwrap();
},
None,
);
let observed_values = Arc::try_unwrap(observed_values).unwrap().into_inner().unwrap();
match (drop_receiver, rendezvous) {
(true, true) => assert_eq!(
observed_values,
vec![Err(TrySendError::Full(1)), Err(TrySendError::Disconnected(1))]
),
(true, false) => {
assert_eq!(observed_values.len(), 2);
assert!(observed_values.contains(&Err(TrySendError::Disconnected(1))));
assert!(observed_values.contains(&Ok(())));
}
(false, true) => {
assert_eq!(observed_values.len(), 2);
assert!(observed_values.contains(&Err(TrySendError::Full(1))));
assert!(observed_values.contains(&Ok(())));
}
(false, false) => assert_eq!(observed_values, vec![Ok(()), Ok(())]),
}
}
/// `try_send` on a capacity-1 channel racing with a thread that receives once.
#[test]
fn mpsc_try_send_permutations_no_drop() {
    let (drop_receiver, rendezvous) = (false, false);
    mpsc_try_send_permutations(drop_receiver, rendezvous);
}
/// `try_send` on a capacity-1 channel racing with a thread that drops the receiver.
#[test]
fn mpsc_try_send_permutations_drop() {
    let (drop_receiver, rendezvous) = (true, false);
    mpsc_try_send_permutations(drop_receiver, rendezvous);
}
/// `try_send` on a rendezvous channel racing with a thread that receives once.
#[test]
fn mpsc_try_send_permutations_no_drop_rendezvous() {
    let (drop_receiver, rendezvous) = (false, true);
    mpsc_try_send_permutations(drop_receiver, rendezvous);
}
/// `try_send` on a rendezvous channel racing with a thread that drops the receiver.
#[test]
fn mpsc_try_send_permutations_drop_rendezvous() {
    let (drop_receiver, rendezvous) = (true, true);
    mpsc_try_send_permutations(drop_receiver, rendezvous);
}