test_scheduler.rs

  1use crate::{
  2    BackgroundExecutor, Clock, ForegroundExecutor, Instant, Priority, RunnableMeta, Scheduler,
  3    SessionId, TestClock, Timer,
  4};
  5use async_task::Runnable;
  6use backtrace::{Backtrace, BacktraceFrame};
  7use futures::channel::oneshot;
  8use parking_lot::{Mutex, MutexGuard};
  9use rand::{
 10    distr::{StandardUniform, uniform::SampleRange, uniform::SampleUniform},
 11    prelude::*,
 12};
 13use std::{
 14    any::type_name_of_val,
 15    collections::{BTreeMap, HashSet, VecDeque},
 16    env,
 17    fmt::Write,
 18    future::Future,
 19    mem,
 20    ops::RangeInclusive,
 21    panic::{self, AssertUnwindSafe},
 22    pin::Pin,
 23    sync::{
 24        Arc,
 25        atomic::{AtomicBool, Ordering::SeqCst},
 26    },
 27    task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
 28    thread::{self, Thread},
 29    time::Duration,
 30};
 31
// Env var that, when set to "1"/"true", enables capturing backtraces for
// pending (not-yet-woken) wakers so a forbidden-parking panic can report them.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
 33
/// A deterministic scheduler for tests: tasks and timers are driven manually,
/// optionally in seeded-random order, instead of by real threads and real time.
pub struct TestScheduler {
    clock: Arc<TestClock>,             // virtual clock controlled by the test
    rng: Arc<Mutex<StdRng>>,           // seeded RNG driving every scheduling decision
    state: Arc<Mutex<SchedulerState>>, // queued runnables, timers, and bookkeeping
    thread: Thread,                    // thread the scheduler was created on (the "main" thread)
}
 40
 41impl TestScheduler {
 42    /// Run a test once with default configuration (seed 0)
 43    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
 44        Self::with_seed(0, f)
 45    }
 46
 47    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
 48    pub fn many<R>(
 49        default_iterations: usize,
 50        mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
 51    ) -> Vec<R> {
 52        let num_iterations = std::env::var("ITERATIONS")
 53            .map(|iterations| iterations.parse().unwrap())
 54            .unwrap_or(default_iterations);
 55
 56        let seed = std::env::var("SEED")
 57            .map(|seed| seed.parse().unwrap())
 58            .unwrap_or(0);
 59
 60        (seed..seed + num_iterations as u64)
 61            .map(|seed| {
 62                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
 63                eprintln!("Running seed: {seed}");
 64                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
 65                    Ok(result) => result,
 66                    Err(error) => {
 67                        eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
 68                        panic::resume_unwind(error);
 69                    }
 70                }
 71            })
 72            .collect()
 73    }
 74
    /// Create a scheduler seeded with `seed`, run `f` to completion on a
    /// foreground executor, then drain any remaining spawned work.
    fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
        let future = f(scheduler.clone());
        let result = scheduler.foreground().block_on(future);
        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
        result
    }
 82
    /// Build a scheduler from `config`. The calling thread is recorded as the
    /// scheduler's main thread; task activity is expected to stay on it unless
    /// parking has been allowed.
    pub fn new(config: TestSchedulerConfig) -> Self {
        Self {
            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
            state: Arc::new(Mutex::new(SchedulerState {
                runnables: VecDeque::new(),
                timers: Vec::new(),
                blocked_sessions: Vec::new(),
                randomize_order: config.randomize_order,
                allow_parking: config.allow_parking,
                timeout_ticks: config.timeout_ticks,
                next_session_id: SessionId(0),
                capture_pending_traces: config.capture_pending_traces,
                pending_traces: BTreeMap::new(),
                next_trace_id: TraceId(0),
                is_main_thread: true,
                non_determinism_error: None,
                finished: false,
                parking_allowed_once: false,
                unparked: false,
            })),
            clock: Arc::new(TestClock::new()),
            thread: thread::current(),
        }
    }
