test_scheduler.rs

  1use crate::{
  2    BackgroundExecutor, Clock, ForegroundExecutor, Priority, RunnableMeta, Scheduler, SessionId,
  3    TestClock, Timer,
  4};
  5use async_task::Runnable;
  6use backtrace::{Backtrace, BacktraceFrame};
  7use futures::channel::oneshot;
  8use parking_lot::{Mutex, MutexGuard};
  9use rand::{
 10    distr::{StandardUniform, uniform::SampleRange, uniform::SampleUniform},
 11    prelude::*,
 12};
 13use std::{
 14    any::type_name_of_val,
 15    collections::{BTreeMap, HashSet, VecDeque},
 16    env,
 17    fmt::Write,
 18    future::Future,
 19    mem,
 20    ops::RangeInclusive,
 21    panic::{self, AssertUnwindSafe},
 22    pin::Pin,
 23    sync::{
 24        Arc,
 25        atomic::{AtomicBool, Ordering::SeqCst},
 26    },
 27    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
 28    thread::{self, Thread},
 29    time::{Duration, Instant},
 30};
 31
/// Environment variable that, when set to `1` or `true`, makes the default
/// [`TestSchedulerConfig`] capture a backtrace for every cloned scheduler waker
/// so that a "Parking forbidden" panic can show which wakers are still pending.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";

