1use crate::{
2 BackgroundExecutor, Clock, ForegroundExecutor, Scheduler, SessionId, TestClock, Timer,
3};
4use async_task::Runnable;
5use backtrace::{Backtrace, BacktraceFrame};
6use futures::{FutureExt as _, channel::oneshot, future::LocalBoxFuture};
7use parking_lot::Mutex;
8use rand::prelude::*;
9use std::{
10 any::type_name_of_val,
11 collections::{BTreeMap, VecDeque},
12 env,
13 fmt::Write,
14 future::Future,
15 mem,
16 ops::RangeInclusive,
17 panic::{self, AssertUnwindSafe},
18 pin::Pin,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering::SeqCst},
22 },
23 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
24 thread::{self, Thread},
25 time::{Duration, Instant},
26};
27
/// Name of the environment variable that turns on capturing of pending waker
/// backtraces. It is read when building the default `TestSchedulerConfig` and
/// is mentioned in the "Parking forbidden" panic message.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
29
/// Scheduler used in tests. It owns a seeded random number generator, a test
/// clock, the shared scheduler state, and a handle to the thread that created
/// it.
pub struct TestScheduler {
    /// Test clock returned by `clock()` and compared against timer expirations.
    clock: Arc<TestClock>,
    /// Random number generator seeded from the configuration's `seed`.
    rng: Arc<Mutex<StdRng>>,
    /// Queued runnables, timers, and settings; also shared with the wakers
    /// created in `block`.
    state: Arc<Mutex<SchedulerState>>,
    /// Thread on which the scheduler was constructed; it is unparked whenever
    /// work is scheduled or a waker fires.
    thread: Thread,
}
36
37impl TestScheduler {
38 /// Run a test once with default configuration (seed 0)
39 pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
40 Self::with_seed(0, f)
41 }
42
43 /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
44 pub fn many<R>(
45 default_iterations: usize,
46 mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
47 ) -> Vec<R> {
48 let num_iterations = std::env::var("ITERATIONS")
49 .map(|iterations| iterations.parse().unwrap())
50 .unwrap_or(default_iterations);
51
52 let seed = std::env::var("SEED")
53 .map(|seed| seed.parse().unwrap())
54 .unwrap_or(0);
55
56 (seed..num_iterations as u64)
57 .map(|seed| {
58 let mut unwind_safe_f = AssertUnwindSafe(&mut f);
59 eprintln!("Running seed: {seed}");
60 match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
61 Ok(result) => result,
62 Err(error) => {
63 eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
64 panic::resume_unwind(error);
65 }
66 }
67 })
68 .collect()
69 }
70
71 fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
72 let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
73 let future = f(scheduler.clone());
74 let result = scheduler.foreground().block_on(future);
75 scheduler.run(); // Ensure spawned tasks finish up before returning in tests
76 result
77 }
78
79 pub fn new(config: TestSchedulerConfig) -> Self {
80 Self {
81 rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
82 state: Arc::new(Mutex::new(SchedulerState {
83 runnables: VecDeque::new(),
84 timers: Vec::new(),
85 blocked_sessions: Vec::new(),
86 randomize_order: config.randomize_order,
87 allow_parking: config.allow_parking,
88 timeout_ticks: config.timeout_ticks,
89 next_session_id: SessionId(0),
90 capture_pending_traces: config.capture_pending_traces,
91 pending_traces: BTreeMap::new(),
92 next_trace_id: TraceId(0),
93 })),
94 clock: Arc::new(TestClock::new()),
95 thread: thread::current(),
96 }
97 }
98
99 pub fn clock(&self) -> Arc<TestClock> {
100 self.clock.clone()
101 }
102
103 pub fn rng(&self) -> Arc<Mutex<StdRng>> {
104 self.rng.clone()
105 }
106
107 pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
108 self.state.lock().timeout_ticks = timeout_ticks;
109 }
110
111 pub fn allow_parking(&self) {
112 self.state.lock().allow_parking = true;
113 }
114
115 pub fn forbid_parking(&self) {
116 self.state.lock().allow_parking = false;
117 }
118
119 /// Create a foreground executor for this scheduler
120 pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
121 let session_id = {
122 let mut state = self.state.lock();
123 state.next_session_id.0 += 1;
124 state.next_session_id
125 };
126 ForegroundExecutor::new(session_id, self.clone())
127 }
128
129 /// Create a background executor for this scheduler
130 pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
131 BackgroundExecutor::new(self.clone())
132 }
133
134 pub fn yield_random(&self) -> Yield {
135 let rng = &mut *self.rng.lock();
136 if rng.random_bool(0.1) {
137 Yield(rng.random_range(10..20))
138 } else {
139 Yield(rng.random_range(0..2))
140 }
141 }
142
143 pub fn run(&self) {
144 while self.step() {
145 // Continue until no work remains
146 }
147 }
148
149 pub fn run_with_clock_advancement(&self) {
150 while self.step() || self.advance_clock_to_next_timer() {
151 // Continue until no work remains
152 }
153 }
154
155 fn step(&self) -> bool {
156 let elapsed_timers = {
157 let mut state = self.state.lock();
158 let end_ix = state
159 .timers
160 .partition_point(|timer| timer.expiration <= self.clock.now());
161 state.timers.drain(..end_ix).collect::<Vec<_>>()
162 };
163
164 if !elapsed_timers.is_empty() {
165 return true;
166 }
167
168 let runnable = {
169 let state = &mut *self.state.lock();
170 let ix = state.runnables.iter().position(|runnable| {
171 runnable
172 .session_id
173 .is_none_or(|session_id| !state.blocked_sessions.contains(&session_id))
174 });
175 ix.and_then(|ix| state.runnables.remove(ix))
176 };
177
178 if let Some(runnable) = runnable {
179 runnable.run();
180 return true;
181 }
182
183 false
184 }
185
186 fn advance_clock_to_next_timer(&self) -> bool {
187 if let Some(timer) = self.state.lock().timers.first() {
188 self.clock.advance(timer.expiration - self.clock.now());
189 true
190 } else {
191 false
192 }
193 }
194
195 pub fn advance_clock(&self, duration: Duration) {
196 let next_now = self.clock.now() + duration;
197 loop {
198 self.run();
199 if let Some(timer) = self.state.lock().timers.first()
200 && timer.expiration <= next_now
201 {
202 self.clock.advance(timer.expiration - self.clock.now());
203 } else {
204 break;
205 }
206 }
207 self.clock.advance(next_now - self.clock.now());
208 }
209
210 fn park(&self, deadline: Option<Instant>) -> bool {
211 if self.state.lock().allow_parking {
212 if let Some(deadline) = deadline {
213 let now = Instant::now();
214 let timeout = deadline.saturating_duration_since(now);
215 thread::park_timeout(timeout);
216 now.elapsed() < timeout
217 } else {
218 thread::park();
219 true
220 }
221 } else if deadline.is_some() {
222 false
223 } else if self.state.lock().capture_pending_traces {
224 let mut pending_traces = String::new();
225 for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
226 writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
227 }
228 panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
229 } else {
230 panic!(
231 "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
232 );
233 }
234 }
235}
236
237impl Scheduler for TestScheduler {
238 /// Block until the given future completes, with an optional timeout. If the
239 /// future is unable to make progress at any moment before the timeout and
240 /// no other tasks or timers remain, we panic unless parking is allowed. If
241 /// parking is allowed, we block up to the timeout or indefinitely if none
242 /// is provided. This is to allow testing a mix of deterministic and
243 /// non-deterministic async behavior, such as when interacting with I/O in
244 /// an otherwise deterministic test.
245 fn block(
246 &self,
247 session_id: Option<SessionId>,
248 mut future: LocalBoxFuture<()>,
249 timeout: Option<Duration>,
250 ) {
251 if let Some(session_id) = session_id {
252 self.state.lock().blocked_sessions.push(session_id);
253 }
254
255 let deadline = timeout.map(|timeout| Instant::now() + timeout);
256 let awoken = Arc::new(AtomicBool::new(false));
257 let waker = Box::new(TracingWaker {
258 id: None,
259 awoken: awoken.clone(),
260 thread: self.thread.clone(),
261 state: self.state.clone(),
262 });
263 let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
264 let max_ticks = if timeout.is_some() {
265 self.rng
266 .lock()
267 .random_range(self.state.lock().timeout_ticks.clone())
268 } else {
269 usize::MAX
270 };
271 let mut cx = Context::from_waker(&waker);
272
273 for _ in 0..max_ticks {
274 let Poll::Pending = future.poll_unpin(&mut cx) else {
275 break;
276 };
277
278 let mut stepped = None;
279 while self.rng.lock().random() {
280 let stepped = stepped.get_or_insert(false);
281 if self.step() {
282 *stepped = true;
283 } else {
284 break;
285 }
286 }
287
288 let stepped = stepped.unwrap_or(true);
289 let awoken = awoken.swap(false, SeqCst);
290 if !stepped && !awoken && !self.advance_clock_to_next_timer() {
291 if !self.park(deadline) {
292 break;
293 }
294 }
295 }
296
297 if session_id.is_some() {
298 self.state.lock().blocked_sessions.pop();
299 }
300 }
301
302 fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable) {
303 let mut state = self.state.lock();
304 let ix = if state.randomize_order {
305 let start_ix = state
306 .runnables
307 .iter()
308 .rposition(|task| task.session_id == Some(session_id))
309 .map_or(0, |ix| ix + 1);
310 self.rng
311 .lock()
312 .random_range(start_ix..=state.runnables.len())
313 } else {
314 state.runnables.len()
315 };
316 state.runnables.insert(
317 ix,
318 ScheduledRunnable {
319 session_id: Some(session_id),
320 runnable,
321 },
322 );
323 drop(state);
324 self.thread.unpark();
325 }
326
327 fn schedule_background(&self, runnable: Runnable) {
328 let mut state = self.state.lock();
329 let ix = if state.randomize_order {
330 self.rng.lock().random_range(0..=state.runnables.len())
331 } else {
332 state.runnables.len()
333 };
334 state.runnables.insert(
335 ix,
336 ScheduledRunnable {
337 session_id: None,
338 runnable,
339 },
340 );
341 drop(state);
342 self.thread.unpark();
343 }
344
345 fn timer(&self, duration: Duration) -> Timer {
346 let (tx, rx) = oneshot::channel();
347 let state = &mut *self.state.lock();
348 state.timers.push(ScheduledTimer {
349 expiration: self.clock.now() + duration,
350 _notify: tx,
351 });
352 state.timers.sort_by_key(|timer| timer.expiration);
353 Timer(rx)
354 }
355
356 fn clock(&self) -> Arc<dyn Clock> {
357 self.clock.clone()
358 }
359
    /// Returns this scheduler as a `TestScheduler` reference.
    fn as_test(&self) -> &TestScheduler {
        self
    }
