1use crate::{
2 BackgroundExecutor, Clock, ForegroundExecutor, Scheduler, SessionId, TestClock, Timer,
3};
4use async_task::Runnable;
5use backtrace::{Backtrace, BacktraceFrame};
6use futures::{FutureExt as _, channel::oneshot, future::LocalBoxFuture};
7use parking_lot::Mutex;
8use rand::prelude::*;
9use std::{
10 any::type_name_of_val,
11 collections::{BTreeMap, VecDeque},
12 env,
13 fmt::Write,
14 future::Future,
15 mem,
16 ops::RangeInclusive,
17 panic::{self, AssertUnwindSafe},
18 pin::Pin,
19 sync::{
20 Arc,
21 atomic::{AtomicBool, Ordering::SeqCst},
22 },
23 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
24 thread::{self, Thread},
25 time::{Duration, Instant},
26};
27
/// Environment variable that, when set to `1` or `true`, makes
/// `TestSchedulerConfig::default` enable pending-trace capture, so that a
/// forbidden park reports the backtraces of still-outstanding waker clones.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
29
/// A deterministic scheduler for tests.
///
/// Task ordering, the number of scheduler steps interleaved while blocking,
/// and timeout tick budgets are all drawn from a single seeded RNG. Time is
/// a `TestClock` that this scheduler advances explicitly (see
/// `advance_clock`).
pub struct TestScheduler {
    /// Virtual clock, advanced by `advance_clock` and while blocking.
    clock: Arc<TestClock>,
    /// Seeded RNG shared by every scheduling decision.
    rng: Arc<Mutex<StdRng>>,
    /// Mutable scheduler state: run queue, timers, sessions, settings.
    state: Arc<Mutex<SchedulerState>>,
    /// Thread that created the scheduler. It is unparked whenever work is
    /// scheduled or a blocking waker fires.
    thread: Thread,
}
36
37impl TestScheduler {
38 /// Run a test once with default configuration (seed 0)
39 pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
40 Self::with_seed(0, f)
41 }
42
43 /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
44 pub fn many<R>(iterations: usize, mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R) -> Vec<R> {
45 (0..iterations as u64)
46 .map(|seed| {
47 let mut unwind_safe_f = AssertUnwindSafe(&mut f);
48 match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
49 Ok(result) => result,
50 Err(error) => {
51 eprintln!("Failing Seed: {seed}");
52 panic::resume_unwind(error);
53 }
54 }
55 })
56 .collect()
57 }
58
59 /// Run a test once with a specific seed
60 pub fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
61 let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
62 let future = f(scheduler.clone());
63 let result = scheduler.foreground().block_on(future);
64 scheduler.run(); // Ensure spawned tasks finish up before returning in tests
65 result
66 }
67
68 pub fn new(config: TestSchedulerConfig) -> Self {
69 Self {
70 rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
71 state: Arc::new(Mutex::new(SchedulerState {
72 runnables: VecDeque::new(),
73 timers: Vec::new(),
74 blocked_sessions: Vec::new(),
75 randomize_order: config.randomize_order,
76 allow_parking: config.allow_parking,
77 timeout_ticks: config.timeout_ticks,
78 next_session_id: SessionId(0),
79 capture_pending_traces: config.capture_pending_traces,
80 pending_traces: BTreeMap::new(),
81 next_trace_id: TraceId(0),
82 })),
83 clock: Arc::new(TestClock::new()),
84 thread: thread::current(),
85 }
86 }
87
88 pub fn clock(&self) -> Arc<TestClock> {
89 self.clock.clone()
90 }
91
92 pub fn rng(&self) -> Arc<Mutex<StdRng>> {
93 self.rng.clone()
94 }
95
96 pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
97 self.state.lock().timeout_ticks = timeout_ticks;
98 }
99
100 pub fn allow_parking(&self) {
101 self.state.lock().allow_parking = true;
102 }
103
104 pub fn forbid_parking(&self) {
105 self.state.lock().allow_parking = false;
106 }
107
108 /// Create a foreground executor for this scheduler
109 pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
110 let session_id = {
111 let mut state = self.state.lock();
112 state.next_session_id.0 += 1;
113 state.next_session_id
114 };
115 ForegroundExecutor::new(session_id, self.clone())
116 }
117
118 /// Create a background executor for this scheduler
119 pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
120 BackgroundExecutor::new(self.clone())
121 }
122
123 pub fn yield_random(&self) -> Yield {
124 let rng = &mut *self.rng.lock();
125 if rng.random_bool(0.1) {
126 Yield(rng.random_range(10..20))
127 } else {
128 Yield(rng.random_range(0..2))
129 }
130 }
131
132 pub fn run(&self) {
133 while self.step() {
134 // Continue until no work remains
135 }
136 }
137
138 fn step(&self) -> bool {
139 let elapsed_timers = {
140 let mut state = self.state.lock();
141 let end_ix = state
142 .timers
143 .partition_point(|timer| timer.expiration <= self.clock.now());
144 state.timers.drain(..end_ix).collect::<Vec<_>>()
145 };
146
147 if !elapsed_timers.is_empty() {
148 return true;
149 }
150
151 let runnable = {
152 let state = &mut *self.state.lock();
153 let ix = state.runnables.iter().position(|runnable| {
154 runnable
155 .session_id
156 .is_none_or(|session_id| !state.blocked_sessions.contains(&session_id))
157 });
158 ix.and_then(|ix| state.runnables.remove(ix))
159 };
160
161 if let Some(runnable) = runnable {
162 runnable.run();
163 return true;
164 }
165
166 false
167 }
168
169 fn advance_clock_to_next_timer(&self) -> bool {
170 if let Some(timer) = self.state.lock().timers.first() {
171 self.clock.advance(timer.expiration - self.clock.now());
172 true
173 } else {
174 false
175 }
176 }
177
178 pub fn advance_clock(&self, duration: Duration) {
179 let next_now = self.clock.now() + duration;
180 loop {
181 self.run();
182 if let Some(timer) = self.state.lock().timers.first()
183 && timer.expiration <= next_now
184 {
185 self.clock.advance(timer.expiration - self.clock.now());
186 } else {
187 break;
188 }
189 }
190 self.clock.advance(next_now - self.clock.now());
191 }
192
193 fn park(&self, deadline: Option<Instant>) -> bool {
194 if self.state.lock().allow_parking {
195 if let Some(deadline) = deadline {
196 let now = Instant::now();
197 let timeout = deadline.saturating_duration_since(now);
198 thread::park_timeout(timeout);
199 now.elapsed() < timeout
200 } else {
201 thread::park();
202 true
203 }
204 } else if deadline.is_some() {
205 false
206 } else if self.state.lock().capture_pending_traces {
207 let mut pending_traces = String::new();
208 for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
209 writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
210 }
211 panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
212 } else {
213 panic!(
214 "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
215 );
216 }
217 }
218}
219
220impl Scheduler for TestScheduler {
221 /// Block until the given future completes, with an optional timeout. If the
222 /// future is unable to make progress at any moment before the timeout and
223 /// no other tasks or timers remain, we panic unless parking is allowed. If
224 /// parking is allowed, we block up to the timeout or indefinitely if none
225 /// is provided. This is to allow testing a mix of deterministic and
226 /// non-deterministic async behavior, such as when interacting with I/O in
227 /// an otherwise deterministic test.
228 fn block(
229 &self,
230 session_id: Option<SessionId>,
231 mut future: LocalBoxFuture<()>,
232 timeout: Option<Duration>,
233 ) {
234 if let Some(session_id) = session_id {
235 self.state.lock().blocked_sessions.push(session_id);
236 }
237
238 let deadline = timeout.map(|timeout| Instant::now() + timeout);
239 let awoken = Arc::new(AtomicBool::new(false));
240 let waker = Box::new(TracingWaker {
241 id: None,
242 awoken: awoken.clone(),
243 thread: self.thread.clone(),
244 state: self.state.clone(),
245 });
246 let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
247 let max_ticks = if timeout.is_some() {
248 self.rng
249 .lock()
250 .random_range(self.state.lock().timeout_ticks.clone())
251 } else {
252 usize::MAX
253 };
254 let mut cx = Context::from_waker(&waker);
255
256 for _ in 0..max_ticks {
257 let Poll::Pending = future.poll_unpin(&mut cx) else {
258 break;
259 };
260
261 let mut stepped = None;
262 while self.rng.lock().random() {
263 let stepped = stepped.get_or_insert(false);
264 if self.step() {
265 *stepped = true;
266 } else {
267 break;
268 }
269 }
270
271 let stepped = stepped.unwrap_or(true);
272 let awoken = awoken.swap(false, SeqCst);
273 if !stepped && !awoken && !self.advance_clock_to_next_timer() {
274 if !self.park(deadline) {
275 break;
276 }
277 }
278 }
279
280 if session_id.is_some() {
281 self.state.lock().blocked_sessions.pop();
282 }
283 }
284
285 fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable) {
286 let mut state = self.state.lock();
287 let ix = if state.randomize_order {
288 let start_ix = state
289 .runnables
290 .iter()
291 .rposition(|task| task.session_id == Some(session_id))
292 .map_or(0, |ix| ix + 1);
293 self.rng
294 .lock()
295 .random_range(start_ix..=state.runnables.len())
296 } else {
297 state.runnables.len()
298 };
299 state.runnables.insert(
300 ix,
301 ScheduledRunnable {
302 session_id: Some(session_id),
303 runnable,
304 },
305 );
306 drop(state);
307 self.thread.unpark();
308 }
309
310 fn schedule_background(&self, runnable: Runnable) {
311 let mut state = self.state.lock();
312 let ix = if state.randomize_order {
313 self.rng.lock().random_range(0..=state.runnables.len())
314 } else {
315 state.runnables.len()
316 };
317 state.runnables.insert(
318 ix,
319 ScheduledRunnable {
320 session_id: None,
321 runnable,
322 },
323 );
324 drop(state);
325 self.thread.unpark();
326 }
327
328 fn timer(&self, duration: Duration) -> Timer {
329 let (tx, rx) = oneshot::channel();
330 let state = &mut *self.state.lock();
331 state.timers.push(ScheduledTimer {
332 expiration: self.clock.now() + duration,
333 _notify: tx,
334 });
335 state.timers.sort_by_key(|timer| timer.expiration);
336 Timer(rx)
337 }
338
339 fn clock(&self) -> Arc<dyn Clock> {
340 self.clock.clone()
341 }
342
343 fn as_test(&self) -> &TestScheduler {
344 self
345 }
346}
347
348#[derive(Clone, Debug)]
349pub struct TestSchedulerConfig {
350 pub seed: u64,
351 pub randomize_order: bool,
352 pub allow_parking: bool,
353 pub capture_pending_traces: bool,
354 pub timeout_ticks: RangeInclusive<usize>,
355}
356
357impl TestSchedulerConfig {
358 pub fn with_seed(seed: u64) -> Self {
359 Self {
360 seed,
361 ..Default::default()
362 }
363 }
364}
365
366impl Default for TestSchedulerConfig {
367 fn default() -> Self {
368 Self {
369 seed: 0,
370 randomize_order: true,
371 allow_parking: false,
372 capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
373 .map_or(false, |var| var == "1" || var == "true"),
374 timeout_ticks: 0..=1000,
375 }
376 }
377}
378
379struct ScheduledRunnable {
380 session_id: Option<SessionId>,
381 runnable: Runnable,
382}
383
384impl ScheduledRunnable {
385 fn run(self) {
386 self.runnable.run();
387 }
388}
389
/// A pending timer registered through `Scheduler::timer`.
struct ScheduledTimer {
    /// Virtual-clock instant at which `step` drains this timer.
    expiration: Instant,
    /// Never sent on. The timer is signalled by dropping this sender when
    /// `step` drains it; presumably the paired `Timer` treats cancellation as
    /// expiry (TODO confirm against `Timer`).
    _notify: oneshot::Sender<()>,
}
394
/// Mutable scheduler state, shared with the blocking wakers behind one mutex.
struct SchedulerState {
    /// Run queue. `step` runs the first entry whose session is not blocked.
    runnables: VecDeque<ScheduledRunnable>,
    /// Pending timers, kept sorted by expiration.
    timers: Vec<ScheduledTimer>,
    /// Stack of sessions currently inside `block`. Their foreground
    /// runnables are skipped by `step`.
    blocked_sessions: Vec<SessionId>,
    /// Insert new runnables at random positions rather than at the tail.
    randomize_order: bool,
    /// Whether `park` may actually park the thread.
    allow_parking: bool,
    /// Range for `block`'s poll budget when a timeout is given.
    timeout_ticks: RangeInclusive<usize>,
    /// Most recently issued session id. `foreground` increments it before use,
    /// so the first session is 1.
    next_session_id: SessionId,
    /// Record a backtrace for each cloned blocking waker.
    capture_pending_traces: bool,
    /// Id assigned to the next recorded trace.
    next_trace_id: TraceId,
    /// Backtraces of waker clones that have been neither woken nor dropped.
    pending_traces: BTreeMap<TraceId, Backtrace>,
}
407
/// Vtable for the `Waker` built in `TestScheduler::block`. Every entry treats
/// the data pointer as a `TracingWaker` allocated with `Box::into_raw`.
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);
414
/// Key of a backtrace recorded in `SchedulerState::pending_traces`; ordered so
/// traces are reported in creation order.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
struct TraceId(usize);
417
/// Waker used by `TestScheduler::block`. When pending-trace capture is on,
/// each clone records a backtrace. `park` can then report wakers that are
/// still outstanding.
struct TracingWaker {
    /// Key into `SchedulerState::pending_traces`. It is `None` for the waker
    /// created directly in `block`, and for clones when capture is disabled.
    id: Option<TraceId>,
    /// Set on wake so `block` knows the future may be able to progress.
    awoken: Arc<AtomicBool>,
    /// Thread to unpark on wake.
    thread: Thread,
    /// Scheduler state, used to add and remove pending traces.
    state: Arc<Mutex<SchedulerState>>,
}
424
425impl Clone for TracingWaker {
426 fn clone(&self) -> Self {
427 let mut state = self.state.lock();
428 let id = if state.capture_pending_traces {
429 let id = state.next_trace_id;
430 state.next_trace_id.0 += 1;
431 state.pending_traces.insert(id, Backtrace::new_unresolved());
432 Some(id)
433 } else {
434 None
435 };
436 Self {
437 id,
438 awoken: self.awoken.clone(),
439 thread: self.thread.clone(),
440 state: self.state.clone(),
441 }
442 }
443}
444
445impl Drop for TracingWaker {
446 fn drop(&mut self) {
447 if let Some(id) = self.id {
448 self.state.lock().pending_traces.remove(&id);
449 }
450 }
451}
452
453impl TracingWaker {
454 fn wake(self) {
455 self.wake_by_ref();
456 }
457
458 fn wake_by_ref(&self) {
459 if let Some(id) = self.id {
460 self.state.lock().pending_traces.remove(&id);
461 }
462 self.awoken.store(true, SeqCst);
463 self.thread.unpark();
464 }
465
466 fn clone_raw(waker: *const ()) -> RawWaker {
467 let waker = waker as *const TracingWaker;
468 let waker = unsafe { &*waker };
469 RawWaker::new(
470 Box::into_raw(Box::new(waker.clone())) as *const (),
471 &WAKER_VTABLE,
472 )
473 }
474
475 fn wake_raw(waker: *const ()) {
476 let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
477 waker.wake();
478 }
479
480 fn wake_by_ref_raw(waker: *const ()) {
481 let waker = waker as *const TracingWaker;
482 let waker = unsafe { &*waker };
483 waker.wake_by_ref();
484 }
485
486 fn drop_raw(waker: *const ()) {
487 let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
488 drop(waker);
489 }
490}
491
/// A future that returns `Pending` a fixed number of times before completing.
/// It wakes itself each time, so the executor re-polls it right away.
pub struct Yield(usize);

impl Future for Yield {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match self.0.checked_sub(1) {
            None => Poll::Ready(()),
            Some(remaining) => {
                self.0 = remaining;
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
507
508fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
509 trace.resolve();
510 let mut frames: Vec<BacktraceFrame> = trace.into();
511 let waker_clone_frame_ix = frames.iter().position(|frame| {
512 frame.symbols().iter().any(|symbol| {
513 symbol
514 .name()
515 .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
516 })
517 });
518
519 if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
520 frames.drain(..waker_clone_frame_ix + 1);
521 }
522
523 Backtrace::from(frames)
524}