/// A seedable test scheduler that runs tasks on the thread that created it.
///
/// Task ordering and `block` timeouts are driven by a seeded RNG, and time is
/// a [`TestClock`] that only moves when the scheduler advances it.
pub struct TestScheduler {
    /// Virtual clock; advanced explicitly (`advance_clock*`) or by the real
    /// time spent parked when parking is allowed.
    clock: Arc<TestClock>,
    /// Seeded RNG used for task selection/insertion order and `block` timeouts.
    rng: Arc<Mutex<StdRng>>,
    /// Queues, timers and flags; also shared with every scheduler waker.
    state: Arc<Mutex<SchedulerState>>,
    /// Thread that created the scheduler (`thread::current()` in `new`); it is
    /// unparked on new work and is the thread activity is expected on.
    thread: Thread,
}
 40
 41impl TestScheduler {
 42    /// Run a test once with default configuration (seed 0)
 43    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 44        Self::with_seed(0, f)
 45    }
 46
 47    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
 48    pub fn many<R>(
 49        default_iterations: usize,
 50        mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
 51    ) -> Vec<R> {
 52        let num_iterations = std::env::var("ITERATIONS")
 53            .map(|iterations| iterations.parse().unwrap())
 54            .unwrap_or(default_iterations);
 55
 56        let seed = std::env::var("SEED")
 57            .map(|seed| seed.parse().unwrap())
 58            .unwrap_or(0);
 59
 60        (seed..num_iterations as u64)
 61            .map(|seed| {
 62                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
 63                eprintln!("Running seed: {seed}");
 64                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
 65                    Ok(result) => result,
 66                    Err(error) => {
 67                        eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
 68                        panic::resume_unwind(error);
 69                    }
 70                }
 71            })
 72            .collect()
 73    }
 74
 75    fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 76        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
 77        let future = f(scheduler.clone());
 78        let result = scheduler.foreground().block_on(future);
 79        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
 80        result
 81    }
 82
 83    pub fn new(config: TestSchedulerConfig) -> Self {
 84        Self {
 85            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
 86            state: Arc::new(Mutex::new(SchedulerState {
 87                runnables: VecDeque::new(),
 88                timers: Vec::new(),
 89                blocked_sessions: Vec::new(),
 90                randomize_order: config.randomize_order,
 91                allow_parking: config.allow_parking,
 92                timeout_ticks: config.timeout_ticks,
 93                next_session_id: SessionId(0),
 94                capture_pending_traces: config.capture_pending_traces,
 95                pending_traces: BTreeMap::new(),
 96                next_trace_id: TraceId(0),
 97                is_main_thread: true,
 98                non_determinism_error: None,
 99                finished: false,
100                parking_allowed_once: false,
101                unparked: false,
102            })),
103            clock: Arc::new(TestClock::new()),
104            thread: thread::current(),
105        }
106    }
107
108    pub fn end_test(&self) {
109        let mut state = self.state.lock();
110        if let Some((message, backtrace)) = &state.non_determinism_error {
111            panic!("{}\n{:?}", message, backtrace)
112        }
113        state.finished = true;
114    }
115
116    pub fn clock(&self) -> Arc<TestClock> {
117        self.clock.clone()
118    }
119
120    pub fn rng(&self) -> SharedRng {
121        SharedRng(self.rng.clone())
122    }
123
124    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
125        self.state.lock().timeout_ticks = timeout_ticks;
126    }
127
128    pub fn allow_parking(&self) {
129        let mut state = self.state.lock();
130        state.allow_parking = true;
131        state.parking_allowed_once = true;
132    }
133
134    pub fn forbid_parking(&self) {
135        self.state.lock().allow_parking = false;
136    }
137
138    pub fn parking_allowed(&self) -> bool {
139        self.state.lock().allow_parking
140    }
141
142    pub fn is_main_thread(&self) -> bool {
143        self.state.lock().is_main_thread
144    }
145
146    /// Allocate a new session ID for foreground task scheduling.
147    /// This is used by GPUI's TestDispatcher to map dispatcher instances to sessions.
148    pub fn allocate_session_id(&self) -> SessionId {
149        let mut state = self.state.lock();
150        state.next_session_id.0 += 1;
151        state.next_session_id
152    }
153
154    /// Create a foreground executor for this scheduler
155    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
156        let session_id = self.allocate_session_id();
157        ForegroundExecutor::new(session_id, self.clone())
158    }
159
160    /// Create a background executor for this scheduler
161    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
162        BackgroundExecutor::new(self.clone())
163    }
164
165    pub fn yield_random(&self) -> Yield {
166        let rng = &mut *self.rng.lock();
167        if rng.random_bool(0.1) {
168            Yield(rng.random_range(10..20))
169        } else {
170            Yield(rng.random_range(0..2))
171        }
172    }
173
174    pub fn run(&self) {
175        while self.step() {
176            // Continue until no work remains
177        }
178    }
179
180    pub fn run_with_clock_advancement(&self) {
181        while self.step() || self.advance_clock_to_next_timer() {
182            // Continue until no work remains
183        }
184    }
185
186    /// Execute one tick of the scheduler, processing expired timers and running
187    /// at most one task. Returns true if any work was done.
188    ///
189    /// This is the public interface for GPUI's TestDispatcher to drive task execution.
190    pub fn tick(&self) -> bool {
191        self.step_filtered(false)
192    }
193
194    /// Execute one tick, but only run background tasks (no foreground/session tasks).
195    /// Returns true if any work was done.
196    pub fn tick_background_only(&self) -> bool {
197        self.step_filtered(true)
198    }
199
200    /// Check if there are any pending tasks or timers that could run.
201    pub fn has_pending_tasks(&self) -> bool {
202        let state = self.state.lock();
203        !state.runnables.is_empty() || !state.timers.is_empty()
204    }
205
206    /// Returns counts of (foreground_tasks, background_tasks) currently queued.
207    /// Foreground tasks are those with a session_id, background tasks have none.
208    pub fn pending_task_counts(&self) -> (usize, usize) {
209        let state = self.state.lock();
210        let foreground = state
211            .runnables
212            .iter()
213            .filter(|r| r.session_id.is_some())
214            .count();
215        let background = state
216            .runnables
217            .iter()
218            .filter(|r| r.session_id.is_none())
219            .count();
220        (foreground, background)
221    }
222
223    fn step(&self) -> bool {
224        self.step_filtered(false)
225    }
226
227    fn step_filtered(&self, background_only: bool) -> bool {
228        let (elapsed_count, runnables_before) = {
229            let mut state = self.state.lock();
230            let end_ix = state
231                .timers
232                .partition_point(|timer| timer.expiration <= self.clock.now());
233            let elapsed: Vec<_> = state.timers.drain(..end_ix).collect();
234            let count = elapsed.len();
235            let runnables = state.runnables.len();
236            drop(state);
237            // Dropping elapsed timers here wakes the waiting futures
238            drop(elapsed);
239            (count, runnables)
240        };
241
242        if elapsed_count > 0 {
243            let runnables_after = self.state.lock().runnables.len();
244            if std::env::var("DEBUG_SCHEDULER").is_ok() {
245                eprintln!(
246                    "[scheduler] Expired {} timers at {:?}, runnables: {} -> {}",
247                    elapsed_count,
248                    self.clock.now(),
249                    runnables_before,
250                    runnables_after
251                );
252            }
253            return true;
254        }
255
256        let runnable = {
257            let state = &mut *self.state.lock();
258
259            // Find candidate tasks:
260            // - For foreground tasks (with session_id), only the first task from each session
261            //   is a candidate (to preserve intra-session ordering)
262            // - For background tasks (no session_id), all are candidates
263            // - Tasks from blocked sessions are excluded
264            // - If background_only is true, skip foreground tasks entirely
265            let mut seen_sessions = HashSet::new();
266            let candidate_indices: Vec<usize> = state
267                .runnables
268                .iter()
269                .enumerate()
270                .filter(|(_, runnable)| {
271                    if let Some(session_id) = runnable.session_id {
272                        // Skip foreground tasks if background_only mode
273                        if background_only {
274                            return false;
275                        }
276                        // Exclude tasks from blocked sessions
277                        if state.blocked_sessions.contains(&session_id) {
278                            return false;
279                        }
280                        // Only include first task from each session (insert returns true if new)
281                        seen_sessions.insert(session_id)
282                    } else {
283                        // Background tasks are always candidates
284                        true
285                    }
286                })
287                .map(|(ix, _)| ix)
288                .collect();
289
290            if candidate_indices.is_empty() {
291                None
292            } else if state.randomize_order {
293                // Use priority-weighted random selection
294                let weights: Vec<u32> = candidate_indices
295                    .iter()
296                    .map(|&ix| state.runnables[ix].priority.weight())
297                    .collect();
298                let total_weight: u32 = weights.iter().sum();
299
300                if total_weight == 0 {
301                    // Fallback to uniform random if all weights are zero
302                    let choice = self.rng.lock().random_range(0..candidate_indices.len());
303                    state.runnables.remove(candidate_indices[choice])
304                } else {
305                    let mut target = self.rng.lock().random_range(0..total_weight);
306                    let mut selected_idx = 0;
307                    for (i, &weight) in weights.iter().enumerate() {
308                        if target < weight {
309                            selected_idx = i;
310                            break;
311                        }
312                        target -= weight;
313                    }
314                    state.runnables.remove(candidate_indices[selected_idx])
315                }
316            } else {
317                // Non-randomized: just take the first candidate task
318                state.runnables.remove(candidate_indices[0])
319            }
320        };
321
322        if let Some(runnable) = runnable {
323            // Check if the executor that spawned this task was closed
324            if runnable.runnable.metadata().is_closed() {
325                return true;
326            }
327            let is_foreground = runnable.session_id.is_some();
328            let was_main_thread = self.state.lock().is_main_thread;
329            self.state.lock().is_main_thread = is_foreground;
330            runnable.run();
331            self.state.lock().is_main_thread = was_main_thread;
332            return true;
333        }
334
335        false
336    }
337
338    pub fn advance_clock_to_next_timer(&self) -> bool {
339        if let Some(timer) = self.state.lock().timers.first() {
340            self.clock.advance(timer.expiration - self.clock.now());
341            true
342        } else {
343            false
344        }
345    }
346
347    pub fn advance_clock(&self, duration: Duration) {
348        let debug = std::env::var("DEBUG_SCHEDULER").is_ok();
349        let start = self.clock.now();
350        let next_now = start + duration;
351        if debug {
352            let timer_count = self.state.lock().timers.len();
353            eprintln!(
354                "[scheduler] advance_clock({:?}) from {:?}, {} pending timers",
355                duration, start, timer_count
356            );
357        }
358        loop {
359            self.run();
360            if let Some(timer) = self.state.lock().timers.first()
361                && timer.expiration <= next_now
362            {
363                let advance_to = timer.expiration;
364                if debug {
365                    eprintln!(
366                        "[scheduler] Advancing clock {:?} -> {:?} for timer",
367                        self.clock.now(),
368                        advance_to
369                    );
370                }
371                self.clock.advance(advance_to - self.clock.now());
372            } else {
373                break;
374            }
375        }
376        self.clock.advance(next_now - self.clock.now());
377        if debug {
378            eprintln!(
379                "[scheduler] advance_clock done, now at {:?}",
380                self.clock.now()
381            );
382        }
383    }
384
385    fn park(&self, deadline: Option<Instant>) -> bool {
386        if self.state.lock().allow_parking {
387            let start = Instant::now();
388            // Enforce a hard timeout to prevent tests from hanging indefinitely
389            let hard_deadline = start + Duration::from_secs(15);
390
391            // Use the earlier of the provided deadline or the hard timeout deadline
392            let effective_deadline = deadline
393                .map(|d| d.min(hard_deadline))
394                .unwrap_or(hard_deadline);
395
396            // Park in small intervals to allow checking both deadlines
397            const PARK_INTERVAL: Duration = Duration::from_millis(100);
398            loop {
399                let now = Instant::now();
400                if now >= effective_deadline {
401                    // Check if we hit the hard timeout
402                    if now >= hard_deadline {
403                        panic!(
404                            "Test timed out after 15 seconds while parking. \
405                            This may indicate a deadlock or missing waker.",
406                        );
407                    }
408                    // Hit the provided deadline
409                    return false;
410                }
411
412                let remaining = effective_deadline.saturating_duration_since(now);
413                let park_duration = remaining.min(PARK_INTERVAL);
414                let before_park = Instant::now();
415                thread::park_timeout(park_duration);
416                let elapsed = before_park.elapsed();
417
418                // Advance the test clock by the real elapsed time while parking
419                self.clock.advance(elapsed);
420
421                // Check if any timers have expired after advancing the clock.
422                // If so, return so the caller can process them.
423                if self
424                    .state
425                    .lock()
426                    .timers
427                    .first()
428                    .map_or(false, |t| t.expiration <= self.clock.now())
429                {
430                    return true;
431                }
432
433                // Check if we were woken up by a different thread.
434                // We use a flag because timing-based detection is unreliable:
435                // OS scheduling delays can cause elapsed >= park_duration even when
436                // we were woken early by unpark().
437                if std::mem::take(&mut self.state.lock().unparked) {
438                    return true;
439                }
440            }
441        } else if deadline.is_some() {
442            false
443        } else if self.state.lock().capture_pending_traces {
444            let mut pending_traces = String::new();
445            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
446                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
447            }
448            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
449        } else {
450            panic!(
451                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
452            );
453        }
454    }
455}
456
457fn assert_correct_thread(expected: &Thread, state: &Arc<Mutex<SchedulerState>>) {
458    let current_thread = thread::current();
459    let mut state = state.lock();
460    if state.parking_allowed_once {
461        return;
462    }
463    if current_thread.id() == expected.id() {
464        return;
465    }
466
467    let message = format!(
468        "Detected activity on thread {:?} {:?}, but test scheduler is running on {:?} {:?}. Your test is not deterministic.",
469        current_thread.name(),
470        current_thread.id(),
471        expected.name(),
472        expected.id(),
473    );
474    let backtrace = Backtrace::new();
475    if state.finished {
476        panic!("{}", message);
477    } else {
478        state.non_determinism_error = Some((message, backtrace))
479    }
480}
481
482impl Scheduler for TestScheduler {
483    /// Block until the given future completes, with an optional timeout. If the
484    /// future is unable to make progress at any moment before the timeout and
485    /// no other tasks or timers remain, we panic unless parking is allowed. If
486    /// parking is allowed, we block up to the timeout or indefinitely if none
487    /// is provided. This is to allow testing a mix of deterministic and
488    /// non-deterministic async behavior, such as when interacting with I/O in
489    /// an otherwise deterministic test.
490    fn block(
491        &self,
492        session_id: Option<SessionId>,
493        mut future: Pin<&mut dyn Future<Output = ()>>,
494        timeout: Option<Duration>,
495    ) -> bool {
496        if let Some(session_id) = session_id {
497            self.state.lock().blocked_sessions.push(session_id);
498        }
499
500        let deadline = timeout.map(|timeout| Instant::now() + timeout);
501        let awoken = Arc::new(AtomicBool::new(false));
502        let waker = Box::new(TracingWaker {
503            id: None,
504            awoken: awoken.clone(),
505            thread: self.thread.clone(),
506            state: self.state.clone(),
507        });
508        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
509        let max_ticks = if timeout.is_some() {
510            self.rng
511                .lock()
512                .random_range(self.state.lock().timeout_ticks.clone())
513        } else {
514            usize::MAX
515        };
516        let mut cx = Context::from_waker(&waker);
517
518        let mut completed = false;
519        for _ in 0..max_ticks {
520            match future.as_mut().poll(&mut cx) {
521                Poll::Ready(()) => {
522                    completed = true;
523                    break;
524                }
525                Poll::Pending => {}
526            }
527
528            let mut stepped = None;
529            while self.rng.lock().random() {
530                let stepped = stepped.get_or_insert(false);
531                if self.step() {
532                    *stepped = true;
533                } else {
534                    break;
535                }
536            }
537
538            let stepped = stepped.unwrap_or(true);
539            let awoken = awoken.swap(false, SeqCst);
540            if !stepped && !awoken {
541                let parking_allowed = self.state.lock().allow_parking;
542                // In deterministic mode (parking forbidden), instantly jump to the next timer.
543                // In non-deterministic mode (parking allowed), let real time pass instead.
544                let advanced_to_timer = !parking_allowed && self.advance_clock_to_next_timer();
545                if !advanced_to_timer && !self.park(deadline) {
546                    break;
547                }
548            }
549        }
550
551        if session_id.is_some() {
552            self.state.lock().blocked_sessions.pop();
553        }
554
555        completed
556    }
557
558    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable<RunnableMeta>) {
559        assert_correct_thread(&self.thread, &self.state);
560        let mut state = self.state.lock();
561        let ix = if state.randomize_order {
562            let start_ix = state
563                .runnables
564                .iter()
565                .rposition(|task| task.session_id == Some(session_id))
566                .map_or(0, |ix| ix + 1);
567            self.rng
568                .lock()
569                .random_range(start_ix..=state.runnables.len())
570        } else {
571            state.runnables.len()
572        };
573        state.runnables.insert(
574            ix,
575            ScheduledRunnable {
576                session_id: Some(session_id),
577                priority: Priority::default(),
578                runnable,
579            },
580        );
581        state.unparked = true;
582        drop(state);
583        self.thread.unpark();
584    }
585
586    fn schedule_background_with_priority(
587        &self,
588        runnable: Runnable<RunnableMeta>,
589        priority: Priority,
590    ) {
591        assert_correct_thread(&self.thread, &self.state);
592        let mut state = self.state.lock();
593        let ix = if state.randomize_order {
594            self.rng.lock().random_range(0..=state.runnables.len())
595        } else {
596            state.runnables.len()
597        };
598        state.runnables.insert(
599            ix,
600            ScheduledRunnable {
601                session_id: None,
602                priority,
603                runnable,
604            },
605        );
606        state.unparked = true;
607        drop(state);
608        self.thread.unpark();
609    }
610
611    fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
612        std::thread::spawn(move || {
613            f();
614        });
615    }
616
617    fn timer(&self, duration: Duration) -> Timer {
618        let (tx, rx) = oneshot::channel();
619        let state = &mut *self.state.lock();
620        state.timers.push(ScheduledTimer {
621            expiration: self.clock.now() + duration,
622            _notify: tx,
623        });
624        state.timers.sort_by_key(|timer| timer.expiration);
625        Timer(rx)
626    }
627
628    fn clock(&self) -> Arc<dyn Clock> {
629        self.clock.clone()
630    }
631
632    fn as_test(&self) -> Option<&TestScheduler> {
633        Some(self)
634    }
635}
636
637#[derive(Clone, Debug)]
638pub struct TestSchedulerConfig {
639    pub seed: u64,
640    pub randomize_order: bool,
641    pub allow_parking: bool,
642    pub capture_pending_traces: bool,
643    pub timeout_ticks: RangeInclusive<usize>,
644}
645
646impl TestSchedulerConfig {
647    pub fn with_seed(seed: u64) -> Self {
648        Self {
649            seed,
650            ..Default::default()
651        }
652    }
653}
654
655impl Default for TestSchedulerConfig {
656    fn default() -> Self {
657        Self {
658            seed: 0,
659            randomize_order: true,
660            allow_parking: false,
661            capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
662                .map_or(false, |var| var == "1" || var == "true"),
663            timeout_ticks: 1..=1000,
664        }
665    }
666}
667
668struct ScheduledRunnable {
669    session_id: Option<SessionId>,
670    priority: Priority,
671    runnable: Runnable<RunnableMeta>,
672}
673
674impl ScheduledRunnable {
675    fn run(self) {
676        self.runnable.run();
677    }
678}
679
/// A pending timer. Dropping it drops `_notify`, the sending half of the
/// `oneshot` channel whose receiver is held by the `Timer` returned from
/// `timer`, which is how the waiting future learns the timer has fired.
struct ScheduledTimer {
    expiration: Instant,
    _notify: oneshot::Sender<()>,
}

