test_scheduler.rs

  1use crate::{
  2    BackgroundExecutor, Clock, ForegroundExecutor, Priority, RunnableMeta, Scheduler, SessionId,
  3    TestClock, Timer,
  4};
  5use async_task::Runnable;
  6use backtrace::{Backtrace, BacktraceFrame};
  7use futures::channel::oneshot;
  8use parking_lot::{Mutex, MutexGuard};
  9use rand::{
 10    distr::{StandardUniform, uniform::SampleRange, uniform::SampleUniform},
 11    prelude::*,
 12};
 13use std::{
 14    any::type_name_of_val,
 15    collections::{BTreeMap, HashSet, VecDeque},
 16    env,
 17    fmt::Write,
 18    future::Future,
 19    mem,
 20    ops::RangeInclusive,
 21    panic::{self, AssertUnwindSafe},
 22    pin::Pin,
 23    sync::{
 24        Arc,
 25        atomic::{AtomicBool, Ordering::SeqCst},
 26    },
 27    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
 28    thread::{self, Thread},
 29    time::{Duration, Instant},
 30};
 31
/// Environment variable that, when set to `1` or `true`, makes the default
/// `TestSchedulerConfig` capture a backtrace for every outstanding waker clone, so that a
/// forbidden park can print where the still-pending wakers were created.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
 33