107
    /// Finish the test: panic if cross-thread (non-deterministic) activity was
    /// recorded during the run, otherwise mark the scheduler finished so any
    /// later cross-thread activity panics immediately (see `assert_correct_thread`).
    pub fn end_test(&self) {
        let mut state = self.state.lock();
        if let Some((message, backtrace)) = &state.non_determinism_error {
            panic!("{}\n{:?}", message, backtrace)
        }
        state.finished = true;
    }
115
116    pub fn clock(&self) -> Arc<TestClock> {
117        self.clock.clone()
118    }
119
120    pub fn rng(&self) -> SharedRng {
121        SharedRng(self.rng.clone())
122    }
123
    /// Set the range from which `block` draws its random tick budget when a
    /// timeout is supplied.
    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
        self.state.lock().timeout_ticks = timeout_ticks;
    }
127
    /// Allow the scheduler to park its thread when no work is available (for
    /// tests mixing in real I/O). Also permanently disables the same-thread
    /// determinism check for this scheduler via `parking_allowed_once`.
    pub fn allow_parking(&self) {
        let mut state = self.state.lock();
        state.allow_parking = true;
        state.parking_allowed_once = true;
    }
133
    /// Forbid parking again; `block` will panic rather than wait for external
    /// wakeups. Note: this does not reset `parking_allowed_once`, so the
    /// thread-determinism check stays disabled once parking was ever allowed.
    pub fn forbid_parking(&self) {
        self.state.lock().allow_parking = false;
    }
137
    /// Whether parking is currently allowed.
    pub fn parking_allowed(&self) -> bool {
        self.state.lock().allow_parking
    }
141
    /// True while the scheduler considers itself on the main thread, i.e. not
    /// currently running a background (session-less) task. See `step_filtered`.
    pub fn is_main_thread(&self) -> bool {
        self.state.lock().is_main_thread
    }
145
    /// Allocate a new session ID for foreground task scheduling.
    /// This is used by GPUI's TestDispatcher to map dispatcher instances to sessions.
    /// IDs are sequential; the first allocated ID is `SessionId(1)`.
    pub fn allocate_session_id(&self) -> SessionId {
        let mut state = self.state.lock();
        state.next_session_id.0 += 1;
        state.next_session_id
    }
153
    /// Create a foreground executor for this scheduler.
    /// Each call allocates a fresh session; ordering is only guaranteed among
    /// tasks within the same session (see `step_filtered`).
    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
        let session_id = self.allocate_session_id();
        ForegroundExecutor::new(session_id, self.clone())
    }
159
    /// Create a background executor for this scheduler. Background tasks carry
    /// no session ID and may run in any order relative to each other.
    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
        BackgroundExecutor::new(self.clone())
    }
164
165    pub fn yield_random(&self) -> Yield {
166        let rng = &mut *self.rng.lock();
167        if rng.random_bool(0.1) {
168            Yield(rng.random_range(10..20))
169        } else {
170            Yield(rng.random_range(0..2))
171        }
172    }
173
174    pub fn run(&self) {
175        while self.step() {
176            // Continue until no work remains
177        }
178    }
179
    /// Drive the scheduler until quiescent, advancing the virtual clock to the
    /// next pending timer whenever no other work remains.
    pub fn run_with_clock_advancement(&self) {
        while self.step() || self.advance_clock_to_next_timer() {
            // Continue until no work remains
        }
    }
185
    /// Execute one tick of the scheduler, processing expired timers and running
    /// at most one task. Returns true if any work was done.
    ///
    /// This is the public interface for GPUI's TestDispatcher to drive task
    /// execution; it is equivalent to the private `step`.
    pub fn tick(&self) -> bool {
        self.step_filtered(false)
    }
193
    /// Execute one tick, but only run background tasks (no foreground/session
    /// tasks). Expired timers are still processed. Returns true if any work
    /// was done.
    pub fn tick_background_only(&self) -> bool {
        self.step_filtered(true)
    }