/// Mutable scheduler state, shared between the scheduler and its wakers.
///
/// NOTE: field declaration order is also drop order.
struct SchedulerState {
    /// Queued tasks; selection order is decided in `step_filtered`.
    runnables: VecDeque<ScheduledRunnable>,
    /// Pending timers, kept sorted by `expiration`.
    timers: Vec<ScheduledTimer>,
    /// Sessions currently inside `block`; their tasks are not selected.
    blocked_sessions: Vec<SessionId>,
    randomize_order: bool,
    allow_parking: bool,
    /// Range from which a timed `block` draws its tick budget.
    timeout_ticks: RangeInclusive<usize>,
    /// Last allocated session id (incremented before being handed out).
    next_session_id: SessionId,
    capture_pending_traces: bool,
    next_trace_id: TraceId,
    /// Clone-site backtraces of wakers that have not yet woken or been dropped.
    pending_traces: BTreeMap<TraceId, Backtrace>,
    /// Set to `false` while a background task runs in `step_filtered`.
    is_main_thread: bool,
    /// Most recently recorded cross-thread activity; reported by `end_test`.
    non_determinism_error: Option<(String, Backtrace)>,
    /// Set once `allow_parking` has ever been called; disables thread checks.
    parking_allowed_once: bool,
    /// Set by `end_test`; later cross-thread activity panics immediately.
    finished: bool,
    /// Set when work is scheduled or a waker fires; consumed by `park`.
    unparked: bool,
}
702
/// Vtable for wakers whose data pointer is a `Box<TracingWaker>` turned into
/// a raw pointer (see the `*_raw` functions on `TracingWaker`).
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);

