test_scheduler.rs

  1use crate::{
  2    BackgroundExecutor, Clock, ForegroundExecutor, Scheduler, SessionId, TestClock, Timer,
  3};
  4use async_task::Runnable;
  5use backtrace::{Backtrace, BacktraceFrame};
  6use futures::{FutureExt as _, channel::oneshot, future::LocalBoxFuture};
  7use parking_lot::Mutex;
  8use rand::prelude::*;
  9use std::{
 10    any::type_name_of_val,
 11    collections::{BTreeMap, VecDeque},
 12    env,
 13    fmt::Write,
 14    future::Future,
 15    mem,
 16    ops::RangeInclusive,
 17    panic::{self, AssertUnwindSafe},
 18    pin::Pin,
 19    sync::{
 20        Arc,
 21        atomic::{AtomicBool, Ordering::SeqCst},
 22    },
 23    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
 24    thread::{self, Thread},
 25    time::{Duration, Instant},
 26};
 27
/// Environment variable read by `TestSchedulerConfig::default`: when it is set
/// to `1` or `true`, waker clones made during `block` record backtraces so a
/// forbidden park can report which wakers are still outstanding.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
 29
/// Scheduler for tests whose task ordering, step interleaving and timeout
/// budgets are driven by a seeded RNG and a [`TestClock`].
pub struct TestScheduler {
    /// Test clock handed out by `clock()`; within this file it is only moved
    /// by `advance_clock` / `advance_clock_to_next_timer`.
    clock: Arc<TestClock>,
    /// RNG seeded from `TestSchedulerConfig::seed`.
    rng: Arc<Mutex<StdRng>>,
    /// Mutable scheduling state, also shared with the wakers built by `block`.
    state: Arc<Mutex<SchedulerState>>,
    /// Thread that called `new` (`thread::current()`); unparked whenever work
    /// is scheduled or a `block` waker fires.
    thread: Thread,
}
 36
 37impl TestScheduler {
 38    /// Run a test once with default configuration (seed 0)
 39    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 40        Self::with_seed(0, f)
 41    }
 42
 43    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
 44    pub fn many<R>(iterations: usize, mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R) -> Vec<R> {
 45        (0..iterations as u64)
 46            .map(|seed| {
 47                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
 48                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
 49                    Ok(result) => result,
 50                    Err(error) => {
 51                        eprintln!("Failing Seed: {seed}");
 52                        panic::resume_unwind(error);
 53                    }
 54                }
 55            })
 56            .collect()
 57    }
 58
 59    /// Run a test once with a specific seed
 60    pub fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 61        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
 62        let future = f(scheduler.clone());
 63        let result = scheduler.foreground().block_on(future);
 64        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
 65        result
 66    }
 67
 68    pub fn new(config: TestSchedulerConfig) -> Self {
 69        Self {
 70            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
 71            state: Arc::new(Mutex::new(SchedulerState {
 72                runnables: VecDeque::new(),
 73                timers: Vec::new(),
 74                blocked_sessions: Vec::new(),
 75                randomize_order: config.randomize_order,
 76                allow_parking: config.allow_parking,
 77                timeout_ticks: config.timeout_ticks,
 78                next_session_id: SessionId(0),
 79                capture_pending_traces: config.capture_pending_traces,
 80                pending_traces: BTreeMap::new(),
 81                next_trace_id: TraceId(0),
 82            })),
 83            clock: Arc::new(TestClock::new()),
 84            thread: thread::current(),
 85        }
 86    }
 87
 88    pub fn clock(&self) -> Arc<TestClock> {
 89        self.clock.clone()
 90    }
 91
 92    pub fn rng(&self) -> Arc<Mutex<StdRng>> {
 93        self.rng.clone()
 94    }
 95
 96    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
 97        self.state.lock().timeout_ticks = timeout_ticks;
 98    }
 99