199
200    /// Check if there are any pending tasks or timers that could run.
201    pub fn has_pending_tasks(&self) -> bool {
202        let state = self.state.lock();
203        !state.runnables.is_empty() || !state.timers.is_empty()
204    }
205
206    /// Returns counts of (foreground_tasks, background_tasks) currently queued.
207    /// Foreground tasks are those with a session_id, background tasks have none.
208    pub fn pending_task_counts(&self) -> (usize, usize) {
209        let state = self.state.lock();
210        let foreground = state
211            .runnables
212            .iter()
213            .filter(|r| r.session_id.is_some())
214            .count();
215        let background = state
216            .runnables
217            .iter()
218            .filter(|r| r.session_id.is_none())
219            .count();
220        (foreground, background)
221    }
222
    /// Run one scheduler tick (timers plus at most one task of any kind);
    /// returns true if any work was done.
    fn step(&self) -> bool {
        self.step_filtered(false)
    }
226
    /// Core scheduler tick. First fires all timers whose expiration is at or
    /// before the current virtual time; if none fired, runs at most one queued
    /// task chosen according to the configured ordering policy. Returns true
    /// if any work (timer expiry or task execution) was done.
    fn step_filtered(&self, background_only: bool) -> bool {
        let (elapsed_count, runnables_before) = {
            let mut state = self.state.lock();
            // Timers are kept sorted by expiration, so a partition point finds
            // all of the expired ones in one scan.
            let end_ix = state
                .timers
                .partition_point(|timer| timer.expiration <= self.clock.now());
            let elapsed: Vec<_> = state.timers.drain(..end_ix).collect();
            let count = elapsed.len();
            let runnables = state.runnables.len();
            // Release the state lock before dropping the timers: dropping may
            // wake futures that re-enter the scheduler.
            drop(state);
            // Dropping elapsed timers here wakes the waiting futures
            drop(elapsed);
            (count, runnables)
        };

        if elapsed_count > 0 {
            let runnables_after = self.state.lock().runnables.len();
            if std::env::var("DEBUG_SCHEDULER").is_ok() {
                eprintln!(
                    "[scheduler] Expired {} timers at {:?}, runnables: {} -> {}",
                    elapsed_count,
                    self.clock.now(),
                    runnables_before,
                    runnables_after
                );
            }
            return true;
        }

        let runnable = {
            let state = &mut *self.state.lock();

            // Find candidate tasks:
            // - For foreground tasks (with session_id), only the first task from each session
            //   is a candidate (to preserve intra-session ordering)
            // - For background tasks (no session_id), all are candidates
            // - Tasks from blocked sessions are excluded
            // - If background_only is true, skip foreground tasks entirely
            let mut seen_sessions = HashSet::new();
            let candidate_indices: Vec<usize> = state
                .runnables
                .iter()
                .enumerate()
                .filter(|(_, runnable)| {
                    if let Some(session_id) = runnable.session_id {
                        // Skip foreground tasks if background_only mode
                        if background_only {
                            return false;
                        }
                        // Exclude tasks from blocked sessions
                        if state.blocked_sessions.contains(&session_id) {
                            return false;
                        }
                        // Only include first task from each session (insert returns true if new)
                        seen_sessions.insert(session_id)
                    } else {
                        // Background tasks are always candidates
                        true
                    }
                })
                .map(|(ix, _)| ix)
                .collect();

            if candidate_indices.is_empty() {
                None
            } else if state.randomize_order {
                // Use priority-weighted random selection
                let weights: Vec<u32> = candidate_indices
                    .iter()
                    .map(|&ix| state.runnables[ix].priority.weight())
                    .collect();
                let total_weight: u32 = weights.iter().sum();

                if total_weight == 0 {
                    // Fallback to uniform random if all weights are zero
                    let choice = self.rng.lock().random_range(0..candidate_indices.len());
                    state.runnables.remove(candidate_indices[choice])
                } else {
                    // Roulette-wheel selection: walk the weights until the
                    // random target falls inside one of them.
                    let mut target = self.rng.lock().random_range(0..total_weight);
                    let mut selected_idx = 0;
                    for (i, &weight) in weights.iter().enumerate() {
                        if target < weight {
                            selected_idx = i;
                            break;
                        }
                        target -= weight;
                    }
                    state.runnables.remove(candidate_indices[selected_idx])
                }
            } else {
                // Non-randomized: just take the first candidate task
                state.runnables.remove(candidate_indices[0])
            }
        };

        if let Some(runnable) = runnable {
            // While the task runs, record whether we are logically on the main
            // thread: foreground (session) tasks count as main-thread work.
            let is_foreground = runnable.session_id.is_some();
            let was_main_thread = self.state.lock().is_main_thread;
            self.state.lock().is_main_thread = is_foreground;
            runnable.run();
            self.state.lock().is_main_thread = was_main_thread;
            return true;
        }

        false
    }