/// Key of a captured waker-clone backtrace in `SchedulerState::pending_traces`.
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);

/// Waker used by `block`: it flags `awoken`, marks the scheduler `unparked`,
/// and unparks the scheduler thread. It also tracks its clone-site backtrace
/// when trace capture is enabled.
struct TracingWaker {
    /// `Some` when this clone registered a pending trace.
    id: Option<TraceId>,
    /// Set on wake; checked (and cleared) by `block`.
    awoken: Arc<AtomicBool>,
    /// Scheduler thread to unpark and to check activity against.
    thread: Thread,
    state: Arc<Mutex<SchedulerState>>,
}
719
720impl Clone for TracingWaker {
721    fn clone(&self) -> Self {
722        let mut state = self.state.lock();
723        let id = if state.capture_pending_traces {
724            let id = state.next_trace_id;
725            state.next_trace_id.0 += 1;
726            state.pending_traces.insert(id, Backtrace::new_unresolved());
727            Some(id)
728        } else {
729            None
730        };
731        Self {
732            id,
733            awoken: self.awoken.clone(),
734            thread: self.thread.clone(),
735            state: self.state.clone(),
736        }
737    }
738}
739
740impl Drop for TracingWaker {
741    fn drop(&mut self) {
742        assert_correct_thread(&self.thread, &self.state);
743
744        if let Some(id) = self.id {
745            self.state.lock().pending_traces.remove(&id);
746        }
747    }
748}
749
impl TracingWaker {
    /// Wake by value: performs the by-reference wake, then drops `self`
    /// (which runs `Drop for TracingWaker`).
    fn wake(self) {
        self.wake_by_ref();
    }

