test_scheduler.rs

  1use crate::{
  2    BackgroundExecutor, Clock as _, ForegroundExecutor, Scheduler, SessionId, TestClock, Timer,
  3};
  4use async_task::Runnable;
  5use chrono::{DateTime, Duration as ChronoDuration, Utc};
  6use futures::{FutureExt as _, channel::oneshot, future::LocalBoxFuture};
  7use parking_lot::Mutex;
  8use rand::prelude::*;
  9use std::{
 10    collections::VecDeque,
 11    future::Future,
 12    panic::{self, AssertUnwindSafe},
 13    pin::Pin,
 14    sync::{
 15        Arc,
 16        atomic::{AtomicBool, Ordering::SeqCst},
 17    },
 18    task::{Context, Poll, Wake, Waker},
 19    thread,
 20    time::{Duration, Instant},
 21};
 22
 23pub struct TestScheduler {
 24    clock: Arc<TestClock>,
 25    rng: Arc<Mutex<StdRng>>,
 26    state: Mutex<SchedulerState>,
 27    pub thread_id: thread::ThreadId,
 28    pub config: SchedulerConfig,
 29}
 30
 31impl TestScheduler {
 32    /// Run a test once with default configuration (seed 0)
 33    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 34        Self::with_seed(0, f)
 35    }
 36
 37    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
 38    pub fn many<R>(iterations: usize, mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R) -> Vec<R> {
 39        (0..iterations as u64)
 40            .map(|seed| {
 41                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
 42                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
 43                    Ok(result) => result,
 44                    Err(error) => {
 45                        eprintln!("Failing Seed: {seed}");
 46                        panic::resume_unwind(error);
 47                    }
 48                }
 49            })
 50            .collect()
 51    }
 52
 53    /// Run a test once with a specific seed
 54    pub fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 55        let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::with_seed(seed)));
 56        let future = f(scheduler.clone());
 57        let result = scheduler.block_on(future);
 58        scheduler.run();
 59        result
 60    }
 61
 62    pub fn new(config: SchedulerConfig) -> Self {
 63        Self {
 64            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
 65            state: Mutex::new(SchedulerState {
 66                runnables: VecDeque::new(),
 67                timers: Vec::new(),
 68                randomize_order: config.randomize_order,
 69                allow_parking: config.allow_parking,
 70                next_session_id: SessionId(0),
 71            }),
 72            thread_id: thread::current().id(),
 73            clock: Arc::new(TestClock::new()),
 74            config,
 75        }
 76    }
 77
 78    pub fn clock(&self) -> Arc<TestClock> {
 79        self.clock.clone()
 80    }
 81
 82    pub fn rng(&self) -> Arc<Mutex<StdRng>> {
 83        self.rng.clone()
 84    }
 85
 86    /// Create a foreground executor for this scheduler
 87    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
 88        let session_id = {
 89            let mut state = self.state.lock();
 90            state.next_session_id.0 += 1;
 91            state.next_session_id
 92        };
 93        ForegroundExecutor::new(session_id, self.clone())
 94    }
 95
 96    /// Create a background executor for this scheduler
 97    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
 98        BackgroundExecutor::new(self.clone())
 99    }
