test_scheduler.rs

  1use crate::{
  2    BackgroundExecutor, Clock, ForegroundExecutor, Instant, Priority, RunnableMeta, Scheduler,
  3    SessionId, TestClock, Timer,
  4};
  5use async_task::Runnable;
  6use backtrace::{Backtrace, BacktraceFrame};
  7use futures::channel::oneshot;
  8use parking_lot::{Mutex, MutexGuard};
  9use rand::{
 10    distr::{StandardUniform, uniform::SampleRange, uniform::SampleUniform},
 11    prelude::*,
 12};
 13use std::{
 14    any::type_name_of_val,
 15    collections::{BTreeMap, HashSet, VecDeque},
 16    env,
 17    fmt::Write,
 18    future::Future,
 19    mem,
 20    ops::RangeInclusive,
 21    panic::{self, AssertUnwindSafe},
 22    pin::Pin,
 23    sync::{
 24        Arc,
 25        atomic::{AtomicBool, Ordering::SeqCst},
 26    },
 27    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
 28    thread::{self, Thread},
 29    time::Duration,
 30};
 31
/// Environment variable that turns on pending-trace capture when set to `1` or `true`.
///
/// With capture on, every waker clone records a backtrace, and a forbidden park panics with
/// the traces of wakers that were never woken or dropped.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
 33
