use crate::{
    BackgroundExecutor, Clock, ForegroundExecutor, Priority, RunnableMeta, Scheduler, SessionId,
    TestClock, Timer,
};
use async_task::Runnable;
use backtrace::{Backtrace, BacktraceFrame};
use futures::channel::oneshot;
use parking_lot::{Mutex, MutexGuard};
use rand::{
    distr::{StandardUniform, uniform::SampleRange, uniform::SampleUniform},
    prelude::*,
};
use std::{
    any::type_name_of_val,
    collections::{BTreeMap, HashSet, VecDeque},
    env,
    fmt::Write,
    future::Future,
    mem,
    ops::RangeInclusive,
    panic::{self, AssertUnwindSafe},
    pin::Pin,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering::SeqCst},
    },
    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
    thread::{self, Thread},
    time::{Duration, Instant},
};

const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";

pub struct TestScheduler {
    clock: Arc<TestClock>,
    rng: Arc<Mutex<StdRng>>,
    state: Arc<Mutex<SchedulerState>>,
    thread: Thread,
}

impl TestScheduler {
    /// Run a test once with default configuration (seed 0).
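    ///
    /// A minimal usage sketch (illustrative, not compiled as a doc-test):
    ///
    /// ```ignore
    /// let sum = TestScheduler::once(async |_scheduler| { 2 + 2 });
    /// assert_eq!(sum, 4);
    /// ```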
    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        Self::with_seed(0, f)
    }

    /// Run a test multiple times with sequential seeds (0, 1, 2, ...). The
    /// `SEED` and `ITERATIONS` environment variables override the starting
    /// seed and the exclusive upper bound, respectively.
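    ///
    /// A minimal usage sketch (illustrative, not compiled as a doc-test):
    ///
    /// ```ignore
    /// // With no SEED/ITERATIONS overrides, this runs seeds 0 through 9.
    /// let results = TestScheduler::many(10, async |scheduler| {
    ///     scheduler.run();
    ///     42
    /// });
    /// assert_eq!(results.len(), 10);
    /// ```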
    pub fn many<R>(
        default_iterations: usize,
        mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
    ) -> Vec<R> {
        let num_iterations = std::env::var("ITERATIONS")
            .map(|iterations| iterations.parse().unwrap())
            .unwrap_or(default_iterations);

        let seed = std::env::var("SEED")
            .map(|seed| seed.parse().unwrap())
            .unwrap_or(0);

        (seed..num_iterations as u64)
            .map(|seed| {
                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
                eprintln!("Running seed: {seed}");
                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
                    Ok(result) => result,
                    Err(error) => {
                        eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
                        panic::resume_unwind(error);
                    }
                }
            })
            .collect()
    }

    fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
        let future = f(scheduler.clone());
        let result = scheduler.foreground().block_on(future);
        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
        result
    }

    pub fn new(config: TestSchedulerConfig) -> Self {
        Self {
            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
            state: Arc::new(Mutex::new(SchedulerState {
                runnables: VecDeque::new(),
                timers: Vec::new(),
                blocked_sessions: Vec::new(),
                randomize_order: config.randomize_order,
                allow_parking: config.allow_parking,
                timeout_ticks: config.timeout_ticks,
                next_session_id: SessionId(0),
                capture_pending_traces: config.capture_pending_traces,
                pending_traces: BTreeMap::new(),
                next_trace_id: TraceId(0),
                is_main_thread: true,
                non_determinism_error: None,
                finished: false,
                parking_allowed_once: false,
                unparked: false,
            })),
            clock: Arc::new(TestClock::new()),
            thread: thread::current(),
        }
    }

    /// Mark the test as finished, panicking if any non-determinism (such as
    /// activity on an unexpected thread) was detected while it ran.
    pub fn end_test(&self) {
        let mut state = self.state.lock();
        if let Some((message, backtrace)) = &state.non_determinism_error {
            panic!("{}\n{:?}", message, backtrace)
        }
        state.finished = true;
    }

    pub fn clock(&self) -> Arc<TestClock> {
        self.clock.clone()
    }

    pub fn rng(&self) -> SharedRng {
        SharedRng(self.rng.clone())
    }

    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
        self.state.lock().timeout_ticks = timeout_ticks;
    }

    /// Allow the scheduler to park the thread when it runs out of work,
    /// enabling non-deterministic waits on real I/O or other threads.
    pub fn allow_parking(&self) {
        let mut state = self.state.lock();
        state.allow_parking = true;
        state.parking_allowed_once = true;
    }

    /// Forbid parking again after a call to `allow_parking`.
    pub fn forbid_parking(&self) {
        self.state.lock().allow_parking = false;
    }

    pub fn parking_allowed(&self) -> bool {
        self.state.lock().allow_parking
    }

    pub fn is_main_thread(&self) -> bool {
        self.state.lock().is_main_thread
    }

    /// Allocate a new session ID for foreground task scheduling.
    /// This is used by GPUI's TestDispatcher to map dispatcher instances to sessions.
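    ///
    /// Each call returns a fresh, monotonically increasing ID. An illustrative
    /// sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// let first = scheduler.allocate_session_id();
    /// let second = scheduler.allocate_session_id();
    /// // `second` is distinct from (and greater than) `first`.
    /// ```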
    pub fn allocate_session_id(&self) -> SessionId {
        let mut state = self.state.lock();
        state.next_session_id.0 += 1;
        state.next_session_id
    }

    /// Create a foreground executor for this scheduler.
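    ///
    /// An illustrative sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// let foreground = scheduler.foreground();
    /// let value = foreground.block_on(async { 2 + 2 });
    /// assert_eq!(value, 4);
    /// ```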
    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
        let session_id = self.allocate_session_id();
        ForegroundExecutor::new(session_id, self.clone())
    }

    /// Create a background executor for this scheduler.
    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
        BackgroundExecutor::new(self.clone())
    }

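    /// Return a future that yields control back to the scheduler a random
    /// number of times before completing: usually 0 or 1 times, occasionally
    /// (with probability 0.1) 10 to 19 times.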
    pub fn yield_random(&self) -> Yield {
        let rng = &mut *self.rng.lock();
        if rng.random_bool(0.1) {
            Yield(rng.random_range(10..20))
        } else {
            Yield(rng.random_range(0..2))
        }
    }

    pub fn run(&self) {
        while self.step() {
            // Continue until no work remains
        }
    }

    /// Execute one tick of the scheduler, processing expired timers and running
    /// at most one task. Returns true if any work was done.
    ///
    /// This is the public interface for GPUI's TestDispatcher to drive task execution.
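    ///
    /// An illustrative sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// // Drain all currently runnable work, one task per tick:
    /// while scheduler.tick() {}
    /// ```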
    pub fn tick(&self) -> bool {
        self.step_filtered(false)
    }

    /// Execute one tick, but only run background tasks (no foreground/session tasks).
    /// Returns true if any work was done.
    pub fn tick_background_only(&self) -> bool {
        self.step_filtered(true)
    }

    /// Check if there are any pending tasks or timers that could run.
    pub fn has_pending_tasks(&self) -> bool {
        let state = self.state.lock();
        !state.runnables.is_empty() || !state.timers.is_empty()
    }

    /// Returns counts of (foreground_tasks, background_tasks) currently queued.
    /// Foreground tasks are those with a session_id; background tasks have none.
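    ///
    /// An illustrative sketch (not compiled as a doc-test):
    ///
    /// ```ignore
    /// let (foreground, background) = scheduler.pending_task_counts();
    /// ```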
    pub fn pending_task_counts(&self) -> (usize, usize) {
        let state = self.state.lock();
        let foreground = state
            .runnables
            .iter()
            .filter(|r| r.session_id.is_some())
            .count();
        let background = state
            .runnables
            .iter()
            .filter(|r| r.session_id.is_none())
            .count();
        (foreground, background)
    }

    fn step(&self) -> bool {
        self.step_filtered(false)
    }

    fn step_filtered(&self, background_only: bool) -> bool {
        let (elapsed_count, runnables_before) = {
            let mut state = self.state.lock();
            let end_ix = state
                .timers
                .partition_point(|timer| timer.expiration <= self.clock.now());
            let elapsed: Vec<_> = state.timers.drain(..end_ix).collect();
            let count = elapsed.len();
            let runnables = state.runnables.len();
            drop(state);
            // Dropping elapsed timers here wakes the waiting futures
            drop(elapsed);
            (count, runnables)
        };

        if elapsed_count > 0 {
            let runnables_after = self.state.lock().runnables.len();
            if std::env::var("DEBUG_SCHEDULER").is_ok() {
                eprintln!(
                    "[scheduler] Expired {} timers at {:?}, runnables: {} -> {}",
                    elapsed_count,
                    self.clock.now(),
                    runnables_before,
                    runnables_after
                );
            }
            return true;
        }

        let runnable = {
            let state = &mut *self.state.lock();

            // Find candidate tasks:
            // - For foreground tasks (with session_id), only the first task from each session
            //   is a candidate (to preserve intra-session ordering)
            // - For background tasks (no session_id), all are candidates
            // - Tasks from blocked sessions are excluded
            // - If background_only is true, skip foreground tasks entirely
            let mut seen_sessions = HashSet::new();
            let candidate_indices: Vec<usize> = state
                .runnables
                .iter()
                .enumerate()
                .filter(|(_, runnable)| {
                    if let Some(session_id) = runnable.session_id {
                        // Skip foreground tasks if background_only mode
                        if background_only {
                            return false;
                        }
                        // Exclude tasks from blocked sessions
                        if state.blocked_sessions.contains(&session_id) {
                            return false;
                        }
                        // Only include first task from each session (insert returns true if new)
                        seen_sessions.insert(session_id)
                    } else {
                        // Background tasks are always candidates
                        true
                    }
                })
                .map(|(ix, _)| ix)
                .collect();

            if candidate_indices.is_empty() {
                None
            } else if state.randomize_order {
                // Use priority-weighted random selection
                let weights: Vec<u32> = candidate_indices
                    .iter()
                    .map(|&ix| state.runnables[ix].priority.weight())
                    .collect();
                let total_weight: u32 = weights.iter().sum();

                if total_weight == 0 {
                    // Fall back to uniform random selection if all weights are zero
                    let choice = self.rng.lock().random_range(0..candidate_indices.len());
                    state.runnables.remove(candidate_indices[choice])
                } else {
                    let mut target = self.rng.lock().random_range(0..total_weight);
                    let mut selected_ix = 0;
                    for (ix, &weight) in weights.iter().enumerate() {
                        if target < weight {
                            selected_ix = ix;
                            break;
                        }
                        target -= weight;
                    }
                    state.runnables.remove(candidate_indices[selected_ix])
                }
            } else {
                // Non-randomized: just take the first candidate task
                state.runnables.remove(candidate_indices[0])
            }
        };

        if let Some(runnable) = runnable {
            // Check if the executor that spawned this task was closed
            if runnable.runnable.metadata().is_closed() {
                return true;
            }
            let is_foreground = runnable.session_id.is_some();
            let was_main_thread = self.state.lock().is_main_thread;
            self.state.lock().is_main_thread = is_foreground;
            runnable.run();
            self.state.lock().is_main_thread = was_main_thread;
            return true;
        }

        false
    }

    /// Advance the test clock directly to the next pending timer's expiration.
    /// Returns true if a timer was pending.
    pub fn advance_clock_to_next_timer(&self) -> bool {
        if let Some(timer) = self.state.lock().timers.first() {
            self.clock.advance(timer.expiration - self.clock.now());
            true
        } else {
            false
        }
    }

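    /// Advance the test clock by `duration`, running all ready tasks and
    /// firing any timers that expire along the way.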
    pub fn advance_clock(&self, duration: Duration) {
        let debug = std::env::var("DEBUG_SCHEDULER").is_ok();
        let start = self.clock.now();
        let next_now = start + duration;
        if debug {
            let timer_count = self.state.lock().timers.len();
            eprintln!(
                "[scheduler] advance_clock({:?}) from {:?}, {} pending timers",
                duration, start, timer_count
            );
        }
        loop {
            self.run();
            if let Some(timer) = self.state.lock().timers.first()
                && timer.expiration <= next_now
            {
                let advance_to = timer.expiration;
                if debug {
                    eprintln!(
                        "[scheduler] Advancing clock {:?} -> {:?} for timer",
                        self.clock.now(),
                        advance_to
                    );
                }
                self.clock.advance(advance_to - self.clock.now());
            } else {
                break;
            }
        }
        self.clock.advance(next_now - self.clock.now());
        if debug {
            eprintln!(
                "[scheduler] advance_clock done, now at {:?}",
                self.clock.now()
            );
        }
    }

    /// Park the current thread until woken, a timer expires, or the deadline
    /// passes. Returns true if there may be more work to do, false if the
    /// deadline was reached. Panics if parking is forbidden and no deadline
    /// was provided.
    fn park(&self, deadline: Option<Instant>) -> bool {
        if self.state.lock().allow_parking {
            let start = Instant::now();
            // Enforce a hard timeout to prevent tests from hanging indefinitely
            let hard_deadline = start + Duration::from_secs(15);

            // Use the earlier of the provided deadline or the hard timeout deadline
            let effective_deadline = deadline
                .map(|d| d.min(hard_deadline))
                .unwrap_or(hard_deadline);

            // Park in small intervals to allow checking both deadlines
            const PARK_INTERVAL: Duration = Duration::from_millis(100);
            loop {
                let now = Instant::now();
                if now >= effective_deadline {
                    // Check if we hit the hard timeout
                    if now >= hard_deadline {
                        panic!(
                            "Test timed out after 15 seconds while parking. \
                             This may indicate a deadlock or missing waker.",
                        );
                    }
                    // Hit the provided deadline
                    return false;
                }

                let remaining = effective_deadline.saturating_duration_since(now);
                let park_duration = remaining.min(PARK_INTERVAL);
                let before_park = Instant::now();
                thread::park_timeout(park_duration);
                let elapsed = before_park.elapsed();

                // Advance the test clock by the real elapsed time while parking
                self.clock.advance(elapsed);

                // Check if any timers have expired after advancing the clock.
                // If so, return so the caller can process them.
                if self
                    .state
                    .lock()
                    .timers
                    .first()
                    .map_or(false, |t| t.expiration <= self.clock.now())
                {
                    return true;
                }

                // Check if we were woken up by a different thread.
                // We use a flag because timing-based detection is unreliable:
                // OS scheduling delays can cause elapsed >= park_duration even when
                // we were woken early by unpark().
                if std::mem::take(&mut self.state.lock().unparked) {
                    return true;
                }
            }
        } else if deadline.is_some() {
            false
        } else if self.state.lock().capture_pending_traces {
            let mut pending_traces = String::new();
            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
            }
            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
        } else {
            panic!(
                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
            );
        }
    }
}

fn assert_correct_thread(expected: &Thread, state: &Arc<Mutex<SchedulerState>>) {
    let current_thread = thread::current();
    let mut state = state.lock();
    if state.parking_allowed_once {
        return;
    }
    if current_thread.id() == expected.id() {
        return;
    }

    let message = format!(
        "Detected activity on thread {:?} {:?}, but test scheduler is running on {:?} {:?}. Your test is not deterministic.",
        current_thread.name(),
        current_thread.id(),
        expected.name(),
        expected.id(),
    );
    let backtrace = Backtrace::new();
    if state.finished {
        panic!("{}", message);
    } else {
        state.non_determinism_error = Some((message, backtrace))
    }
}

impl Scheduler for TestScheduler {
    /// Block until the given future completes, with an optional timeout. If the
    /// future is unable to make progress at any moment before the timeout and
    /// no other tasks or timers remain, we panic unless parking is allowed. If
    /// parking is allowed, we block up to the timeout or indefinitely if none
    /// is provided. This allows testing a mix of deterministic and
    /// non-deterministic async behavior, such as when interacting with I/O in
    /// an otherwise deterministic test.
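    ///
    /// An illustrative sketch of calling this directly (not compiled as a
    /// doc-test; executors normally call it on your behalf):
    ///
    /// ```ignore
    /// let mut future = std::pin::pin!(async { /* ... */ });
    /// let completed = scheduler.block(None, future.as_mut(), Some(Duration::from_secs(1)));
    /// ```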
    fn block(
        &self,
        session_id: Option<SessionId>,
        mut future: Pin<&mut dyn Future<Output = ()>>,
        timeout: Option<Duration>,
    ) -> bool {
        if let Some(session_id) = session_id {
            self.state.lock().blocked_sessions.push(session_id);
        }

        let deadline = timeout.map(|timeout| Instant::now() + timeout);
        let awoken = Arc::new(AtomicBool::new(false));
        let waker = Box::new(TracingWaker {
            id: None,
            awoken: awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        });
        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
        let max_ticks = if timeout.is_some() {
            self.rng
                .lock()
                .random_range(self.state.lock().timeout_ticks.clone())
        } else {
            usize::MAX
        };
        let mut cx = Context::from_waker(&waker);

        let mut completed = false;
        for _ in 0..max_ticks {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(()) => {
                    completed = true;
                    break;
                }
                Poll::Pending => {}
            }

            let mut stepped = None;
            while self.rng.lock().random() {
                let stepped = stepped.get_or_insert(false);
                if self.step() {
                    *stepped = true;
                } else {
                    break;
                }
            }

            let stepped = stepped.unwrap_or(true);
            let awoken = awoken.swap(false, SeqCst);
            if !stepped && !awoken {
                let parking_allowed = self.state.lock().allow_parking;
                // In deterministic mode (parking forbidden), instantly jump to the next timer.
                // In non-deterministic mode (parking allowed), let real time pass instead.
                let advanced_to_timer = !parking_allowed && self.advance_clock_to_next_timer();
                if !advanced_to_timer && !self.park(deadline) {
                    break;
                }
            }
        }

        if session_id.is_some() {
            self.state.lock().blocked_sessions.pop();
        }

        completed
    }

    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable<RunnableMeta>) {
        assert_correct_thread(&self.thread, &self.state);
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            let start_ix = state
                .runnables
                .iter()
                .rposition(|task| task.session_id == Some(session_id))
                .map_or(0, |ix| ix + 1);
            self.rng
                .lock()
                .random_range(start_ix..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: Some(session_id),
                priority: Priority::default(),
                runnable,
            },
        );
        state.unparked = true;
        drop(state);
        self.thread.unpark();
    }

    fn schedule_background_with_priority(
        &self,
        runnable: Runnable<RunnableMeta>,
        priority: Priority,
    ) {
        assert_correct_thread(&self.thread, &self.state);
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            self.rng.lock().random_range(0..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: None,
                priority,
                runnable,
            },
        );
        state.unparked = true;
        drop(state);
        self.thread.unpark();
    }

    fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
        std::thread::spawn(move || {
            f();
        });
    }

    fn timer(&self, duration: Duration) -> Timer {
        let (tx, rx) = oneshot::channel();
        let state = &mut *self.state.lock();
        state.timers.push(ScheduledTimer {
            expiration: self.clock.now() + duration,
            _notify: tx,
        });
        state.timers.sort_by_key(|timer| timer.expiration);
        Timer(rx)
    }

    fn clock(&self) -> Arc<dyn Clock> {
        self.clock.clone()
    }

    fn as_test(&self) -> Option<&TestScheduler> {
        Some(self)
    }
}

#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    pub seed: u64,
    pub randomize_order: bool,
    pub allow_parking: bool,
    pub capture_pending_traces: bool,
    pub timeout_ticks: RangeInclusive<usize>,
}

impl TestSchedulerConfig {
    pub fn with_seed(seed: u64) -> Self {
        Self {
            seed,
            ..Default::default()
        }
    }
}

impl Default for TestSchedulerConfig {
    fn default() -> Self {
        Self {
            seed: 0,
            randomize_order: true,
            allow_parking: false,
            capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
                .map_or(false, |var| var == "1" || var == "true"),
            timeout_ticks: 1..=1000,
        }
    }
}

struct ScheduledRunnable {
    session_id: Option<SessionId>,
    priority: Priority,
    runnable: Runnable<RunnableMeta>,
}

impl ScheduledRunnable {
    fn run(self) {
        self.runnable.run();
    }
}

struct ScheduledTimer {
    expiration: Instant,
    _notify: oneshot::Sender<()>,
}

struct SchedulerState {
    runnables: VecDeque<ScheduledRunnable>,
    timers: Vec<ScheduledTimer>,
    blocked_sessions: Vec<SessionId>,
    randomize_order: bool,
    allow_parking: bool,
    timeout_ticks: RangeInclusive<usize>,
    next_session_id: SessionId,
    capture_pending_traces: bool,
    next_trace_id: TraceId,
    pending_traces: BTreeMap<TraceId, Backtrace>,
    is_main_thread: bool,
    non_determinism_error: Option<(String, Backtrace)>,
    parking_allowed_once: bool,
    finished: bool,
    unparked: bool,
}

const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);

#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);

/// A waker that optionally records a backtrace when cloned so that tests can
/// report where pending futures are stuck when parking is forbidden.
struct TracingWaker {
    id: Option<TraceId>,
    awoken: Arc<AtomicBool>,
    thread: Thread,
    state: Arc<Mutex<SchedulerState>>,
}

impl Clone for TracingWaker {
    fn clone(&self) -> Self {
        let mut state = self.state.lock();
        let id = if state.capture_pending_traces {
            let id = state.next_trace_id;
            state.next_trace_id.0 += 1;
            state.pending_traces.insert(id, Backtrace::new_unresolved());
            Some(id)
        } else {
            None
        };
        Self {
            id,
            awoken: self.awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        }
    }
}

impl Drop for TracingWaker {
    fn drop(&mut self) {
        assert_correct_thread(&self.thread, &self.state);

        if let Some(id) = self.id {
            self.state.lock().pending_traces.remove(&id);
        }
    }
}

impl TracingWaker {
    fn wake(self) {
        self.wake_by_ref();
    }

    fn wake_by_ref(&self) {
        assert_correct_thread(&self.thread, &self.state);

        let mut state = self.state.lock();
        if let Some(id) = self.id {
            state.pending_traces.remove(&id);
        }
        state.unparked = true;
        drop(state);
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    fn wake_raw(waker: *const ()) {
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    fn drop_raw(waker: *const ()) {
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}

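/// A future that yields to the scheduler a fixed number of times before
/// completing. Returned by `TestScheduler::yield_random`.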
pub struct Yield(usize);

/// A wrapper around `Arc<Mutex<StdRng>>` that provides convenient methods
/// for random number generation without requiring explicit locking.
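///
/// An illustrative sketch (not compiled as a doc-test):
///
/// ```ignore
/// let rng = scheduler.rng();
/// let n: usize = rng.random_range(0..10);
/// let coin_flip = rng.random_bool(0.5);
/// ```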
#[derive(Clone)]
pub struct SharedRng(Arc<Mutex<StdRng>>);

impl SharedRng {
    /// Lock the inner RNG for direct access. Use this when you need multiple
    /// random operations without re-locking between each one.
    pub fn lock(&self) -> MutexGuard<'_, StdRng> {
        self.0.lock()
    }

    /// Generate a random value in the given range.
    pub fn random_range<T, R>(&self, range: R) -> T
    where
        T: SampleUniform,
        R: SampleRange<T>,
    {
        self.0.lock().random_range(range)
    }

    /// Generate a random boolean with the given probability of being true.
    pub fn random_bool(&self, p: f64) -> bool {
        self.0.lock().random_bool(p)
    }

    /// Generate a random value of the given type.
    pub fn random<T>(&self) -> T
    where
        StandardUniform: Distribution<T>,
    {
        self.0.lock().random()
    }

    /// Generate a random ratio: true with probability `numerator/denominator`.
    pub fn random_ratio(&self, numerator: u32, denominator: u32) -> bool {
        self.0.lock().random_ratio(numerator, denominator)
    }
}

impl Future for Yield {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        if self.0 == 0 {
            Poll::Ready(())
        } else {
            self.0 -= 1;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}

/// Resolve the given backtrace and drop all frames up to and including the
/// `Waker::clone` call, so the reported pending trace starts at the code that
/// cloned the waker.
fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
    trace.resolve();
    let mut frames: Vec<BacktraceFrame> = trace.into();
    let waker_clone_frame_ix = frames.iter().position(|frame| {
        frame.symbols().iter().any(|symbol| {
            symbol
                .name()
                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
        })
    });

    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
        frames.drain(..waker_clone_frame_ix + 1);
    }

    Backtrace::from(frames)
}