100
101    pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
102        (self as &dyn Scheduler).block_on(future)
103    }
104
105    pub fn yield_random(&self) -> Yield {
106        Yield(self.rng.lock().random_range(0..20))
107    }
108
109    pub fn run(&self) {
110        while self.step() || self.advance_clock() {
111            // Continue until no work remains
112        }
113    }
114
115    fn step(&self) -> bool {
116        let elapsed_timers = {
117            let mut state = self.state.lock();
118            let end_ix = state
119                .timers
120                .partition_point(|timer| timer.expiration <= self.clock.now());
121            state.timers.drain(..end_ix).collect::<Vec<_>>()
122        };
123
124        if !elapsed_timers.is_empty() {
125            return true;
126        }
127
128        let runnable = self.state.lock().runnables.pop_front();
129        if let Some(runnable) = runnable {
130            runnable.run();
131            return true;
132        }
133
134        false
135    }
136
137    fn advance_clock(&self) -> bool {
138        if let Some(timer) = self.state.lock().timers.first() {
139            self.clock.set_now(timer.expiration);
140            true
141        } else {
142            false
143        }
144    }
145}
146
147impl Scheduler for TestScheduler {
148    fn is_main_thread(&self) -> bool {
149        thread::current().id() == self.thread_id
150    }
151
152    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable) {
153        let mut state = self.state.lock();
154        let ix = if state.randomize_order {
155            let start_ix = state
156                .runnables
157                .iter()
158                .rposition(|task| task.session_id == Some(session_id))
159                .map_or(0, |ix| ix + 1);
160            self.rng
161                .lock()
162                .random_range(start_ix..=state.runnables.len())
163        } else {
164            state.runnables.len()
165        };
166        state.runnables.insert(
167            ix,
168            ScheduledRunnable {
169                session_id: Some(session_id),
170                runnable,
171            },
172        );
173    }
174
175    fn schedule_background(&self, runnable: Runnable) {
176        let mut state = self.state.lock();
177        let ix = if state.randomize_order {
178            self.rng.lock().random_range(0..=state.runnables.len())
179        } else {
180            state.runnables.len()
181        };
182        state.runnables.insert(
183            ix,
184            ScheduledRunnable {
185                session_id: None,
186                runnable,
187            },
188        );
189    }
190
191    fn timer(&self, duration: Duration) -> Timer {
192        let (tx, rx) = oneshot::channel();
193        let expiration = self.clock.now() + ChronoDuration::from_std(duration).unwrap();
194        let state = &mut *self.state.lock();
195        state.timers.push(ScheduledTimer {
196            expiration,
197            _notify: tx,
198        });
199        state.timers.sort_by_key(|timer| timer.expiration);
200        Timer(rx)
201    }
202
203    /// Block until the given future completes, with an optional timeout. If the
204    /// future is unable to make progress at any moment before the timeout and
205    /// no other tasks or timers remain, we panic unless parking is allowed. If
206    /// parking is allowed, we block up to the timeout or indefinitely if none
207    /// is provided. This is to allow testing a mix of deterministic and
208    /// non-deterministic async behavior, such as when interacting with I/O in
209    /// an otherwise deterministic test.
210    fn block(&self, mut future: LocalBoxFuture<()>, timeout: Option<Duration>) {
211        let (parker, unparker) = parking::pair();
212        let deadline = timeout.map(|timeout| Instant::now() + timeout);
213        let awoken = Arc::new(AtomicBool::new(false));
214        let waker = Waker::from(Arc::new(WakerFn::new({
215            let awoken = awoken.clone();
216            move || {
217                awoken.store(true, SeqCst);
218                unparker.unpark();
219            }
220        })));
221        let max_ticks = if timeout.is_some() {
222            self.rng
223                .lock()
224                .random_range(0..=self.config.max_timeout_ticks)
225        } else {
226            usize::MAX
227        };
228        let mut cx = Context::from_waker(&waker);
229
230        for _ in 0..max_ticks {
231            let Poll::Pending = future.poll_unpin(&mut cx) else {
232                break;
233            };
234
235            let mut stepped = None;
236            while self.rng.lock().random() && stepped.unwrap_or(true) {
237                *stepped.get_or_insert(false) |= self.step();
238            }
239
240            let stepped = stepped.unwrap_or(true);
241            let awoken = awoken.swap(false, SeqCst);
242            if !stepped && !awoken && !self.advance_clock() {
243                if self.state.lock().allow_parking {
244                    if !park(&parker, deadline) {
245                        break;
246                    }
247                } else if deadline.is_some() {
248                    break;
249                } else {
250                    panic!("Parking forbidden");
251                }
252            }
253        }
254    }
255}
256
/// Knobs controlling how a `TestScheduler` orders and drives work.
#[derive(Clone, Debug)]
pub struct SchedulerConfig {
    /// Seed for the scheduler's RNG; the same seed replays the same schedule.
    pub seed: u64,
    /// Insert newly scheduled runnables at random queue positions rather than
    /// appending them (foreground runnables stay ordered within their session).
    pub randomize_order: bool,
    /// When `block` cannot make progress, park the thread instead of giving up
    /// (with a timeout) or panicking (without one).
    pub allow_parking: bool,
    /// For a `block` call with a timeout, the number of polls is drawn from
    /// `0..=max_timeout_ticks`.
    pub max_timeout_ticks: usize,
}

impl SchedulerConfig {
    /// The default configuration, but with the RNG seeded by `seed`.
    pub fn with_seed(seed: u64) -> Self {
        SchedulerConfig {
            seed,
            ..SchedulerConfig::default()
        }
    }
}

impl Default for SchedulerConfig {
    /// Seed 0, randomized ordering, parking forbidden, and at most 1000 polls
    /// for a timed `block`.
    fn default() -> Self {
        const DEFAULT_MAX_TIMEOUT_TICKS: usize = 1000;
        SchedulerConfig {
            max_timeout_ticks: DEFAULT_MAX_TIMEOUT_TICKS,
            allow_parking: false,
            randomize_order: true,
            seed: 0,
        }
    }
}
284
285struct ScheduledRunnable {
286    session_id: Option<SessionId>,
287    runnable: Runnable,
288}
289
290impl ScheduledRunnable {
291    fn run(self) {
292        self.runnable.run();
293    }
294}
295
296struct ScheduledTimer {
297    expiration: DateTime<Utc>,
298    _notify: oneshot::Sender<()>,
299}
300
301struct SchedulerState {
302    runnables: VecDeque<ScheduledRunnable>,
303    timers: Vec<ScheduledTimer>,
304    randomize_order: bool,
305    allow_parking: bool,
306    next_session_id: SessionId,
307}
308
/// Adapter that lets any `Fn()` closure back a [`Waker`] through [`Wake`].
struct WakerFn<F> {
    f: F,
}

impl<F: Fn()> WakerFn<F> {
    /// Wrap `f`; put the result in an `Arc` and hand it to `Waker::from`.
    fn new(f: F) -> Self {
        WakerFn { f }
    }
}

impl<F: Fn()> Wake for WakerFn<F> {
    /// Consuming wake; behaves exactly like `wake_by_ref`.
    fn wake(self: Arc<Self>) {
        <Self as Wake>::wake_by_ref(&self);
    }

    /// Invoke the wrapped closure.
    fn wake_by_ref(self: &Arc<Self>) {
        let callback = &self.f;
        callback();
    }
}
328
/// Future that returns `Pending` a fixed number of times, waking itself
/// immediately each time, and then completes.
pub struct Yield(usize);

impl Future for Yield {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let remaining = &mut self.0;
        match remaining.checked_sub(1) {
            None => Poll::Ready(()),
            Some(next) => {
                *remaining = next;
                // Ask to be polled again right away.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }
}
344
345fn park(parker: &parking::Parker, deadline: Option<Instant>) -> bool {
346    if let Some(deadline) = deadline {
347        parker.park_deadline(deadline)
348    } else {
349        parker.park();
350        true
351    }
352}