use shuttle::scheduler::RandomScheduler; use shuttle::{check_random, thread, Runner}; use std::panic::{catch_unwind, AssertUnwindSafe}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tracing::field::{Field, Visit}; use tracing::span::{Attributes, Record}; use tracing::{Event, Id, Metadata, Subscriber}; // Simple `Subscriber` that just remembers the last value of the `iterations` field it has seen from // a `MetricsScheduler`-generated event #[derive(Clone)] struct MetricsSubscriber { iterations: Arc, } impl MetricsSubscriber { fn new() -> Self { Self { iterations: Arc::new(AtomicUsize::new(0)), } } } impl Subscriber for MetricsSubscriber { fn enabled(&self, _metadata: &Metadata<'_>) -> bool { true } fn new_span(&self, _span: &Attributes<'_>) -> Id { // We don't care about span equality so just use the same identity for everything Id::from_u64(1) } fn record(&self, _span: &Id, _values: &Record<'_>) {} fn record_follows_from(&self, _span: &Id, _follows: &Id) {} fn event(&self, event: &Event<'_>) { // If it's an event from the `MetricsScheduler` with an `iterations` counter, record it let metadata = event.metadata(); if metadata.target() == "shuttle::scheduler::metrics" { struct FindIterationsVisitor(Option); impl Visit for FindIterationsVisitor { fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {} fn record_u64(&mut self, field: &Field, value: u64) { if field.name() == "iterations" { self.0 = Some(value); } } } let mut visitor = FindIterationsVisitor(None); event.record(&mut visitor); if let Some(iterations) = visitor.0 { self.iterations.store(iterations as usize, Ordering::SeqCst); } } } fn enter(&self, _span: &Id) {} fn exit(&self, _span: &Id) {} } // Note: `panic_iteration` is 1-indexed because "iterations" is a count fn iterations_test(run_iterations: usize, panic_iteration: usize) { let metrics = MetricsSubscriber::new(); let _guard = tracing::subscriber::set_default(metrics.clone()); let iterations = Arc::new(AtomicUsize::new(0)); let result = catch_unwind(AssertUnwindSafe(|| { check_random( move || { iterations.fetch_add(1, Ordering::SeqCst); if iterations.load(Ordering::SeqCst) >= panic_iteration { panic!("expected panic"); } thread::spawn(move || { thread::yield_now(); }); }, run_iterations, ); })); assert_eq!(result.is_err(), panic_iteration <= run_iterations); assert_eq!( metrics.iterations.load(Ordering::SeqCst), run_iterations.min(panic_iteration) ); } #[test] fn iterations_test_basic() { iterations_test(10, 20); } #[test] fn iterations_test_panic() { iterations_test(10, 1); iterations_test(10, 5); iterations_test(10, 10); } #[test] fn iterations_without_running() { let metrics = MetricsSubscriber::new(); { let _guard = tracing::subscriber::set_default(metrics.clone()); let scheduler = RandomScheduler::new(10); let _runner = Runner::new(scheduler, Default::default()); } assert_eq!(metrics.iterations.load(Ordering::SeqCst), 0); }