1use crate::{
2 BackgroundExecutor, Clock, ForegroundExecutor, Instant, Priority, RunnableMeta, Scheduler,
3 SessionId, TestClock, Timer,
4};
5use async_task::Runnable;
6use backtrace::{Backtrace, BacktraceFrame};
7use futures::channel::oneshot;
8use parking_lot::{Mutex, MutexGuard};
9use rand::{
10 distr::{StandardUniform, uniform::SampleRange, uniform::SampleUniform},
11 prelude::*,
12};
13use std::{
14 any::type_name_of_val,
15 collections::{BTreeMap, HashSet, VecDeque},
16 env,
17 fmt::Write,
18 future::Future,
19 mem,
20 ops::RangeInclusive,
21 panic::{self, AssertUnwindSafe},
22 pin::Pin,
23 sync::{
24 Arc,
25 atomic::{AtomicBool, Ordering::SeqCst},
26 },
27 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
28 thread::{self, Thread},
29 time::Duration,
30};
31
/// Env var: set to `1` or `true` to capture a backtrace for every pending
/// waker, so a "parking forbidden" panic can report what was still waiting.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
33
/// A deterministic scheduler for tests: runs all tasks on a single thread in
/// a seeded, reproducible pseudo-random order against a fake clock.
pub struct TestScheduler {
    // Fake clock that only moves when the test (or parking) advances it.
    clock: Arc<TestClock>,
    // Seeded RNG driving every scheduling decision, so runs are reproducible.
    rng: Arc<Mutex<StdRng>>,
    // Shared mutable scheduler state (task queue, timers, flags).
    state: Arc<Mutex<SchedulerState>>,
    // The thread the scheduler was created on; parked/unparked while blocking.
    thread: Thread,
}
40
41impl TestScheduler {
    /// Run a test once with default configuration (seed 0).
    ///
    /// Convenience wrapper around [`Self::with_seed`] with seed 0.
    pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        Self::with_seed(0, f)
    }
46
    /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
    ///
    /// The iteration count defaults to `default_iterations` but can be
    /// overridden with the `ITERATIONS` env var; the starting seed defaults
    /// to 0 and can be overridden with `SEED`. A failing iteration prints its
    /// seed (in red) before re-raising the panic so it can be reproduced.
    pub fn many<R>(
        default_iterations: usize,
        mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
    ) -> Vec<R> {
        let num_iterations = std::env::var("ITERATIONS")
            .map(|iterations| iterations.parse().unwrap())
            .unwrap_or(default_iterations);

        let seed = std::env::var("SEED")
            .map(|seed| seed.parse().unwrap())
            .unwrap_or(0);

        (seed..seed + num_iterations as u64)
            .map(|seed| {
                // AssertUnwindSafe: the panic is re-raised below, so `f`'s
                // state is never reused after an unwind.
                let mut unwind_safe_f = AssertUnwindSafe(&mut f);
                eprintln!("Running seed: {seed}");
                match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
                    Ok(result) => result,
                    Err(error) => {
                        eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
                        panic::resume_unwind(error);
                    }
                }
            })
            .collect()
    }
74
    /// Run `f` to completion under a scheduler seeded with `seed`, then drain
    /// any remaining spawned work before returning the result.
    fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
        let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
        let future = f(scheduler.clone());
        let result = scheduler.foreground().block_on(future);
        scheduler.run(); // Ensure spawned tasks finish up before returning in tests
        result
    }
82
    /// Construct a scheduler from `config`, seeding the RNG and starting a
    /// fresh fake clock. The creating thread becomes the scheduler's thread.
    pub fn new(config: TestSchedulerConfig) -> Self {
        Self {
            rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
            state: Arc::new(Mutex::new(SchedulerState {
                runnables: VecDeque::new(),
                timers: Vec::new(),
                blocked_sessions: Vec::new(),
                randomize_order: config.randomize_order,
                allow_parking: config.allow_parking,
                timeout_ticks: config.timeout_ticks,
                next_session_id: SessionId(0),
                capture_pending_traces: config.capture_pending_traces,
                pending_traces: BTreeMap::new(),
                next_trace_id: TraceId(0),
                is_main_thread: true,
                non_determinism_error: None,
                finished: false,
                parking_allowed_once: false,
                unparked: false,
            })),
            clock: Arc::new(TestClock::new()),
            thread: thread::current(),
        }
    }
