1use crate::{
2 BackgroundExecutor, Clock, ForegroundExecutor, Priority, RunnableMeta, Scheduler, SessionId,
3 TestClock, Timer,
4};
5use async_task::Runnable;
6use backtrace::{Backtrace, BacktraceFrame};
7use futures::channel::oneshot;
8use parking_lot::{Mutex, MutexGuard};
9use rand::{
10 distr::{StandardUniform, uniform::SampleRange, uniform::SampleUniform},
11 prelude::*,
12};
13use std::{
14 any::type_name_of_val,
15 collections::{BTreeMap, HashSet, VecDeque},
16 env,
17 fmt::Write,
18 future::Future,
19 mem,
20 ops::RangeInclusive,
21 panic::{self, AssertUnwindSafe},
22 pin::Pin,
23 sync::{
24 Arc,
25 atomic::{AtomicBool, Ordering::SeqCst},
26 },
27 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
28 thread::{self, Thread},
29 time::{Duration, Instant},
30};
31
/// Name of the environment variable that turns on capturing of pending-waker
/// backtraces. `TestSchedulerConfig::default` treats the values `1` and `true`
/// as "enabled"; the variable name is also quoted in the "Parking forbidden"
/// panic message.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
33
/// Scheduler used in tests: task ordering is driven by a seeded RNG and time
/// is driven by an explicitly advanced test clock.
pub struct TestScheduler {
    /// Test clock; the scheduler advances it when moving to the next timer.
    clock: Arc<TestClock>,
    /// Random number generator seeded from `TestSchedulerConfig::seed`.
    rng: Arc<Mutex<StdRng>>,
    /// Shared mutable state: queued runnables, timers, flags and trace data.
    state: Arc<Mutex<SchedulerState>>,
    /// Handle of the thread that constructed the scheduler. It is unparked
    /// whenever new work is scheduled or a blocking waker fires.
    thread: Thread,
}
40
41impl TestScheduler {
42 /// Run a test once with default configuration (seed 0)
43 pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
44 Self::with_seed(0, f)
45 }
46
47 /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
48 pub fn many<R>(
49 default_iterations: usize,
50 mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
51 ) -> Vec<R> {
52 let num_iterations = std::env::var("ITERATIONS")
53 .map(|iterations| iterations.parse().unwrap())
54 .unwrap_or(default_iterations);
55
56 let seed = std::env::var("SEED")
57 .map(|seed| seed.parse().unwrap())
58 .unwrap_or(0);
59
60 (seed..num_iterations as u64)
61 .map(|seed| {
62 let mut unwind_safe_f = AssertUnwindSafe(&mut f);
63 eprintln!("Running seed: {seed}");
64 match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
65 Ok(result) => result,
66 Err(error) => {
67 eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
68 panic::resume_unwind(error);
69 }
70 }
71 })
72 .collect()
73 }
74
75 fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
76 let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
77 let future = f(scheduler.clone());
78 let result = scheduler.foreground().block_on(future);
79 scheduler.run(); // Ensure spawned tasks finish up before returning in tests
80 result
81 }
82
83 pub fn new(config: TestSchedulerConfig) -> Self {
84 Self {
85 rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
86 state: Arc::new(Mutex::new(SchedulerState {
87 runnables: VecDeque::new(),
88 timers: Vec::new(),
89 blocked_sessions: Vec::new(),
90 randomize_order: config.randomize_order,
91 allow_parking: config.allow_parking,
92 timeout_ticks: config.timeout_ticks,
93 next_session_id: SessionId(0),
94 capture_pending_traces: config.capture_pending_traces,
95 pending_traces: BTreeMap::new(),
96 next_trace_id: TraceId(0),
97 is_main_thread: true,
98 })),
99 clock: Arc::new(TestClock::new()),
100 thread: thread::current(),
101 }
102 }
103
104 pub fn clock(&self) -> Arc<TestClock> {
105 self.clock.clone()
106 }
107
108 pub fn rng(&self) -> SharedRng {
109 SharedRng(self.rng.clone())
110 }
111
112 pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
113 self.state.lock().timeout_ticks = timeout_ticks;
114 }
115
116 pub fn allow_parking(&self) {
117 self.state.lock().allow_parking = true;
118 }
119
120 pub fn forbid_parking(&self) {
121 self.state.lock().allow_parking = false;
122 }
123
124 pub fn parking_allowed(&self) -> bool {
125 self.state.lock().allow_parking
126 }
127
128 pub fn is_main_thread(&self) -> bool {
129 self.state.lock().is_main_thread
130 }
131
132 /// Allocate a new session ID for foreground task scheduling.
133 /// This is used by GPUI's TestDispatcher to map dispatcher instances to sessions.
134 pub fn allocate_session_id(&self) -> SessionId {
135 let mut state = self.state.lock();
136 state.next_session_id.0 += 1;
137 state.next_session_id
138 }
139
140 /// Create a foreground executor for this scheduler
141 pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
142 let session_id = self.allocate_session_id();
143 ForegroundExecutor::new(session_id, self.clone())
144 }
145
146 /// Create a background executor for this scheduler
147 pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
148 BackgroundExecutor::new(self.clone())
149 }
150
151 pub fn yield_random(&self) -> Yield {
152 let rng = &mut *self.rng.lock();
153 if rng.random_bool(0.1) {
154 Yield(rng.random_range(10..20))
155 } else {
156 Yield(rng.random_range(0..2))
157 }
158 }
159
160 pub fn run(&self) {
161 while self.step() {
162 // Continue until no work remains
163 }
164 }
165
166 pub fn run_with_clock_advancement(&self) {
167 while self.step() || self.advance_clock_to_next_timer() {
168 // Continue until no work remains
169 }
170 }
171
172 /// Execute one tick of the scheduler, processing expired timers and running
173 /// at most one task. Returns true if any work was done.
174 ///
175 /// This is the public interface for GPUI's TestDispatcher to drive task execution.
176 pub fn tick(&self) -> bool {
177 self.step_filtered(false)
178 }
179
180 /// Execute one tick, but only run background tasks (no foreground/session tasks).
181 /// Returns true if any work was done.
182 pub fn tick_background_only(&self) -> bool {
183 self.step_filtered(true)
184 }
185
186 /// Check if there are any pending tasks or timers that could run.
187 pub fn has_pending_tasks(&self) -> bool {
188 let state = self.state.lock();
189 !state.runnables.is_empty() || !state.timers.is_empty()
190 }
191
192 /// Returns counts of (foreground_tasks, background_tasks) currently queued.
193 /// Foreground tasks are those with a session_id, background tasks have none.
194 pub fn pending_task_counts(&self) -> (usize, usize) {
195 let state = self.state.lock();
196 let foreground = state
197 .runnables
198 .iter()
199 .filter(|r| r.session_id.is_some())
200 .count();
201 let background = state
202 .runnables
203 .iter()
204 .filter(|r| r.session_id.is_none())
205 .count();
206 (foreground, background)
207 }
208
209 fn step(&self) -> bool {
210 self.step_filtered(false)
211 }
212
213 fn step_filtered(&self, background_only: bool) -> bool {
214 let (elapsed_count, runnables_before) = {
215 let mut state = self.state.lock();
216 let end_ix = state
217 .timers
218 .partition_point(|timer| timer.expiration <= self.clock.now());
219 let elapsed: Vec<_> = state.timers.drain(..end_ix).collect();
220 let count = elapsed.len();
221 let runnables = state.runnables.len();
222 drop(state);
223 // Dropping elapsed timers here wakes the waiting futures
224 drop(elapsed);
225 (count, runnables)
226 };
227
228 if elapsed_count > 0 {
229 let runnables_after = self.state.lock().runnables.len();
230 if std::env::var("DEBUG_SCHEDULER").is_ok() {
231 eprintln!(
232 "[scheduler] Expired {} timers at {:?}, runnables: {} -> {}",
233 elapsed_count,
234 self.clock.now(),
235 runnables_before,
236 runnables_after
237 );
238 }
239 return true;
240 }
241
242 let runnable = {
243 let state = &mut *self.state.lock();
244
245 // Find candidate tasks:
246 // - For foreground tasks (with session_id), only the first task from each session
247 // is a candidate (to preserve intra-session ordering)
248 // - For background tasks (no session_id), all are candidates
249 // - Tasks from blocked sessions are excluded
250 // - If background_only is true, skip foreground tasks entirely
251 let mut seen_sessions = HashSet::new();
252 let candidate_indices: Vec<usize> = state
253 .runnables
254 .iter()
255 .enumerate()
256 .filter(|(_, runnable)| {
257 if let Some(session_id) = runnable.session_id {
258 // Skip foreground tasks if background_only mode
259 if background_only {
260 return false;
261 }
262 // Exclude tasks from blocked sessions
263 if state.blocked_sessions.contains(&session_id) {
264 return false;
265 }
266 // Only include first task from each session (insert returns true if new)
267 seen_sessions.insert(session_id)
268 } else {
269 // Background tasks are always candidates
270 true
271 }
272 })
273 .map(|(ix, _)| ix)
274 .collect();
275
276 if candidate_indices.is_empty() {
277 None
278 } else if state.randomize_order {
279 // Use priority-weighted random selection
280 let weights: Vec<u32> = candidate_indices
281 .iter()
282 .map(|&ix| state.runnables[ix].priority.weight())
283 .collect();
284 let total_weight: u32 = weights.iter().sum();
285
286 if total_weight == 0 {
287 // Fallback to uniform random if all weights are zero
288 let choice = self.rng.lock().random_range(0..candidate_indices.len());
289 state.runnables.remove(candidate_indices[choice])
290 } else {
291 let mut target = self.rng.lock().random_range(0..total_weight);
292 let mut selected_idx = 0;
293 for (i, &weight) in weights.iter().enumerate() {
294 if target < weight {
295 selected_idx = i;
296 break;
297 }
298 target -= weight;
299 }
300 state.runnables.remove(candidate_indices[selected_idx])
301 }
302 } else {
303 // Non-randomized: just take the first candidate task
304 state.runnables.remove(candidate_indices[0])
305 }
306 };
307
308 if let Some(runnable) = runnable {
309 // Check if the executor that spawned this task was closed
310 if runnable.runnable.metadata().is_closed() {
311 return true;
312 }
313 let is_foreground = runnable.session_id.is_some();
314 let was_main_thread = self.state.lock().is_main_thread;
315 self.state.lock().is_main_thread = is_foreground;
316 runnable.run();
317 self.state.lock().is_main_thread = was_main_thread;
318 return true;
319 }
320
321 false
322 }
323
324 pub fn advance_clock_to_next_timer(&self) -> bool {
325 if let Some(timer) = self.state.lock().timers.first() {
326 self.clock.advance(timer.expiration - self.clock.now());
327 true
328 } else {
329 false
330 }
331 }
332
333 pub fn advance_clock(&self, duration: Duration) {
334 let debug = std::env::var("DEBUG_SCHEDULER").is_ok();
335 let start = self.clock.now();
336 let next_now = start + duration;
337 if debug {
338 let timer_count = self.state.lock().timers.len();
339 eprintln!(
340 "[scheduler] advance_clock({:?}) from {:?}, {} pending timers",
341 duration, start, timer_count
342 );
343 }
344 loop {
345 self.run();
346 if let Some(timer) = self.state.lock().timers.first()
347 && timer.expiration <= next_now
348 {
349 let advance_to = timer.expiration;
350 if debug {
351 eprintln!(
352 "[scheduler] Advancing clock {:?} -> {:?} for timer",
353 self.clock.now(),
354 advance_to
355 );
356 }
357 self.clock.advance(advance_to - self.clock.now());
358 } else {
359 break;
360 }
361 }
362 self.clock.advance(next_now - self.clock.now());
363 if debug {
364 eprintln!(
365 "[scheduler] advance_clock done, now at {:?}",
366 self.clock.now()
367 );
368 }
369 }
370
371 fn park(&self, deadline: Option<Instant>) -> bool {
372 if self.state.lock().allow_parking {
373 if let Some(deadline) = deadline {
374 let now = Instant::now();
375 let timeout = deadline.saturating_duration_since(now);
376 thread::park_timeout(timeout);
377 now.elapsed() < timeout
378 } else {
379 thread::park();
380 true
381 }
382 } else if deadline.is_some() {
383 false
384 } else if self.state.lock().capture_pending_traces {
385 let mut pending_traces = String::new();
386 for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
387 writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
388 }
389 panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
390 } else {
391 panic!(
392 "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
393 );
394 }
395 }
396}
397
398impl Scheduler for TestScheduler {
399 /// Block until the given future completes, with an optional timeout. If the
400 /// future is unable to make progress at any moment before the timeout and
401 /// no other tasks or timers remain, we panic unless parking is allowed. If
402 /// parking is allowed, we block up to the timeout or indefinitely if none
403 /// is provided. This is to allow testing a mix of deterministic and
404 /// non-deterministic async behavior, such as when interacting with I/O in
405 /// an otherwise deterministic test.
406 fn block(
407 &self,
408 session_id: Option<SessionId>,
409 mut future: Pin<&mut dyn Future<Output = ()>>,
410 timeout: Option<Duration>,
411 ) -> bool {
412 if let Some(session_id) = session_id {
413 self.state.lock().blocked_sessions.push(session_id);
414 }
415
416 let deadline = timeout.map(|timeout| Instant::now() + timeout);
417 let awoken = Arc::new(AtomicBool::new(false));
418 let waker = Box::new(TracingWaker {
419 id: None,
420 awoken: awoken.clone(),
421 thread: self.thread.clone(),
422 state: self.state.clone(),
423 });
424 let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
425 let max_ticks = if timeout.is_some() {
426 self.rng
427 .lock()
428 .random_range(self.state.lock().timeout_ticks.clone())
429 } else {
430 usize::MAX
431 };
432 let mut cx = Context::from_waker(&waker);
433
434 let mut completed = false;
435 for _ in 0..max_ticks {
436 match future.as_mut().poll(&mut cx) {
437 Poll::Ready(()) => {
438 completed = true;
439 break;
440 }
441 Poll::Pending => {}
442 }
443
444 let mut stepped = None;
445 while self.rng.lock().random() {
446 let stepped = stepped.get_or_insert(false);
447 if self.step() {
448 *stepped = true;
449 } else {
450 break;
451 }
452 }
453
454 let stepped = stepped.unwrap_or(true);
455 let awoken = awoken.swap(false, SeqCst);
456 if !stepped && !awoken && !self.advance_clock_to_next_timer() {
457 if !self.park(deadline) {
458 break;
459 }
460 }
461 }
462
463 if session_id.is_some() {
464 self.state.lock().blocked_sessions.pop();
465 }
466
467 completed
468 }
469
470 fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable<RunnableMeta>) {
471 let mut state = self.state.lock();
472 let ix = if state.randomize_order {
473 let start_ix = state
474 .runnables
475 .iter()
476 .rposition(|task| task.session_id == Some(session_id))
477 .map_or(0, |ix| ix + 1);
478 self.rng
479 .lock()
480 .random_range(start_ix..=state.runnables.len())
481 } else {
482 state.runnables.len()
483 };
484 state.runnables.insert(
485 ix,
486 ScheduledRunnable {
487 session_id: Some(session_id),
488 priority: Priority::default(),
489 runnable,
490 },
491 );
492 drop(state);
493 self.thread.unpark();
494 }
495
496 fn schedule_background_with_priority(
497 &self,
498 runnable: Runnable<RunnableMeta>,
499 priority: Priority,
500 ) {
501 let mut state = self.state.lock();
502 let ix = if state.randomize_order {
503 self.rng.lock().random_range(0..=state.runnables.len())
504 } else {
505 state.runnables.len()
506 };
507 state.runnables.insert(
508 ix,
509 ScheduledRunnable {
510 session_id: None,
511 priority,
512 runnable,
513 },
514 );
515 drop(state);
516 self.thread.unpark();
517 }
518
519 fn timer(&self, duration: Duration) -> Timer {
520 let (tx, rx) = oneshot::channel();
521 let state = &mut *self.state.lock();
522 state.timers.push(ScheduledTimer {
523 expiration: self.clock.now() + duration,
524 _notify: tx,
525 });
526 state.timers.sort_by_key(|timer| timer.expiration);
527 Timer(rx)
528 }
529
530 fn clock(&self) -> Arc<dyn Clock> {
531 self.clock.clone()
532 }
533
534 fn as_test(&self) -> Option<&TestScheduler> {
535 Some(self)
536 }
537}
538
539#[derive(Clone, Debug)]
540pub struct TestSchedulerConfig {
541 pub seed: u64,
542 pub randomize_order: bool,
543 pub allow_parking: bool,
544 pub capture_pending_traces: bool,
545 pub timeout_ticks: RangeInclusive<usize>,
546}
547
548impl TestSchedulerConfig {
549 pub fn with_seed(seed: u64) -> Self {
550 Self {
551 seed,
552 ..Default::default()
553 }
554 }
555}
556
557impl Default for TestSchedulerConfig {
558 fn default() -> Self {
559 Self {
560 seed: 0,
561 randomize_order: true,
562 allow_parking: false,
563 capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
564 .map_or(false, |var| var == "1" || var == "true"),
565 timeout_ticks: 0..=1000,
566 }
567 }
568}
569
570struct ScheduledRunnable {
571 session_id: Option<SessionId>,
572 priority: Priority,
573 runnable: Runnable<RunnableMeta>,
574}
575
576impl ScheduledRunnable {
577 fn run(self) {
578 self.runnable.run();
579 }
580}
581
/// A registered timer awaiting expiration.
struct ScheduledTimer {
    /// Clock time at which the timer expires.
    expiration: Instant,
    /// Sending half of the timer's channel. It is never sent on; the
    /// scheduler drops expired timers, and dropping this sender is what wakes
    /// the waiting future.
    _notify: oneshot::Sender<()>,
}
586
/// Mutable scheduler state, shared behind a mutex between the scheduler and
/// the wakers it hands out.
struct SchedulerState {
    /// Tasks queued for execution.
    runnables: VecDeque<ScheduledRunnable>,
    /// Registered timers, kept sorted by expiration.
    timers: Vec<ScheduledTimer>,
    /// Sessions currently inside a blocking call; their foreground tasks are
    /// not eligible to run.
    blocked_sessions: Vec<SessionId>,
    /// Whether task insertion and selection are randomized.
    randomize_order: bool,
    /// Whether parking the thread is permitted.
    allow_parking: bool,
    /// Range from which a blocking call with a timeout draws its tick budget.
    timeout_ticks: RangeInclusive<usize>,
    /// Session ID counter; incremented before each allocation, so it holds
    /// the most recently allocated ID.
    next_session_id: SessionId,
    /// Whether backtraces are captured when wakers are cloned.
    capture_pending_traces: bool,
    /// ID that will be assigned to the next captured trace.
    next_trace_id: TraceId,
    /// Backtraces of cloned wakers that have not yet been woken or dropped.
    pending_traces: BTreeMap<TraceId, Backtrace>,
    /// Starts as `true`; while a task runs, holds whether that task is a
    /// foreground task.
    is_main_thread: bool,
}
600
/// Raw waker vtable whose entries (clone, wake, wake-by-ref, drop) all treat
/// the waker's data pointer as a heap-allocated `TracingWaker`.
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);
607
/// Key identifying one captured waker backtrace in
/// `SchedulerState::pending_traces`; IDs are handed out in increasing order.
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);
610
/// Waker state used while blocking on a future. Waking sets a flag and
/// unparks the scheduler thread; clones can optionally record a backtrace so
/// that wakers still pending can be reported.
struct TracingWaker {
    /// ID of the backtrace captured when this waker was cloned, if capture
    /// was enabled. The waker created directly by `block` has `None`.
    id: Option<TraceId>,
    /// Flag set to `true` when the waker is woken; shared by all clones.
    awoken: Arc<AtomicBool>,
    /// Thread to unpark on wake.
    thread: Thread,
    /// Scheduler state, needed to register and remove pending traces.
    state: Arc<Mutex<SchedulerState>>,
}
617
618impl Clone for TracingWaker {
619 fn clone(&self) -> Self {
620 let mut state = self.state.lock();
621 let id = if state.capture_pending_traces {
622 let id = state.next_trace_id;
623 state.next_trace_id.0 += 1;
624 state.pending_traces.insert(id, Backtrace::new_unresolved());
625 Some(id)
626 } else {
627 None
628 };
629 Self {
630 id,
631 awoken: self.awoken.clone(),
632 thread: self.thread.clone(),
633 state: self.state.clone(),
634 }
635 }
636}
637
638impl Drop for TracingWaker {
639 fn drop(&mut self) {
640 if let Some(id) = self.id {
641 self.state.lock().pending_traces.remove(&id);
642 }
643 }
644}
645
646impl TracingWaker {
647 fn wake(self) {
648 self.wake_by_ref();
649 }
650
651 fn wake_by_ref(&self) {
652 if let Some(id) = self.id {
653 self.state.lock().pending_traces.remove(&id);
654 }
655 self.awoken.store(true, SeqCst);
656 self.thread.unpark();
657 }
658
659 fn clone_raw(waker: *const ()) -> RawWaker {
660 let waker = waker as *const TracingWaker;
661 let waker = unsafe { &*waker };
662 RawWaker::new(
663 Box::into_raw(Box::new(waker.clone())) as *const (),
664 &WAKER_VTABLE,
665 )
666 }
667
668 fn wake_raw(waker: *const ()) {
669 let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
670 waker.wake();
671 }
672
673 fn wake_by_ref_raw(waker: *const ()) {
674 let waker = waker as *const TracingWaker;
675 let waker = unsafe { &*waker };
676 waker.wake_by_ref();
677 }
678
679 fn drop_raw(waker: *const ()) {
680 let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
681 drop(waker);
682 }
683}
684
/// Future that returns `Poll::Pending` the wrapped number of times (waking
/// itself each time) before completing. Created by
/// `TestScheduler::yield_random`.
pub struct Yield(usize);
686
687/// A wrapper around `Arc<Mutex<StdRng>>` that provides convenient methods
688/// for random number generation without requiring explicit locking.
689#[derive(Clone)]
690pub struct SharedRng(Arc<Mutex<StdRng>>);
691
692impl SharedRng {
693 /// Lock the inner RNG for direct access. Use this when you need multiple
694 /// random operations without re-locking between each one.
695 pub fn lock(&self) -> MutexGuard<'_, StdRng> {
696 self.0.lock()
697 }
698
699 /// Generate a random value in the given range.
700 pub fn random_range<T, R>(&self, range: R) -> T
701 where
702 T: SampleUniform,
703 R: SampleRange<T>,
704 {
705 self.0.lock().random_range(range)
706 }
707
708 /// Generate a random boolean with the given probability of being true.
709 pub fn random_bool(&self, p: f64) -> bool {
710 self.0.lock().random_bool(p)
711 }
712
713 /// Generate a random value of the given type.
714 pub fn random<T>(&self) -> T
715 where
716 StandardUniform: Distribution<T>,
717 {
718 self.0.lock().random()
719 }
720
721 /// Generate a random ratio - true with probability `numerator/denominator`.
722 pub fn random_ratio(&self, numerator: u32, denominator: u32) -> bool {
723 self.0.lock().random_ratio(numerator, denominator)
724 }
725}
726
727impl Future for Yield {
728 type Output = ();
729
730 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
731 if self.0 == 0 {
732 Poll::Ready(())
733 } else {
734 self.0 -= 1;
735 cx.waker().wake_by_ref();
736 Poll::Pending
737 }
738 }
739}
740
741fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
742 trace.resolve();
743 let mut frames: Vec<BacktraceFrame> = trace.into();
744 let waker_clone_frame_ix = frames.iter().position(|frame| {
745 frame.symbols().iter().any(|symbol| {
746 symbol
747 .name()
748 .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
749 })
750 });
751
752 if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
753 frames.drain(..waker_clone_frame_ix + 1);
754 }
755
756 Backtrace::from(frames)
757}