test_scheduler.rs

  1use crate::{
  2    BackgroundExecutor, Clock, ForegroundExecutor, Scheduler, SessionId, TestClock, Timer,
  3};
  4use async_task::Runnable;
  5use backtrace::{Backtrace, BacktraceFrame};
  6use futures::{FutureExt as _, channel::oneshot, future::LocalBoxFuture};
  7use parking_lot::Mutex;
  8use rand::prelude::*;
  9use std::{
 10    any::type_name_of_val,
 11    collections::{BTreeMap, VecDeque},
 12    env,
 13    fmt::Write,
 14    future::Future,
 15    mem,
 16    ops::RangeInclusive,
 17    panic::{self, AssertUnwindSafe},
 18    pin::Pin,
 19    sync::{
 20        Arc,
 21        atomic::{AtomicBool, Ordering::SeqCst},
 22    },
 23    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
 24    thread::{self, Thread},
 25    time::{Duration, Instant},
 26};
 27
 28const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
 29
 30pub struct TestScheduler {
 31    clock: Arc<TestClock>,
 32    rng: Arc<Mutex<StdRng>>,
 33    state: Arc<Mutex<SchedulerState>>,
 34    thread: Thread,
 35}
 36
 37impl TestScheduler {
 38    /// Run a test once with default configuration (seed 0)
 39    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 40        Self::with_seed(0, f)
 41    }
 42
 43    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
 44    pub fn many<R>(
 45        default_iterations: usize,
 46        mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
 47    ) -> Vec<R> {
 48        let num_iterations = std::env::var("ITERATIONS")
 49            .map(|iterations| iterations.parse().unwrap())
 50            .unwrap_or(default_iterations);
 51
 52        let seed = std::env::var("SEED")
 53            .map(|seed| seed.parse().unwrap())
 54            .unwrap_or(0);
 55
 56        (seed..num_iterations as u64)
 57            .map(|seed| {
 58                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
 59                eprintln!("Running seed: {seed}");
 60                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
 61                    Ok(result) => result,
 62                    Err(error) => {
 63                        eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
 64                        panic::resume_unwind(error);
 65                    }
 66                }
 67            })
 68            .collect()
 69    }
 70
 71    fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 72        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
 73        let future = f(scheduler.clone());
 74        let result = scheduler.foreground().block_on(future);
 75        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
 76        result
 77    }
 78
 79    pub fn new(config: TestSchedulerConfig) -> Self {
 80        Self {
 81            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
 82            state: Arc::new(Mutex::new(SchedulerState {
 83                runnables: VecDeque::new(),
 84                timers: Vec::new(),
 85                blocked_sessions: Vec::new(),
 86                randomize_order: config.randomize_order,
 87                allow_parking: config.allow_parking,
 88                timeout_ticks: config.timeout_ticks,
 89                next_session_id: SessionId(0),
 90                capture_pending_traces: config.capture_pending_traces,
 91                pending_traces: BTreeMap::new(),
 92                next_trace_id: TraceId(0),
 93            })),
 94            clock: Arc::new(TestClock::new()),
 95            thread: thread::current(),
 96        }
 97    }
 98
 99    pub fn clock(&self) -> Arc<TestClock> {
100        self.clock.clone()
101    }
102
103    pub fn rng(&self) -> Arc<Mutex<StdRng>> {
104        self.rng.clone()
105    }
106
107    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
108        self.state.lock().timeout_ticks = timeout_ticks;
109    }
110
111    pub fn allow_parking(&self) {
112        self.state.lock().allow_parking = true;
113    }
114
115    pub fn forbid_parking(&self) {
116        self.state.lock().allow_parking = false;
117    }
118
119    /// Create a foreground executor for this scheduler
120    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
121        let session_id = {
122            let mut state = self.state.lock();
123            state.next_session_id.0 += 1;
124            state.next_session_id
125        };
126        ForegroundExecutor::new(session_id, self.clone())
127    }
128
129    /// Create a background executor for this scheduler
130    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
131        BackgroundExecutor::new(self.clone())
132    }
133
134    pub fn yield_random(&self) -> Yield {
135        let rng = &mut *self.rng.lock();
136        if rng.random_bool(0.1) {
137            Yield(rng.random_range(10..20))
138        } else {
139            Yield(rng.random_range(0..2))
140        }
141    }
142
143    pub fn run(&self) {
144        while self.step() {
145            // Continue until no work remains
146        }
147    }
148
149    pub fn run_with_clock_advancement(&self) {
150        while self.step() || self.advance_clock_to_next_timer() {
151            // Continue until no work remains
152        }
153    }
154
155    fn step(&self) -> bool {
156        let elapsed_timers = {
157            let mut state = self.state.lock();
158            let end_ix = state
159                .timers
160                .partition_point(|timer| timer.expiration <= self.clock.now());
161            state.timers.drain(..end_ix).collect::<Vec<_>>()
162        };
163
164        if !elapsed_timers.is_empty() {
165            return true;
166        }
167
168        let runnable = {
169            let state = &mut *self.state.lock();
170            let ix = state.runnables.iter().position(|runnable| {
171                runnable
172                    .session_id
173                    .is_none_or(|session_id| !state.blocked_sessions.contains(&session_id))
174            });
175            ix.and_then(|ix| state.runnables.remove(ix))
176        };
177
178        if let Some(runnable) = runnable {
179            runnable.run();
180            return true;
181        }
182
183        false
184    }
185
186    fn advance_clock_to_next_timer(&self) -> bool {
187        if let Some(timer) = self.state.lock().timers.first() {
188            self.clock.advance(timer.expiration - self.clock.now());
189            true
190        } else {
191            false
192        }
193    }
194
195    pub fn advance_clock(&self, duration: Duration) {
196        let next_now = self.clock.now() + duration;
197        loop {
198            self.run();
199            if let Some(timer) = self.state.lock().timers.first()
200                && timer.expiration <= next_now
201            {
202                self.clock.advance(timer.expiration - self.clock.now());
203            } else {
204                break;
205            }
206        }
207        self.clock.advance(next_now - self.clock.now());
208    }
209
210    fn park(&self, deadline: Option<Instant>) -> bool {
211        if self.state.lock().allow_parking {
212            if let Some(deadline) = deadline {
213                let now = Instant::now();
214                let timeout = deadline.saturating_duration_since(now);
215                thread::park_timeout(timeout);
216                now.elapsed() < timeout
217            } else {
218                thread::park();
219                true
220            }
221        } else if deadline.is_some() {
222            false
223        } else if self.state.lock().capture_pending_traces {
224            let mut pending_traces = String::new();
225            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
226                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
227            }
228            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
229        } else {
230            panic!(
231                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
232            );
233        }
234    }
235}
236
237impl Scheduler for TestScheduler {
238    /// Block until the given future completes, with an optional timeout. If the
239    /// future is unable to make progress at any moment before the timeout and
240    /// no other tasks or timers remain, we panic unless parking is allowed. If
241    /// parking is allowed, we block up to the timeout or indefinitely if none
242    /// is provided. This is to allow testing a mix of deterministic and
243    /// non-deterministic async behavior, such as when interacting with I/O in
244    /// an otherwise deterministic test.
245    fn block(
246        &self,
247        session_id: Option<SessionId>,
248        mut future: LocalBoxFuture<()>,
249        timeout: Option<Duration>,
250    ) {
251        if let Some(session_id) = session_id {
252            self.state.lock().blocked_sessions.push(session_id);
253        }
254
255        let deadline = timeout.map(|timeout| Instant::now() + timeout);
256        let awoken = Arc::new(AtomicBool::new(false));
257        let waker = Box::new(TracingWaker {
258            id: None,
259            awoken: awoken.clone(),
260            thread: self.thread.clone(),
261            state: self.state.clone(),
262        });
263        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
264        let max_ticks = if timeout.is_some() {
265            self.rng
266                .lock()
267                .random_range(self.state.lock().timeout_ticks.clone())
268        } else {
269            usize::MAX
270        };
271        let mut cx = Context::from_waker(&waker);
272
273        for _ in 0..max_ticks {
274            let Poll::Pending = future.poll_unpin(&mut cx) else {
275                break;
276            };
277
278            let mut stepped = None;
279            while self.rng.lock().random() {
280                let stepped = stepped.get_or_insert(false);
281                if self.step() {
282                    *stepped = true;
283                } else {
284                    break;
285                }
286            }
287
288            let stepped = stepped.unwrap_or(true);
289            let awoken = awoken.swap(false, SeqCst);
290            if !stepped && !awoken && !self.advance_clock_to_next_timer() {
291                if !self.park(deadline) {
292                    break;
293                }
294            }
295        }
296
297        if session_id.is_some() {
298            self.state.lock().blocked_sessions.pop();
299        }
300    }
301
302    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable) {
303        let mut state = self.state.lock();
304        let ix = if state.randomize_order {
305            let start_ix = state
306                .runnables
307                .iter()
308                .rposition(|task| task.session_id == Some(session_id))
309                .map_or(0, |ix| ix + 1);
310            self.rng
311                .lock()
312                .random_range(start_ix..=state.runnables.len())
313        } else {
314            state.runnables.len()
315        };
316        state.runnables.insert(
317            ix,
318            ScheduledRunnable {
319                session_id: Some(session_id),
320                runnable,
321            },
322        );
323        drop(state);
324        self.thread.unpark();
325    }
326
327    fn schedule_background(&self, runnable: Runnable) {
328        let mut state = self.state.lock();
329        let ix = if state.randomize_order {
330            self.rng.lock().random_range(0..=state.runnables.len())
331        } else {
332            state.runnables.len()
333        };
334        state.runnables.insert(
335            ix,
336            ScheduledRunnable {
337                session_id: None,
338                runnable,
339            },
340        );
341        drop(state);
342        self.thread.unpark();
343    }
344
345    fn timer(&self, duration: Duration) -> Timer {
346        let (tx, rx) = oneshot::channel();
347        let state = &mut *self.state.lock();
348        state.timers.push(ScheduledTimer {
349            expiration: self.clock.now() + duration,
350            _notify: tx,
351        });
352        state.timers.sort_by_key(|timer| timer.expiration);
353        Timer(rx)
354    }
355
356    fn clock(&self) -> Arc<dyn Clock> {
357        self.clock.clone()
358    }
359
360    fn as_test(&self) -> &TestScheduler {
361        self
362    }
363}
364
365#[derive(Clone, Debug)]
366pub struct TestSchedulerConfig {
367    pub seed: u64,
368    pub randomize_order: bool,
369    pub allow_parking: bool,
370    pub capture_pending_traces: bool,
371    pub timeout_ticks: RangeInclusive<usize>,
372}
373
374impl TestSchedulerConfig {
375    pub fn with_seed(seed: u64) -> Self {
376        Self {
377            seed,
378            ..Default::default()
379        }
380    }
381}
382
383impl Default for TestSchedulerConfig {
384    fn default() -> Self {
385        Self {
386            seed: 0,
387            randomize_order: true,
388            allow_parking: false,
389            capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
390                .map_or(false, |var| var == "1" || var == "true"),
391            timeout_ticks: 0..=1000,
392        }
393    }
394}
395
396struct ScheduledRunnable {
397    session_id: Option<SessionId>,
398    runnable: Runnable,
399}
400
401impl ScheduledRunnable {
402    fn run(self) {
403        self.runnable.run();
404    }
405}
406
407struct ScheduledTimer {
408    expiration: Instant,
409    _notify: oneshot::Sender<()>,
410}
411
412struct SchedulerState {
413    runnables: VecDeque<ScheduledRunnable>,
414    timers: Vec<ScheduledTimer>,
415    blocked_sessions: Vec<SessionId>,
416    randomize_order: bool,
417    allow_parking: bool,
418    timeout_ticks: RangeInclusive<usize>,
419    next_session_id: SessionId,
420    capture_pending_traces: bool,
421    next_trace_id: TraceId,
422    pending_traces: BTreeMap<TraceId, Backtrace>,
423}
424
425const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
426    TracingWaker::clone_raw,
427    TracingWaker::wake_raw,
428    TracingWaker::wake_by_ref_raw,
429    TracingWaker::drop_raw,
430);
431
/// Key under which a captured waker backtrace is stored in
/// `SchedulerState::pending_traces`; ordered by allocation.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct TraceId(usize);
434
435struct TracingWaker {
436    id: Option<TraceId>,
437    awoken: Arc<AtomicBool>,
438    thread: Thread,
439    state: Arc<Mutex<SchedulerState>>,
440}
441
442impl Clone for TracingWaker {
443    fn clone(&self) -> Self {
444        let mut state = self.state.lock();
445        let id = if state.capture_pending_traces {
446            let id = state.next_trace_id;
447            state.next_trace_id.0 += 1;
448            state.pending_traces.insert(id, Backtrace::new_unresolved());
449            Some(id)
450        } else {
451            None
452        };
453        Self {
454            id,
455            awoken: self.awoken.clone(),
456            thread: self.thread.clone(),
457            state: self.state.clone(),
458        }
459    }
460}
461
462impl Drop for TracingWaker {
463    fn drop(&mut self) {
464        if let Some(id) = self.id {
465            self.state.lock().pending_traces.remove(&id);
466        }
467    }
468}
469
470impl TracingWaker {
471    fn wake(self) {
472        self.wake_by_ref();
473    }
474
475    fn wake_by_ref(&self) {
476        if let Some(id) = self.id {
477            self.state.lock().pending_traces.remove(&id);
478        }
479        self.awoken.store(true, SeqCst);
480        self.thread.unpark();
481    }
482
483    fn clone_raw(waker: *const ()) -> RawWaker {
484        let waker = waker as *const TracingWaker;
485        let waker = unsafe { &*waker };
486        RawWaker::new(
487            Box::into_raw(Box::new(waker.clone())) as *const (),
488            &WAKER_VTABLE,
489        )
490    }
491
492    fn wake_raw(waker: *const ()) {
493        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
494        waker.wake();
495    }
496
497    fn wake_by_ref_raw(waker: *const ()) {
498        let waker = waker as *const TracingWaker;
499        let waker = unsafe { &*waker };
500        waker.wake_by_ref();
501    }
502
503    fn drop_raw(waker: *const ()) {
504        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
505        drop(waker);
506    }
507}
508
/// Future that returns `Pending` a fixed number of times, waking itself each
/// time, and then completes.
pub struct Yield(usize);

impl Future for Yield {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.0.checked_sub(1) {
            None => Poll::Ready(()),
            Some(remaining) => {
                self.0 = remaining;
                // Reschedule immediately so the yield costs a trip through the queue.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
524
525fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
526    trace.resolve();
527    let mut frames: Vec<BacktraceFrame> = trace.into();
528    let waker_clone_frame_ix = frames.iter().position(|frame| {
529        frame.symbols().iter().any(|symbol| {
530            symbol
531                .name()
532                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
533        })
534    });
535
536    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
537        frames.drain(..waker_clone_frame_ix + 1);
538    }
539
540    Backtrace::from(frames)
541}