107
108 pub fn end_test(&self) {
109 let mut state = self.state.lock();
110 if let Some((message, backtrace)) = &state.non_determinism_error {
111 panic!("{}\n{:?}", message, backtrace)
112 }
113 state.finished = true;
114 }
115
    /// Shared handle to the scheduler's fake clock.
    pub fn clock(&self) -> Arc<TestClock> {
        self.clock.clone()
    }
119
    /// Handle to the scheduler's seeded RNG, for deterministic randomness in tests.
    pub fn rng(&self) -> SharedRng {
        SharedRng(self.rng.clone())
    }
123
    /// Override the range from which `block` draws its random tick budget
    /// when a timeout is supplied.
    pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
        self.state.lock().timeout_ticks = timeout_ticks;
    }
127
    /// Permit the scheduler thread to actually park (real waiting). Also
    /// records that parking was allowed at least once, which relaxes the
    /// wrong-thread non-determinism check for the rest of the test.
    pub fn allow_parking(&self) {
        let mut state = self.state.lock();
        state.allow_parking = true;
        state.parking_allowed_once = true;
    }
133
    /// Disallow parking again; `park` will panic when no work remains.
    pub fn forbid_parking(&self) {
        self.state.lock().allow_parking = false;
    }
137
    /// Whether the scheduler may currently park its thread.
    pub fn parking_allowed(&self) -> bool {
        self.state.lock().allow_parking
    }
141
    /// True while the scheduler is logically on the "main" thread, i.e. not
    /// inside a background (session-less) task.
    pub fn is_main_thread(&self) -> bool {
        self.state.lock().is_main_thread
    }
145
146 /// Allocate a new session ID for foreground task scheduling.
147 /// This is used by GPUI's TestDispatcher to map dispatcher instances to sessions.
148 pub fn allocate_session_id(&self) -> SessionId {
149 let mut state = self.state.lock();
150 state.next_session_id.0 += 1;
151 state.next_session_id
152 }
153
    /// Create a foreground executor for this scheduler.
    ///
    /// Each call allocates a fresh session, so tasks on different foreground
    /// executors may interleave while tasks within one executor keep their
    /// relative order (see `step_filtered`).
    pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
        let session_id = self.allocate_session_id();
        ForegroundExecutor::new(session_id, self.clone())
    }
159
    /// Create a background executor for this scheduler.
    ///
    /// Background tasks carry no session id and may run in any order.
    pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
        BackgroundExecutor::new(self.clone())
    }
164
165 pub fn yield_random(&self) -> Yield {
166 let rng = &mut *self.rng.lock();
167 if rng.random_bool(0.1) {
168 Yield(rng.random_range(10..20))
169 } else {
170 Yield(rng.random_range(0..2))
171 }
172 }
173
174 pub fn run(&self) {
175 while self.step() {
176 // Continue until no work remains
177 }
178 }
179
180 pub fn run_with_clock_advancement(&self) {
181 while self.step() || self.advance_clock_to_next_timer() {
182 // Continue until no work remains
183 }
184 }
185
    /// Execute one tick of the scheduler, processing expired timers and running
    /// at most one task. Returns true if any work was done.
    ///
    /// This is the public interface for GPUI's TestDispatcher to drive task execution.
    /// Equivalent to one internal, unfiltered `step`.
    pub fn tick(&self) -> bool {
        self.step_filtered(false)
    }
193
    /// Execute one tick, but only run background tasks (no foreground/session tasks).
    /// Expired timers are still processed. Returns true if any work was done.
    pub fn tick_background_only(&self) -> bool {
        self.step_filtered(true)
    }
199
200 /// Check if there are any pending tasks or timers that could run.
201 pub fn has_pending_tasks(&self) -> bool {
202 let state = self.state.lock();
203 !state.runnables.is_empty() || !state.timers.is_empty()
204 }
205
206 /// Returns counts of (foreground_tasks, background_tasks) currently queued.
207 /// Foreground tasks are those with a session_id, background tasks have none.
208 pub fn pending_task_counts(&self) -> (usize, usize) {
209 let state = self.state.lock();
210 let foreground = state
211 .runnables
212 .iter()
213 .filter(|r| r.session_id.is_some())
214 .count();
215 let background = state
216 .runnables
217 .iter()
218 .filter(|r| r.session_id.is_none())
219 .count();
220 (foreground, background)
221 }
222
    // Internal shorthand for an unfiltered tick (expired timers + one task).
    fn step(&self) -> bool {
        self.step_filtered(false)
    }