333
    /// Drops all runnable tasks from the scheduler.
    ///
    /// This is used by the leak detector to ensure that all tasks have been dropped as tasks may keep entities alive otherwise.
    /// Why do we even have tasks left when tests finish you may ask. The reason for that is simple, the scheduler itself is the executor and it retains the scheduled runnables.
    /// A lot of tasks, including every foreground task contain an executor handle that keeps the test scheduler alive, causing a reference cycle, thus the need for this function right now.
    pub fn drain_tasks(&self) {
        // dropping runnables may reschedule tasks
        // due to drop impls with executors in them
        // so drop until we reach a fixpoint
        loop {
            let mut state = self.state.lock();
            if state.runnables.is_empty() && state.timers.is_empty() {
                break;
            }
            let runnables = std::mem::take(&mut state.runnables);
            let timers = std::mem::take(&mut state.timers);
            // Release the lock before dropping so drop impls may re-enter the
            // scheduler without deadlocking.
            drop(state);
            drop(timers);
            drop(runnables);
        }
    }
355
356    pub fn advance_clock_to_next_timer(&self) -> bool {
357        if let Some(timer) = self.state.lock().timers.first() {
358            self.clock.advance(timer.expiration - self.clock.now());
359            true
360        } else {
361            false
362        }
363    }
364
    /// Advance the virtual clock by `duration`, running all ready work and
    /// firing every timer that falls due along the way (in expiration order)
    /// before settling the clock exactly at `start + duration`.
    pub fn advance_clock(&self, duration: Duration) {
        let debug = std::env::var("DEBUG_SCHEDULER").is_ok();
        let start = self.clock.now();
        let next_now = start + duration;
        if debug {
            let timer_count = self.state.lock().timers.len();
            eprintln!(
                "[scheduler] advance_clock({:?}) from {:?}, {} pending timers",
                duration, start, timer_count
            );
        }
        loop {
            self.run();
            // Hop timer-by-timer so tasks scheduled by one timer get to run
            // before the next timer fires.
            if let Some(timer) = self.state.lock().timers.first()
                && timer.expiration <= next_now
            {
                let advance_to = timer.expiration;
                if debug {
                    eprintln!(
                        "[scheduler] Advancing clock {:?} -> {:?} for timer",
                        self.clock.now(),
                        advance_to
                    );
                }
                self.clock.advance(advance_to - self.clock.now());
            } else {
                break;
            }
        }
        self.clock.advance(next_now - self.clock.now());
        if debug {
            eprintln!(
                "[scheduler] advance_clock done, now at {:?}",
                self.clock.now()
            );
        }
    }
