// test_scheduler.rs

use crate::{
    BackgroundExecutor, Clock, ForegroundExecutor, Priority, RunnableMeta, Scheduler, SessionId,
    TestClock, Timer,
};
use async_task::Runnable;
use backtrace::{Backtrace, BacktraceFrame};
use futures::channel::oneshot;
use parking_lot::{Mutex, MutexGuard};
use rand::{
    distr::{StandardUniform, uniform::SampleRange, uniform::SampleUniform},
    prelude::*,
};
use std::{
    any::type_name_of_val,
    collections::{BTreeMap, HashSet, VecDeque},
    env,
    fmt::Write,
    future::Future,
    mem,
    ops::RangeInclusive,
    panic::{self, AssertUnwindSafe},
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering::SeqCst},
    },
    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
    thread::{self, Thread},
    time::{Duration, Instant},
};

const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";

pub struct TestScheduler {
    clock: Arc<TestClock>,
    rng: Arc<Mutex<StdRng>>,
    state: Arc<Mutex<SchedulerState>>,
    thread: Thread,
}

impl TestScheduler {
    /// Run a test once with default configuration (seed 0)
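    ///
    /// A rough usage sketch (the timer duration and assertion are illustrative
    /// only):
    ///
    /// ```ignore
    /// let elapsed = TestScheduler::once(async |scheduler| {
    ///     let start = scheduler.clock().now();
    ///     scheduler.timer(Duration::from_secs(1)).await;
    ///     scheduler.clock().now() - start
    /// });
    /// assert_eq!(elapsed, Duration::from_secs(1));
    /// ```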
    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        Self::with_seed(0, f)
    }

    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
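    ///
    /// The iteration count can be overridden with the `ITERATIONS` environment
    /// variable and the starting seed with `SEED`. A rough sketch:
    ///
    /// ```ignore
    /// // Each iteration gets a scheduler seeded with 0, 1, 2, ...
    /// TestScheduler::many(100, async |scheduler| {
    ///     scheduler.yield_random().await;
    /// });
    /// ```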
    pub fn many<R>(
        default_iterations: usize,
        mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
    ) -> Vec<R> {
        let num_iterations = std::env::var("ITERATIONS")
            .map(|iterations| iterations.parse().unwrap())
            .unwrap_or(default_iterations);

        let seed = std::env::var("SEED")
            .map(|seed| seed.parse().unwrap())
            .unwrap_or(0);

        (seed..num_iterations as u64)
            .map(|seed| {
                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
                eprintln!("Running seed: {seed}");
                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
                    Ok(result) => result,
                    Err(error) => {
                        eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
                        panic::resume_unwind(error);
                    }
                }
            })
            .collect()
    }

    fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
        let future = f(scheduler.clone());
        let result = scheduler.foreground().block_on(future);
        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
        result
    }

    pub fn new(config: TestSchedulerConfig) -> Self {
        Self {
            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
            state: Arc::new(Mutex::new(SchedulerState {
                runnables: VecDeque::new(),
                timers: Vec::new(),
                blocked_sessions: Vec::new(),
                randomize_order: config.randomize_order,
                allow_parking: config.allow_parking,
                timeout_ticks: config.timeout_ticks,
                next_session_id: SessionId(0),
                capture_pending_traces: config.capture_pending_traces,
                pending_traces: BTreeMap::new(),
                next_trace_id: TraceId(0),
                is_main_thread: true,
            })),
            clock: Arc::new(TestClock::new()),
            thread: thread::current(),
        }
    }

    pub fn clock(&self) -> Arc<TestClock> {
        self.clock.clone()
    }

    pub fn rng(&self) -> SharedRng {
        SharedRng(self.rng.clone())
    }

    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
        self.state.lock().timeout_ticks = timeout_ticks;
    }

    pub fn allow_parking(&self) {
        self.state.lock().allow_parking = true;
    }

    pub fn forbid_parking(&self) {
        self.state.lock().allow_parking = false;
    }

    pub fn parking_allowed(&self) -> bool {
        self.state.lock().allow_parking
    }

    pub fn is_main_thread(&self) -> bool {
        self.state.lock().is_main_thread
    }

    /// Allocate a new session ID for foreground task scheduling.
    /// This is used by GPUI's TestDispatcher to map dispatcher instances to sessions.
    pub fn allocate_session_id(&self) -> SessionId {
        let mut state = self.state.lock();
        state.next_session_id.0 += 1;
        state.next_session_id
    }

    /// Create a foreground executor for this scheduler
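    ///
    /// A rough sketch; `block_on` runs a future to completion on the test
    /// scheduler:
    ///
    /// ```ignore
    /// let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::default()));
    /// let foreground = scheduler.foreground();
    /// foreground.block_on(async { /* deterministic async test body */ });
    /// ```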
    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
        let session_id = self.allocate_session_id();
        ForegroundExecutor::new(session_id, self.clone())
    }

    /// Create a background executor for this scheduler
    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
        BackgroundExecutor::new(self.clone())
    }

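    /// Returns a future that yields to the scheduler a random number of times
    /// before completing: usually 0 or 1 times, but occasionally (10% of the
    /// time) 10 to 19 times, to shake out ordering assumptions in tests.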
    pub fn yield_random(&self) -> Yield {
        let rng = &mut *self.rng.lock();
        if rng.random_bool(0.1) {
            Yield(rng.random_range(10..20))
        } else {
            Yield(rng.random_range(0..2))
        }
    }

    pub fn run(&self) {
        while self.step() {
            // Continue until no work remains
        }
    }

    pub fn run_with_clock_advancement(&self) {
        while self.step() || self.advance_clock_to_next_timer() {
            // Continue until no work remains
        }
    }

    /// Execute one tick of the scheduler, processing expired timers and running
    /// at most one task. Returns true if any work was done.
    ///
    /// This is the public interface for GPUI's TestDispatcher to drive task execution.
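    ///
    /// A rough sketch of a manual drive loop that advances the clock whenever
    /// the queue drains but timers remain:
    ///
    /// ```ignore
    /// while scheduler.tick() || scheduler.advance_clock_to_next_timer() {}
    /// ```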
    pub fn tick(&self) -> bool {
        self.step_filtered(false)
    }

    /// Execute one tick, but only run background tasks (no foreground/session tasks).
    /// Returns true if any work was done.
    pub fn tick_background_only(&self) -> bool {
        self.step_filtered(true)
    }

    /// Check if there are any pending tasks or timers that could run.
    pub fn has_pending_tasks(&self) -> bool {
        let state = self.state.lock();
        !state.runnables.is_empty() || !state.timers.is_empty()
    }

    /// Returns counts of (foreground_tasks, background_tasks) currently queued.
    /// Foreground tasks are those with a session_id; background tasks have none.
    pub fn pending_task_counts(&self) -> (usize, usize) {
        let state = self.state.lock();
        let foreground = state
            .runnables
            .iter()
            .filter(|r| r.session_id.is_some())
            .count();
        let background = state
            .runnables
            .iter()
            .filter(|r| r.session_id.is_none())
            .count();
        (foreground, background)
    }

    fn step(&self) -> bool {
        self.step_filtered(false)
    }

    fn step_filtered(&self, background_only: bool) -> bool {
        let (elapsed_count, runnables_before) = {
            let mut state = self.state.lock();
            let end_ix = state
                .timers
                .partition_point(|timer| timer.expiration <= self.clock.now());
            let elapsed: Vec<_> = state.timers.drain(..end_ix).collect();
            let count = elapsed.len();
            let runnables = state.runnables.len();
            drop(state);
            // Dropping elapsed timers here wakes the waiting futures
            drop(elapsed);
            (count, runnables)
        };

        if elapsed_count > 0 {
            let runnables_after = self.state.lock().runnables.len();
            if std::env::var("DEBUG_SCHEDULER").is_ok() {
                eprintln!(
                    "[scheduler] Expired {} timers at {:?}, runnables: {} -> {}",
                    elapsed_count,
                    self.clock.now(),
                    runnables_before,
                    runnables_after
                );
            }
            return true;
        }

        let runnable = {
            let state = &mut *self.state.lock();

            // Find candidate tasks:
            // - For foreground tasks (with session_id), only the first task from each session
            //   is a candidate (to preserve intra-session ordering)
            // - For background tasks (no session_id), all are candidates
            // - Tasks from blocked sessions are excluded
            // - If background_only is true, skip foreground tasks entirely
            let mut seen_sessions = HashSet::new();
            let candidate_indices: Vec<usize> = state
                .runnables
                .iter()
                .enumerate()
                .filter(|(_, runnable)| {
                    if let Some(session_id) = runnable.session_id {
                        // Skip foreground tasks in background_only mode
                        if background_only {
                            return false;
                        }
                        // Exclude tasks from blocked sessions
                        if state.blocked_sessions.contains(&session_id) {
                            return false;
                        }
                        // Only include first task from each session (insert returns true if new)
                        seen_sessions.insert(session_id)
                    } else {
                        // Background tasks are always candidates
                        true
                    }
                })
                .map(|(ix, _)| ix)
                .collect();

            if candidate_indices.is_empty() {
                None
            } else if state.randomize_order {
                // Use priority-weighted random selection
                let weights: Vec<u32> = candidate_indices
                    .iter()
                    .map(|&ix| state.runnables[ix].priority.weight())
                    .collect();
                let total_weight: u32 = weights.iter().sum();

                if total_weight == 0 {
                    // Fallback to uniform random if all weights are zero
                    let choice = self.rng.lock().random_range(0..candidate_indices.len());
                    state.runnables.remove(candidate_indices[choice])
                } else {
                    let mut target = self.rng.lock().random_range(0..total_weight);
                    let mut selected_idx = 0;
                    for (i, &weight) in weights.iter().enumerate() {
                        if target < weight {
                            selected_idx = i;
                            break;
                        }
                        target -= weight;
                    }
                    state.runnables.remove(candidate_indices[selected_idx])
                }
            } else {
                // Non-randomized: just take the first candidate task
                state.runnables.remove(candidate_indices[0])
            }
        };

        if let Some(runnable) = runnable {
            // Check if the executor that spawned this task was closed
            if runnable.runnable.metadata().is_closed() {
                return true;
            }
            let is_foreground = runnable.session_id.is_some();
            let was_main_thread = self.state.lock().is_main_thread;
            self.state.lock().is_main_thread = is_foreground;
            runnable.run();
            self.state.lock().is_main_thread = was_main_thread;
            return true;
        }

        false
    }

    pub fn advance_clock_to_next_timer(&self) -> bool {
        if let Some(timer) = self.state.lock().timers.first() {
            self.clock.advance(timer.expiration - self.clock.now());
            true
        } else {
            false
        }
    }

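    /// Advance the test clock by `duration`, running all ready work and firing
    /// every timer that expires along the way. The clock always ends exactly
    /// `duration` past where it started, even if no timers were pending.
    ///
    /// A rough sketch:
    ///
    /// ```ignore
    /// let _timer = scheduler.timer(Duration::from_millis(500));
    /// scheduler.advance_clock(Duration::from_secs(1)); // fires the 500ms timer
    /// ```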
    pub fn advance_clock(&self, duration: Duration) {
        let debug = std::env::var("DEBUG_SCHEDULER").is_ok();
        let start = self.clock.now();
        let next_now = start + duration;
        if debug {
            let timer_count = self.state.lock().timers.len();
            eprintln!(
                "[scheduler] advance_clock({:?}) from {:?}, {} pending timers",
                duration, start, timer_count
            );
        }
        loop {
            self.run();
            if let Some(timer) = self.state.lock().timers.first()
                && timer.expiration <= next_now
            {
                let advance_to = timer.expiration;
                if debug {
                    eprintln!(
                        "[scheduler] Advancing clock {:?} -> {:?} for timer",
                        self.clock.now(),
                        advance_to
                    );
                }
                self.clock.advance(advance_to - self.clock.now());
            } else {
                break;
            }
        }
        self.clock.advance(next_now - self.clock.now());
        if debug {
            eprintln!(
                "[scheduler] advance_clock done, now at {:?}",
                self.clock.now()
            );
        }
    }

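    /// Park the current thread until it is unparked or `deadline` elapses,
    /// returning true if it was woken before the deadline. When parking is
    /// forbidden, a deadline simply returns false (letting `block` give up),
    /// while the no-deadline case panics, dumping pending wake traces if
    /// `PENDING_TRACES` capture is enabled.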
    fn park(&self, deadline: Option<Instant>) -> bool {
        if self.state.lock().allow_parking {
            if let Some(deadline) = deadline {
                let now = Instant::now();
                let timeout = deadline.saturating_duration_since(now);
                thread::park_timeout(timeout);
                now.elapsed() < timeout
            } else {
                thread::park();
                true
            }
        } else if deadline.is_some() {
            false
        } else if self.state.lock().capture_pending_traces {
            let mut pending_traces = String::new();
            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
            }
            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
        } else {
            panic!(
                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
            );
        }
    }
}

impl Scheduler for TestScheduler {
    /// Block until the given future completes, with an optional timeout. If the
    /// future is unable to make progress at any moment before the timeout and
    /// no other tasks or timers remain, we panic unless parking is allowed. If
    /// parking is allowed, we block up to the timeout or indefinitely if none
    /// is provided. This is to allow testing a mix of deterministic and
    /// non-deterministic async behavior, such as when interacting with I/O in
    /// an otherwise deterministic test.
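    ///
    /// A rough sketch of blocking on a pinned future with a timeout (the
    /// `pin!` usage and durations are illustrative only):
    ///
    /// ```ignore
    /// let mut future = std::pin::pin!(async {
    ///     scheduler.timer(Duration::from_secs(1)).await;
    /// });
    /// let completed = scheduler.block(None, future.as_mut(), Some(Duration::from_secs(5)));
    /// ```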
    fn block(
        &self,
        session_id: Option<SessionId>,
        mut future: Pin<&mut dyn Future<Output = ()>>,
        timeout: Option<Duration>,
    ) -> bool {
        if let Some(session_id) = session_id {
            self.state.lock().blocked_sessions.push(session_id);
        }

        let deadline = timeout.map(|timeout| Instant::now() + timeout);
        let awoken = Arc::new(AtomicBool::new(false));
        let waker = Box::new(TracingWaker {
            id: None,
            awoken: awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        });
        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
        let max_ticks = if timeout.is_some() {
            self.rng
                .lock()
                .random_range(self.state.lock().timeout_ticks.clone())
        } else {
            usize::MAX
        };
        let mut cx = Context::from_waker(&waker);

        let mut completed = false;
        for _ in 0..max_ticks {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(()) => {
                    completed = true;
                    break;
                }
                Poll::Pending => {}
            }

            let mut stepped = None;
            while self.rng.lock().random() {
                let stepped = stepped.get_or_insert(false);
                if self.step() {
                    *stepped = true;
                } else {
                    break;
                }
            }

            let stepped = stepped.unwrap_or(true);
            let awoken = awoken.swap(false, SeqCst);
            if !stepped && !awoken && !self.advance_clock_to_next_timer() {
                if !self.park(deadline) {
                    break;
                }
            }
        }

        if session_id.is_some() {
            self.state.lock().blocked_sessions.pop();
        }

        completed
    }

    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable<RunnableMeta>) {
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            let start_ix = state
                .runnables
                .iter()
                .rposition(|task| task.session_id == Some(session_id))
                .map_or(0, |ix| ix + 1);
            self.rng
                .lock()
                .random_range(start_ix..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: Some(session_id),
                priority: Priority::default(),
                runnable,
            },
        );
        drop(state);
        self.thread.unpark();
    }

    fn schedule_background_with_priority(
        &self,
        runnable: Runnable<RunnableMeta>,
        priority: Priority,
    ) {
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            self.rng.lock().random_range(0..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: None,
                priority,
                runnable,
            },
        );
        drop(state);
        self.thread.unpark();
    }

    fn timer(&self, duration: Duration) -> Timer {
        let (tx, rx) = oneshot::channel();
        let state = &mut *self.state.lock();
        state.timers.push(ScheduledTimer {
            expiration: self.clock.now() + duration,
            _notify: tx,
        });
        state.timers.sort_by_key(|timer| timer.expiration);
        Timer(rx)
    }

    fn clock(&self) -> Arc<dyn Clock> {
        self.clock.clone()
    }

    fn as_test(&self) -> Option<&TestScheduler> {
        Some(self)
    }
}

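/// Configuration for a [`TestScheduler`].
///
/// A rough sketch of a customized configuration (the field values are
/// illustrative only):
///
/// ```ignore
/// let scheduler = TestScheduler::new(TestSchedulerConfig {
///     seed: 42,
///     allow_parking: true,
///     ..Default::default()
/// });
/// ```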
#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    pub seed: u64,
    pub randomize_order: bool,
    pub allow_parking: bool,
    pub capture_pending_traces: bool,
    pub timeout_ticks: RangeInclusive<usize>,
}

impl TestSchedulerConfig {
    pub fn with_seed(seed: u64) -> Self {
        Self {
            seed,
            ..Default::default()
        }
    }
}

impl Default for TestSchedulerConfig {
    fn default() -> Self {
        Self {
            seed: 0,
            randomize_order: true,
            allow_parking: false,
            capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
                .map_or(false, |var| var == "1" || var == "true"),
            timeout_ticks: 0..=1000,
        }
    }
}

struct ScheduledRunnable {
    session_id: Option<SessionId>,
    priority: Priority,
    runnable: Runnable<RunnableMeta>,
}

impl ScheduledRunnable {
    fn run(self) {
        self.runnable.run();
    }
}

struct ScheduledTimer {
    expiration: Instant,
    _notify: oneshot::Sender<()>,
}

struct SchedulerState {
    runnables: VecDeque<ScheduledRunnable>,
    timers: Vec<ScheduledTimer>,
    blocked_sessions: Vec<SessionId>,
    randomize_order: bool,
    allow_parking: bool,
    timeout_ticks: RangeInclusive<usize>,
    next_session_id: SessionId,
    capture_pending_traces: bool,
    next_trace_id: TraceId,
    pending_traces: BTreeMap<TraceId, Backtrace>,
    is_main_thread: bool,
}

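// Raw waker vtable for `TracingWaker`: `clone_raw` boxes a fresh clone,
// `wake_raw` and `drop_raw` reclaim the box, and `wake_by_ref_raw` only
// borrows it, matching the ownership contract `RawWakerVTable` expects.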
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);

#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);

struct TracingWaker {
    id: Option<TraceId>,
    awoken: Arc<AtomicBool>,
    thread: Thread,
    state: Arc<Mutex<SchedulerState>>,
}

impl Clone for TracingWaker {
    fn clone(&self) -> Self {
        let mut state = self.state.lock();
        let id = if state.capture_pending_traces {
            let id = state.next_trace_id;
            state.next_trace_id.0 += 1;
            state.pending_traces.insert(id, Backtrace::new_unresolved());
            Some(id)
        } else {
            None
        };
        Self {
            id,
            awoken: self.awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        }
    }
}

impl Drop for TracingWaker {
    fn drop(&mut self) {
        if let Some(id) = self.id {
            self.state.lock().pending_traces.remove(&id);
        }
    }
}

impl TracingWaker {
    fn wake(self) {
        self.wake_by_ref();
    }

    fn wake_by_ref(&self) {
        if let Some(id) = self.id {
            self.state.lock().pending_traces.remove(&id);
        }
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    fn wake_raw(waker: *const ()) {
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    fn drop_raw(waker: *const ()) {
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}

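/// A future that returns `Poll::Pending` the given number of times, waking
/// itself after each poll, before finally resolving. Produced by
/// [`TestScheduler::yield_random`].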
pub struct Yield(usize);

/// A wrapper around `Arc<Mutex<StdRng>>` that provides convenient methods
/// for random number generation without requiring explicit locking.
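///
/// A rough usage sketch:
///
/// ```ignore
/// let rng = scheduler.rng();
/// let delay_ms: u64 = rng.random_range(0..100);
/// let should_fail = rng.random_bool(0.25);
/// ```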
#[derive(Clone)]
pub struct SharedRng(Arc<Mutex<StdRng>>);

impl SharedRng {
    /// Lock the inner RNG for direct access. Use this when you need multiple
    /// random operations without re-locking between each one.
    pub fn lock(&self) -> MutexGuard<'_, StdRng> {
        self.0.lock()
    }

    /// Generate a random value in the given range.
    pub fn random_range<T, R>(&self, range: R) -> T
    where
        T: SampleUniform,
        R: SampleRange<T>,
    {
        self.0.lock().random_range(range)
    }

    /// Generate a random boolean with the given probability of being true.
    pub fn random_bool(&self, p: f64) -> bool {
        self.0.lock().random_bool(p)
    }

    /// Generate a random value of the given type.
    pub fn random<T>(&self) -> T
    where
        StandardUniform: Distribution<T>,
    {
        self.0.lock().random()
    }

    /// Generate a random ratio - true with probability `numerator/denominator`.
    pub fn random_ratio(&self, numerator: u32, denominator: u32) -> bool {
        self.0.lock().random_ratio(numerator, denominator)
    }
}

impl Future for Yield {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        if self.0 == 0 {
            Poll::Ready(())
        } else {
            self.0 -= 1;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

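/// Trim a captured backtrace so it starts at the caller of `Waker::clone`,
/// hiding the waker plumbing frames from pending-trace output.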
fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
    trace.resolve();
    let mut frames: Vec<BacktraceFrame> = trace.into();
    let waker_clone_frame_ix = frames.iter().position(|frame| {
        frame.symbols().iter().any(|symbol| {
            symbol
                .name()
                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
        })
    });

    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
        frames.drain(..waker_clone_frame_ix + 1);
    }

    Backtrace::from(frames)
}