226
    /// Core scheduling step: first fire any timers expired at the current
    /// fake time; if none fired, pick and run at most one task. Returns true
    /// if any work was done. With `background_only`, foreground (session)
    /// tasks are not eligible.
    fn step_filtered(&self, background_only: bool) -> bool {
        let (elapsed_count, runnables_before) = {
            let mut state = self.state.lock();
            // Timers are kept sorted by expiration, so all expired timers
            // form a prefix of the vector.
            let end_ix = state
                .timers
                .partition_point(|timer| timer.expiration <= self.clock.now());
            let elapsed: Vec<_> = state.timers.drain(..end_ix).collect();
            let count = elapsed.len();
            let runnables = state.runnables.len();
            // Release the lock first: waking a timer's future may re-enter
            // the scheduler and lock state again.
            drop(state);
            // Dropping elapsed timers here wakes the waiting futures
            drop(elapsed);
            (count, runnables)
        };

        if elapsed_count > 0 {
            let runnables_after = self.state.lock().runnables.len();
            if std::env::var("DEBUG_SCHEDULER").is_ok() {
                eprintln!(
                    "[scheduler] Expired {} timers at {:?}, runnables: {} -> {}",
                    elapsed_count,
                    self.clock.now(),
                    runnables_before,
                    runnables_after
                );
            }
            return true;
        }

        let runnable = {
            let state = &mut *self.state.lock();

            // Find candidate tasks:
            // - For foreground tasks (with session_id), only the first task from each session
            //   is a candidate (to preserve intra-session ordering)
            // - For background tasks (no session_id), all are candidates
            // - Tasks from blocked sessions are excluded
            // - If background_only is true, skip foreground tasks entirely
            let mut seen_sessions = HashSet::new();
            let candidate_indices: Vec<usize> = state
                .runnables
                .iter()
                .enumerate()
                .filter(|(_, runnable)| {
                    if let Some(session_id) = runnable.session_id {
                        // Skip foreground tasks if background_only mode
                        if background_only {
                            return false;
                        }
                        // Exclude tasks from blocked sessions
                        if state.blocked_sessions.contains(&session_id) {
                            return false;
                        }
                        // Only include first task from each session (insert returns true if new)
                        seen_sessions.insert(session_id)
                    } else {
                        // Background tasks are always candidates
                        true
                    }
                })
                .map(|(ix, _)| ix)
                .collect();

            if candidate_indices.is_empty() {
                None
            } else if state.randomize_order {
                // Use priority-weighted random selection
                let weights: Vec<u32> = candidate_indices
                    .iter()
                    .map(|&ix| state.runnables[ix].priority.weight())
                    .collect();
                let total_weight: u32 = weights.iter().sum();

                if total_weight == 0 {
                    // Fallback to uniform random if all weights are zero
                    let choice = self.rng.lock().random_range(0..candidate_indices.len());
                    state.runnables.remove(candidate_indices[choice])
                } else {
                    // Walk the weights, decrementing `target` until it falls
                    // inside some candidate's bucket; higher weight => more
                    // likely to be chosen.
                    let mut target = self.rng.lock().random_range(0..total_weight);
                    let mut selected_idx = 0;
                    for (i, &weight) in weights.iter().enumerate() {
                        if target < weight {
                            selected_idx = i;
                            break;
                        }
                        target -= weight;
                    }
                    state.runnables.remove(candidate_indices[selected_idx])
                }
            } else {
                // Non-randomized: just take the first candidate task
                state.runnables.remove(candidate_indices[0])
            }
        };

        if let Some(runnable) = runnable {
            // Check if the executor that spawned this task was closed
            if runnable.runnable.metadata().is_closed() {
                return true;
            }
            // Foreground tasks run "on the main thread"; restore the previous
            // flag afterwards so nested stepping behaves correctly.
            let is_foreground = runnable.session_id.is_some();
            let was_main_thread = self.state.lock().is_main_thread;
            self.state.lock().is_main_thread = is_foreground;
            runnable.run();
            self.state.lock().is_main_thread = was_main_thread;
            return true;
        }

        false
    }