402
    /// Park the current thread until woken, a timer expires, or `deadline`
    /// passes. Returns true when woken by an unpark or an expired timer, and
    /// false when the provided deadline was reached. Panics when parking is
    /// forbidden and no deadline exists, since that would hang forever.
    fn park(&self, deadline: Option<Instant>) -> bool {
        if self.state.lock().allow_parking {
            let start = Instant::now();
            // Enforce a hard timeout to prevent tests from hanging indefinitely
            let hard_deadline = start + Duration::from_secs(15);

            // Use the earlier of the provided deadline or the hard timeout deadline
            let effective_deadline = deadline
                .map(|d| d.min(hard_deadline))
                .unwrap_or(hard_deadline);

            // Park in small intervals to allow checking both deadlines
            const PARK_INTERVAL: Duration = Duration::from_millis(100);
            loop {
                let now = Instant::now();
                if now >= effective_deadline {
                    // Check if we hit the hard timeout
                    if now >= hard_deadline {
                        panic!(
                            "Test timed out after 15 seconds while parking. \
                            This may indicate a deadlock or missing waker.",
                        );
                    }
                    // Hit the provided deadline
                    return false;
                }

                let remaining = effective_deadline.saturating_duration_since(now);
                let park_duration = remaining.min(PARK_INTERVAL);
                let before_park = Instant::now();
                thread::park_timeout(park_duration);
                let elapsed = before_park.elapsed();

                // Advance the test clock by the real elapsed time while parking
                self.clock.advance(elapsed);

                // Check if any timers have expired after advancing the clock.
                // If so, return so the caller can process them.
                if self
                    .state
                    .lock()
                    .timers
                    .first()
                    .map_or(false, |t| t.expiration <= self.clock.now())
                {
                    return true;
                }

                // Check if we were woken up by a different thread.
                // We use a flag because timing-based detection is unreliable:
                // OS scheduling delays can cause elapsed >= park_duration even when
                // we were woken early by unpark().
                if std::mem::take(&mut self.state.lock().unparked) {
                    return true;
                }
            }
        } else if deadline.is_some() {
            // Deterministic mode with a timeout: report the deadline as missed
            // without actually sleeping.
            false
        } else if self.state.lock().capture_pending_traces {
            // Report every pending waker's (trimmed) backtrace to help find
            // the future that never got woken.
            let mut pending_traces = String::new();
            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
            }
            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
        } else {
            panic!(
                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
            );
        }
    }
473}
474
/// Verify that scheduler activity is happening on the scheduler's own thread,
/// recording (or panicking with) a non-determinism error otherwise. Skipped
/// entirely once parking has ever been allowed, since parking implies
/// intentional cross-thread activity.
fn assert_correct_thread(expected: &Thread, state: &Arc<Mutex<SchedulerState>>) {
    let current_thread = thread::current();
    let mut state = state.lock();
    if state.parking_allowed_once {
        return;
    }
    if current_thread.id() == expected.id() {
        return;
    }

    let message = format!(
        "Detected activity on thread {:?} {:?}, but test scheduler is running on {:?} {:?}. Your test is not deterministic.",
        current_thread.name(),
        current_thread.id(),
        expected.name(),
        expected.id(),
    );
    let backtrace = Backtrace::new();
    if state.finished {
        // The test already ended: panic here rather than record an error
        // nobody will read.
        panic!("{}", message);
    } else {
        // Defer the panic to `end_test` so it surfaces on the test's thread.
        state.non_determinism_error = Some((message, backtrace))
    }
}
499
500impl Scheduler for TestScheduler {
501    /// Block until the given future completes, with an optional timeout. If the
502    /// future is unable to make progress at any moment before the timeout and
503    /// no other tasks or timers remain, we panic unless parking is allowed. If
504    /// parking is allowed, we block up to the timeout or indefinitely if none
505    /// is provided. This is to allow testing a mix of deterministic and
506    /// non-deterministic async behavior, such as when interacting with I/O in
507    /// an otherwise deterministic test.
508    fn block(
509        &self,
510        session_id: Option<SessionId>,
511        mut future: Pin<&mut dyn Future<Output = ()>>,
512        timeout: Option<Duration>,
513    ) -> bool {
514        if let Some(session_id) = session_id {
515            self.state.lock().blocked_sessions.push(session_id);
516        }
517
518        let deadline = timeout.map(|timeout| Instant::now() + timeout);
519        let awoken = Arc::new(AtomicBool::new(false));
520        let waker = Box::new(TracingWaker {
521            id: None,
522            awoken: awoken.clone(),
523            thread: self.thread.clone(),
524            state: self.state.clone(),
525        });
526        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
527        let max_ticks = if timeout.is_some() {
528            self.rng
529                .lock()
530                .random_range(self.state.lock().timeout_ticks.clone())
531        } else {
532            usize::MAX
533        };
534        let mut cx = Context::from_waker(&waker);
535
536        let mut completed = false;
537        for _ in 0..max_ticks {
538            match future.as_mut().poll(&mut cx) {
539                Poll::Ready(()) => {
540                    completed = true;
541                    break;
542                }
543                Poll::Pending => {}
544            }
545
546            let mut stepped = None;
547            while self.rng.lock().random() {
548                let stepped = stepped.get_or_insert(false);
549                if self.step() {
550                    *stepped = true;
551                } else {
552                    break;
553                }
554            }
555
556            let stepped = stepped.unwrap_or(true);
557            let awoken = awoken.swap(false, SeqCst);
558            if !stepped && !awoken {
559                let parking_allowed = self.state.lock().allow_parking;
560                // In deterministic mode (parking forbidden), instantly jump to the next timer.
561                // In non-deterministic mode (parking allowed), let real time pass instead.
562                let advanced_to_timer = !parking_allowed && self.advance_clock_to_next_timer();
563                if !advanced_to_timer && !self.park(deadline) {
564                    break;
565                }
566            }
567        }
568
569        if session_id.is_some() {
570            self.state.lock().blocked_sessions.pop();
571        }
572
573        completed
574    }
575
    /// Queue a foreground task for `session_id`. With randomized order, the
    /// insertion point is random but never before an already-queued task of
    /// the same session, preserving intra-session FIFO order.
    /// Lock order here is state -> rng; other code paths must match.
    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable<RunnableMeta>) {
        assert_correct_thread(&self.thread, &self.state);
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            // Earliest legal slot: just after the last task of this session.
            let start_ix = state
                .runnables
                .iter()
                .rposition(|task| task.session_id == Some(session_id))
                .map_or(0, |ix| ix + 1);
            self.rng
                .lock()
                .random_range(start_ix..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: Some(session_id),
                priority: Priority::default(),
                runnable,
            },
        );
        state.unparked = true;
        drop(state);
        // Wake the scheduler thread in case it is parked waiting for work.
        self.thread.unpark();
    }
