test_scheduler.rs

use crate::{
    BackgroundExecutor, Clock, ForegroundExecutor, Scheduler, SessionId, TestClock, Timer,
};
use async_task::Runnable;
use backtrace::{Backtrace, BacktraceFrame};
use futures::{FutureExt as _, channel::oneshot, future::LocalBoxFuture};
use parking_lot::Mutex;
use rand::prelude::*;
use std::{
    any::type_name_of_val,
    collections::{BTreeMap, VecDeque},
    env,
    fmt::Write,
    future::Future,
    mem,
    ops::RangeInclusive,
    panic::{self, AssertUnwindSafe},
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering::SeqCst},
    },
    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
    thread::{self, Thread},
    time::{Duration, Instant},
};

const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";

pub struct TestScheduler {
    clock: Arc<TestClock>,
    rng: Arc<Mutex<StdRng>>,
    state: Arc<Mutex<SchedulerState>>,
    thread: Thread,
}

impl TestScheduler {
    /// Run a test once with default configuration (seed 0)
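    ///
    /// A minimal usage sketch (hypothetical test body; `yield_random` is
    /// defined further down in this file):
    ///
    /// ```ignore
    /// let value = TestScheduler::once(async |scheduler| {
    ///     scheduler.yield_random().await;
    ///     42
    /// });
    /// assert_eq!(value, 42);
    /// ```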
 39    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 40        Self::with_seed(0, f)
 41    }
 42
 43    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
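    ///
    /// Each iteration gets a fresh scheduler seeded with its index, so a
    /// failure printed as "Failing Seed: N" can be replayed with
    /// [`TestScheduler::with_seed`]. A sketch (hypothetical test body):
    ///
    /// ```ignore
    /// let results = TestScheduler::many(100, async |scheduler| {
    ///     scheduler.yield_random().await;
    /// });
    /// assert_eq!(results.len(), 100);
    /// ```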
    pub fn many<R>(iterations: usize, mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R) -> Vec<R> {
        (0..iterations as u64)
            .map(|seed| {
                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
                    Ok(result) => result,
                    Err(error) => {
                        eprintln!("Failing Seed: {seed}");
                        panic::resume_unwind(error);
                    }
                }
            })
            .collect()
    }

    /// Run a test once with a specific seed
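    ///
    /// Handy for replaying a seed reported as failing by [`TestScheduler::many`].
    /// A sketch (hypothetical test body):
    ///
    /// ```ignore
    /// TestScheduler::with_seed(42, async |scheduler| {
    ///     scheduler.yield_random().await;
    /// });
    /// ```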
    pub fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
        let future = f(scheduler.clone());
        let result = scheduler.foreground().block_on(future);
        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
        result
    }

    pub fn new(config: TestSchedulerConfig) -> Self {
        Self {
            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
            state: Arc::new(Mutex::new(SchedulerState {
                runnables: VecDeque::new(),
                timers: Vec::new(),
                blocked_sessions: Vec::new(),
                randomize_order: config.randomize_order,
                allow_parking: config.allow_parking,
                timeout_ticks: config.timeout_ticks,
                next_session_id: SessionId(0),
                capture_pending_traces: config.capture_pending_traces,
                pending_traces: BTreeMap::new(),
                next_trace_id: TraceId(0),
            })),
            clock: Arc::new(TestClock::new()),
            thread: thread::current(),
        }
    }

    pub fn clock(&self) -> Arc<TestClock> {
        self.clock.clone()
    }

    pub fn rng(&self) -> Arc<Mutex<StdRng>> {
        self.rng.clone()
    }

    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
        self.state.lock().timeout_ticks = timeout_ticks;
    }

    pub fn allow_parking(&self) {
        self.state.lock().allow_parking = true;
    }

    pub fn forbid_parking(&self) {
        self.state.lock().allow_parking = false;
    }

    /// Create a foreground executor for this scheduler
    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
        let session_id = {
            let mut state = self.state.lock();
            state.next_session_id.0 += 1;
            state.next_session_id
        };
        ForegroundExecutor::new(session_id, self.clone())
    }

    /// Create a background executor for this scheduler
    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
        BackgroundExecutor::new(self.clone())
    }

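    /// Returns a future that yields back to the scheduler a random number of
    /// times before completing: usually zero or one time, but occasionally
    /// (with probability 0.1) between 10 and 19 times, to exercise longer
    /// re-scheduling chains.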
    pub fn yield_random(&self) -> Yield {
        let rng = &mut *self.rng.lock();
        if rng.random_bool(0.1) {
            Yield(rng.random_range(10..20))
        } else {
            Yield(rng.random_range(0..2))
        }
    }

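    /// Runs scheduled work until no elapsed timer or runnable task remains.
    /// The clock is not advanced; timers set to fire in the future are left
    /// untouched.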
    pub fn run(&self) {
        while self.step() {
            // Continue until no work remains
        }
    }

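    /// Performs one unit of work: first drains timers whose expiration has
    /// passed (dropping a `ScheduledTimer` drops its `_notify` sender, which
    /// closes the channel held by the corresponding `Timer`), then runs the
    /// first runnable whose session is not currently blocked. Returns false
    /// once neither is possible.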
    fn step(&self) -> bool {
        let elapsed_timers = {
            let mut state = self.state.lock();
            let end_ix = state
                .timers
                .partition_point(|timer| timer.expiration <= self.clock.now());
            state.timers.drain(..end_ix).collect::<Vec<_>>()
        };

        if !elapsed_timers.is_empty() {
            return true;
        }

        let runnable = {
            let state = &mut *self.state.lock();
            let ix = state.runnables.iter().position(|runnable| {
                runnable
                    .session_id
                    .is_none_or(|session_id| !state.blocked_sessions.contains(&session_id))
            });
            ix.and_then(|ix| state.runnables.remove(ix))
        };

        if let Some(runnable) = runnable {
            runnable.run();
            return true;
        }

        false
    }

    fn advance_clock_to_next_timer(&self) -> bool {
        if let Some(timer) = self.state.lock().timers.first() {
            self.clock.advance(timer.expiration - self.clock.now());
            true
        } else {
            false
        }
    }

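    /// Advances the test clock by `duration`, draining ready work and firing
    /// every timer that falls within the window in expiration order before
    /// finally settling the clock at `now + duration`.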
    pub fn advance_clock(&self, duration: Duration) {
        let next_now = self.clock.now() + duration;
        loop {
            self.run();
            if let Some(timer) = self.state.lock().timers.first()
                && timer.expiration <= next_now
            {
                self.clock.advance(timer.expiration - self.clock.now());
            } else {
                break;
            }
        }
        self.clock.advance(next_now - self.clock.now());
    }

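    /// Parks the thread when no more progress can be made. If parking is
    /// forbidden (the default), a `block` call with a timeout gives up
    /// immediately, while one without a timeout panics, optionally dumping
    /// backtraces of the futures that are still pending.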
    fn park(&self, deadline: Option<Instant>) -> bool {
        if self.state.lock().allow_parking {
            if let Some(deadline) = deadline {
                let now = Instant::now();
                let timeout = deadline.saturating_duration_since(now);
                thread::park_timeout(timeout);
                now.elapsed() < timeout
            } else {
                thread::park();
                true
            }
        } else if deadline.is_some() {
            false
        } else if self.state.lock().capture_pending_traces {
            let mut pending_traces = String::new();
            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
            }
            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
        } else {
            panic!(
                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
            );
        }
    }
}

impl Scheduler for TestScheduler {
    /// Block until the given future completes, with an optional timeout. If the
    /// future is unable to make progress at any moment before the timeout and
    /// no other tasks or timers remain, we panic unless parking is allowed. If
    /// parking is allowed, we block up to the timeout or indefinitely if none
    /// is provided. This is to allow testing a mix of deterministic and
    /// non-deterministic async behavior, such as when interacting with I/O in
    /// an otherwise deterministic test.
    fn block(
        &self,
        session_id: Option<SessionId>,
        mut future: LocalBoxFuture<()>,
        timeout: Option<Duration>,
    ) {
        if let Some(session_id) = session_id {
            self.state.lock().blocked_sessions.push(session_id);
        }

        let deadline = timeout.map(|timeout| Instant::now() + timeout);
        let awoken = Arc::new(AtomicBool::new(false));
        let waker = Box::new(TracingWaker {
            id: None,
            awoken: awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        });
        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
        let max_ticks = if timeout.is_some() {
            self.rng
                .lock()
                .random_range(self.state.lock().timeout_ticks.clone())
        } else {
            usize::MAX
        };
        let mut cx = Context::from_waker(&waker);

        for _ in 0..max_ticks {
            let Poll::Pending = future.poll_unpin(&mut cx) else {
                break;
            };

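            // Between polls, randomly run other scheduled work so the future
            // under test cannot rely on a particular interleaving of tasks.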
            let mut stepped = None;
            while self.rng.lock().random() {
                let stepped = stepped.get_or_insert(false);
                if self.step() {
                    *stepped = true;
                } else {
                    break;
                }
            }

            let stepped = stepped.unwrap_or(true);
            let awoken = awoken.swap(false, SeqCst);
            if !stepped && !awoken && !self.advance_clock_to_next_timer() {
                if !self.park(deadline) {
                    break;
                }
            }
        }

        if session_id.is_some() {
            self.state.lock().blocked_sessions.pop();
        }
    }

    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable) {
        let mut state = self.state.lock();
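        // With randomized ordering, insert anywhere at or after the last
        // runnable of the same session, so tasks within a session still run
        // in FIFO order relative to one another.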
        let ix = if state.randomize_order {
            let start_ix = state
                .runnables
                .iter()
                .rposition(|task| task.session_id == Some(session_id))
                .map_or(0, |ix| ix + 1);
            self.rng
                .lock()
                .random_range(start_ix..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: Some(session_id),
                runnable,
            },
        );
        drop(state);
        self.thread.unpark();
    }

    fn schedule_background(&self, runnable: Runnable) {
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            self.rng.lock().random_range(0..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: None,
                runnable,
            },
        );
        drop(state);
        self.thread.unpark();
    }

    fn timer(&self, duration: Duration) -> Timer {
        let (tx, rx) = oneshot::channel();
        let state = &mut *self.state.lock();
        state.timers.push(ScheduledTimer {
            expiration: self.clock.now() + duration,
            _notify: tx,
        });
        state.timers.sort_by_key(|timer| timer.expiration);
        Timer(rx)
    }

    fn clock(&self) -> Arc<dyn Clock> {
        self.clock.clone()
    }

    fn as_test(&self) -> &TestScheduler {
        self
    }
}

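/// Configuration for [`TestScheduler::new`]. A construction sketch with every
/// field spelled out (the values are illustrative, not recommendations):
///
/// ```ignore
/// let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig {
///     seed: 7,
///     randomize_order: false,
///     allow_parking: true,
///     capture_pending_traces: false,
///     timeout_ticks: 0..=100,
/// }));
/// ```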
#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    pub seed: u64,
    pub randomize_order: bool,
    pub allow_parking: bool,
    pub capture_pending_traces: bool,
    pub timeout_ticks: RangeInclusive<usize>,
}

impl TestSchedulerConfig {
    pub fn with_seed(seed: u64) -> Self {
        Self {
            seed,
            ..Default::default()
        }
    }
}

impl Default for TestSchedulerConfig {
    fn default() -> Self {
        Self {
            seed: 0,
            randomize_order: true,
            allow_parking: false,
            capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
                .map_or(false, |var| var == "1" || var == "true"),
            timeout_ticks: 0..=1000,
        }
    }
}

struct ScheduledRunnable {
    session_id: Option<SessionId>,
    runnable: Runnable,
}

impl ScheduledRunnable {
    fn run(self) {
        self.runnable.run();
    }
}

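/// A timer registered via [`Scheduler::timer`]. Dropping it drops the
/// `_notify` sender, closing the channel held by the paired `Timer`; this is
/// presumably how elapsed timers signal completion.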
struct ScheduledTimer {
    expiration: Instant,
    _notify: oneshot::Sender<()>,
}

struct SchedulerState {
    runnables: VecDeque<ScheduledRunnable>,
    timers: Vec<ScheduledTimer>,
    blocked_sessions: Vec<SessionId>,
    randomize_order: bool,
    allow_parking: bool,
    timeout_ticks: RangeInclusive<usize>,
    next_session_id: SessionId,
    capture_pending_traces: bool,
    next_trace_id: TraceId,
    pending_traces: BTreeMap<TraceId, Backtrace>,
}

const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);

#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);

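/// A waker that can record where a future went pending. When
/// `capture_pending_traces` is enabled, every clone captures an unresolved
/// backtrace keyed by a fresh [`TraceId`]; the trace is discarded when the
/// waker is woken or dropped, so any traces left over when parking is
/// forbidden point at futures that never got woken.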
struct TracingWaker {
    id: Option<TraceId>,
    awoken: Arc<AtomicBool>,
    thread: Thread,
    state: Arc<Mutex<SchedulerState>>,
}

impl Clone for TracingWaker {
    fn clone(&self) -> Self {
        let mut state = self.state.lock();
        let id = if state.capture_pending_traces {
            let id = state.next_trace_id;
            state.next_trace_id.0 += 1;
            state.pending_traces.insert(id, Backtrace::new_unresolved());
            Some(id)
        } else {
            None
        };
        Self {
            id,
            awoken: self.awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        }
    }
}

impl Drop for TracingWaker {
    fn drop(&mut self) {
        if let Some(id) = self.id {
            self.state.lock().pending_traces.remove(&id);
        }
    }
}

impl TracingWaker {
    fn wake(self) {
        self.wake_by_ref();
    }

    fn wake_by_ref(&self) {
        if let Some(id) = self.id {
            self.state.lock().pending_traces.remove(&id);
        }
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    fn wake_raw(waker: *const ()) {
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    fn drop_raw(waker: *const ()) {
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}

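/// A future that returns `Pending` (after immediately re-waking itself) the
/// given number of times before resolving, forcing the surrounding task back
/// through the scheduler queue on each poll.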
pub struct Yield(usize);

impl Future for Yield {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        if self.0 == 0 {
            Poll::Ready(())
        } else {
            self.0 -= 1;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

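/// Resolves a captured backtrace and trims every frame up to and including
/// `Waker::clone`, so pending-trace output starts at the future being polled
/// rather than inside the waker machinery.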
fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
    trace.resolve();
    let mut frames: Vec<BacktraceFrame> = trace.into();
    let waker_clone_frame_ix = frames.iter().position(|frame| {
        frame.symbols().iter().any(|symbol| {
            symbol
                .name()
                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
        })
    });

    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
        frames.drain(..waker_clone_frame_ix + 1);
    }

    Backtrace::from(frames)
}