100    pub fn allow_parking(&self) {
101        self.state.lock().allow_parking = true;
102    }
103
104    pub fn forbid_parking(&self) {
105        self.state.lock().allow_parking = false;
106    }
107
108    /// Create a foreground executor for this scheduler
109    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
110        let session_id = {
111            let mut state = self.state.lock();
112            state.next_session_id.0 += 1;
113            state.next_session_id
114        };
115        ForegroundExecutor::new(session_id, self.clone())
116    }
117
118    /// Create a background executor for this scheduler
119    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
120        BackgroundExecutor::new(self.clone())
121    }
122
123    pub fn yield_random(&self) -> Yield {
124        let rng = &mut *self.rng.lock();
125        if rng.random_bool(0.1) {
126            Yield(rng.random_range(10..20))
127        } else {
128            Yield(rng.random_range(0..2))
129        }
130    }
131
132    pub fn run(&self) {
133        while self.step() {
134            // Continue until no work remains
135        }
136    }
137
138    pub fn run_with_clock_advancement(&self) {
139        while self.step() || self.advance_clock_to_next_timer() {
140            // Continue until no work remains
141        }
142    }
143
144    fn step(&self) -> bool {
145        let elapsed_timers = {
146            let mut state = self.state.lock();
147            let end_ix = state
148                .timers
149                .partition_point(|timer| timer.expiration <= self.clock.now());
150            state.timers.drain(..end_ix).collect::<Vec<_>>()
151        };
152
153        if !elapsed_timers.is_empty() {
154            return true;
155        }
156
157        let runnable = {
158            let state = &mut *self.state.lock();
159            let ix = state.runnables.iter().position(|runnable| {
160                runnable
161                    .session_id
162                    .is_none_or(|session_id| !state.blocked_sessions.contains(&session_id))
163            });
164            ix.and_then(|ix| state.runnables.remove(ix))
165        };
166
167        if let Some(runnable) = runnable {
168            runnable.run();
169            return true;
170        }
171
172        false
173    }
174
175    fn advance_clock_to_next_timer(&self) -> bool {
176        if let Some(timer) = self.state.lock().timers.first() {
177            self.clock.advance(timer.expiration - self.clock.now());
178            true
179        } else {
180            false
181        }
182    }
183
184    pub fn advance_clock(&self, duration: Duration) {
185        let next_now = self.clock.now() + duration;
186        loop {
187            self.run();
188            if let Some(timer) = self.state.lock().timers.first()
189                && timer.expiration <= next_now
190            {
191                self.clock.advance(timer.expiration - self.clock.now());
192            } else {
193                break;
194            }
195        }
196        self.clock.advance(next_now - self.clock.now());
197    }
198
199    fn park(&self, deadline: Option<Instant>) -> bool {
200        if self.state.lock().allow_parking {
201            if let Some(deadline) = deadline {
202                let now = Instant::now();
203                let timeout = deadline.saturating_duration_since(now);
204                thread::park_timeout(timeout);
205                now.elapsed() < timeout
206            } else {
207                thread::park();
208                true
209            }
210        } else if deadline.is_some() {
211            false
212        } else if self.state.lock().capture_pending_traces {
213            let mut pending_traces = String::new();
214            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
215                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
216            }
217            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
218        } else {
219            panic!(
220                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
221            );
222        }
223    }
224}
225
226impl Scheduler for TestScheduler {
227    /// Block until the given future completes, with an optional timeout. If the
228    /// future is unable to make progress at any moment before the timeout and
229    /// no other tasks or timers remain, we panic unless parking is allowed. If
230    /// parking is allowed, we block up to the timeout or indefinitely if none
231    /// is provided. This is to allow testing a mix of deterministic and
232    /// non-deterministic async behavior, such as when interacting with I/O in
233    /// an otherwise deterministic test.
234    fn block(
235        &self,
236        session_id: Option<SessionId>,
237        mut future: LocalBoxFuture<()>,
238        timeout: Option<Duration>,
239    ) {
240        if let Some(session_id) = session_id {
241            self.state.lock().blocked_sessions.push(session_id);
242        }
243
244        let deadline = timeout.map(|timeout| Instant::now() + timeout);
245        let awoken = Arc::new(AtomicBool::new(false));
246        let waker = Box::new(TracingWaker {
247            id: None,
248            awoken: awoken.clone(),
249            thread: self.thread.clone(),
250            state: self.state.clone(),
251        });
252        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
253        let max_ticks = if timeout.is_some() {
254            self.rng
255                .lock()
256                .random_range(self.state.lock().timeout_ticks.clone())
257        } else {
258            usize::MAX
259        };
260        let mut cx = Context::from_waker(&waker);
261
262        for _ in 0..max_ticks {
263            let Poll::Pending = future.poll_unpin(&mut cx) else {
264                break;
265            };
266
267            let mut stepped = None;
268            while self.rng.lock().random() {
269                let stepped = stepped.get_or_insert(false);
270                if self.step() {
271                    *stepped = true;
272                } else {
273                    break;
274                }
275            }
276
277            let stepped = stepped.unwrap_or(true);
278            let awoken = awoken.swap(false, SeqCst);
279            if !stepped && !awoken && !self.advance_clock_to_next_timer() {
280                if !self.park(deadline) {
281                    break;
282                }
283            }
284        }
285
286        if session_id.is_some() {
287            self.state.lock().blocked_sessions.pop();
288        }
289    }
290
291    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable) {
292        let mut state = self.state.lock();
293        let ix = if state.randomize_order {
294            let start_ix = state
295                .runnables
296                .iter()
297                .rposition(|task| task.session_id == Some(session_id))
298                .map_or(0, |ix| ix + 1);
299            self.rng
300                .lock()
301                .random_range(start_ix..=state.runnables.len())
302        } else {
303            state.runnables.len()
304        };
305        state.runnables.insert(
306            ix,
307            ScheduledRunnable {
308                session_id: Some(session_id),
309                runnable,
310            },
311        );
312        drop(state);
313        self.thread.unpark();
314    }
315
316    fn schedule_background(&self, runnable: Runnable) {
317        let mut state = self.state.lock();
318        let ix = if state.randomize_order {
319            self.rng.lock().random_range(0..=state.runnables.len())
320        } else {
321            state.runnables.len()
322        };
323        state.runnables.insert(
324            ix,
325            ScheduledRunnable {
326                session_id: None,
327                runnable,
328            },
329        );
330        drop(state);
331        self.thread.unpark();
332    }
333
334    fn timer(&self, duration: Duration) -> Timer {
335        let (tx, rx) = oneshot::channel();
336        let state = &mut *self.state.lock();
337        state.timers.push(ScheduledTimer {
338            expiration: self.clock.now() + duration,
339            _notify: tx,
340        });
341        state.timers.sort_by_key(|timer| timer.expiration);
342        Timer(rx)
343    }
344
345    fn clock(&self) -> Arc<dyn Clock> {
346        self.clock.clone()
347    }
348
349    fn as_test(&self) -> &TestScheduler {
350        self
351    }
352}
353
/// Configuration consumed by [`TestScheduler::new`].
#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    /// Seed for the scheduler's `StdRng`.
    pub seed: u64,
    /// Insert newly scheduled runnables at random queue positions instead of
    /// appending them.
    pub randomize_order: bool,
    /// Allow `block` to park the thread when no progress can be made; when
    /// false, a stalled `block` without a timeout panics.
    pub allow_parking: bool,
    /// Record a backtrace for each clone of a `block` waker so a forbidden park
    /// can report outstanding wakers.
    pub capture_pending_traces: bool,
    /// Range from which a timed `block` draws its maximum number of polls.
    pub timeout_ticks: RangeInclusive<usize>,
}
362
363impl TestSchedulerConfig {
364    pub fn with_seed(seed: u64) -> Self {
365        Self {
366            seed,
367            ..Default::default()
368        }
369    }
370}
371
372impl Default for TestSchedulerConfig {
373    fn default() -> Self {
374        Self {
375            seed: 0,
376            randomize_order: true,
377            allow_parking: false,
378            capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
379                .map_or(false, |var| var == "1" || var == "true"),
380            timeout_ticks: 0..=1000,
381        }
382    }
383}
384
/// A queued task together with the foreground session that owns it.
struct ScheduledRunnable {
    /// Owning session, or `None` for background work. `step` skips runnables
    /// whose session is currently inside `block`.
    session_id: Option<SessionId>,
    runnable: Runnable,
}
389
390impl ScheduledRunnable {
391    fn run(self) {
392        self.runnable.run();
393    }
394}
395
/// A pending timer registered through `Scheduler::timer`.
struct ScheduledTimer {
    /// Test-clock instant at which `step` removes this timer.
    expiration: Instant,
    /// Never sent on: the timer "fires" when `step` drops it, which closes the
    /// channel whose receiver lives in the returned `Timer`. Presumably `Timer`
    /// resolves on that cancellation; confirm in its definition.
    _notify: oneshot::Sender<()>,
}
400
/// Mutable state shared by the scheduler and the wakers created in `block`.
struct SchedulerState {
    /// Tasks waiting to run, in the order `step` considers them.
    runnables: VecDeque<ScheduledRunnable>,
    /// Pending timers, kept sorted by expiration by `Scheduler::timer`.
    timers: Vec<ScheduledTimer>,
    /// Sessions currently inside `block` (a stack: pushed on entry, popped on
    /// exit); `step` skips their runnables.
    blocked_sessions: Vec<SessionId>,
    /// Whether scheduling inserts runnables at random positions.
    randomize_order: bool,
    /// Whether `park` may actually park the thread.
    allow_parking: bool,
    /// Range for the poll budget of a timed `block`.
    timeout_ticks: RangeInclusive<usize>,
    /// Most recently issued session id (`foreground` increments, then uses it).
    next_session_id: SessionId,
    /// Whether waker clones record backtraces into `pending_traces`.
    capture_pending_traces: bool,
    /// Id assigned to the next captured trace.
    next_trace_id: TraceId,
    /// Backtraces of waker clones that have neither woken nor been dropped.
    pending_traces: BTreeMap<TraceId, Backtrace>,
}
413
/// Vtable for `Waker`s whose data pointer is a `Box<TracingWaker>` leaked with
/// `Box::into_raw` (see `block` and `TracingWaker::clone_raw`).
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);
420
/// Key of a captured waker-clone backtrace in `SchedulerState::pending_traces`.
/// Ids are handed out sequentially, so map order is clone order.
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);
423
/// Waker payload used by `block`: flags `awoken`, unparks `thread`, and
/// optionally tracks where it was cloned.
struct TracingWaker {
    /// Id of this clone's backtrace in `pending_traces`; `None` for the
    /// original waker created in `block` or when capture is disabled.
    id: Option<TraceId>,
    /// Set on wake; `block` reads and clears it after each tick.
    awoken: Arc<AtomicBool>,
    /// Thread to unpark on wake.
    thread: Thread,
    /// Shared scheduler state, used to record and remove pending traces.
    state: Arc<Mutex<SchedulerState>>,
}
430
431impl Clone for TracingWaker {
432    fn clone(&self) -> Self {
433        let mut state = self.state.lock();
434        let id = if state.capture_pending_traces {
435            let id = state.next_trace_id;
436            state.next_trace_id.0 += 1;
437            state.pending_traces.insert(id, Backtrace::new_unresolved());
438            Some(id)
439        } else {
440            None
441        };
442        Self {
443            id,
444            awoken: self.awoken.clone(),
445            thread: self.thread.clone(),
446            state: self.state.clone(),
447        }
448    }
449}
450
451impl Drop for TracingWaker {
452    fn drop(&mut self) {
453        if let Some(id) = self.id {
454            self.state.lock().pending_traces.remove(&id);
455        }
456    }
457}
458
impl TracingWaker {
    /// Consuming wake: same effect as `wake_by_ref`, after which `self` is
    /// dropped (its trace, if any, is already removed).
    fn wake(self) {
        self.wake_by_ref();
    }

