1use crate::{
2 BackgroundExecutor, Clock as _, ForegroundExecutor, Scheduler, SessionId, TestClock, Timer,
3};
4use async_task::Runnable;
5use chrono::{DateTime, Duration as ChronoDuration, Utc};
6use futures::{FutureExt as _, channel::oneshot, future::LocalBoxFuture};
7use parking_lot::Mutex;
8use rand::prelude::*;
9use std::{
10 collections::VecDeque,
11 future::Future,
12 panic::{self, AssertUnwindSafe},
13 pin::Pin,
14 sync::{
15 Arc,
16 atomic::{AtomicBool, Ordering::SeqCst},
17 },
18 task::{Context, Poll, Wake, Waker},
19 thread,
20 time::{Duration, Instant},
21};
22
/// Scheduler for tests whose run-queue ordering, timer expirations, and stepping are driven by a
/// seeded RNG and a simulated [`TestClock`] (intended to make each seed reproducible).
pub struct TestScheduler {
    /// Simulated clock; `advance_clock` jumps it to the earliest pending timer's expiration.
    clock: Arc<TestClock>,
    /// RNG seeded from `config.seed`; drives queue-order randomization, `yield_random`, and the
    /// random stepping inside `block`.
    rng: Arc<Mutex<StdRng>>,
    /// Run queue, pending timers, and the remaining mutable scheduler state.
    state: Mutex<SchedulerState>,
    /// Thread that created the scheduler; `is_main_thread` compares the current thread to it.
    pub thread_id: thread::ThreadId,
    /// Configuration the scheduler was constructed with.
    pub config: SchedulerConfig,
}
30
31impl TestScheduler {
32 /// Run a test once with default configuration (seed 0)
33 pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
34 Self::with_seed(0, f)
35 }
36
37 /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
38 pub fn many<R>(iterations: usize, mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R) -> Vec<R> {
39 (0..iterations as u64)
40 .map(|seed| {
41 let mut unwind_safe_f = AssertUnwindSafe(&mut f);
42 match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
43 Ok(result) => result,
44 Err(error) => {
45 eprintln!("Failing Seed: {seed}");
46 panic::resume_unwind(error);
47 }
48 }
49 })
50 .collect()
51 }
52
53 /// Run a test once with a specific seed
54 pub fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
55 let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::with_seed(seed)));
56 let future = f(scheduler.clone());
57 let result = scheduler.block_on(future);
58 scheduler.run();
59 result
60 }
61
62 pub fn new(config: SchedulerConfig) -> Self {
63 Self {
64 rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
65 state: Mutex::new(SchedulerState {
66 runnables: VecDeque::new(),
67 timers: Vec::new(),
68 randomize_order: config.randomize_order,
69 allow_parking: config.allow_parking,
70 next_session_id: SessionId(0),
71 }),
72 thread_id: thread::current().id(),
73 clock: Arc::new(TestClock::new()),
74 config,
75 }
76 }
77
78 pub fn clock(&self) -> Arc<TestClock> {
79 self.clock.clone()
80 }
81
82 pub fn rng(&self) -> Arc<Mutex<StdRng>> {
83 self.rng.clone()
84 }
85
86 /// Create a foreground executor for this scheduler
87 pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
88 let session_id = {
89 let mut state = self.state.lock();
90 state.next_session_id.0 += 1;
91 state.next_session_id
92 };
93 ForegroundExecutor::new(session_id, self.clone())
94 }
95
96 /// Create a background executor for this scheduler
97 pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
98 BackgroundExecutor::new(self.clone())
99 }
100
101 pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
102 (self as &dyn Scheduler).block_on(future)
103 }
104
105 pub fn yield_random(&self) -> Yield {
106 Yield(self.rng.lock().random_range(0..20))
107 }
108
109 pub fn run(&self) {
110 while self.step() || self.advance_clock() {
111 // Continue until no work remains
112 }
113 }
114
115 fn step(&self) -> bool {
116 let elapsed_timers = {
117 let mut state = self.state.lock();
118 let end_ix = state
119 .timers
120 .partition_point(|timer| timer.expiration <= self.clock.now());
121 state.timers.drain(..end_ix).collect::<Vec<_>>()
122 };
123
124 if !elapsed_timers.is_empty() {
125 return true;
126 }
127
128 let runnable = self.state.lock().runnables.pop_front();
129 if let Some(runnable) = runnable {
130 runnable.run();
131 return true;
132 }
133
134 false
135 }
136
137 fn advance_clock(&self) -> bool {
138 if let Some(timer) = self.state.lock().timers.first() {
139 self.clock.set_now(timer.expiration);
140 true
141 } else {
142 false
143 }
144 }
145}
146
147impl Scheduler for TestScheduler {
148 fn is_main_thread(&self) -> bool {
149 thread::current().id() == self.thread_id
150 }
151
152 fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable) {
153 let mut state = self.state.lock();
154 let ix = if state.randomize_order {
155 let start_ix = state
156 .runnables
157 .iter()
158 .rposition(|task| task.session_id == Some(session_id))
159 .map_or(0, |ix| ix + 1);
160 self.rng
161 .lock()
162 .random_range(start_ix..=state.runnables.len())
163 } else {
164 state.runnables.len()
165 };
166 state.runnables.insert(
167 ix,
168 ScheduledRunnable {
169 session_id: Some(session_id),
170 runnable,
171 },
172 );
173 }
174
175 fn schedule_background(&self, runnable: Runnable) {
176 let mut state = self.state.lock();
177 let ix = if state.randomize_order {
178 self.rng.lock().random_range(0..=state.runnables.len())
179 } else {
180 state.runnables.len()
181 };
182 state.runnables.insert(
183 ix,
184 ScheduledRunnable {
185 session_id: None,
186 runnable,
187 },
188 );
189 }
190
191 fn timer(&self, duration: Duration) -> Timer {
192 let (tx, rx) = oneshot::channel();
193 let expiration = self.clock.now() + ChronoDuration::from_std(duration).unwrap();
194 let state = &mut *self.state.lock();
195 state.timers.push(ScheduledTimer {
196 expiration,
197 _notify: tx,
198 });
199 state.timers.sort_by_key(|timer| timer.expiration);
200 Timer(rx)
201 }
202
203 /// Block until the given future completes, with an optional timeout. If the
204 /// future is unable to make progress at any moment before the timeout and
205 /// no other tasks or timers remain, we panic unless parking is allowed. If
206 /// parking is allowed, we block up to the timeout or indefinitely if none
207 /// is provided. This is to allow testing a mix of deterministic and
208 /// non-deterministic async behavior, such as when interacting with I/O in
209 /// an otherwise deterministic test.
210 fn block(&self, mut future: LocalBoxFuture<()>, timeout: Option<Duration>) {
211 let (parker, unparker) = parking::pair();
212 let deadline = timeout.map(|timeout| Instant::now() + timeout);
213 let awoken = Arc::new(AtomicBool::new(false));
214 let waker = Waker::from(Arc::new(WakerFn::new({
215 let awoken = awoken.clone();
216 move || {
217 awoken.store(true, SeqCst);
218 unparker.unpark();
219 }
220 })));
221 let max_ticks = if timeout.is_some() {
222 self.rng
223 .lock()
224 .random_range(0..=self.config.max_timeout_ticks)
225 } else {
226 usize::MAX
227 };
228 let mut cx = Context::from_waker(&waker);
229
230 for _ in 0..max_ticks {
231 let Poll::Pending = future.poll_unpin(&mut cx) else {
232 break;
233 };
234
235 let mut stepped = None;
236 while self.rng.lock().random() && stepped.unwrap_or(true) {
237 *stepped.get_or_insert(false) |= self.step();
238 }
239
240 let stepped = stepped.unwrap_or(true);
241 let awoken = awoken.swap(false, SeqCst);
242 if !stepped && !awoken && !self.advance_clock() {
243 if self.state.lock().allow_parking {
244 if !park(&parker, deadline) {
245 break;
246 }
247 } else if deadline.is_some() {
248 break;
249 } else {
250 panic!("Parking forbidden");
251 }
252 }
253 }
254 }
255}
256
/// Configuration for a [`TestScheduler`].
#[derive(Clone, Debug)]
pub struct SchedulerConfig {
    /// Seed for the scheduler's RNG.
    pub seed: u64,
    /// When true, newly scheduled runnables are inserted at random queue positions; when false
    /// they are appended.
    pub randomize_order: bool,
    /// Whether `block` may park the thread when no progress is possible. When parking is not
    /// allowed, `block` returns if a timeout was given and panics otherwise.
    pub allow_parking: bool,
    /// Inclusive upper bound on the random number of polls `block` performs when a timeout is
    /// given.
    pub max_timeout_ticks: usize,
}
264
265impl SchedulerConfig {
266 pub fn with_seed(seed: u64) -> Self {
267 Self {
268 seed,
269 ..Default::default()
270 }
271 }
272}
273
274impl Default for SchedulerConfig {
275 fn default() -> Self {
276 Self {
277 seed: 0,
278 randomize_order: true,
279 allow_parking: false,
280 max_timeout_ticks: 1000,
281 }
282 }
283}
284
/// A queued task together with the session it belongs to.
struct ScheduledRunnable {
    /// `Some` for runnables queued via `schedule_foreground`, `None` for those queued via
    /// `schedule_background`.
    session_id: Option<SessionId>,
    runnable: Runnable,
}
289
290impl ScheduledRunnable {
291 fn run(self) {
292 self.runnable.run();
293 }
294}
295
/// A pending timer registered through `TestScheduler::timer`.
struct ScheduledTimer {
    /// Simulated time at which the timer expires.
    expiration: DateTime<Utc>,
    /// Sender half of the oneshot whose receiver is wrapped in the returned `Timer`. Nothing in
    /// this file sends on it. It is dropped when `step` removes the expired timer, which
    /// presumably is what resolves the `Timer` (confirm against `Timer`'s implementation).
    _notify: oneshot::Sender<()>,
}
300
/// Mutable scheduler state, guarded by the `TestScheduler::state` mutex.
struct SchedulerState {
    /// Queued runnables; `step` pops and runs them from the front.
    runnables: VecDeque<ScheduledRunnable>,
    /// Pending timers, kept sorted by expiration (earliest first).
    timers: Vec<ScheduledTimer>,
    /// Copied from `SchedulerConfig::randomize_order` at construction.
    randomize_order: bool,
    /// Copied from `SchedulerConfig::allow_parking` at construction.
    allow_parking: bool,
    /// Last session id handed out by `foreground`; starts at 0, so the first id issued is 1.
    next_session_id: SessionId,
}
308
/// Adapter that lets a closure back a [`Waker`] through its [`Wake`] impl; waking calls `f`.
struct WakerFn<F> {
    f: F,
}
312
313impl<F: Fn()> WakerFn<F> {
314 fn new(f: F) -> Self {
315 Self { f }
316 }
317}
318
319impl<F: Fn()> Wake for WakerFn<F> {
320 fn wake(self: Arc<Self>) {
321 (self.f)();
322 }
323
324 fn wake_by_ref(self: &Arc<Self>) {
325 (self.f)();
326 }
327}
328
/// Future returned by `TestScheduler::yield_random`.
///
/// It re-wakes itself and returns `Poll::Pending` the stored number of times before completing.
/// Like other futures it does nothing unless polled, hence `#[must_use]`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Yield(usize);
330
331impl Future for Yield {
332 type Output = ();
333
334 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
335 if self.0 == 0 {
336 Poll::Ready(())
337 } else {
338 self.0 -= 1;
339 cx.waker().wake_by_ref();
340 Poll::Pending
341 }
342 }
343}
344
345fn park(parker: &parking::Parker, deadline: Option<Instant>) -> bool {
346 if let Some(deadline) = deadline {
347 parker.park_deadline(deadline)
348 } else {
349 parker.park();
350 true
351 }
352}