/// A scheduler for tests that runs tasks, timers and wakeups from the thread that created it.
///
/// Task order is drawn from a seeded RNG and timers are driven by a [`TestClock`] that the
/// scheduler advances explicitly (or by real elapsed time while parking is allowed). The intent
/// is that a run is reproducible for a given seed.
pub struct TestScheduler {
    /// Clock used for timer expirations; advanced by the scheduler, not by wall time.
    clock: Arc<TestClock>,
    /// Seeded RNG used for task ordering; also handed out via [`TestScheduler::rng`].
    rng: Arc<Mutex<StdRng>>,
    /// Queued runnables, pending timers, and configuration/diagnostic flags.
    state: Arc<Mutex<SchedulerState>>,
    /// The thread that created the scheduler; unparked whenever new work is scheduled.
    thread: Thread,
}
 40
 41impl TestScheduler {
 42    /// Run a test once with default configuration (seed 0)
 43    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 44        Self::with_seed(0, f)
 45    }
 46
 47    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
 48    pub fn many<R>(
 49        default_iterations: usize,
 50        mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
 51    ) -> Vec<R> {
 52        let num_iterations = std::env::var("ITERATIONS")
 53            .map(|iterations| iterations.parse().unwrap())
 54            .unwrap_or(default_iterations);
 55
 56        let seed = std::env::var("SEED")
 57            .map(|seed| seed.parse().unwrap())
 58            .unwrap_or(0);
 59
 60        (seed..num_iterations as u64)
 61            .map(|seed| {
 62                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
 63                eprintln!("Running seed: {seed}");
 64                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
 65                    Ok(result) => result,
 66                    Err(error) => {
 67                        eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
 68                        panic::resume_unwind(error);
 69                    }
 70                }
 71            })
 72            .collect()
 73    }
 74
 75    fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 76        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
 77        let future = f(scheduler.clone());
 78        let result = scheduler.foreground().block_on(future);
 79        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
 80        result
 81    }
 82
 83    pub fn new(config: TestSchedulerConfig) -> Self {
 84        Self {
 85            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
 86            state: Arc::new(Mutex::new(SchedulerState {
 87                runnables: VecDeque::new(),
 88                timers: Vec::new(),
 89                blocked_sessions: Vec::new(),
 90                randomize_order: config.randomize_order,
 91                allow_parking: config.allow_parking,
 92                timeout_ticks: config.timeout_ticks,
 93                next_session_id: SessionId(0),
 94                capture_pending_traces: config.capture_pending_traces,
 95                pending_traces: BTreeMap::new(),
 96                next_trace_id: TraceId(0),
 97                is_main_thread: true,
 98                non_determinism_error: None,
 99                finished: false,
100                parking_allowed_once: false,
101                unparked: false,
102            })),
103            clock: Arc::new(TestClock::new()),
104            thread: thread::current(),
105        }
106    }
107
108    pub fn end_test(&self) {
109        let mut state = self.state.lock();
110        if let Some((message, backtrace)) = &state.non_determinism_error {
111            panic!("{}\n{:?}", message, backtrace)
112        }
113        state.finished = true;
114    }
115
116    pub fn clock(&self) -> Arc<TestClock> {
117        self.clock.clone()
118    }
119
120    pub fn rng(&self) -> SharedRng {
121        SharedRng(self.rng.clone())
122    }
123
124    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
125        self.state.lock().timeout_ticks = timeout_ticks;
126    }
127
128    pub fn allow_parking(&self) {
129        let mut state = self.state.lock();
130        state.allow_parking = true;
131        state.parking_allowed_once = true;
132    }
133
134    pub fn forbid_parking(&self) {
135        self.state.lock().allow_parking = false;
136    }
137
138    pub fn parking_allowed(&self) -> bool {
139        self.state.lock().allow_parking
140    }
141
142    pub fn is_main_thread(&self) -> bool {
143        self.state.lock().is_main_thread
144    }
145
146    /// Allocate a new session ID for foreground task scheduling.
147    /// This is used by GPUI's TestDispatcher to map dispatcher instances to sessions.
148    pub fn allocate_session_id(&self) -> SessionId {
149        let mut state = self.state.lock();
150        state.next_session_id.0 += 1;
151        state.next_session_id
152    }
153
154    /// Create a foreground executor for this scheduler
155    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
156        let session_id = self.allocate_session_id();
157        ForegroundExecutor::new(session_id, self.clone())
158    }
159
160    /// Create a background executor for this scheduler
161    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
162        BackgroundExecutor::new(self.clone())
163    }
164
165    pub fn yield_random(&self) -> Yield {
166        let rng = &mut *self.rng.lock();
167        if rng.random_bool(0.1) {
168            Yield(rng.random_range(10..20))
169        } else {
170            Yield(rng.random_range(0..2))
171        }
172    }
173
174    pub fn run(&self) {
175        while self.step() {
176            // Continue until no work remains
177        }
178    }
179
180    pub fn run_with_clock_advancement(&self) {
181        while self.step() || self.advance_clock_to_next_timer() {
182            // Continue until no work remains
183        }
184    }
185
186    /// Execute one tick of the scheduler, processing expired timers and running
187    /// at most one task. Returns true if any work was done.
188    ///
189    /// This is the public interface for GPUI's TestDispatcher to drive task execution.
190    pub fn tick(&self) -> bool {
191        self.step_filtered(false)
192    }
193
194    /// Execute one tick, but only run background tasks (no foreground/session tasks).
195    /// Returns true if any work was done.
196    pub fn tick_background_only(&self) -> bool {
197        self.step_filtered(true)
198    }
199
200    /// Check if there are any pending tasks or timers that could run.
201    pub fn has_pending_tasks(&self) -> bool {
202        let state = self.state.lock();
203        !state.runnables.is_empty() || !state.timers.is_empty()
204    }
205
206    /// Returns counts of (foreground_tasks, background_tasks) currently queued.
207    /// Foreground tasks are those with a session_id, background tasks have none.
208    pub fn pending_task_counts(&self) -> (usize, usize) {
209        let state = self.state.lock();
210        let foreground = state
211            .runnables
212            .iter()
213            .filter(|r| r.session_id.is_some())
214            .count();
215        let background = state
216            .runnables
217            .iter()
218            .filter(|r| r.session_id.is_none())
219            .count();
220        (foreground, background)
221    }
222
223    fn step(&self) -> bool {
224        self.step_filtered(false)
225    }
226
227    fn step_filtered(&self, background_only: bool) -> bool {
228        let (elapsed_count, runnables_before) = {
229            let mut state = self.state.lock();
230            let end_ix = state
231                .timers
232                .partition_point(|timer| timer.expiration <= self.clock.now());
233            let elapsed: Vec<_> = state.timers.drain(..end_ix).collect();
234            let count = elapsed.len();
235            let runnables = state.runnables.len();
236            drop(state);
237            // Dropping elapsed timers here wakes the waiting futures
238            drop(elapsed);
239            (count, runnables)
240        };
241
242        if elapsed_count > 0 {
243            let runnables_after = self.state.lock().runnables.len();
244            if std::env::var("DEBUG_SCHEDULER").is_ok() {
245                eprintln!(
246                    "[scheduler] Expired {} timers at {:?}, runnables: {} -> {}",
247                    elapsed_count,
248                    self.clock.now(),
249                    runnables_before,
250                    runnables_after
251                );
252            }
253            return true;
254        }
255
256        let runnable = {
257            let state = &mut *self.state.lock();
258
259            // Find candidate tasks:
260            // - For foreground tasks (with session_id), only the first task from each session
261            //   is a candidate (to preserve intra-session ordering)
262            // - For background tasks (no session_id), all are candidates
263            // - Tasks from blocked sessions are excluded
264            // - If background_only is true, skip foreground tasks entirely
265            let mut seen_sessions = HashSet::new();
266            let candidate_indices: Vec<usize> = state
267                .runnables
268                .iter()
269                .enumerate()
270                .filter(|(_, runnable)| {
271                    if let Some(session_id) = runnable.session_id {
272                        // Skip foreground tasks if background_only mode
273                        if background_only {
274                            return false;
275                        }
276                        // Exclude tasks from blocked sessions
277                        if state.blocked_sessions.contains(&session_id) {
278                            return false;
279                        }
280                        // Only include first task from each session (insert returns true if new)
281                        seen_sessions.insert(session_id)
282                    } else {
283                        // Background tasks are always candidates
284                        true
285                    }
286                })
287                .map(|(ix, _)| ix)
288                .collect();
289
290            if candidate_indices.is_empty() {
291                None
292            } else if state.randomize_order {
293                // Use priority-weighted random selection
294                let weights: Vec<u32> = candidate_indices
295                    .iter()
296                    .map(|&ix| state.runnables[ix].priority.weight())
297                    .collect();
298                let total_weight: u32 = weights.iter().sum();
299
300                if total_weight == 0 {
301                    // Fallback to uniform random if all weights are zero
302                    let choice = self.rng.lock().random_range(0..candidate_indices.len());
303                    state.runnables.remove(candidate_indices[choice])
304                } else {
305                    let mut target = self.rng.lock().random_range(0..total_weight);
306                    let mut selected_idx = 0;
307                    for (i, &weight) in weights.iter().enumerate() {
308                        if target < weight {
309                            selected_idx = i;
310                            break;
311                        }
312                        target -= weight;
313                    }
314                    state.runnables.remove(candidate_indices[selected_idx])
315                }
316            } else {
317                // Non-randomized: just take the first candidate task
318                state.runnables.remove(candidate_indices[0])
319            }
320        };
321
322        if let Some(runnable) = runnable {
323            // Check if the executor that spawned this task was closed
324            if runnable.runnable.metadata().is_closed() {
325                return true;
326            }
327            let is_foreground = runnable.session_id.is_some();
328            let was_main_thread = self.state.lock().is_main_thread;
329            self.state.lock().is_main_thread = is_foreground;
330            runnable.run();
331            self.state.lock().is_main_thread = was_main_thread;
332            return true;
333        }
334
335        false
336    }
337
338    pub fn advance_clock_to_next_timer(&self) -> bool {
339        if let Some(timer) = self.state.lock().timers.first() {
340            self.clock.advance(timer.expiration - self.clock.now());
341            true
342        } else {
343            false
344        }
345    }
346
347    pub fn advance_clock(&self, duration: Duration) {
348        let debug = std::env::var("DEBUG_SCHEDULER").is_ok();
349        let start = self.clock.now();
350        let next_now = start + duration;
351        if debug {
352            let timer_count = self.state.lock().timers.len();
353            eprintln!(
354                "[scheduler] advance_clock({:?}) from {:?}, {} pending timers",
355                duration, start, timer_count
356            );
357        }
358        loop {
359            self.run();
360            if let Some(timer) = self.state.lock().timers.first()
361                && timer.expiration <= next_now
362            {
363                let advance_to = timer.expiration;
364                if debug {
365                    eprintln!(
366                        "[scheduler] Advancing clock {:?} -> {:?} for timer",
367                        self.clock.now(),
368                        advance_to
369                    );
370                }
371                self.clock.advance(advance_to - self.clock.now());
372            } else {
373                break;
374            }
375        }
376        self.clock.advance(next_now - self.clock.now());
377        if debug {
378            eprintln!(
379                "[scheduler] advance_clock done, now at {:?}",
380                self.clock.now()
381            );
382        }
383    }
384
385    fn park(&self, deadline: Option<Instant>) -> bool {
386        if self.state.lock().allow_parking {
387            let start = Instant::now();
388            // Enforce a hard timeout to prevent tests from hanging indefinitely
389            let hard_deadline = start + Duration::from_secs(15);
390
391            // Use the earlier of the provided deadline or the hard timeout deadline
392            let effective_deadline = deadline
393                .map(|d| d.min(hard_deadline))
394                .unwrap_or(hard_deadline);
395
396            // Park in small intervals to allow checking both deadlines
397            const PARK_INTERVAL: Duration = Duration::from_millis(100);
398            loop {
399                let now = Instant::now();
400                if now >= effective_deadline {
401                    // Check if we hit the hard timeout
402                    if now >= hard_deadline {
403                        panic!(
404                            "Test timed out after 15 seconds while parking. \
405                            This may indicate a deadlock or missing waker.",
406                        );
407                    }
408                    // Hit the provided deadline
409                    return false;
410                }
411
412                let remaining = effective_deadline.saturating_duration_since(now);
413                let park_duration = remaining.min(PARK_INTERVAL);
414                let before_park = Instant::now();
415                thread::park_timeout(park_duration);
416                let elapsed = before_park.elapsed();
417
418                // Advance the test clock by the real elapsed time while parking
419                self.clock.advance(elapsed);
420
421                // Check if any timers have expired after advancing the clock.
422                // If so, return so the caller can process them.
423                if self
424                    .state
425                    .lock()
426                    .timers
427                    .first()
428                    .map_or(false, |t| t.expiration <= self.clock.now())
429                {
430                    return true;
431                }
432
433                // Check if we were woken up by a different thread.
434                // We use a flag because timing-based detection is unreliable:
435                // OS scheduling delays can cause elapsed >= park_duration even when
436                // we were woken early by unpark().
437                if std::mem::take(&mut self.state.lock().unparked) {
438                    return true;
439                }
440            }
441        } else if deadline.is_some() {
442            false
443        } else if self.state.lock().capture_pending_traces {
444            let mut pending_traces = String::new();
445            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
446                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
447            }
448            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
449        } else {
450            panic!(
451                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
452            );
453        }
454    }
455}
456
457fn assert_correct_thread(expected: &Thread, state: &Arc<Mutex<SchedulerState>>) {
458    let current_thread = thread::current();
459    let mut state = state.lock();
460    if state.parking_allowed_once {
461        return;
462    }
463    if current_thread.id() == expected.id() {
464        return;
465    }
466
467    let message = format!(
468        "Detected activity on thread {:?} {:?}, but test scheduler is running on {:?} {:?}. Your test is not deterministic.",
469        current_thread.name(),
470        current_thread.id(),
471        expected.name(),
472        expected.id(),
473    );
474    let backtrace = Backtrace::new();
475    if state.finished {
476        panic!("{}", message);
477    } else {
478        state.non_determinism_error = Some((message, backtrace))
479    }
480}
481
482impl Scheduler for TestScheduler {
483    /// Block until the given future completes, with an optional timeout. If the
484    /// future is unable to make progress at any moment before the timeout and
485    /// no other tasks or timers remain, we panic unless parking is allowed. If
486    /// parking is allowed, we block up to the timeout or indefinitely if none
487    /// is provided. This is to allow testing a mix of deterministic and
488    /// non-deterministic async behavior, such as when interacting with I/O in
489    /// an otherwise deterministic test.
490    fn block(
491        &self,
492        session_id: Option<SessionId>,
493        mut future: Pin<&mut dyn Future<Output = ()>>,
494        timeout: Option<Duration>,
495    ) -> bool {
496        if let Some(session_id) = session_id {
497            self.state.lock().blocked_sessions.push(session_id);
498        }
499
500        let deadline = timeout.map(|timeout| Instant::now() + timeout);
501        let awoken = Arc::new(AtomicBool::new(false));
502        let waker = Box::new(TracingWaker {
503            id: None,
504            awoken: awoken.clone(),
505            thread: self.thread.clone(),
506            state: self.state.clone(),
507        });
508        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
509        let max_ticks = if timeout.is_some() {
510            self.rng
511                .lock()
512                .random_range(self.state.lock().timeout_ticks.clone())
513        } else {
514            usize::MAX
515        };
516        let mut cx = Context::from_waker(&waker);
517
518        let mut completed = false;
519        for _ in 0..max_ticks {
520            match future.as_mut().poll(&mut cx) {
521                Poll::Ready(()) => {
522                    completed = true;
523                    break;
524                }
525                Poll::Pending => {}
526            }
527
528            let mut stepped = None;
529            while self.rng.lock().random() {
530                let stepped = stepped.get_or_insert(false);
531                if self.step() {
532                    *stepped = true;
533                } else {
534                    break;
535                }
536            }
537
538            let stepped = stepped.unwrap_or(true);
539            let awoken = awoken.swap(false, SeqCst);
540            if !stepped && !awoken {
541                let parking_allowed = self.state.lock().allow_parking;
542                // In deterministic mode (parking forbidden), instantly jump to the next timer.
543                // In non-deterministic mode (parking allowed), let real time pass instead.
544                let advanced_to_timer = !parking_allowed && self.advance_clock_to_next_timer();
545                if !advanced_to_timer && !self.park(deadline) {
546                    break;
547                }
548            }
549        }
550
551        if session_id.is_some() {
552            self.state.lock().blocked_sessions.pop();
553        }
554
555        completed
556    }
557
558    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable<RunnableMeta>) {
559        assert_correct_thread(&self.thread, &self.state);
560        let mut state = self.state.lock();
561        let ix = if state.randomize_order {
562            let start_ix = state
563                .runnables
564                .iter()
565                .rposition(|task| task.session_id == Some(session_id))
566                .map_or(0, |ix| ix + 1);
567            self.rng
568                .lock()
569                .random_range(start_ix..=state.runnables.len())
570        } else {
571            state.runnables.len()
572        };
573        state.runnables.insert(
574            ix,
575            ScheduledRunnable {
576                session_id: Some(session_id),
577                priority: Priority::default(),
578                runnable,
579            },
580        );
581        state.unparked = true;
582        drop(state);
583        self.thread.unpark();
584    }
585
586    fn schedule_background_with_priority(
587        &self,
588        runnable: Runnable<RunnableMeta>,
589        priority: Priority,
590    ) {
591        assert_correct_thread(&self.thread, &self.state);
592        let mut state = self.state.lock();
593        let ix = if state.randomize_order {
594            self.rng.lock().random_range(0..=state.runnables.len())
595        } else {
596            state.runnables.len()
597        };
598        state.runnables.insert(
599            ix,
600            ScheduledRunnable {
601                session_id: None,
602                priority,
603                runnable,
604            },
605        );
606        state.unparked = true;
607        drop(state);
608        self.thread.unpark();
609    }
610
611    fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
612        std::thread::spawn(move || {
613            f();
614        });
615    }
616
617    #[track_caller]
618    fn timer(&self, duration: Duration) -> Timer {
619        let (tx, rx) = oneshot::channel();
620        let state = &mut *self.state.lock();
621        state.timers.push(ScheduledTimer {
622            expiration: self.clock.now() + duration,
623            _notify: tx,
624        });
625        state.timers.sort_by_key(|timer| timer.expiration);
626        Timer(rx)
627    }
628
629    fn clock(&self) -> Arc<dyn Clock> {
630        self.clock.clone()
631    }
632
633    fn as_test(&self) -> Option<&TestScheduler> {
634        Some(self)
635    }
636}
637
/// Configuration for [`TestScheduler::new`].
#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    /// Seed for the scheduler's RNG.
    pub seed: u64,
    /// Insert and pick tasks in seeded-random (priority-weighted) order instead of FIFO.
    pub randomize_order: bool,
    /// Let `block` park the thread to wait for real-world wakeups instead of panicking.
    pub allow_parking: bool,
    /// Record a backtrace for every waker clone so a forbidden park can report pending wakers.
    pub capture_pending_traces: bool,
    /// Range from which `block` draws its maximum number of poll iterations when a timeout is
    /// given.
    pub timeout_ticks: RangeInclusive<usize>,
}
646
647impl TestSchedulerConfig {
648    pub fn with_seed(seed: u64) -> Self {
649        Self {
650            seed,
651            ..Default::default()
652        }
653    }
654}
655
656impl Default for TestSchedulerConfig {
657    fn default() -> Self {
658        Self {
659            seed: 0,
660            randomize_order: true,
661            allow_parking: false,
662            capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
663                .map_or(false, |var| var == "1" || var == "true"),
664            timeout_ticks: 1..=1000,
665        }
666    }
667}
668
/// A queued task together with its scheduling metadata.
struct ScheduledRunnable {
    /// The foreground session the task belongs to, or `None` for a background task.
    session_id: Option<SessionId>,
    /// Weight used by randomized selection in `step_filtered`; foreground tasks get the default.
    priority: Priority,
    /// The task itself.
    runnable: Runnable<RunnableMeta>,
}
674
675impl ScheduledRunnable {
676    fn run(self) {
677        self.runnable.run();
678    }
679}
680
/// A pending timer created by `Scheduler::timer`.
struct ScheduledTimer {
    /// Test-clock time at which the timer fires.
    expiration: Instant,
    /// Not sent on anywhere in this file. Dropping it when the timer expires is what wakes the
    /// `Timer` holding the paired receiver.
    _notify: oneshot::Sender<()>,
}
685
/// Mutable scheduler state, shared behind a mutex between the scheduler and its wakers.
struct SchedulerState {
    /// Queued tasks. Entries with a `session_id` are foreground tasks; the rest are background.
    runnables: VecDeque<ScheduledRunnable>,
    /// Pending timers, kept sorted by expiration (earliest first).
    timers: Vec<ScheduledTimer>,
    /// Sessions currently inside `block`; `step_filtered` skips their tasks.
    blocked_sessions: Vec<SessionId>,
    /// Insert and pick tasks in seeded-random order instead of FIFO.
    randomize_order: bool,
    /// Whether `block` may park the thread to wait for real-world wakeups.
    allow_parking: bool,
    /// Range from which `block` draws its tick budget when given a timeout.
    timeout_ticks: RangeInclusive<usize>,
    /// Last session ID handed out by `allocate_session_id` (the first one is 1).
    next_session_id: SessionId,
    /// Whether waker clones record backtraces into `pending_traces`.
    capture_pending_traces: bool,
    /// ID to assign to the next traced waker clone.
    next_trace_id: TraceId,
    /// Backtraces of waker clones that have not yet been woken or dropped.
    pending_traces: BTreeMap<TraceId, Backtrace>,
    /// True initially and while a foreground task runs; false while a background task runs.
    is_main_thread: bool,
    /// Message and backtrace of a detected wrong-thread access, reported by `end_test`.
    non_determinism_error: Option<(String, Backtrace)>,
    /// Set once `allow_parking` has ever been called; disables wrong-thread detection.
    parking_allowed_once: bool,
    /// Set by `end_test`; afterwards wrong-thread accesses panic immediately.
    finished: bool,
    /// Set whenever work is scheduled or a waker fires; consumed by `park` to detect wakeups.
    unparked: bool,
}
703
/// Raw waker vtable for [`TracingWaker`].
///
/// Every data pointer used with it is a `Box<TracingWaker>` converted with `Box::into_raw`.
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);
710
/// Key of a recorded waker-clone backtrace in `SchedulerState::pending_traces`.
///
/// IDs are allocated sequentially.
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);
713
/// The waker `block` installs for the future it polls, boxed behind [`WAKER_VTABLE`].
struct TracingWaker {
    /// Key of this clone's backtrace in `pending_traces` when trace capture is on. `None` for
    /// the original waker created by `block`.
    id: Option<TraceId>,
    /// Set on wake; `block` swaps it back to `false` to learn whether the future was woken.
    awoken: Arc<AtomicBool>,
    /// The scheduler's thread, used for wrong-thread checks and unparked on wake.
    thread: Thread,
    /// Shared scheduler state.
    state: Arc<Mutex<SchedulerState>>,
}
720
721impl Clone for TracingWaker {
722    fn clone(&self) -> Self {
723        let mut state = self.state.lock();
724        let id = if state.capture_pending_traces {
725            let id = state.next_trace_id;
726            state.next_trace_id.0 += 1;
727            state.pending_traces.insert(id, Backtrace::new_unresolved());
728            Some(id)
729        } else {
730            None
731        };
732        Self {
733            id,
734            awoken: self.awoken.clone(),
735            thread: self.thread.clone(),
736            state: self.state.clone(),
737        }
738    }
739}
740
741impl Drop for TracingWaker {
742    fn drop(&mut self) {
743        assert_correct_thread(&self.thread, &self.state);
744
745        if let Some(id) = self.id {
746            self.state.lock().pending_traces.remove(&id);
747        }
748    }
749}
750
impl TracingWaker {
    /// Wake by value: signal via [`TracingWaker::wake_by_ref`], then drop `self`.
    ///
    /// The `Drop` impl repeats the thread check and clears this waker's pending trace.
    fn wake(self) {
        self.wake_by_ref();
    }