/// A scheduler for tests whose decisions are driven by a seeded RNG and a test clock.
///
/// Task ordering, the poll budget of timed `block` calls, and `yield_random` all draw from
/// `rng`, so a given seed is intended to reproduce the same interleaving. Scheduling or waking
/// from a thread other than `thread` is recorded as non-determinism (reported by `end_test`)
/// unless parking has been allowed at some point.
pub struct TestScheduler {
    /// Clock used for timers. The scheduler advances it in `advance_clock`, when jumping to the
    /// next timer, and by the real elapsed time while parked.
    clock: Arc<TestClock>,
    /// Seeded RNG behind every randomized scheduling decision.
    rng: Arc<Mutex<StdRng>>,
    /// Task queue, timers, and flags; also shared with the wakers created by `block`.
    state: Arc<Mutex<SchedulerState>>,
    /// The thread that created the scheduler; it is unparked whenever work arrives.
    thread: Thread,
}
 40
 41impl TestScheduler {
 42    /// Run a test once with default configuration (seed 0)
 43    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 44        Self::with_seed(0, f)
 45    }
 46
 47    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
 48    pub fn many<R>(
 49        default_iterations: usize,
 50        mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
 51    ) -> Vec<R> {
 52        let num_iterations = std::env::var("ITERATIONS")
 53            .map(|iterations| iterations.parse().unwrap())
 54            .unwrap_or(default_iterations);
 55
 56        let seed = std::env::var("SEED")
 57            .map(|seed| seed.parse().unwrap())
 58            .unwrap_or(0);
 59
 60        (seed..num_iterations as u64)
 61            .map(|seed| {
 62                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
 63                eprintln!("Running seed: {seed}");
 64                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
 65                    Ok(result) => result,
 66                    Err(error) => {
 67                        eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
 68                        panic::resume_unwind(error);
 69                    }
 70                }
 71            })
 72            .collect()
 73    }
 74
 75    fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 76        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
 77        let future = f(scheduler.clone());
 78        let result = scheduler.foreground().block_on(future);
 79        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
 80        result
 81    }
 82
 83    pub fn new(config: TestSchedulerConfig) -> Self {
 84        Self {
 85            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
 86            state: Arc::new(Mutex::new(SchedulerState {
 87                runnables: VecDeque::new(),
 88                timers: Vec::new(),
 89                blocked_sessions: Vec::new(),
 90                randomize_order: config.randomize_order,
 91                allow_parking: config.allow_parking,
 92                timeout_ticks: config.timeout_ticks,
 93                next_session_id: SessionId(0),
 94                capture_pending_traces: config.capture_pending_traces,
 95                pending_traces: BTreeMap::new(),
 96                next_trace_id: TraceId(0),
 97                is_main_thread: true,
 98                non_determinism_error: None,
 99                finished: false,
100                parking_allowed_once: false,
101                unparked: false,
102            })),
103            clock: Arc::new(TestClock::new()),
104            thread: thread::current(),
105        }
106    }
107
108    pub fn end_test(&self) {
109        let mut state = self.state.lock();
110        if let Some((message, backtrace)) = &state.non_determinism_error {
111            panic!("{}\n{:?}", message, backtrace)
112        }
113        state.finished = true;
114    }
115
116    pub fn clock(&self) -> Arc<TestClock> {
117        self.clock.clone()
118    }
119
120    pub fn rng(&self) -> SharedRng {
121        SharedRng(self.rng.clone())
122    }
123
124    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
125        self.state.lock().timeout_ticks = timeout_ticks;
126    }
127
128    pub fn allow_parking(&self) {
129        let mut state = self.state.lock();
130        state.allow_parking = true;
131        state.parking_allowed_once = true;
132    }
133
134    pub fn forbid_parking(&self) {
135        self.state.lock().allow_parking = false;
136    }
137
138    pub fn parking_allowed(&self) -> bool {
139        self.state.lock().allow_parking
140    }
141
142    pub fn is_main_thread(&self) -> bool {
143        self.state.lock().is_main_thread
144    }
145
146    /// Allocate a new session ID for foreground task scheduling.
147    /// This is used by GPUI's TestDispatcher to map dispatcher instances to sessions.
148    pub fn allocate_session_id(&self) -> SessionId {
149        let mut state = self.state.lock();
150        state.next_session_id.0 += 1;
151        state.next_session_id
152    }
153
154    /// Create a foreground executor for this scheduler
155    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
156        let session_id = self.allocate_session_id();
157        ForegroundExecutor::new(session_id, self.clone())
158    }
159
160    /// Create a background executor for this scheduler
161    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
162        BackgroundExecutor::new(self.clone())
163    }
164
165    pub fn yield_random(&self) -> Yield {
166        let rng = &mut *self.rng.lock();
167        if rng.random_bool(0.1) {
168            Yield(rng.random_range(10..20))
169        } else {
170            Yield(rng.random_range(0..2))
171        }
172    }
173
174    pub fn run(&self) {
175        while self.step() {
176            // Continue until no work remains
177        }
178    }
179
180    /// Execute one tick of the scheduler, processing expired timers and running
181    /// at most one task. Returns true if any work was done.
182    ///
183    /// This is the public interface for GPUI's TestDispatcher to drive task execution.
184    pub fn tick(&self) -> bool {
185        self.step_filtered(false)
186    }
187
188    /// Execute one tick, but only run background tasks (no foreground/session tasks).
189    /// Returns true if any work was done.
190    pub fn tick_background_only(&self) -> bool {
191        self.step_filtered(true)
192    }
193
194    /// Check if there are any pending tasks or timers that could run.
195    pub fn has_pending_tasks(&self) -> bool {
196        let state = self.state.lock();
197        !state.runnables.is_empty() || !state.timers.is_empty()
198    }
199
200    /// Returns counts of (foreground_tasks, background_tasks) currently queued.
201    /// Foreground tasks are those with a session_id, background tasks have none.
202    pub fn pending_task_counts(&self) -> (usize, usize) {
203        let state = self.state.lock();
204        let foreground = state
205            .runnables
206            .iter()
207            .filter(|r| r.session_id.is_some())
208            .count();
209        let background = state
210            .runnables
211            .iter()
212            .filter(|r| r.session_id.is_none())
213            .count();
214        (foreground, background)
215    }
216
217    fn step(&self) -> bool {
218        self.step_filtered(false)
219    }
220
221    fn step_filtered(&self, background_only: bool) -> bool {
222        let (elapsed_count, runnables_before) = {
223            let mut state = self.state.lock();
224            let end_ix = state
225                .timers
226                .partition_point(|timer| timer.expiration <= self.clock.now());
227            let elapsed: Vec<_> = state.timers.drain(..end_ix).collect();
228            let count = elapsed.len();
229            let runnables = state.runnables.len();
230            drop(state);
231            // Dropping elapsed timers here wakes the waiting futures
232            drop(elapsed);
233            (count, runnables)
234        };
235
236        if elapsed_count > 0 {
237            let runnables_after = self.state.lock().runnables.len();
238            if std::env::var("DEBUG_SCHEDULER").is_ok() {
239                eprintln!(
240                    "[scheduler] Expired {} timers at {:?}, runnables: {} -> {}",
241                    elapsed_count,
242                    self.clock.now(),
243                    runnables_before,
244                    runnables_after
245                );
246            }
247            return true;
248        }
249
250        let runnable = {
251            let state = &mut *self.state.lock();
252
253            // Find candidate tasks:
254            // - For foreground tasks (with session_id), only the first task from each session
255            //   is a candidate (to preserve intra-session ordering)
256            // - For background tasks (no session_id), all are candidates
257            // - Tasks from blocked sessions are excluded
258            // - If background_only is true, skip foreground tasks entirely
259            let mut seen_sessions = HashSet::new();
260            let candidate_indices: Vec<usize> = state
261                .runnables
262                .iter()
263                .enumerate()
264                .filter(|(_, runnable)| {
265                    if let Some(session_id) = runnable.session_id {
266                        // Skip foreground tasks if background_only mode
267                        if background_only {
268                            return false;
269                        }
270                        // Exclude tasks from blocked sessions
271                        if state.blocked_sessions.contains(&session_id) {
272                            return false;
273                        }
274                        // Only include first task from each session (insert returns true if new)
275                        seen_sessions.insert(session_id)
276                    } else {
277                        // Background tasks are always candidates
278                        true
279                    }
280                })
281                .map(|(ix, _)| ix)
282                .collect();
283
284            if candidate_indices.is_empty() {
285                None
286            } else if state.randomize_order {
287                // Use priority-weighted random selection
288                let weights: Vec<u32> = candidate_indices
289                    .iter()
290                    .map(|&ix| state.runnables[ix].priority.weight())
291                    .collect();
292                let total_weight: u32 = weights.iter().sum();
293
294                if total_weight == 0 {
295                    // Fallback to uniform random if all weights are zero
296                    let choice = self.rng.lock().random_range(0..candidate_indices.len());
297                    state.runnables.remove(candidate_indices[choice])
298                } else {
299                    let mut target = self.rng.lock().random_range(0..total_weight);
300                    let mut selected_idx = 0;
301                    for (i, &weight) in weights.iter().enumerate() {
302                        if target < weight {
303                            selected_idx = i;
304                            break;
305                        }
306                        target -= weight;
307                    }
308                    state.runnables.remove(candidate_indices[selected_idx])
309                }
310            } else {
311                // Non-randomized: just take the first candidate task
312                state.runnables.remove(candidate_indices[0])
313            }
314        };
315
316        if let Some(runnable) = runnable {
317            // Check if the executor that spawned this task was closed
318            if runnable.runnable.metadata().is_closed() {
319                return true;
320            }
321            let is_foreground = runnable.session_id.is_some();
322            let was_main_thread = self.state.lock().is_main_thread;
323            self.state.lock().is_main_thread = is_foreground;
324            runnable.run();
325            self.state.lock().is_main_thread = was_main_thread;
326            return true;
327        }
328
329        false
330    }
331
332    pub fn advance_clock_to_next_timer(&self) -> bool {
333        if let Some(timer) = self.state.lock().timers.first() {
334            self.clock.advance(timer.expiration - self.clock.now());
335            true
336        } else {
337            false
338        }
339    }
340
341    pub fn advance_clock(&self, duration: Duration) {
342        let debug = std::env::var("DEBUG_SCHEDULER").is_ok();
343        let start = self.clock.now();
344        let next_now = start + duration;
345        if debug {
346            let timer_count = self.state.lock().timers.len();
347            eprintln!(
348                "[scheduler] advance_clock({:?}) from {:?}, {} pending timers",
349                duration, start, timer_count
350            );
351        }
352        loop {
353            self.run();
354            if let Some(timer) = self.state.lock().timers.first()
355                && timer.expiration <= next_now
356            {
357                let advance_to = timer.expiration;
358                if debug {
359                    eprintln!(
360                        "[scheduler] Advancing clock {:?} -> {:?} for timer",
361                        self.clock.now(),
362                        advance_to
363                    );
364                }
365                self.clock.advance(advance_to - self.clock.now());
366            } else {
367                break;
368            }
369        }
370        self.clock.advance(next_now - self.clock.now());
371        if debug {
372            eprintln!(
373                "[scheduler] advance_clock done, now at {:?}",
374                self.clock.now()
375            );
376        }
377    }
378
379    fn park(&self, deadline: Option<Instant>) -> bool {
380        if self.state.lock().allow_parking {
381            let start = Instant::now();
382            // Enforce a hard timeout to prevent tests from hanging indefinitely
383            let hard_deadline = start + Duration::from_secs(15);
384
385            // Use the earlier of the provided deadline or the hard timeout deadline
386            let effective_deadline = deadline
387                .map(|d| d.min(hard_deadline))
388                .unwrap_or(hard_deadline);
389
390            // Park in small intervals to allow checking both deadlines
391            const PARK_INTERVAL: Duration = Duration::from_millis(100);
392            loop {
393                let now = Instant::now();
394                if now >= effective_deadline {
395                    // Check if we hit the hard timeout
396                    if now >= hard_deadline {
397                        panic!(
398                            "Test timed out after 15 seconds while parking. \
399                            This may indicate a deadlock or missing waker.",
400                        );
401                    }
402                    // Hit the provided deadline
403                    return false;
404                }
405
406                let remaining = effective_deadline.saturating_duration_since(now);
407                let park_duration = remaining.min(PARK_INTERVAL);
408                let before_park = Instant::now();
409                thread::park_timeout(park_duration);
410                let elapsed = before_park.elapsed();
411
412                // Advance the test clock by the real elapsed time while parking
413                self.clock.advance(elapsed);
414
415                // Check if any timers have expired after advancing the clock.
416                // If so, return so the caller can process them.
417                if self
418                    .state
419                    .lock()
420                    .timers
421                    .first()
422                    .map_or(false, |t| t.expiration <= self.clock.now())
423                {
424                    return true;
425                }
426
427                // Check if we were woken up by a different thread.
428                // We use a flag because timing-based detection is unreliable:
429                // OS scheduling delays can cause elapsed >= park_duration even when
430                // we were woken early by unpark().
431                if std::mem::take(&mut self.state.lock().unparked) {
432                    return true;
433                }
434            }
435        } else if deadline.is_some() {
436            false
437        } else if self.state.lock().capture_pending_traces {
438            let mut pending_traces = String::new();
439            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
440                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
441            }
442            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
443        } else {
444            panic!(
445                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
446            );
447        }
448    }
449}
450
451fn assert_correct_thread(expected: &Thread, state: &Arc<Mutex<SchedulerState>>) {
452    let current_thread = thread::current();
453    let mut state = state.lock();
454    if state.parking_allowed_once {
455        return;
456    }
457    if current_thread.id() == expected.id() {
458        return;
459    }
460
461    let message = format!(
462        "Detected activity on thread {:?} {:?}, but test scheduler is running on {:?} {:?}. Your test is not deterministic.",
463        current_thread.name(),
464        current_thread.id(),
465        expected.name(),
466        expected.id(),
467    );
468    let backtrace = Backtrace::new();
469    if state.finished {
470        panic!("{}", message);
471    } else {
472        state.non_determinism_error = Some((message, backtrace))
473    }
474}
475
476impl Scheduler for TestScheduler {
477    /// Block until the given future completes, with an optional timeout. If the
478    /// future is unable to make progress at any moment before the timeout and
479    /// no other tasks or timers remain, we panic unless parking is allowed. If
480    /// parking is allowed, we block up to the timeout or indefinitely if none
481    /// is provided. This is to allow testing a mix of deterministic and
482    /// non-deterministic async behavior, such as when interacting with I/O in
483    /// an otherwise deterministic test.
484    fn block(
485        &self,
486        session_id: Option<SessionId>,
487        mut future: Pin<&mut dyn Future<Output = ()>>,
488        timeout: Option<Duration>,
489    ) -> bool {
490        if let Some(session_id) = session_id {
491            self.state.lock().blocked_sessions.push(session_id);
492        }
493
494        let deadline = timeout.map(|timeout| Instant::now() + timeout);
495        let awoken = Arc::new(AtomicBool::new(false));
496        let waker = Box::new(TracingWaker {
497            id: None,
498            awoken: awoken.clone(),
499            thread: self.thread.clone(),
500            state: self.state.clone(),
501        });
502        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
503        let max_ticks = if timeout.is_some() {
504            self.rng
505                .lock()
506                .random_range(self.state.lock().timeout_ticks.clone())
507        } else {
508            usize::MAX
509        };
510        let mut cx = Context::from_waker(&waker);
511
512        let mut completed = false;
513        for _ in 0..max_ticks {
514            match future.as_mut().poll(&mut cx) {
515                Poll::Ready(()) => {
516                    completed = true;
517                    break;
518                }
519                Poll::Pending => {}
520            }
521
522            let mut stepped = None;
523            while self.rng.lock().random() {
524                let stepped = stepped.get_or_insert(false);
525                if self.step() {
526                    *stepped = true;
527                } else {
528                    break;
529                }
530            }
531
532            let stepped = stepped.unwrap_or(true);
533            let awoken = awoken.swap(false, SeqCst);
534            if !stepped && !awoken {
535                let parking_allowed = self.state.lock().allow_parking;
536                // In deterministic mode (parking forbidden), instantly jump to the next timer.
537                // In non-deterministic mode (parking allowed), let real time pass instead.
538                let advanced_to_timer = !parking_allowed && self.advance_clock_to_next_timer();
539                if !advanced_to_timer && !self.park(deadline) {
540                    break;
541                }
542            }
543        }
544
545        if session_id.is_some() {
546            self.state.lock().blocked_sessions.pop();
547        }
548
549        completed
550    }
551
552    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable<RunnableMeta>) {
553        assert_correct_thread(&self.thread, &self.state);
554        let mut state = self.state.lock();
555        let ix = if state.randomize_order {
556            let start_ix = state
557                .runnables
558                .iter()
559                .rposition(|task| task.session_id == Some(session_id))
560                .map_or(0, |ix| ix + 1);
561            self.rng
562                .lock()
563                .random_range(start_ix..=state.runnables.len())
564        } else {
565            state.runnables.len()
566        };
567        state.runnables.insert(
568            ix,
569            ScheduledRunnable {
570                session_id: Some(session_id),
571                priority: Priority::default(),
572                runnable,
573            },
574        );
575        state.unparked = true;
576        drop(state);
577        self.thread.unpark();
578    }
579
580    fn schedule_background_with_priority(
581        &self,
582        runnable: Runnable<RunnableMeta>,
583        priority: Priority,
584    ) {
585        assert_correct_thread(&self.thread, &self.state);
586        let mut state = self.state.lock();
587        let ix = if state.randomize_order {
588            self.rng.lock().random_range(0..=state.runnables.len())
589        } else {
590            state.runnables.len()
591        };
592        state.runnables.insert(
593            ix,
594            ScheduledRunnable {
595                session_id: None,
596                priority,
597                runnable,
598            },
599        );
600        state.unparked = true;
601        drop(state);
602        self.thread.unpark();
603    }
604
605    fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
606        std::thread::spawn(move || {
607            f();
608        });
609    }
610
611    fn timer(&self, duration: Duration) -> Timer {
612        let (tx, rx) = oneshot::channel();
613        let state = &mut *self.state.lock();
614        state.timers.push(ScheduledTimer {
615            expiration: self.clock.now() + duration,
616            _notify: tx,
617        });
618        state.timers.sort_by_key(|timer| timer.expiration);
619        Timer(rx)
620    }
621
622    fn clock(&self) -> Arc<dyn Clock> {
623        self.clock.clone()
624    }
625
626    fn as_test(&self) -> Option<&TestScheduler> {
627        Some(self)
628    }
629}
630
/// Configuration used to construct a `TestScheduler`.
#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    /// Seed for the scheduler's RNG.
    pub seed: u64,
    /// When true, queued tasks are inserted at random positions and picked by priority-weighted
    /// random choice. When false, they are appended and picked in queue order.
    pub randomize_order: bool,
    /// Initial value for whether `block` may really park the thread (see
    /// `TestScheduler::allow_parking`).
    pub allow_parking: bool,
    /// Capture a backtrace for each outstanding waker clone, so a forbidden park can report
    /// them.
    pub capture_pending_traces: bool,
    /// Range from which a timed `block` call draws its maximum number of polls.
    pub timeout_ticks: RangeInclusive<usize>,
}
639
640impl TestSchedulerConfig {
641    pub fn with_seed(seed: u64) -> Self {
642        Self {
643            seed,
644            ..Default::default()
645        }
646    }
647}
648
649impl Default for TestSchedulerConfig {
650    fn default() -> Self {
651        Self {
652            seed: 0,
653            randomize_order: true,
654            allow_parking: false,
655            capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
656                .map_or(false, |var| var == "1" || var == "true"),
657            timeout_ticks: 1..=1000,
658        }
659    }
660}
661
/// A queued task together with its scheduling metadata.
struct ScheduledRunnable {
    /// Owning session for foreground tasks; `None` for background tasks.
    session_id: Option<SessionId>,
    /// Weight used by priority-weighted random selection (foreground tasks get
    /// `Priority::default()`).
    priority: Priority,
    /// The task itself.
    runnable: Runnable<RunnableMeta>,
}
667
668impl ScheduledRunnable {
669    fn run(self) {
670        self.runnable.run();
671    }
672}
673
/// A pending timer.
///
/// The `Timer` future returned by `timer` holds the receiving end of `_notify`. When the timer
/// expires, the scheduler drops this entry, and dropping the sender wakes that future.
struct ScheduledTimer {
    /// Test-clock instant at which the timer fires.
    expiration: Instant,
    /// Held only for its drop side effect.
    _notify: oneshot::Sender<()>,
}
678
/// Mutable scheduler state, shared between `TestScheduler` and its `TracingWaker`s.
struct SchedulerState {
    /// Queued tasks. Entries with a `session_id` are foreground tasks; the rest are background.
    runnables: VecDeque<ScheduledRunnable>,
    /// Pending timers, kept sorted by expiration.
    timers: Vec<ScheduledTimer>,
    /// Sessions currently inside `block`; their queued tasks are skipped when picking work.
    blocked_sessions: Vec<SessionId>,
    /// Whether tasks are inserted and selected randomly.
    randomize_order: bool,
    /// Whether `park` may block the thread. When false, `park` returns `false` if given a
    /// deadline and panics otherwise.
    allow_parking: bool,
    /// Range from which a timed `block` draws its maximum number of polls.
    timeout_ticks: RangeInclusive<usize>,
    /// The most recently allocated session id (allocation increments first, so ids start at 1).
    next_session_id: SessionId,
    /// Whether waker clones record backtraces in `pending_traces`.
    capture_pending_traces: bool,
    /// Id to assign to the next recorded waker-clone trace.
    next_trace_id: TraceId,
    /// Backtraces of waker clones that have been neither woken nor dropped.
    pending_traces: BTreeMap<TraceId, Backtrace>,
    /// True outside task execution and while a foreground task runs; false while a background
    /// task runs.
    is_main_thread: bool,
    /// First off-thread activity detected before the test finished; reported by `end_test`.
    non_determinism_error: Option<(String, Backtrace)>,
    /// Set once parking has ever been allowed; disables the wrong-thread checks from then on.
    parking_allowed_once: bool,
    /// Set by `end_test`; after this, off-thread activity panics immediately.
    finished: bool,
    /// Set whenever work is scheduled or a waker fires; consumed by `park` to stop waiting.
    unparked: bool,
}
696
/// Vtable for the `Waker` that `TestScheduler::block` builds around a `TracingWaker`.
///
/// Every data pointer passed to these functions is a `Box<TracingWaker>` converted with
/// `Box::into_raw`, both in `block` and in `TracingWaker::clone_raw`.
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);
703
/// Key of a captured waker-clone backtrace in `SchedulerState::pending_traces`. Ids are
/// handed out sequentially, so the map iterates in creation order.
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);
706
/// The waker used by `TestScheduler::block`.
///
/// Waking sets `awoken` and unparks `thread`. When pending-trace capture is enabled, every clone
/// records a backtrace under `id` until it is woken or dropped. A forbidden park can then show
/// which wakers are still outstanding.
struct TracingWaker {
    /// This clone's entry in `pending_traces`. `None` for the waker created directly by `block`,
    /// or when capture is disabled.
    id: Option<TraceId>,
    /// Shared with `block`, which swaps it to learn whether the future was woken.
    awoken: Arc<AtomicBool>,
    /// The scheduler's thread, unparked on wake.
    thread: Thread,
    /// Scheduler state, used for thread checks, trace bookkeeping, and the `unparked` flag.
    state: Arc<Mutex<SchedulerState>>,
}
713
714impl Clone for TracingWaker {
715    fn clone(&self) -> Self {
716        let mut state = self.state.lock();
717        let id = if state.capture_pending_traces {
718            let id = state.next_trace_id;
719            state.next_trace_id.0 += 1;
720            state.pending_traces.insert(id, Backtrace::new_unresolved());
721            Some(id)
722        } else {
723            None
724        };
725        Self {
726            id,
727            awoken: self.awoken.clone(),
728            thread: self.thread.clone(),
729            state: self.state.clone(),
730        }
731    }
732}
733
734impl Drop for TracingWaker {
735    fn drop(&mut self) {
736        assert_correct_thread(&self.thread, &self.state);
737
738        if let Some(id) = self.id {
739            self.state.lock().pending_traces.remove(&id);
740        }
741    }
742}
743
impl TracingWaker {
    /// Consuming wake. Delegates to `wake_by_ref`; `self` is then dropped, which runs
    /// `Drop for TracingWaker`.
    fn wake(self) {
        self.wake_by_ref();
    }