337
    /// Drops all runnable tasks from the scheduler.
    ///
    /// This is used by the leak detector to ensure that all tasks have been dropped as tasks may keep entities alive otherwise.
    /// Why do we even have tasks left when tests finish you may ask. The reason for that is simple, the scheduler itself is the executor and it retains the scheduled runnables.
    /// A lot of tasks, including every foreground task contain an executor handle that keeps the test scheduler alive, causing a reference cycle, thus the need for this function right now.
    pub fn drain_tasks(&self) {
        // dropping runnables may reschedule tasks
        // due to drop impls with executors in them
        // so drop until we reach a fixpoint
        loop {
            let mut state = self.state.lock();
            if state.runnables.is_empty() && state.timers.is_empty() {
                break;
            }
            // Move the queues out while holding the lock, then release it
            // before dropping so Drop impls can re-lock and reschedule.
            let runnables = std::mem::take(&mut state.runnables);
            let timers = std::mem::take(&mut state.timers);
            drop(state);
            drop(timers);
            drop(runnables);
        }
    }
359
    /// If any timer is pending, advance the fake clock straight to the
    /// earliest expiration and return true; otherwise return false.
    pub fn advance_clock_to_next_timer(&self) -> bool {
        if let Some(timer) = self.state.lock().timers.first() {
            // NOTE(review): the state lock is held across `clock.advance`
            // here — assumes `TestClock::advance` never re-enters scheduler
            // state; confirm against TestClock.
            self.clock.advance(timer.expiration - self.clock.now());
            true
        } else {
            false
        }
    }
368
    /// Advance the fake clock by `duration`, draining ready work and stopping
    /// at each intermediate timer expiration so timers fire at their exact
    /// scheduled times rather than all at once at the end.
    pub fn advance_clock(&self, duration: Duration) {
        let debug = std::env::var("DEBUG_SCHEDULER").is_ok();
        let start = self.clock.now();
        let next_now = start + duration;
        if debug {
            let timer_count = self.state.lock().timers.len();
            eprintln!(
                "[scheduler] advance_clock({:?}) from {:?}, {} pending timers",
                duration, start, timer_count
            );
        }
        loop {
            // Run everything that's ready at the current time first.
            self.run();
            if let Some(timer) = self.state.lock().timers.first()
                && timer.expiration <= next_now
            {
                let advance_to = timer.expiration;
                if debug {
                    eprintln!(
                        "[scheduler] Advancing clock {:?} -> {:?} for timer",
                        self.clock.now(),
                        advance_to
                    );
                }
                self.clock.advance(advance_to - self.clock.now());
            } else {
                break;
            }
        }
        // Cover any remainder past the last timer expiration.
        self.clock.advance(next_now - self.clock.now());
        if debug {
            eprintln!(
                "[scheduler] advance_clock done, now at {:?}",
                self.clock.now()
            );
        }
    }