    /// Signal that the future owning this waker should be polled again.
    ///
    /// First runs the wrong-thread check. It then removes this waker's pending trace, sets the
    /// scheduler's `unparked` flag and this waker's `awoken` flag, and unparks the scheduler
    /// thread.
    fn wake_by_ref(&self) {
        assert_correct_thread(&self.thread, &self.state);

        let mut state = self.state.lock();
        if let Some(id) = self.id {
            state.pending_traces.remove(&id);
        }
        state.unparked = true;
        drop(state);
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    /// Vtable `clone`: box a clone of the `TracingWaker` behind `waker` and return a new raw
    /// waker that owns it.
    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        // SAFETY: every data pointer paired with `WAKER_VTABLE` comes from `Box::into_raw` of a
        // `TracingWaker` (in `block` and below). It is freed only by `wake_raw`/`drop_raw`, so
        // it points at a live `TracingWaker` for the duration of this borrow.
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    /// Vtable `wake`: take back ownership of the boxed waker, wake it, and free it.
    fn wake_raw(waker: *const ()) {
        // SAFETY: the pointer came from `Box::into_raw` of a `TracingWaker`. Per the
        // `RawWakerVTable` contract, `wake` consumes the raw waker, so ownership is reclaimed
        // exactly once.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    /// Vtable `wake_by_ref`: wake without consuming the boxed waker.
    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        // SAFETY: the raw waker is still live (it is not consumed by `wake_by_ref`), so the
        // boxed `TracingWaker` it points at is valid to borrow.
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    /// Vtable `drop`: reclaim and free the boxed waker, running its `Drop` impl.
    fn drop_raw(waker: *const ()) {
        // SAFETY: the pointer came from `Box::into_raw` of a `TracingWaker`, and `drop` is the
        // last use of this raw waker, so ownership is reclaimed exactly once.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}
794
/// A future that returns `Pending`, waking its own task each time, for the given number of
/// polls before completing. Created by [`TestScheduler::yield_random`].
pub struct Yield(usize);
796
/// A wrapper around `Arc<Mutex<StdRng>>` that provides convenient methods
/// for random number generation without requiring explicit locking.
///
/// Clones share the same underlying generator. [`TestScheduler::rng`] returns one that shares
/// the scheduler's own seeded RNG.
#[derive(Clone)]
pub struct SharedRng(Arc<Mutex<StdRng>>);
801
802impl SharedRng {
803    /// Lock the inner RNG for direct access. Use this when you need multiple
804    /// random operations without re-locking between each one.
805    pub fn lock(&self) -> MutexGuard<'_, StdRng> {
806        self.0.lock()
807    }
808
809    /// Generate a random value in the given range.
810    pub fn random_range<T, R>(&self, range: R) -> T
811    where
812        T: SampleUniform,
813        R: SampleRange<T>,
814    {
815        self.0.lock().random_range(range)
816    }
817
818    /// Generate a random boolean with the given probability of being true.
819    pub fn random_bool(&self, p: f64) -> bool {
820        self.0.lock().random_bool(p)
821    }
822
823    /// Generate a random value of the given type.
824    pub fn random<T>(&self) -> T
825    where
826        StandardUniform: Distribution<T>,
827    {
828        self.0.lock().random()
829    }
830
831    /// Generate a random ratio - true with probability `numerator/denominator`.
832    pub fn random_ratio(&self, numerator: u32, denominator: u32) -> bool {
833        self.0.lock().random_ratio(numerator, denominator)
834    }
835}
836
837impl Future for Yield {
838    type Output = ();
839
840    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
841        if self.0 == 0 {
842            Poll::Ready(())
843        } else {
844            self.0 -= 1;
845            cx.waker().wake_by_ref();
846            Poll::Pending
847        }
848    }
849}
850
851fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
852    trace.resolve();
853    let mut frames: Vec<BacktraceFrame> = trace.into();
854    let waker_clone_frame_ix = frames.iter().position(|frame| {
855        frame.symbols().iter().any(|symbol| {
856            symbol
857                .name()
858                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
859        })
860    });
861
862    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
863        frames.drain(..waker_clone_frame_ix + 1);
864    }
865
866    Backtrace::from(frames)
867}