1use crate::{
2 BackgroundExecutor, Clock, ForegroundExecutor, Scheduler, SessionId, TestClock, Timer,
3};
4use async_task::Runnable;
5use backtrace::{Backtrace, BacktraceFrame};
6use futures::{FutureExt as _, channel::oneshot, future::LocalBoxFuture};
7use parking_lot::Mutex;
8use rand::prelude::*;
9use std::{
10 any::type_name_of_val,
11 collections::{BTreeMap, VecDeque},
12 env,
13 fmt::Write,
14 future::Future,
15 mem,
16 ops::RangeInclusive,
17 panic::{self, AssertUnwindSafe},
18 pin::Pin,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering::SeqCst},
22 },
23 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
24 thread::{self, Thread},
25 time::{Duration, Instant},
26};
27
/// Environment variable that, when set to "1" or "true", enables capturing a
/// backtrace each time a waker is cloned (see `TestSchedulerConfig`), so a
/// wedged test can report which futures are still pending.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
29
/// A deterministic scheduler for tests: tasks, timers, and time are driven
/// manually from a single thread, with all ordering decisions drawn from a
/// seeded RNG so runs are reproducible per seed.
pub struct TestScheduler {
    /// Virtual clock; timers compare against this, not wall time.
    clock: Arc<TestClock>,
    /// Seeded RNG; the single source of scheduling randomness.
    rng: Arc<Mutex<StdRng>>,
    /// Shared mutable state: run queue, timers, sessions, and configuration.
    state: Arc<Mutex<SchedulerState>>,
    /// The thread that created the scheduler; unparked whenever work arrives.
    thread: Thread,
}
36
impl TestScheduler {
    /// Run a test once with default configuration (seed 0)
    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        Self::with_seed(0, f)
    }

    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
    ///
    /// If an iteration panics, the failing seed is printed to stderr before
    /// the panic is propagated, so the failure can be replayed
    /// deterministically via [`Self::with_seed`].
    pub fn many<R>(iterations: usize, mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R) -> Vec<R> {
        (0..iterations as u64)
            .map(|seed| {
                // AssertUnwindSafe is acceptable here: on Err we immediately
                // resume unwinding, so `f` is never observed afterwards.
                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
                    Ok(result) => result,
                    Err(error) => {
                        eprintln!("Failing Seed: {seed}");
                        panic::resume_unwind(error);
                    }
                }
            })
            .collect()
    }

    /// Run a test once with a specific seed
    pub fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
        let future = f(scheduler.clone());
        let result = scheduler.foreground().block_on(future);
        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
        result
    }

    /// Create a scheduler from an explicit configuration. The RNG is seeded
    /// from `config.seed`, so identical configs replay identical schedules.
    pub fn new(config: TestSchedulerConfig) -> Self {
        Self {
            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
            state: Arc::new(Mutex::new(SchedulerState {
                runnables: VecDeque::new(),
                timers: Vec::new(),
                blocked_sessions: Vec::new(),
                randomize_order: config.randomize_order,
                allow_parking: config.allow_parking,
                timeout_ticks: config.timeout_ticks,
                next_session_id: SessionId(0),
                capture_pending_traces: config.capture_pending_traces,
                pending_traces: BTreeMap::new(),
                next_trace_id: TraceId(0),
            })),
            clock: Arc::new(TestClock::new()),
            thread: thread::current(),
        }
    }

    /// The virtual clock driving timers; advance it with
    /// [`Self::advance_clock`].
    pub fn clock(&self) -> Arc<TestClock> {
        self.clock.clone()
    }

    /// The seeded RNG shared with the scheduler, for deterministic
    /// randomness inside tests.
    pub fn rng(&self) -> Arc<Mutex<StdRng>> {
        self.rng.clone()
    }

    /// Set the range from which `block` draws its random tick budget when a
    /// timeout is supplied.
    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
        self.state.lock().timeout_ticks = timeout_ticks;
    }

    /// Permit `block` to park the OS thread when it runs out of work.
    pub fn allow_parking(&self) {
        self.state.lock().allow_parking = true;
    }

    /// Make running out of work panic instead of parking (the default).
    pub fn forbid_parking(&self) {
        self.state.lock().allow_parking = false;
    }

    /// Create a foreground executor for this scheduler, assigning it a fresh
    /// session id. Tasks within one session keep FIFO order even when
    /// `randomize_order` is set (see `schedule_foreground`).
    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
        let session_id = {
            let mut state = self.state.lock();
            state.next_session_id.0 += 1;
            state.next_session_id
        };
        ForegroundExecutor::new(session_id, self.clone())
    }

    /// Create a background executor for this scheduler
    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
        BackgroundExecutor::new(self.clone())
    }

    /// Create a [`Yield`] future with a random poll count: usually 0 or 1,
    /// but 10% of the time between 10 and 19, to exercise longer yields.
    pub fn yield_random(&self) -> Yield {
        let rng = &mut *self.rng.lock();
        if rng.random_bool(0.1) {
            Yield(rng.random_range(10..20))
        } else {
            Yield(rng.random_range(0..2))
        }
    }

    /// Run tasks and expired timers until no work remains at the current
    /// virtual time.
    pub fn run(&self) {
        while self.step() {
            // Continue until no work remains
        }
    }

    /// Run until idle, advancing the clock to each pending timer as the
    /// queue drains so timer-gated tasks also make progress.
    pub fn run_with_clock_advancement(&self) {
        while self.step() || self.advance_clock_to_next_timer() {
            // Continue until no work remains
        }
    }

    /// Perform one unit of work: fire any already-expired timers, otherwise
    /// run the first eligible runnable. Returns false when neither exists.
    fn step(&self) -> bool {
        // Collect timers whose expiration is at or before the current
        // virtual time. Dropping the drained entries drops their oneshot
        // senders, which completes the corresponding `Timer` futures.
        let elapsed_timers = {
            let mut state = self.state.lock();
            let end_ix = state
                .timers
                .partition_point(|timer| timer.expiration <= self.clock.now());
            state.timers.drain(..end_ix).collect::<Vec<_>>()
        };

        if !elapsed_timers.is_empty() {
            return true;
        }

        // Pick the first runnable that doesn't belong to a blocked session;
        // background runnables (session_id == None) are always eligible.
        let runnable = {
            let state = &mut *self.state.lock();
            let ix = state.runnables.iter().position(|runnable| {
                runnable
                    .session_id
                    .is_none_or(|session_id| !state.blocked_sessions.contains(&session_id))
            });
            ix.and_then(|ix| state.runnables.remove(ix))
        };

        if let Some(runnable) = runnable {
            runnable.run();
            return true;
        }

        false
    }

    /// Advance the virtual clock to the next pending timer without running
    /// any tasks. Returns false if no timers are pending.
    fn advance_clock_to_next_timer(&self) -> bool {
        if let Some(timer) = self.state.lock().timers.first() {
            self.clock.advance(timer.expiration - self.clock.now());
            true
        } else {
            false
        }
    }

    /// Advance the virtual clock by `duration`, draining ready work and
    /// firing every timer that falls inside the window, in order.
    pub fn advance_clock(&self, duration: Duration) {
        let next_now = self.clock.now() + duration;
        loop {
            // Run ready work before each jump so tasks can register
            // additional timers inside the window.
            self.run();
            if let Some(timer) = self.state.lock().timers.first()
                && timer.expiration <= next_now
            {
                self.clock.advance(timer.expiration - self.clock.now());
            } else {
                break;
            }
        }
        // Land exactly on the requested end time.
        self.clock.advance(next_now - self.clock.now());
    }

    /// Park the current thread until unparked or until `deadline`. Returns
    /// false when the deadline elapsed first. When parking is forbidden: a
    /// deadline yields false immediately (treated as a timeout), while no
    /// deadline means the test is wedged, so we panic — including captured
    /// pending-waker traces when trace capture is enabled.
    fn park(&self, deadline: Option<Instant>) -> bool {
        if self.state.lock().allow_parking {
            if let Some(deadline) = deadline {
                let now = Instant::now();
                let timeout = deadline.saturating_duration_since(now);
                thread::park_timeout(timeout);
                // True if we were unparked before the timeout elapsed.
                // NOTE(review): park_timeout may also return spuriously.
                now.elapsed() < timeout
            } else {
                thread::park();
                true
            }
        } else if deadline.is_some() {
            false
        } else if self.state.lock().capture_pending_traces {
            let mut pending_traces = String::new();
            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
            }
            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
        } else {
            panic!(
                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
            );
        }
    }
}
225
impl Scheduler for TestScheduler {
    /// Block until the given future completes, with an optional timeout. If the
    /// future is unable to make progress at any moment before the timeout and
    /// no other tasks or timers remain, we panic unless parking is allowed. If
    /// parking is allowed, we block up to the timeout or indefinitely if none
    /// is provided. This is to allow testing a mix of deterministic and
    /// non-deterministic async behavior, such as when interacting with I/O in
    /// an otherwise deterministic test.
    fn block(
        &self,
        session_id: Option<SessionId>,
        mut future: LocalBoxFuture<()>,
        timeout: Option<Duration>,
    ) {
        // While we're blocked, `step` must not run other foreground tasks of
        // this session, so record it as blocked for the duration of the call.
        if let Some(session_id) = session_id {
            self.state.lock().blocked_sessions.push(session_id);
        }

        let deadline = timeout.map(|timeout| Instant::now() + timeout);
        let awoken = Arc::new(AtomicBool::new(false));
        // Box a TracingWaker and hand ownership to the raw waker below; it is
        // reclaimed via TracingWaker::drop_raw when `waker` is dropped.
        let waker = Box::new(TracingWaker {
            id: None,
            awoken: awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        });
        // SAFETY: the data pointer comes from Box::into_raw of a
        // TracingWaker, which is exactly what every entry point in
        // WAKER_VTABLE expects to receive.
        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
        // With a timeout, draw a random tick budget from `timeout_ticks` so
        // timeouts fire at varied points across seeds; otherwise poll
        // indefinitely.
        let max_ticks = if timeout.is_some() {
            self.rng
                .lock()
                .random_range(self.state.lock().timeout_ticks.clone())
        } else {
            usize::MAX
        };
        let mut cx = Context::from_waker(&waker);

        for _ in 0..max_ticks {
            let Poll::Pending = future.poll_unpin(&mut cx) else {
                break;
            };

            // Randomly run zero or more queued tasks/timers between polls.
            // `stepped` stays None if we never attempted a step at all.
            let mut stepped = None;
            while self.rng.lock().random() {
                let stepped = stepped.get_or_insert(false);
                if self.step() {
                    *stepped = true;
                } else {
                    break;
                }
            }

            // Treat "didn't try to step" as progress, so we only consider
            // parking when we actually looked for work and found none.
            let stepped = stepped.unwrap_or(true);
            let awoken = awoken.swap(false, SeqCst);
            if !stepped && !awoken && !self.advance_clock_to_next_timer() {
                // No runnable work, no wakeup, no timers to advance to:
                // park, or give up if the deadline has passed / parking is
                // forbidden with a timeout.
                if !self.park(deadline) {
                    break;
                }
            }
        }

        if session_id.is_some() {
            // Matches the push above; nested blocks unwind LIFO.
            self.state.lock().blocked_sessions.pop();
        }
    }