406
    /// Park the scheduler thread until it is woken, a timer expires, or
    /// `deadline` passes. Returns true when woken by an unpark or an expired
    /// timer (caller should keep going), false when the provided deadline was
    /// reached. Panics if parking is forbidden and there is no deadline.
    fn park(&self, deadline: Option<Instant>) -> bool {
        if self.state.lock().allow_parking {
            let start = Instant::now();
            // Enforce a hard timeout to prevent tests from hanging indefinitely
            let hard_deadline = start + Duration::from_secs(15);

            // Use the earlier of the provided deadline or the hard timeout deadline
            let effective_deadline = deadline
                .map(|d| d.min(hard_deadline))
                .unwrap_or(hard_deadline);

            // Park in small intervals to allow checking both deadlines
            const PARK_INTERVAL: Duration = Duration::from_millis(100);
            loop {
                let now = Instant::now();
                if now >= effective_deadline {
                    // Check if we hit the hard timeout
                    if now >= hard_deadline {
                        panic!(
                            "Test timed out after 15 seconds while parking. \
                            This may indicate a deadlock or missing waker.",
                        );
                    }
                    // Hit the provided deadline
                    return false;
                }

                let remaining = effective_deadline.saturating_duration_since(now);
                let park_duration = remaining.min(PARK_INTERVAL);
                let before_park = Instant::now();
                thread::park_timeout(park_duration);
                let elapsed = before_park.elapsed();

                // Advance the test clock by the real elapsed time while parking
                self.clock.advance(elapsed);

                // Check if any timers have expired after advancing the clock.
                // If so, return so the caller can process them.
                if self
                    .state
                    .lock()
                    .timers
                    .first()
                    .map_or(false, |t| t.expiration <= self.clock.now())
                {
                    return true;
                }

                // Check if we were woken up by a different thread.
                // We use a flag because timing-based detection is unreliable:
                // OS scheduling delays can cause elapsed >= park_duration even when
                // we were woken early by unpark().
                if std::mem::take(&mut self.state.lock().unparked) {
                    return true;
                }
            }
        } else if deadline.is_some() {
            // Parking forbidden but a timeout exists: report the deadline as
            // hit so the caller stops waiting instead of panicking.
            false
        } else if self.state.lock().capture_pending_traces {
            // Report what was still pending to make the deadlock debuggable.
            let mut pending_traces = String::new();
            for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
                writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
            }
            panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
        } else {
            panic!(
                "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
            );
        }
    }
477}
478
479fn assert_correct_thread(expected: &Thread, state: &Arc<Mutex<SchedulerState>>) {
480 let current_thread = thread::current();
481 let mut state = state.lock();
482 if state.parking_allowed_once {
483 return;
484 }
485 if current_thread.id() == expected.id() {
486 return;
487 }
488
489 let message = format!(
490 "Detected activity on thread {:?} {:?}, but test scheduler is running on {:?} {:?}. Your test is not deterministic.",
491 current_thread.name(),
492 current_thread.id(),
493 expected.name(),
494 expected.id(),
495 );
496 let backtrace = Backtrace::new();
497 if state.finished {
498 panic!("{}", message);
499 } else {
500 state.non_determinism_error = Some((message, backtrace))
501 }
502}
503
504impl Scheduler for TestScheduler {
    /// Block until the given future completes, with an optional timeout. If the
    /// future is unable to make progress at any moment before the timeout and
    /// no other tasks or timers remain, we panic unless parking is allowed. If
    /// parking is allowed, we block up to the timeout or indefinitely if none
    /// is provided. This is to allow testing a mix of deterministic and
    /// non-deterministic async behavior, such as when interacting with I/O in
    /// an otherwise deterministic test. Returns whether the future completed.
    fn block(
        &self,
        session_id: Option<SessionId>,
        mut future: Pin<&mut dyn Future<Output = ()>>,
        timeout: Option<Duration>,
    ) -> bool {
        // Mark this session as blocked so its other queued tasks are skipped
        // while we poll here (see `step_filtered`).
        if let Some(session_id) = session_id {
            self.state.lock().blocked_sessions.push(session_id);
        }

        let deadline = timeout.map(|timeout| Instant::now() + timeout);
        let awoken = Arc::new(AtomicBool::new(false));
        let waker = Box::new(TracingWaker {
            id: None,
            awoken: awoken.clone(),
            thread: self.thread.clone(),
            state: self.state.clone(),
        });
        // SAFETY: WAKER_VTABLE's functions treat the data pointer as the
        // `Box<TracingWaker>` leaked above, upholding the RawWaker contract.
        let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
        // With a timeout, the deadline is modeled as a random poll-tick
        // budget drawn from `timeout_ticks`, keeping timeouts deterministic.
        let max_ticks = if timeout.is_some() {
            self.rng
                .lock()
                .random_range(self.state.lock().timeout_ticks.clone())
        } else {
            usize::MAX
        };
        let mut cx = Context::from_waker(&waker);

        let mut completed = false;
        for _ in 0..max_ticks {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(()) => {
                    completed = true;
                    break;
                }
                Poll::Pending => {}
            }

            // Interleave a random number of other scheduled tasks between
            // polls. `stepped` is None if we never tried to step, Some(false)
            // if we tried and found no work.
            let mut stepped = None;
            while self.rng.lock().random() {
                let stepped = stepped.get_or_insert(false);
                if self.step() {
                    *stepped = true;
                } else {
                    break;
                }
            }

            // If we didn't even attempt to step, don't treat that as
            // exhaustion — only park when work was sought and not found.
            let stepped = stepped.unwrap_or(true);
            let awoken = awoken.swap(false, SeqCst);
            if !stepped && !awoken {
                let parking_allowed = self.state.lock().allow_parking;
                // In deterministic mode (parking forbidden), instantly jump to the next timer.
                // In non-deterministic mode (parking allowed), let real time pass instead.
                let advanced_to_timer = !parking_allowed && self.advance_clock_to_next_timer();
                if !advanced_to_timer && !self.park(deadline) {
                    // Deadline passed while parked: stop polling.
                    break;
                }
            }
        }

        if session_id.is_some() {
            self.state.lock().blocked_sessions.pop();
        }

        completed
    }
