use crate::{
    BackgroundExecutor, Clock, ForegroundExecutor, Priority, RunnableMeta, Scheduler, SessionId,
    TestClock, Timer,
};
use async_task::Runnable;
use backtrace::{Backtrace, BacktraceFrame};
use futures::channel::oneshot;
use parking_lot::{Mutex, MutexGuard};
use rand::{
    distr::{StandardUniform, uniform::SampleRange, uniform::SampleUniform},
    prelude::*,
};
use std::{
    any::type_name_of_val,
    collections::{BTreeMap, HashSet, VecDeque},
    env,
    fmt::Write,
    future::Future,
    mem,
    ops::RangeInclusive,
    panic::{self, AssertUnwindSafe},
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering::SeqCst},
    },
    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
    thread::{self, Thread},
    time::{Duration, Instant},
};

const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";

pub struct TestScheduler {
    clock: Arc<TestClock>,
    rng: Arc<Mutex<StdRng>>,
    state: Arc<Mutex<SchedulerState>>,
    thread: Thread,
}

impl TestScheduler {
    /// Run a test once with default configuration (seed 0)
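    ///
    /// A minimal usage sketch (hypothetical test body; assumes
    /// `BackgroundExecutor` exposes a `spawn` method returning an awaitable
    /// task):
    ///
    /// ```ignore
    /// let sum = TestScheduler::once(async |scheduler| {
    ///     scheduler.background().spawn(async { 2 + 2 }).await
    /// });
    /// assert_eq!(sum, 4);
    /// ```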
    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        Self::with_seed(0, f)
    }

    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
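    ///
    /// The iteration count can be overridden at runtime with the `ITERATIONS`
    /// environment variable and the starting seed with `SEED`; a failing seed
    /// is echoed in red so the run can be reproduced by setting `SEED` to it.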
    pub fn many<R>(
        default_iterations: usize,
        mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
    ) -> Vec<R> {
        let num_iterations = std::env::var("ITERATIONS")
            .map(|iterations| iterations.parse().unwrap())
            .unwrap_or(default_iterations);

        let seed = std::env::var("SEED")
            .map(|seed| seed.parse().unwrap())
            .unwrap_or(0);

        (seed..num_iterations as u64)
            .map(|seed| {
                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
                eprintln!("Running seed: {seed}");
                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
                    Ok(result) => result,
                    Err(error) => {
                        eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
                        panic::resume_unwind(error);
                    }
                }
            })
            .collect()
    }

    fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
        let future = f(scheduler.clone());
        let result = scheduler.foreground().block_on(future);
        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
        result
    }

    pub fn new(config: TestSchedulerConfig) -> Self {
        Self {
            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
            state: Arc::new(Mutex::new(SchedulerState {
                runnables: VecDeque::new(),
                timers: Vec::new(),
                blocked_sessions: Vec::new(),
                randomize_order: config.randomize_order,
                allow_parking: config.allow_parking,
                timeout_ticks: config.timeout_ticks,
                next_session_id: SessionId(0),
                capture_pending_traces: config.capture_pending_traces,
                pending_traces: BTreeMap::new(),
                next_trace_id: TraceId(0),
                is_main_thread: true,
                non_determinism_error: None,
                finished: false,
                parking_allowed_once: false,
                unparked: false,
            })),
            clock: Arc::new(TestClock::new()),
            thread: thread::current(),
        }
    }

    pub fn end_test(&self) {
        let mut state = self.state.lock();
        if let Some((message, backtrace)) = &state.non_determinism_error {
            panic!("{}\n{:?}", message, backtrace)
        }
        state.finished = true;
    }

    pub fn clock(&self) -> Arc<TestClock> {
        self.clock.clone()
    }

    pub fn rng(&self) -> SharedRng {
        SharedRng(self.rng.clone())
    }

    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
        self.state.lock().timeout_ticks = timeout_ticks;
    }

    pub fn allow_parking(&self) {
        let mut state = self.state.lock();
        state.allow_parking = true;
        state.parking_allowed_once = true;
    }

    pub fn forbid_parking(&self) {
        self.state.lock().allow_parking = false;
    }

    pub fn parking_allowed(&self) -> bool {
        self.state.lock().allow_parking
    }

    pub fn is_main_thread(&self) -> bool {
        self.state.lock().is_main_thread
    }

    /// Allocate a new session ID for foreground task scheduling.
    /// This is used by GPUI's TestDispatcher to map dispatcher instances to sessions.
    pub fn allocate_session_id(&self) -> SessionId {
        let mut state = self.state.lock();
        state.next_session_id.0 += 1;
        state.next_session_id
    }

    /// Create a foreground executor for this scheduler
    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
        let session_id = self.allocate_session_id();
        ForegroundExecutor::new(session_id, self.clone())
    }

    /// Create a background executor for this scheduler
    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
        BackgroundExecutor::new(self.clone())
    }

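    /// Returns a [`Yield`] future that suspends a random number of times:
    /// usually zero or one poll, but roughly 10% of the time between 10 and
    /// 19 polls, to shake out accidental ordering assumptions in tests.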
    pub fn yield_random(&self) -> Yield {
        let rng = &mut *self.rng.lock();
        if rng.random_bool(0.1) {
            Yield(rng.random_range(10..20))
        } else {
            Yield(rng.random_range(0..2))
        }
    }

    pub fn run(&self) {
        while self.step() {
            // Continue until no work remains
        }
    }

    pub fn run_with_clock_advancement(&self) {
        while self.step() || self.advance_clock_to_next_timer() {
            // Continue until no work remains
        }
    }

    /// Execute one tick of the scheduler, processing expired timers and running
    /// at most one task. Returns true if any work was done.
    ///
    /// This is the public interface for GPUI's TestDispatcher to drive task execution.
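    ///
    /// A minimal driving loop (sketch), equivalent to what
    /// [`TestScheduler::run_with_clock_advancement`] does internally:
    ///
    /// ```ignore
    /// while scheduler.tick() || scheduler.advance_clock_to_next_timer() {}
    /// ```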
    pub fn tick(&self) -> bool {
        self.step_filtered(false)
    }

    /// Execute one tick, but only run background tasks (no foreground/session tasks).
    /// Returns true if any work was done.
    pub fn tick_background_only(&self) -> bool {
        self.step_filtered(true)
    }

    /// Check if there are any pending tasks or timers that could run.
    pub fn has_pending_tasks(&self) -> bool {
        let state = self.state.lock();
        !state.runnables.is_empty() || !state.timers.is_empty()
    }

    /// Returns counts of (foreground_tasks, background_tasks) currently queued.
    /// Foreground tasks are those with a session_id; background tasks have none.
    pub fn pending_task_counts(&self) -> (usize, usize) {
        let state = self.state.lock();
        let foreground = state
            .runnables
            .iter()
            .filter(|r| r.session_id.is_some())
            .count();
        let background = state
            .runnables
            .iter()
            .filter(|r| r.session_id.is_none())
            .count();
        (foreground, background)
    }

    fn step(&self) -> bool {
        self.step_filtered(false)
    }

    fn step_filtered(&self, background_only: bool) -> bool {
        let (elapsed_count, runnables_before) = {
            let mut state = self.state.lock();
            let end_ix = state
                .timers
                .partition_point(|timer| timer.expiration <= self.clock.now());
            let elapsed: Vec<_> = state.timers.drain(..end_ix).collect();
            let count = elapsed.len();
            let runnables = state.runnables.len();
            drop(state);
            // Dropping elapsed timers here wakes the waiting futures
            drop(elapsed);
            (count, runnables)
        };

        if elapsed_count > 0 {
            let runnables_after = self.state.lock().runnables.len();
            if std::env::var("DEBUG_SCHEDULER").is_ok() {
                eprintln!(
                    "[scheduler] Expired {} timers at {:?}, runnables: {} -> {}",
                    elapsed_count,
                    self.clock.now(),
                    runnables_before,
                    runnables_after
                );
            }
            return true;
        }

        let runnable = {
            let state = &mut *self.state.lock();

            // Find candidate tasks:
            // - For foreground tasks (with session_id), only the first task from each session
            //   is a candidate (to preserve intra-session ordering)
            // - For background tasks (no session_id), all are candidates
            // - Tasks from blocked sessions are excluded
            // - If background_only is true, skip foreground tasks entirely
            let mut seen_sessions = HashSet::new();
            let candidate_indices: Vec<usize> = state
                .runnables
                .iter()
                .enumerate()
                .filter(|(_, runnable)| {
                    if let Some(session_id) = runnable.session_id {
                        // Skip foreground tasks if background_only mode
                        if background_only {
                            return false;
                        }
                        // Exclude tasks from blocked sessions
                        if state.blocked_sessions.contains(&session_id) {
                            return false;
                        }
                        // Only include first task from each session (insert returns true if new)
                        seen_sessions.insert(session_id)
                    } else {
                        // Background tasks are always candidates
                        true
                    }
                })
                .map(|(ix, _)| ix)
                .collect();

            if candidate_indices.is_empty() {
                None
            } else if state.randomize_order {
                // Use priority-weighted random selection
                let weights: Vec<u32> = candidate_indices
                    .iter()
                    .map(|&ix| state.runnables[ix].priority.weight())
                    .collect();
                let total_weight: u32 = weights.iter().sum();

                if total_weight == 0 {
                    // Fallback to uniform random if all weights are zero
                    let choice = self.rng.lock().random_range(0..candidate_indices.len());
                    state.runnables.remove(candidate_indices[choice])
                } else {
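                    // Cumulative-weight walk: draw a target in [0, total_weight)
                    // and pick the first candidate whose weight pushes the
                    // running total past it, so each task is selected with
                    // probability proportional to its priority weight.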
                    let mut target = self.rng.lock().random_range(0..total_weight);
                    let mut selected_idx = 0;
                    for (i, &weight) in weights.iter().enumerate() {
                        if target < weight {
                            selected_idx = i;
                            break;
                        }
                        target -= weight;
                    }
                    state.runnables.remove(candidate_indices[selected_idx])
                }
            } else {
                // Non-randomized: just take the first candidate task
                state.runnables.remove(candidate_indices[0])
            }
        };

        if let Some(runnable) = runnable {
            // Check if the executor that spawned this task was closed
            if runnable.runnable.metadata().is_closed() {
                return true;
            }
            let is_foreground = runnable.session_id.is_some();
            let was_main_thread = self.state.lock().is_main_thread;
            self.state.lock().is_main_thread = is_foreground;
            runnable.run();
            self.state.lock().is_main_thread = was_main_thread;
            return true;
        }

        false
    }

    pub fn advance_clock_to_next_timer(&self) -> bool {
        if let Some(timer) = self.state.lock().timers.first() {
            self.clock.advance(timer.expiration - self.clock.now());
            true
        } else {
            false
        }
    }

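    /// Advance the test clock by `duration`, draining ready tasks and firing
    /// timers that expire along the way so that work scheduled at intermediate
    /// deadlines runs before the clock reaches its final time.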
    pub fn advance_clock(&self, duration: Duration) {
        let debug = std::env::var("DEBUG_SCHEDULER").is_ok();
        let start = self.clock.now();
        let next_now = start + duration;
        if debug {
            let timer_count = self.state.lock().timers.len();
            eprintln!(
                "[scheduler] advance_clock({:?}) from {:?}, {} pending timers",
                duration, start, timer_count
            );
        }
        loop {
            self.run();
            if let Some(timer) = self.state.lock().timers.first()
                && timer.expiration <= next_now
            {
                let advance_to = timer.expiration;
                if debug {
                    eprintln!(
                        "[scheduler] Advancing clock {:?} -> {:?} for timer",
                        self.clock.now(),
                        advance_to
                    );
                }
                self.clock.advance(advance_to - self.clock.now());
            } else {
                break;
            }
        }
        self.clock.advance(next_now - self.clock.now());
        if debug {
            eprintln!(
                "[scheduler] advance_clock done, now at {:?}",
                self.clock.now()
            );
        }
    }

    fn park(&self, deadline: Option<Instant>) -> bool {
        if self.state.lock().allow_parking {
            let start = Instant::now();
            // Enforce a hard timeout to prevent tests from hanging indefinitely
            let hard_deadline = start + Duration::from_secs(15);

            // Use the earlier of the provided deadline or the hard timeout deadline
            let effective_deadline = deadline
                .map(|d| d.min(hard_deadline))
                .unwrap_or(hard_deadline);

            // Park in small intervals to allow checking both deadlines
            const PARK_INTERVAL: Duration = Duration::from_millis(100);
            loop {
                let now = Instant::now();
                if now >= effective_deadline {
                    // Check if we hit the hard timeout
                    if now >= hard_deadline {
                        panic!(
                            "Test timed out after 15 seconds while parking. \
                             This may indicate a deadlock or missing waker.",
                        );
                    }
                    // Hit the provided deadline
                    return false;
                }

                let remaining = effective_deadline.saturating_duration_since(now);
                let park_duration = remaining.min(PARK_INTERVAL);
                let before_park = Instant::now();
                thread::park_timeout(park_duration);
                let elapsed = before_park.elapsed();

                // Advance the test clock by the real elapsed time while parking
                self.clock.advance(elapsed);

                // Check if any timers have expired after advancing the clock.
                // If so, return so the caller can process them.
                if self
                    .state
                    .lock()
                    .timers
                    .first()
                    .map_or(false, |t| t.expiration <= self.clock.now())
                {
                    return true;
                }

                // Check if we were woken up by a different thread.
                // We use a flag because timing-based detection is unreliable:
                // OS scheduling delays can cause elapsed >= park_duration even when
                // we were woken early by unpark().
                if std::mem::take(&mut self.state.lock().unparked) {
                    return true;
                }
            }
        } else if deadline.is_some() {
            false
        } else if self.state.lock().capture_pending_traces {
            let mut pending_traces = String::new();
            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
            }
            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
        } else {
            panic!(
                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
            );
        }
    }
}

fn assert_correct_thread(expected: &Thread, state: &Arc<Mutex<SchedulerState>>) {
    let current_thread = thread::current();
    let mut state = state.lock();
    if state.parking_allowed_once {
        return;
    }
    if current_thread.id() == expected.id() {
        return;
    }

    let message = format!(
        "Detected activity on thread {:?} {:?}, but test scheduler is running on {:?} {:?}. Your test is not deterministic.",
        current_thread.name(),
        current_thread.id(),
        expected.name(),
        expected.id(),
    );
    let backtrace = Backtrace::new();
    if state.finished {
        panic!("{}", message);
    } else {
        state.non_determinism_error = Some((message, backtrace))
    }
}

impl Scheduler for TestScheduler {
    /// Block until the given future completes, with an optional timeout,
    /// returning true if the future completed in time. If the future is
    /// unable to make progress at any moment before the timeout and no other
    /// tasks or timers remain, we panic unless parking is allowed. If parking
    /// is allowed, we block up to the timeout or indefinitely if none is
    /// provided. This is to allow testing a mix of deterministic and
    /// non-deterministic async behavior, such as when interacting with I/O in
    /// an otherwise deterministic test.
    fn block(
        &self,
        session_id: Option<SessionId>,
        mut future: Pin<&mut dyn Future<Output = ()>>,
        timeout: Option<Duration>,
    ) -> bool {
        if let Some(session_id) = session_id {
            self.state.lock().blocked_sessions.push(session_id);
        }

        let deadline = timeout.map(|timeout| Instant::now() + timeout);
        let awoken = Arc::new(AtomicBool::new(false));
        let waker = Box::new(TracingWaker {
            id: None,
            awoken: awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        });
        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
        let max_ticks = if timeout.is_some() {
            self.rng
                .lock()
                .random_range(self.state.lock().timeout_ticks.clone())
        } else {
            usize::MAX
        };
        let mut cx = Context::from_waker(&waker);

        let mut completed = false;
        for _ in 0..max_ticks {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(()) => {
                    completed = true;
                    break;
                }
                Poll::Pending => {}
            }

            let mut stepped = None;
            while self.rng.lock().random() {
                let stepped = stepped.get_or_insert(false);
                if self.step() {
                    *stepped = true;
                } else {
                    break;
                }
            }

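            // If the coin-flip loop above never ran, conservatively treat this
            // tick as having made progress so we poll the future again before
            // considering parking.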
            let stepped = stepped.unwrap_or(true);
            let awoken = awoken.swap(false, SeqCst);
            if !stepped && !awoken {
                let parking_allowed = self.state.lock().allow_parking;
                // In deterministic mode (parking forbidden), instantly jump to the next timer.
                // In non-deterministic mode (parking allowed), let real time pass instead.
                let advanced_to_timer = !parking_allowed && self.advance_clock_to_next_timer();
                if !advanced_to_timer && !self.park(deadline) {
                    break;
                }
            }
        }

        if session_id.is_some() {
            self.state.lock().blocked_sessions.pop();
        }

        completed
    }

    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable<RunnableMeta>) {
        assert_correct_thread(&self.thread, &self.state);
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            let start_ix = state
                .runnables
                .iter()
                .rposition(|task| task.session_id == Some(session_id))
                .map_or(0, |ix| ix + 1);
            self.rng
                .lock()
                .random_range(start_ix..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: Some(session_id),
                priority: Priority::default(),
                runnable,
            },
        );
        state.unparked = true;
        drop(state);
        self.thread.unpark();
    }

    fn schedule_background_with_priority(
        &self,
        runnable: Runnable<RunnableMeta>,
        priority: Priority,
    ) {
        assert_correct_thread(&self.thread, &self.state);
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            self.rng.lock().random_range(0..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: None,
                priority,
                runnable,
            },
        );
        state.unparked = true;
        drop(state);
        self.thread.unpark();
    }

    fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
        std::thread::spawn(move || {
            f();
        });
    }

    fn timer(&self, duration: Duration) -> Timer {
        let (tx, rx) = oneshot::channel();
        let state = &mut *self.state.lock();
        state.timers.push(ScheduledTimer {
            expiration: self.clock.now() + duration,
            _notify: tx,
        });
        state.timers.sort_by_key(|timer| timer.expiration);
        Timer(rx)
    }

    fn clock(&self) -> Arc<dyn Clock> {
        self.clock.clone()
    }

    fn as_test(&self) -> Option<&TestScheduler> {
        Some(self)
    }
}

#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    pub seed: u64,
    pub randomize_order: bool,
    pub allow_parking: bool,
    pub capture_pending_traces: bool,
    pub timeout_ticks: RangeInclusive<usize>,
}

impl TestSchedulerConfig {
    pub fn with_seed(seed: u64) -> Self {
        Self {
            seed,
            ..Default::default()
        }
    }
}

impl Default for TestSchedulerConfig {
    fn default() -> Self {
        Self {
            seed: 0,
            randomize_order: true,
            allow_parking: false,
            capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
                .map_or(false, |var| var == "1" || var == "true"),
            timeout_ticks: 1..=1000,
        }
    }
}

struct ScheduledRunnable {
    session_id: Option<SessionId>,
    priority: Priority,
    runnable: Runnable<RunnableMeta>,
}

impl ScheduledRunnable {
    fn run(self) {
        self.runnable.run();
    }
}

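/// A pending timer. Firing is signaled by dropping `_notify`: when the
/// scheduler drains an expired timer, dropping the `oneshot::Sender` wakes
/// the receiver held by the corresponding [`Timer`].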
struct ScheduledTimer {
    expiration: Instant,
    _notify: oneshot::Sender<()>,
}

struct SchedulerState {
    runnables: VecDeque<ScheduledRunnable>,
    timers: Vec<ScheduledTimer>,
    blocked_sessions: Vec<SessionId>,
    randomize_order: bool,
    allow_parking: bool,
    timeout_ticks: RangeInclusive<usize>,
    next_session_id: SessionId,
    capture_pending_traces: bool,
    next_trace_id: TraceId,
    pending_traces: BTreeMap<TraceId, Backtrace>,
    is_main_thread: bool,
    non_determinism_error: Option<(String, Backtrace)>,
    parking_allowed_once: bool,
    finished: bool,
    unparked: bool,
}

const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);

#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);

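/// The waker used by [`TestScheduler::block`]. Cloning it records an
/// unresolved backtrace when pending-trace capture is enabled; waking it
/// asserts we are on the scheduler thread, removes the recorded trace, sets
/// the `unparked` flag, and unparks the scheduler thread.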
struct TracingWaker {
    id: Option<TraceId>,
    awoken: Arc<AtomicBool>,
    thread: Thread,
    state: Arc<Mutex<SchedulerState>>,
}

impl Clone for TracingWaker {
    fn clone(&self) -> Self {
        let mut state = self.state.lock();
        let id = if state.capture_pending_traces {
            let id = state.next_trace_id;
            state.next_trace_id.0 += 1;
            state.pending_traces.insert(id, Backtrace::new_unresolved());
            Some(id)
        } else {
            None
        };
        Self {
            id,
            awoken: self.awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        }
    }
}

impl Drop for TracingWaker {
    fn drop(&mut self) {
        assert_correct_thread(&self.thread, &self.state);

        if let Some(id) = self.id {
            self.state.lock().pending_traces.remove(&id);
        }
    }
}

impl TracingWaker {
    fn wake(self) {
        self.wake_by_ref();
    }

    fn wake_by_ref(&self) {
        assert_correct_thread(&self.thread, &self.state);

        let mut state = self.state.lock();
        if let Some(id) = self.id {
            state.pending_traces.remove(&id);
        }
        state.unparked = true;
        drop(state);
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    fn wake_raw(waker: *const ()) {
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    fn drop_raw(waker: *const ()) {
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}

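/// A future that returns `Poll::Pending` a fixed number of times, re-waking
/// itself after each poll before completing. Construct one via
/// [`TestScheduler::yield_random`].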
pub struct Yield(usize);

/// A wrapper around `Arc<Mutex<StdRng>>` that provides convenient methods
/// for random number generation without requiring explicit locking.
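///
/// A short usage sketch (hypothetical test body):
///
/// ```ignore
/// let rng = scheduler.rng();
/// let n: usize = rng.random_range(0..10);
/// if rng.random_bool(0.5) {
///     // take one of two random branches
/// }
/// ```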
#[derive(Clone)]
pub struct SharedRng(Arc<Mutex<StdRng>>);

impl SharedRng {
    /// Lock the inner RNG for direct access. Use this when you need multiple
    /// random operations without re-locking between each one.
    pub fn lock(&self) -> MutexGuard<'_, StdRng> {
        self.0.lock()
    }

    /// Generate a random value in the given range.
    pub fn random_range<T, R>(&self, range: R) -> T
    where
        T: SampleUniform,
        R: SampleRange<T>,
    {
        self.0.lock().random_range(range)
    }

    /// Generate a random boolean with the given probability of being true.
    pub fn random_bool(&self, p: f64) -> bool {
        self.0.lock().random_bool(p)
    }

    /// Generate a random value of the given type.
    pub fn random<T>(&self) -> T
    where
        StandardUniform: Distribution<T>,
    {
        self.0.lock().random()
    }

    /// Generate a random ratio: true with probability `numerator/denominator`.
    pub fn random_ratio(&self, numerator: u32, denominator: u32) -> bool {
        self.0.lock().random_ratio(numerator, denominator)
    }
}

impl Future for Yield {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        if self.0 == 0 {
            Poll::Ready(())
        } else {
            self.0 -= 1;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

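/// Trim a captured pending-trace backtrace so it begins below the
/// `Waker::clone` frame, hiding the scheduler's waker plumbing from the
/// report printed when parking is forbidden.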
fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
    trace.resolve();
    let mut frames: Vec<BacktraceFrame> = trace.into();
    let waker_clone_frame_ix = frames.iter().position(|frame| {
        frame.symbols().iter().any(|symbol| {
            symbol
                .name()
                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
        })
    });

    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
        frames.drain(..waker_clone_frame_ix + 1);
    }

    Backtrace::from(frames)
}