    /// Queue a foreground runnable for `session_id`. With `randomize_order`,
    /// the insertion point is random but never before the last runnable
    /// already queued for the same session, preserving per-session FIFO
    /// order.
    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable) {
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            let start_ix = state
                .runnables
                .iter()
                .rposition(|task| task.session_id == Some(session_id))
                .map_or(0, |ix| ix + 1);
            self.rng
                .lock()
                .random_range(start_ix..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: Some(session_id),
                runnable,
            },
        );
        drop(state);
        // Wake the scheduler thread in case it is parked inside `block`.
        self.thread.unpark();
    }

    /// Queue a background runnable. With `randomize_order` it may be
    /// inserted anywhere in the queue; otherwise it is appended.
    fn schedule_background(&self, runnable: Runnable) {
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            self.rng.lock().random_range(0..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: None,
                runnable,
            },
        );
        drop(state);
        // Wake the scheduler thread in case it is parked inside `block`.
        self.thread.unpark();
    }

    /// Register a timer that elapses `duration` after the current virtual
    /// time. The returned `Timer` completes when the entry is dropped by
    /// `step` (its oneshot sender is dropped).
    fn timer(&self, duration: Duration) -> Timer {
        let (tx, rx) = oneshot::channel();
        let state = &mut *self.state.lock();
        state.timers.push(ScheduledTimer {
            expiration: self.clock.now() + duration,
            _notify: tx,
        });
        // Keep timers sorted so the earliest expiration is always first.
        state.timers.sort_by_key(|timer| timer.expiration);
        Timer(rx)
    }

    /// The scheduler's virtual clock, as a trait object.
    fn clock(&self) -> Arc<dyn Clock> {
        self.clock.clone()
    }

    /// Downcast hook: this scheduler is always a `TestScheduler`.
    fn as_test(&self) -> &TestScheduler {
        self
    }
}
353
/// Configuration for a [`TestScheduler`]; see the [`Default`] impl for the
/// standard test settings.
#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    /// Seed for the scheduler's RNG; equal seeds replay equal schedules.
    pub seed: u64,
    /// When true, runnables are queued at random positions rather than FIFO.
    pub randomize_order: bool,
    /// When true, running out of work parks the thread instead of panicking.
    pub allow_parking: bool,
    /// When true, capture a backtrace each time a waker is cloned, for
    /// diagnosing which futures are still pending when a test wedges.
    pub capture_pending_traces: bool,
    /// Range from which `block` draws its random tick budget when given a
    /// timeout.
    pub timeout_ticks: RangeInclusive<usize>,
}
362
363impl TestSchedulerConfig {
364 pub fn with_seed(seed: u64) -> Self {
365 Self {
366 seed,
367 ..Default::default()
368 }
369 }
370}
371
372impl Default for TestSchedulerConfig {
373 fn default() -> Self {
374 Self {
375 seed: 0,
376 randomize_order: true,
377 allow_parking: false,
378 capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
379 .map_or(false, |var| var == "1" || var == "true"),
380 timeout_ticks: 0..=1000,
381 }
382 }
383}
384
/// A queued task awaiting execution by `TestScheduler::step`.
struct ScheduledRunnable {
    /// Owning session for foreground tasks; `None` for background tasks.
    session_id: Option<SessionId>,
    /// The underlying async-task runnable to poll.
    runnable: Runnable,
}
389
390impl ScheduledRunnable {
391 fn run(self) {
392 self.runnable.run();
393 }
394}
395
/// A pending timer registered via `Scheduler::timer`.
struct ScheduledTimer {
    /// Virtual-clock instant at which the timer elapses.
    expiration: Instant,
    /// Held only for its drop side effect: dropping this sender completes
    /// the receiver inside the associated `Timer` future.
    _notify: oneshot::Sender<()>,
}
400
/// Mutable scheduler state, shared behind a mutex with executors and wakers.
struct SchedulerState {
    /// Tasks ready to run, in the order `step` will consider them.
    runnables: VecDeque<ScheduledRunnable>,
    /// Pending timers, kept sorted by expiration.
    timers: Vec<ScheduledTimer>,
    /// Stack of sessions currently inside `Scheduler::block`; their
    /// foreground runnables are skipped by `step`.
    blocked_sessions: Vec<SessionId>,
    /// See `TestSchedulerConfig::randomize_order`.
    randomize_order: bool,
    /// See `TestSchedulerConfig::allow_parking`.
    allow_parking: bool,
    /// See `TestSchedulerConfig::timeout_ticks`.
    timeout_ticks: RangeInclusive<usize>,
    /// Next id handed out by `TestScheduler::foreground`.
    next_session_id: SessionId,
    /// See `TestSchedulerConfig::capture_pending_traces`.
    capture_pending_traces: bool,
    /// Next key to assign in `pending_traces`.
    next_trace_id: TraceId,
    /// Clone-site backtraces of waker clones that have not yet been woken
    /// or dropped.
    pending_traces: BTreeMap<TraceId, Backtrace>,
}
413
/// Vtable wiring `TracingWaker`'s raw entry points into a `RawWaker`; used by
/// `TestScheduler::block` to observe wakeups and track pending traces.
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);
420
/// Key identifying one captured backtrace in `SchedulerState::pending_traces`.
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);
423
/// Waker used by `TestScheduler::block`: records wakeups in a flag, unparks
/// the scheduler thread, and (when trace capture is enabled) tracks a
/// backtrace of its clone site so wedged tests can report pending futures.
struct TracingWaker {
    /// Key of this waker's entry in `pending_traces`, if capture is enabled.
    id: Option<TraceId>,
    /// Set to true on wake; swapped back to false by `TestScheduler::block`.
    awoken: Arc<AtomicBool>,
    /// Thread to unpark on wake.
    thread: Thread,
    /// Shared scheduler state holding the pending-trace table.
    state: Arc<Mutex<SchedulerState>>,
}
430
431impl Clone for TracingWaker {
432 fn clone(&self) -> Self {
433 let mut state = self.state.lock();
434 let id = if state.capture_pending_traces {
435 let id = state.next_trace_id;
436 state.next_trace_id.0 += 1;
437 state.pending_traces.insert(id, Backtrace::new_unresolved());
438 Some(id)
439 } else {
440 None
441 };
442 Self {
443 id,
444 awoken: self.awoken.clone(),
445 thread: self.thread.clone(),
446 state: self.state.clone(),
447 }
448 }
449}
450
451impl Drop for TracingWaker {
452 fn drop(&mut self) {
453 if let Some(id) = self.id {
454 self.state.lock().pending_traces.remove(&id);
455 }
456 }
457}
458
impl TracingWaker {
    /// By-value wake: delegates to `wake_by_ref`, then drops self (the Drop
    /// impl's redundant `pending_traces` removal is a harmless no-op).
    fn wake(self) {
        self.wake_by_ref();
    }