579
    /// Queue a task for a specific session and unpark the scheduler thread.
    ///
    /// With `randomize_order`, the insertion index is random but never before
    /// the session's last already-queued task, preserving per-session FIFO
    /// ordering while still shuffling relative to other sessions.
    fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable<RunnableMeta>) {
        assert_correct_thread(&self.thread, &self.state);
        let mut state = self.state.lock();
        let ix = if state.randomize_order {
            // Lower bound: one past the last task of the same session.
            let start_ix = state
                .runnables
                .iter()
                .rposition(|task| task.session_id == Some(session_id))
                .map_or(0, |ix| ix + 1);
            self.rng
                .lock()
                .random_range(start_ix..=state.runnables.len())
        } else {
            state.runnables.len()
        };
        state.runnables.insert(
            ix,
            ScheduledRunnable {
                session_id: Some(session_id),
                priority: Priority::default(),
                runnable,
            },
        );
        // Flag the wakeup before unparking so `park` can tell it was genuine.
        state.unparked = true;
        drop(state);
        self.thread.unpark();
    }
607
608 fn schedule_background_with_priority(
609 &self,
610 runnable: Runnable<RunnableMeta>,
611 priority: Priority,
612 ) {
613 assert_correct_thread(&self.thread, &self.state);
614 let mut state = self.state.lock();
615 let ix = if state.randomize_order {
616 self.rng.lock().random_range(0..=state.runnables.len())
617 } else {
618 state.runnables.len()
619 };
620 state.runnables.insert(
621 ix,
622 ScheduledRunnable {
623 session_id: None,
624 priority,
625 runnable,
626 },
627 );
628 state.unparked = true;
629 drop(state);
630 self.thread.unpark();
631 }
632
633 fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
634 std::thread::spawn(move || {
635 f();
636 });
637 }
638
    /// Create a timer that fires after `duration` of fake-clock time.
    ///
    /// The timer's future completes when its entry is removed from
    /// `state.timers`: dropping the `oneshot::Sender` wakes the receiver.
    #[track_caller]
    fn timer(&self, duration: Duration) -> Timer {
        let (tx, rx) = oneshot::channel();
        let state = &mut *self.state.lock();
        state.timers.push(ScheduledTimer {
            expiration: self.clock.now() + duration,
            _notify: tx,
        });
        // Stable sort keeps timers ordered by expiration; timers with equal
        // expirations retain insertion order.
        state.timers.sort_by_key(|timer| timer.expiration);
        Timer(rx)
    }
650
    /// Trait-object handle to the scheduler's fake clock.
    fn clock(&self) -> Arc<dyn Clock> {
        self.clock.clone()
    }
654
    /// Downcast hook: this scheduler is the test implementation.
    fn as_test(&self) -> Option<&TestScheduler> {
        Some(self)
    }