603
    /// Queue a background task (no session). With randomized order it lands at
    /// an arbitrary position in the queue; its `priority` weights its chance
    /// of being picked by `step_filtered`.
    fn schedule_background_with_priority(
        &self,
        runnable: Runnable<RunnableMeta>,
        priority: Priority,
    ) {
        assert_correct_thread(&self.thread, &self.state);
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            self.rng.lock().random_range(0..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: None,
                priority,
                runnable,
            },
        );
        state.unparked = true;
        drop(state);
        // Wake the scheduler thread in case it is parked waiting for work.
        self.thread.unpark();
    }
628
629    fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
630        std::thread::spawn(move || {
631            f();
632        });
633    }
634
635    #[track_caller]
636    fn timer(&self, duration: Duration) -> Timer {
637        let (tx, rx) = oneshot::channel();
638        let state = &mut *self.state.lock();
639        state.timers.push(ScheduledTimer {
640            expiration: self.clock.now() + duration,
641            _notify: tx,
642        });
643        state.timers.sort_by_key(|timer| timer.expiration);
644        Timer(rx)
645    }
646
647    fn clock(&self) -> Arc<dyn Clock> {
648        self.clock.clone()
649    }
650
    /// Downcast hook: this scheduler *is* a `TestScheduler`, so always `Some`.
    fn as_test(&self) -> Option<&TestScheduler> {
        Some(self)
    }