    /// Flag the blocked thread as awoken and unpark it. This waker's pending
    /// trace is removed because it has now fired.
    fn wake_by_ref(&self) {
        if let Some(id) = self.id {
            self.state.lock().pending_traces.remove(&id);
        }
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    /// Vtable shim: clone the `TracingWaker` behind the raw pointer into a
    /// new boxed `RawWaker`.
    ///
    /// Contract for all four shims: `waker` must be a pointer produced by
    /// `Box::into_raw` on a `TracingWaker`, as arranged in
    /// `TestScheduler::block` and in `clone_raw` itself.
    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        // SAFETY: per the vtable contract the pointer is a live TracingWaker;
        // we only borrow it for the duration of the clone.
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    /// Vtable shim: wake by value, consuming (and freeing) the boxed waker.
    fn wake_raw(waker: *const ()) {
        // SAFETY: by-value wake transfers ownership to us; reconstituting the
        // Box frees the allocation after `wake` runs.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    /// Vtable shim: wake without consuming the waker.
    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        // SAFETY: per the vtable contract the pointer is a live TracingWaker;
        // borrowed, not owned.
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    /// Vtable shim: destroy the boxed waker without waking.
    fn drop_raw(waker: *const ()) {
        // SAFETY: drop transfers ownership to us; the Box frees the
        // allocation and runs TracingWaker's Drop impl.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}
497
/// Future that stays pending for a fixed number of polls before resolving,
/// waking itself after each poll. Produced by `TestScheduler::yield_random`.
pub struct Yield(usize);

impl Future for Yield {
    type Output = ();

    /// Counts down one per poll, requesting a re-poll each time, and
    /// resolves once the counter reaches zero.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match self.0 {
            0 => Poll::Ready(()),
            remaining => {
                self.0 = remaining - 1;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
513
/// Resolve `trace` and strip every frame up to and including the
/// `Waker::clone` call, so reported pending traces begin at the user code
/// that cloned the waker rather than inside waker plumbing.
fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
    trace.resolve();
    let mut frames: Vec<BacktraceFrame> = trace.into();
    // Locate the frame whose demangled symbol (the alternate Debug form of
    // SymbolName) equals the fully-qualified path of `Waker::clone`, as
    // produced by `type_name_of_val` on the fn item.
    let waker_clone_frame_ix = frames.iter().position(|frame| {
        frame.symbols().iter().any(|symbol| {
            symbol
                .name()
                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
        })
    });

    // Frames are innermost-first, so draining `..ix + 1` removes the waker
    // machinery while keeping the caller's frames. If no match was found
    // (e.g. symbols unavailable), the trace is returned unmodified.
    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
        frames.drain(..waker_clone_frame_ix + 1);
    }

    Backtrace::from(frames)
}