658}
659
/// Configuration for [`TestScheduler::new`].
#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    /// Seed for the scheduler's deterministic RNG.
    pub seed: u64,
    /// Randomize task execution/insertion order (reproducible via `seed`).
    pub randomize_order: bool,
    /// Allow the scheduler thread to actually park while waiting.
    pub allow_parking: bool,
    /// Capture a backtrace per cloned waker (see `PENDING_TRACES` env var).
    pub capture_pending_traces: bool,
    /// Range from which `block` draws its random tick budget on timeout.
    pub timeout_ticks: RangeInclusive<usize>,
}
668
669impl TestSchedulerConfig {
670 pub fn with_seed(seed: u64) -> Self {
671 Self {
672 seed,
673 ..Default::default()
674 }
675 }
676}
677
678impl Default for TestSchedulerConfig {
679 fn default() -> Self {
680 Self {
681 seed: 0,
682 randomize_order: true,
683 allow_parking: false,
684 capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
685 .map_or(false, |var| var == "1" || var == "true"),
686 timeout_ticks: 1..=1000,
687 }
688 }
689}
690
// A queued task, tagged with its originating session (None for background
// tasks) and its scheduling priority.
struct ScheduledRunnable {
    session_id: Option<SessionId>,
    priority: Priority,
    runnable: Runnable<RunnableMeta>,
}
696
697impl ScheduledRunnable {
698 fn run(self) {
699 self.runnable.run();
700 }
701}
702
// A pending timer. Dropping the entry closes `_notify`, which is what wakes
// the future awaiting the timer (see `step_filtered`).
struct ScheduledTimer {
    expiration: Instant,
    _notify: oneshot::Sender<()>,
}
707
// Mutable state shared between the scheduler, its wakers, and executors.
struct SchedulerState {
    // Queue of scheduled tasks; order is perturbed when `randomize_order` is set.
    runnables: VecDeque<ScheduledRunnable>,
    // Pending timers, kept sorted by expiration.
    timers: Vec<ScheduledTimer>,
    // Sessions currently inside `Scheduler::block`; their other queued tasks
    // are skipped until they unblock.
    blocked_sessions: Vec<SessionId>,
    // Randomize task selection/insertion (seeded, reproducible).
    randomize_order: bool,
    // Whether the scheduler thread may actually park (real waiting).
    allow_parking: bool,
    // Range from which `block` draws its random tick budget on timeout.
    timeout_ticks: RangeInclusive<usize>,
    // Last session id handed out by `allocate_session_id`.
    next_session_id: SessionId,
    // Capture a backtrace per cloned waker for pending-trace reports.
    capture_pending_traces: bool,
    // Next id to assign to a captured trace.
    next_trace_id: TraceId,
    // Unresolved backtraces of wakers not yet woken or dropped.
    pending_traces: BTreeMap<TraceId, Backtrace>,
    // True while running foreground work (tasks with a session id).
    is_main_thread: bool,
    // Deferred wrong-thread error, surfaced by `end_test`.
    non_determinism_error: Option<(String, Backtrace)>,
    // Set once parking has ever been allowed; relaxes the thread check.
    parking_allowed_once: bool,
    // Set by `end_test`; wrong-thread activity afterwards panics immediately.
    finished: bool,
    // Set by schedulers/wakers so `park` can tell a wakeup was deliberate.
    unparked: bool,
}
725
// RawWaker vtable backing the `Waker`s handed to polled futures; each entry
// forwards to the corresponding `TracingWaker` raw fn below.
const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
    TracingWaker::clone_raw,
    TracingWaker::wake_raw,
    TracingWaker::wake_by_ref_raw,
    TracingWaker::drop_raw,
);
732
// Monotonically increasing key identifying a captured pending-waker backtrace.
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
struct TraceId(usize);
735
// Waker state behind WAKER_VTABLE: tracks pending-trace bookkeeping, enforces
// same-thread usage, and unparks the scheduler thread on wake.
struct TracingWaker {
    // Key of this waker's captured backtrace, if trace capture is enabled.
    id: Option<TraceId>,
    // Flag observed by `block` to detect that its future was awoken.
    awoken: Arc<AtomicBool>,
    // Scheduler thread to unpark on wake.
    thread: Thread,
    state: Arc<Mutex<SchedulerState>>,
}
742
743impl Clone for TracingWaker {
744 fn clone(&self) -> Self {
745 let mut state = self.state.lock();
746 let id = if state.capture_pending_traces {
747 let id = state.next_trace_id;
748 state.next_trace_id.0 += 1;
749 state.pending_traces.insert(id, Backtrace::new_unresolved());
750 Some(id)
751 } else {
752 None
753 };
754 Self {
755 id,
756 awoken: self.awoken.clone(),
757 thread: self.thread.clone(),
758 state: self.state.clone(),
759 }
760 }
761}
762
763impl Drop for TracingWaker {
764 fn drop(&mut self) {
765 assert_correct_thread(&self.thread, &self.state);
766
767 if let Some(id) = self.id {
768 self.state.lock().pending_traces.remove(&id);
769 }
770 }
771}
772
impl TracingWaker {
    // Consuming wake: same as wake_by_ref, then Drop removes any pending trace.
    fn wake(self) {
        self.wake_by_ref();
    }