    /// Clears this clone's pending trace, marks the owning `block` as awoken,
    /// and unparks the scheduler thread.
    fn wake_by_ref(&self) {
        if let Some(id) = self.id {
            self.state.lock().pending_traces.remove(&id);
        }
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    /// `RawWakerVTable::clone`: allocates a new boxed clone (which may record
    /// a trace) and returns it with the same vtable.
    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        // SAFETY: every data pointer paired with WAKER_VTABLE comes from
        // `Box::into_raw` of a `TracingWaker` and stays live until `wake_raw` or
        // `drop_raw` reclaims it; the `Waker` contract guarantees that has not
        // happened while `clone` is being called.
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    /// `RawWakerVTable::wake`: takes ownership of the box, wakes, then frees it.
    fn wake_raw(waker: *const ()) {
        // SAFETY: the pointer came from `Box::into_raw` (see `clone_raw`/`block`)
        // and `wake` consumes the waker, so reclaiming the box exactly once here
        // is sound.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    /// `RawWakerVTable::wake_by_ref`: wakes without taking ownership.
    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        // SAFETY: same provenance as in `clone_raw`; the box is only borrowed.
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    /// `RawWakerVTable::drop`: frees the box (running `Drop`, which removes any
    /// pending trace).
    fn drop_raw(waker: *const ()) {
        // SAFETY: the pointer came from `Box::into_raw` and `drop` is the last use
        // of this waker, so the box is reclaimed exactly once.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}
497
/// Future that returns `Pending` a fixed number of times, immediately waking
/// its own waker each time, before completing.
///
/// The field holds how many more polls will return `Pending`.
pub struct Yield(usize);

impl Future for Yield {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match self.0.checked_sub(1) {
            None => Poll::Ready(()),
            Some(remaining) => {
                self.0 = remaining;
                // Reschedule ourselves right away: this is a yield, not a wait.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
513
514fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
515    trace.resolve();
516    let mut frames: Vec<BacktraceFrame> = trace.into();
517    let waker_clone_frame_ix = frames.iter().position(|frame| {
518        frame.symbols().iter().any(|symbol| {
519            symbol
520                .name()
521                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
522        })
523    });
524
525    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
526        frames.drain(..waker_clone_frame_ix + 1);
527    }
528
529    Backtrace::from(frames)
530}