654}
655
/// Configuration for constructing a [`TestScheduler`].
#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    pub seed: u64,                            // RNG seed driving all scheduling decisions
    pub randomize_order: bool,                // run tasks in seeded-random order vs FIFO
    pub allow_parking: bool,                  // permit real-thread parking (non-deterministic tests)
    pub capture_pending_traces: bool,         // record backtraces for pending wakers
    pub timeout_ticks: RangeInclusive<usize>, // tick budget range for `block` with a timeout
}
664
665impl TestSchedulerConfig {
666    pub fn with_seed(seed: u64) -> Self {
667        Self {
668            seed,
669            ..Default::default()
670        }
671    }
672}
673
674impl Default for TestSchedulerConfig {
675    fn default() -> Self {
676        Self {
677            seed: 0,
678            randomize_order: true,
679            allow_parking: false,
680            capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
681                .map_or(false, |var| var == "1" || var == "true"),
682            timeout_ticks: 1..=1000,
683        }
684    }
685}
686
/// A queued task plus the metadata used to order and select it.
struct ScheduledRunnable {
    session_id: Option<SessionId>, // Some = foreground (per-session FIFO), None = background
    priority: Priority,            // weight used in randomized selection
    runnable: Runnable<RunnableMeta>,
}
692
impl ScheduledRunnable {
    /// Consume the queue entry and execute its runnable.
    fn run(self) {
        self.runnable.run();
    }
}
698
/// A pending timer. Dropping the entry drops `_notify`, which wakes the
/// `Timer` future waiting on the paired receiver (see `step_filtered`).
struct ScheduledTimer {
    expiration: Instant,          // virtual-clock time at which the timer fires
    _notify: oneshot::Sender<()>, // never sent; dropped to signal expiration
}
703
/// All mutable scheduler state, guarded by a single mutex.
struct SchedulerState {
    runnables: VecDeque<ScheduledRunnable>, // queued tasks awaiting execution
    timers: Vec<ScheduledTimer>,            // pending timers, kept sorted by expiration
    blocked_sessions: Vec<SessionId>,       // sessions currently inside `block` (LIFO)
    randomize_order: bool,                  // pick tasks pseudo-randomly vs FIFO
    allow_parking: bool,                    // may the scheduler park the OS thread
    timeout_ticks: RangeInclusive<usize>,   // tick budget range for `block` with a timeout
    next_session_id: SessionId,             // most recently allocated session id
    capture_pending_traces: bool,           // record backtraces for live wakers
    next_trace_id: TraceId,                 // id assigned to the next captured trace
    pending_traces: BTreeMap<TraceId, Backtrace>, // traces of not-yet-woken wakers
    is_main_thread: bool,                   // logically on the main thread right now
    non_determinism_error: Option<(String, Backtrace)>, // deferred cross-thread error
    parking_allowed_once: bool,             // parking was ever allowed (disables thread checks)
    finished: bool,                         // `end_test` has completed
    unparked: bool,                         // a wake/schedule happened since last park check
}
721
// Vtable wiring raw `Waker` operations to `TracingWaker`; the data pointer is
// always a leaked `Box<TracingWaker>` (see `Scheduler::block` and `clone_raw`).
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);
728
/// Identifier for one captured pending-waker backtrace.
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);
731
/// Waker used by `Scheduler::block`: flags wakeups, unparks the scheduler
/// thread, and (optionally) tracks a backtrace from where it was cloned.
struct TracingWaker {
    id: Option<TraceId>,               // key of this waker's pending trace, if captured
    awoken: Arc<AtomicBool>,           // set on wake; checked by `block`
    thread: Thread,                    // scheduler thread to unpark on wake
    state: Arc<Mutex<SchedulerState>>, // shared scheduler state
}
738
/// Cloning registers a new pending trace (when capture is enabled) so a
/// forbidden-parking panic can report where each live waker was cloned.
impl Clone for TracingWaker {
    fn clone(&self) -> Self {
        let mut state = self.state.lock();
        let id = if state.capture_pending_traces {
            let id = state.next_trace_id;
            state.next_trace_id.0 += 1;
            // Unresolved capture is cheap; symbols are resolved only if the
            // trace is eventually printed (see `exclude_wakers_from_trace`).
            state.pending_traces.insert(id, Backtrace::new_unresolved());
            Some(id)
        } else {
            None
        };
        Self {
            id,
            awoken: self.awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        }
    }
}
758
impl Drop for TracingWaker {
    fn drop(&mut self) {
        // Dropping a waker on a foreign thread is also non-deterministic activity.
        assert_correct_thread(&self.thread, &self.state);

        // A dropped waker can no longer fire; discard its pending trace.
        if let Some(id) = self.id {
            self.state.lock().pending_traces.remove(&id);
        }
    }
}
768
impl TracingWaker {
    /// Consuming wake; delegates to `wake_by_ref` (the drop impl then cleans up).
    fn wake(self) {
        self.wake_by_ref();
    }