    // Mark the blocked future as awoken and unpark the scheduler thread.
    fn wake_by_ref(&self) {
        assert_correct_thread(&self.thread, &self.state);

        let mut state = self.state.lock();
        // Once woken, this waker's trace is no longer "pending".
        if let Some(id) = self.id {
            state.pending_traces.remove(&id);
        }
        state.unparked = true;
        drop(state);
        self.awoken.store(true, SeqCst);
        self.thread.unpark();
    }

    // The raw fns below implement WAKER_VTABLE. Invariant: the data pointer
    // is always a `Box<TracingWaker>` produced by `Box::into_raw` (see
    // `block` and `clone_raw`).
    fn clone_raw(waker: *const ()) -> RawWaker {
        let waker = waker as *const TracingWaker;
        // SAFETY: pointer originates from Box::into_raw and is still live;
        // we only borrow it for the duration of the clone.
        let waker = unsafe { &*waker };
        RawWaker::new(
            Box::into_raw(Box::new(waker.clone())) as *const (),
            &WAKER_VTABLE,
        )
    }

    fn wake_raw(waker: *const ()) {
        // SAFETY: `wake` consumes the waker, so reclaim ownership of the box.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        waker.wake();
    }

    fn wake_by_ref_raw(waker: *const ()) {
        let waker = waker as *const TracingWaker;
        // SAFETY: borrowed only for this call; ownership stays with the caller.
        let waker = unsafe { &*waker };
        waker.wake_by_ref();
    }

    fn drop_raw(waker: *const ()) {
        // SAFETY: reclaims the box for this waker handle so it is freed once.
        let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
        drop(waker);
    }
}
816
/// Future returned by [`TestScheduler::yield_random`]; resolves after the
/// wrapped number of pending polls.
pub struct Yield(usize);
818
/// A wrapper around `Arc<Mutex<StdRng>>` that provides convenient methods
/// for random number generation without requiring explicit locking.
/// Cloning is cheap and all clones share the same underlying RNG stream.
#[derive(Clone)]
pub struct SharedRng(Arc<Mutex<StdRng>>);
823
824impl SharedRng {
825 /// Lock the inner RNG for direct access. Use this when you need multiple
826 /// random operations without re-locking between each one.
827 pub fn lock(&self) -> MutexGuard<'_, StdRng> {
828 self.0.lock()
829 }
830
831 /// Generate a random value in the given range.
832 pub fn random_range<T, R>(&self, range: R) -> T
833 where
834 T: SampleUniform,
835 R: SampleRange<T>,
836 {
837 self.0.lock().random_range(range)
838 }
839
840 /// Generate a random boolean with the given probability of being true.
841 pub fn random_bool(&self, p: f64) -> bool {
842 self.0.lock().random_bool(p)
843 }
844
845 /// Generate a random value of the given type.
846 pub fn random<T>(&self) -> T
847 where
848 StandardUniform: Distribution<T>,
849 {
850 self.0.lock().random()
851 }
852
853 /// Generate a random ratio - true with probability `numerator/denominator`.
854 pub fn random_ratio(&self, numerator: u32, denominator: u32) -> bool {
855 self.0.lock().random_ratio(numerator, denominator)
856 }
857}
858
859impl Future for Yield {
860 type Output = ();
861
862 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
863 if self.0 == 0 {
864 Poll::Ready(())
865 } else {
866 self.0 -= 1;
867 cx.waker().wake_by_ref();
868 Poll::Pending
869 }
870 }
871}
872
873fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
874 trace.resolve();
875 let mut frames: Vec<BacktraceFrame> = trace.into();
876 let waker_clone_frame_ix = frames.iter().position(|frame| {
877 frame.symbols().iter().any(|symbol| {
878 symbol
879 .name()
880 .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
881 })
882 });
883
884 if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
885 frames.drain(..waker_clone_frame_ix + 1);
886 }
887
888 Backtrace::from(frames)
889}