363}
364
/// Settings used to construct a `TestScheduler`.
#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    /// Seed for the scheduler's random number generator.
    pub seed: u64,
    /// Whether newly scheduled runnables are inserted at random queue
    /// positions instead of being appended.
    pub randomize_order: bool,
    /// Whether the scheduler may park the thread when no progress can be made.
    pub allow_parking: bool,
    /// Whether a backtrace is recorded each time a tracing waker is cloned.
    pub capture_pending_traces: bool,
    /// Range from which the maximum number of poll ticks is drawn when
    /// blocking with a timeout.
    pub timeout_ticks: RangeInclusive<usize>,
}
373
374impl TestSchedulerConfig {
375 pub fn with_seed(seed: u64) -> Self {
376 Self {
377 seed,
378 ..Default::default()
379 }
380 }
381}
382
383impl Default for TestSchedulerConfig {
384 fn default() -> Self {
385 Self {
386 seed: 0,
387 randomize_order: true,
388 allow_parking: false,
389 capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
390 .map_or(false, |var| var == "1" || var == "true"),
391 timeout_ticks: 0..=1000,
392 }
393 }
394}
395
/// A queued unit of work together with the session it belongs to.
struct ScheduledRunnable {
    /// `Some` for runnables scheduled on a foreground session, `None` for
    /// background runnables.
    session_id: Option<SessionId>,
    /// The task to run.
    runnable: Runnable,
}
400
401impl ScheduledRunnable {
402 fn run(self) {
403 self.runnable.run();
404 }
405}
406
/// A pending timer registered with the scheduler.
struct ScheduledTimer {
    /// Clock time at which the timer is considered elapsed.
    expiration: Instant,
    /// Sending half of the channel whose receiver is wrapped by the `Timer`
    /// handed to the caller. It is never sent on here, only dropped together
    /// with the timer; presumably that drop is what completes the `Timer`
    /// (confirm against `Timer`'s implementation).
    _notify: oneshot::Sender<()>,
}
411
/// Mutable scheduler state shared between the scheduler and its wakers.
struct SchedulerState {
    /// Queue of scheduled runnables; `step` runs the first eligible one.
    runnables: VecDeque<ScheduledRunnable>,
    /// Pending timers, kept sorted by expiration.
    timers: Vec<ScheduledTimer>,
    /// Sessions currently inside `block`; their runnables are skipped by `step`.
    blocked_sessions: Vec<SessionId>,
    /// Whether scheduled runnables are inserted at random queue positions.
    randomize_order: bool,
    /// Whether `park` may actually park the thread.
    allow_parking: bool,
    /// Range from which `block` draws its maximum tick count when given a timeout.
    timeout_ticks: RangeInclusive<usize>,
    /// Most recently issued session id; incremented before each new
    /// foreground executor is created.
    next_session_id: SessionId,
    /// Whether cloning a tracing waker records a backtrace.
    capture_pending_traces: bool,
    /// Id to assign to the next recorded trace.
    next_trace_id: TraceId,
    /// Backtraces of waker clones that have been neither woken nor dropped.
    pending_traces: BTreeMap<TraceId, Backtrace>,
}
424
/// Raw waker vtable that routes clone, wake, wake-by-ref, and drop to the
/// corresponding `TracingWaker` functions. The data pointer paired with this
/// vtable must point to a heap-allocated `TracingWaker`.
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);
431
/// Identifier of a backtrace recorded for a cloned waker; used as the key of
/// `SchedulerState::pending_traces`.
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);
434
/// Waker data used by `block`. Waking sets a shared flag and unparks the
/// scheduler thread; clones can additionally record a backtrace so that
/// wakers that are still pending can be reported.
struct TracingWaker {
    /// Id of the backtrace recorded when this waker was cloned; `None` for the
    /// waker created directly in `block` or when trace capture is disabled.
    id: Option<TraceId>,
    /// Flag set when the waker is woken; `block` reads and resets it.
    awoken: Arc<AtomicBool>,
    /// Thread to unpark on wake.
    thread: Thread,
    /// Shared scheduler state holding the pending traces.
    state: Arc<Mutex<SchedulerState>>,
}
441
442impl Clone for TracingWaker {
443 fn clone(&self) -> Self {
444 let mut state = self.state.lock();
445 let id = if state.capture_pending_traces {
446 let id = state.next_trace_id;
447 state.next_trace_id.0 += 1;
448 state.pending_traces.insert(id, Backtrace::new_unresolved());
449 Some(id)
450 } else {
451 None
452 };
453 Self {
454 id,
455 awoken: self.awoken.clone(),
456 thread: self.thread.clone(),
457 state: self.state.clone(),
458 }
459 }
460}
461
462impl Drop for TracingWaker {
463 fn drop(&mut self) {
464 if let Some(id) = self.id {
465 self.state.lock().pending_traces.remove(&id);
466 }
467 }
468}
469
470impl TracingWaker {
471 fn wake(self) {
472 self.wake_by_ref();
473 }
474
475 fn wake_by_ref(&self) {
476 if let Some(id) = self.id {
477 self.state.lock().pending_traces.remove(&id);
478 }
479 self.awoken.store(true, SeqCst);
480 self.thread.unpark();
481 }
482
483 fn clone_raw(waker: *const ()) -> RawWaker {
484 let waker = waker as *const TracingWaker;
485 let waker = unsafe { &*waker };
486 RawWaker::new(
487 Box::into_raw(Box::new(waker.clone())) as *const (),
488 &WAKER_VTABLE,
489 )
490 }
491
492 fn wake_raw(waker: *const ()) {
493 let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
494 waker.wake();
495 }
496
497 fn wake_by_ref_raw(waker: *const ()) {
498 let waker = waker as *const TracingWaker;
499 let waker = unsafe { &*waker };
500 waker.wake_by_ref();
501 }
502
503 fn drop_raw(waker: *const ()) {
504 let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
505 drop(waker);
506 }
507}
508
/// Future that returns `Pending` a fixed number of times before completing.
/// The wrapped value is the number of remaining yields.
pub struct Yield(usize);

impl Future for Yield {
    type Output = ();

    /// Completes once the remaining count is zero. Otherwise decrements the
    /// count, requests an immediate re-poll through the waker, and yields.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match self.0.checked_sub(1) {
            None => Poll::Ready(()),
            Some(remaining) => {
                self.0 = remaining;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
524
525fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
526 trace.resolve();
527 let mut frames: Vec<BacktraceFrame> = trace.into();
528 let waker_clone_frame_ix = frames.iter().position(|frame| {
529 frame.symbols().iter().any(|symbol| {
530 symbol
531 .name()
532 .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
533 })
534 });
535
536 if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
537 frames.drain(..waker_clone_frame_ix + 1);
538 }
539
540 Backtrace::from(frames)
541}