    /// Mark the blocked future as awoken, clear this waker's pending trace,
    /// and unpark the scheduler thread.
    fn wake_by_ref(&self) {
        assert_correct_thread(&self.thread, &self.state);

        let mut state = self.state.lock();
        if let Some(id) = self.id {
            state.pending_traces.remove(&id);
        }
        state.unparked = true;
        // Release the lock before unparking to keep the critical section short.
        drop(state);
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    // SAFETY contract for the raw vtable entries below: `waker` is always a
    // pointer produced by `Box::into_raw(Box<TracingWaker>)` (see
    // `Scheduler::block` and `clone_raw`), so borrowing it or reconstructing
    // the Box is sound, and each pointer is reclaimed exactly once (by
    // `wake_raw` or `drop_raw`).
    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        // SAFETY: see contract above; we only borrow here.
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    /// `wake` vtable entry: takes ownership of the boxed waker and consumes it.
    fn wake_raw(waker: *const ()) {
        // SAFETY: see contract above; ownership transfers to this Box.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    /// `wake_by_ref` vtable entry: borrows the boxed waker without consuming it.
    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        // SAFETY: see contract above; we only borrow here.
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    /// `drop` vtable entry: reclaims and drops the boxed waker.
    fn drop_raw(waker: *const ()) {
        // SAFETY: see contract above; ownership transfers to this Box.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}
812
/// Future that returns `Pending` (and immediately re-wakes itself) the stored
/// number of times before resolving to `()`.
pub struct Yield(usize);
814
/// A wrapper around `Arc<Mutex<StdRng>>` that provides convenient methods
/// for random number generation without requiring explicit locking.
/// Cloning is cheap: it clones the `Arc`, sharing the same underlying RNG.
#[derive(Clone)]
pub struct SharedRng(Arc<Mutex<StdRng>>);
819
820impl SharedRng {
821    /// Lock the inner RNG for direct access. Use this when you need multiple
822    /// random operations without re-locking between each one.
823    pub fn lock(&self) -> MutexGuard<'_, StdRng> {
824        self.0.lock()
825    }
826
827    /// Generate a random value in the given range.
828    pub fn random_range<T, R>(&self, range: R) -> T
829    where
830        T: SampleUniform,
831        R: SampleRange<T>,
832    {
833        self.0.lock().random_range(range)
834    }
835
836    /// Generate a random boolean with the given probability of being true.
837    pub fn random_bool(&self, p: f64) -> bool {
838        self.0.lock().random_bool(p)
839    }
840
841    /// Generate a random value of the given type.
842    pub fn random<T>(&self) -> T
843    where
844        StandardUniform: Distribution<T>,
845    {
846        self.0.lock().random()
847    }
848
849    /// Generate a random ratio - true with probability `numerator/denominator`.
850    pub fn random_ratio(&self, numerator: u32, denominator: u32) -> bool {
851        self.0.lock().random_ratio(numerator, denominator)
852    }
853}
854
855impl Future for Yield {
856    type Output = ();
857
858    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
859        if self.0 == 0 {
860            Poll::Ready(())
861        } else {
862            self.0 -= 1;
863            cx.waker().wake_by_ref();
864            Poll::Pending
865        }
866    }
867}
868
/// Trim a pending-waker backtrace so it begins at the user code that cloned
/// the waker, dropping the `Waker::clone` frame and everything above it.
fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
    // Symbol names are only available after resolving the (lazily captured) trace.
    trace.resolve();
    let mut frames: Vec<BacktraceFrame> = trace.into();
    // Locate the `Waker::clone` frame by comparing each symbol's name against
    // the type name of the `Waker::clone` function item.
    let waker_clone_frame_ix = frames.iter().position(|frame| {
        frame.symbols().iter().any(|symbol| {
            symbol
                .name()
                .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
        })
    });

    if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
        frames.drain(..waker_clone_frame_ix + 1);
    }

    Backtrace::from(frames)
}