    /// Signals the scheduler thread that the blocked future may be able to progress.
    ///
    /// Off-thread wakes are flagged via `assert_correct_thread`, and this clone's pending trace
    /// is forgotten. It then sets `unparked` (read by `TestScheduler::park`) and `awoken` (read
    /// by `block`) and unparks the scheduler thread.
    fn wake_by_ref(&self) {
        assert_correct_thread(&self.thread, &self.state);

        let mut state = self.state.lock();
        if let Some(id) = self.id {
            state.pending_traces.remove(&id);
        }
        state.unparked = true;
        drop(state);
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    /// `RawWakerVTable` clone entry: boxes a `TracingWaker::clone` of the pointee and returns
    /// it as a new `RawWaker` using the same vtable.
    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        // SAFETY: every data pointer used with `WAKER_VTABLE` comes from
        // `Box::into_raw(Box<TracingWaker>)` (in `block` and just below). The `Waker` owning it
        // is alive while its clone entry runs, so a shared borrow is valid here.
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    /// `RawWakerVTable` wake entry: reclaims the boxed waker, wakes, then drops it.
    fn wake_raw(waker: *const ()) {
        // SAFETY: the pointer came from `Box::into_raw(Box<TracingWaker>)`. Under the `Waker`
        // contract, `wake` consumes the waker, so the box is reclaimed exactly once.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    /// `RawWakerVTable` wake_by_ref entry: wakes without taking ownership.
    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        // SAFETY: the pointee is a live boxed `TracingWaker` (see `clone_raw`). Only a shared
        // borrow is created, and ownership stays with the calling `Waker`.
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    /// `RawWakerVTable` drop entry: reclaims and drops the boxed waker.
    fn drop_raw(waker: *const ()) {
        // SAFETY: same provenance as in `wake_raw`. `drop` is called exactly once, for a waker
        // that was not consumed by `wake`.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}
787
/// Future returned by `TestScheduler::yield_random`. It returns `Pending`, waking itself, as
/// many times as the stored count, then completes.
pub struct Yield(usize);
789
/// A wrapper around `Arc<Mutex<StdRng>>` that provides convenient methods
/// for random number generation without requiring explicit locking.
///
/// Cloning shares the same underlying generator (the `Arc` is cloned, not the RNG).
#[derive(Clone)]
pub struct SharedRng(Arc<Mutex<StdRng>>);
794
795impl SharedRng {
796    /// Lock the inner RNG for direct access. Use this when you need multiple
797    /// random operations without re-locking between each one.
798    pub fn lock(&self) -> MutexGuard<'_, StdRng> {
799        self.0.lock()
800    }
801
802    /// Generate a random value in the given range.
803    pub fn random_range<T, R>(&self, range: R) -> T
804    where
805        T: SampleUniform,
806        R: SampleRange<T>,
807    {
808        self.0.lock().random_range(range)
809    }
810
811    /// Generate a random boolean with the given probability of being true.
812    pub fn random_bool(&self, p: f64) -> bool {
813        self.0.lock().random_bool(p)
814    }
815
816    /// Generate a random value of the given type.
817    pub fn random<T>(&self) -> T
818    where
819        StandardUniform: Distribution<T>,
820    {
821        self.0.lock().random()
822    }
823
824    /// Generate a random ratio - true with probability `numerator/denominator`.
825    pub fn random_ratio(&self, numerator: u32, denominator: u32) -> bool {
826        self.0.lock().random_ratio(numerator, denominator)
827    }
828}
829
830impl Future for Yield {
831    type Output = ();
832
833    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
834        if self.0 == 0 {
835            Poll::Ready(())
836        } else {
837            self.0 -= 1;
838            cx.waker().wake_by_ref();
839            Poll::Pending
840        }
841    }
842}
843
844fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
845    trace.resolve();
846    let mut frames: Vec<BacktraceFrame> = trace.into();
847    let waker_clone_frame_ix = frames.iter().position(|frame| {
848        frame.symbols().iter().any(|symbol| {
849            symbol
850                .name()
851                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
852        })
853    });
854
855    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
856        frames.drain(..waker_clone_frame_ix + 1);
857    }
858
859    Backtrace::from(frames)
860}