    /// Clear this waker's pending trace (if any), mark the scheduler
    /// `unparked`, set `awoken`, and unpark the scheduler thread.
    fn wake_by_ref(&self) {
        assert_correct_thread(&self.thread, &self.state);

        let mut state = self.state.lock();
        if let Some(id) = self.id {
            state.pending_traces.remove(&id);
        }
        state.unparked = true;
        // Release the lock before signalling the scheduler thread.
        drop(state);
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    /// `RawWakerVTable::clone`: box a clone of the pointed-to waker.
    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        // SAFETY: every data pointer passed to `WAKER_VTABLE` functions comes
        // from `Box::into_raw(Box<TracingWaker>)` and stays live until
        // `wake_raw`/`drop_raw` consumes it, so borrowing it here is valid.
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    /// `RawWakerVTable::wake`: take ownership of the box and wake by value.
    fn wake_raw(waker: *const ()) {
        // SAFETY: the pointer came from `Box::into_raw(Box<TracingWaker>)`, and
        // the `Waker` contract gives `wake` ownership of the data, so it is
        // reclaimed exactly once here.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    /// `RawWakerVTable::wake_by_ref`: wake without consuming the data.
    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        // SAFETY: the pointer came from `Box::into_raw(Box<TracingWaker>)` and
        // is still owned by the live `Waker`; only a shared borrow is taken.
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    /// `RawWakerVTable::drop`: reclaim and drop the boxed waker.
    fn drop_raw(waker: *const ()) {
        // SAFETY: the pointer came from `Box::into_raw(Box<TracingWaker>)`, and
        // `drop` is called once when the owning `Waker` is dropped.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}
793
/// Future that returns `Pending` (waking itself each time) the given number of
/// times before completing; produced by `TestScheduler::yield_random`.
pub struct Yield(usize);
795
796/// A wrapper around `Arc<Mutex<StdRng>>` that provides convenient methods
797/// for random number generation without requiring explicit locking.
798#[derive(Clone)]
799pub struct SharedRng(Arc<Mutex<StdRng>>);
800
801impl SharedRng {
802    /// Lock the inner RNG for direct access. Use this when you need multiple
803    /// random operations without re-locking between each one.
804    pub fn lock(&self) -> MutexGuard<'_, StdRng> {
805        self.0.lock()
806    }
807
808    /// Generate a random value in the given range.
809    pub fn random_range<T, R>(&self, range: R) -> T
810    where
811        T: SampleUniform,
812        R: SampleRange<T>,
813    {
814        self.0.lock().random_range(range)
815    }
816
817    /// Generate a random boolean with the given probability of being true.
818    pub fn random_bool(&self, p: f64) -> bool {
819        self.0.lock().random_bool(p)
820    }
821
822    /// Generate a random value of the given type.
823    pub fn random<T>(&self) -> T
824    where
825        StandardUniform: Distribution<T>,
826    {
827        self.0.lock().random()
828    }
829
830    /// Generate a random ratio - true with probability `numerator/denominator`.
831    pub fn random_ratio(&self, numerator: u32, denominator: u32) -> bool {
832        self.0.lock().random_ratio(numerator, denominator)
833    }
834}
835
836impl Future for Yield {
837    type Output = ();
838
839    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
840        if self.0 == 0 {
841            Poll::Ready(())
842        } else {
843            self.0 -= 1;
844            cx.waker().wake_by_ref();
845            Poll::Pending
846        }
847    }
848}
849
850fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
851    trace.resolve();
852    let mut frames: Vec<BacktraceFrame> = trace.into();
853    let waker_clone_frame_ix = frames.iter().position(|frame| {
854        frame.symbols().iter().any(|symbol| {
855            symbol
856                .name()
857                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
858        })
859    });
860
861    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
862        frames.drain(..waker_clone_frame_ix + 1);
863    }
864
865    Backtrace::from(frames)
866}