1use crate::{
2 BackgroundExecutor, Clock, ForegroundExecutor, Instant, Priority, RunnableMeta, Scheduler,
3 SessionId, TestClock, Timer,
4};
5use async_task::Runnable;
6use backtrace::{Backtrace, BacktraceFrame};
7use futures::channel::oneshot;
8use parking_lot::{Mutex, MutexGuard};
9use rand::{
10 distr::{StandardUniform, uniform::SampleRange, uniform::SampleUniform},
11 prelude::*,
12};
13use std::{
14 any::type_name_of_val,
15 collections::{BTreeMap, HashSet, VecDeque},
16 env,
17 fmt::Write,
18 future::Future,
19 mem,
20 ops::RangeInclusive,
21 panic::{self, AssertUnwindSafe},
22 pin::Pin,
23 sync::{
24 Arc,
25 atomic::{AtomicBool, Ordering::SeqCst},
26 },
27 task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
28 thread::{self, Thread},
29 time::Duration,
30};
31
/// Environment variable that, when set to `1` or `true`, makes the scheduler record a
/// backtrace for every cloned waker so a forbidden park can report who is still pending.
const PENDING_TRACES_VAR_NAME: &str = "PENDING_TRACES";
33
34pub struct TestScheduler {
35 clock: Arc<TestClock>,
36 rng: Arc<Mutex<StdRng>>,
37 state: Arc<Mutex<SchedulerState>>,
38 thread: Thread,
39}
40
41impl TestScheduler {
42 /// Run a test once with default configuration (seed 0)
43 pub fn once<R>(f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
44 Self::with_seed(0, f)
45 }
46
47 /// Run a test multiple times with sequential seeds (0, 1, 2, ...)
48 pub fn many<R>(
49 default_iterations: usize,
50 mut f: impl AsyncFnMut(Arc<TestScheduler>) -> R,
51 ) -> Vec<R> {
52 let num_iterations = std::env::var("ITERATIONS")
53 .map(|iterations| iterations.parse().unwrap())
54 .unwrap_or(default_iterations);
55
56 let seed = std::env::var("SEED")
57 .map(|seed| seed.parse().unwrap())
58 .unwrap_or(0);
59
60 (seed..seed + num_iterations as u64)
61 .map(|seed| {
62 let mut unwind_safe_f = AssertUnwindSafe(&mut f);
63 eprintln!("Running seed: {seed}");
64 match panic::catch_unwind(move || Self::with_seed(seed, &mut *unwind_safe_f)) {
65 Ok(result) => result,
66 Err(error) => {
67 eprintln!("\x1b[31mFailing Seed: {seed}\x1b[0m");
68 panic::resume_unwind(error);
69 }
70 }
71 })
72 .collect()
73 }
74
75 fn with_seed<R>(seed: u64, f: impl AsyncFnOnce(Arc<TestScheduler>) -> R) -> R {
76 let scheduler = Arc::new(TestScheduler::new(TestSchedulerConfig::with_seed(seed)));
77 let future = f(scheduler.clone());
78 let result = scheduler.foreground().block_on(future);
79 scheduler.run(); // Ensure spawned tasks finish up before returning in tests
80 result
81 }
82
83 pub fn new(config: TestSchedulerConfig) -> Self {
84 Self {
85 rng: Arc::new(Mutex::new(StdRng::seed_from_u64(config.seed))),
86 state: Arc::new(Mutex::new(SchedulerState {
87 runnables: VecDeque::new(),
88 timers: Vec::new(),
89 blocked_sessions: Vec::new(),
90 randomize_order: config.randomize_order,
91 allow_parking: config.allow_parking,
92 timeout_ticks: config.timeout_ticks,
93 next_session_id: SessionId(0),
94 capture_pending_traces: config.capture_pending_traces,
95 pending_traces: BTreeMap::new(),
96 next_trace_id: TraceId(0),
97 is_main_thread: true,
98 non_determinism_error: None,
99 finished: false,
100 parking_allowed_once: false,
101 unparked: false,
102 })),
103 clock: Arc::new(TestClock::new()),
104 thread: thread::current(),
105 }
106 }
107
108 pub fn end_test(&self) {
109 let mut state = self.state.lock();
110 if let Some((message, backtrace)) = &state.non_determinism_error {
111 panic!("{}\n{:?}", message, backtrace)
112 }
113 state.finished = true;
114 }
115
116 pub fn clock(&self) -> Arc<TestClock> {
117 self.clock.clone()
118 }
119
120 pub fn rng(&self) -> SharedRng {
121 SharedRng(self.rng.clone())
122 }
123
124 pub fn set_timeout_ticks(&self, timeout_ticks: RangeInclusive<usize>) {
125 self.state.lock().timeout_ticks = timeout_ticks;
126 }
127
128 pub fn allow_parking(&self) {
129 let mut state = self.state.lock();
130 state.allow_parking = true;
131 state.parking_allowed_once = true;
132 }
133
134 pub fn forbid_parking(&self) {
135 self.state.lock().allow_parking = false;
136 }
137
138 pub fn parking_allowed(&self) -> bool {
139 self.state.lock().allow_parking
140 }
141
142 pub fn is_main_thread(&self) -> bool {
143 self.state.lock().is_main_thread
144 }
145
146 /// Allocate a new session ID for foreground task scheduling.
147 /// This is used by GPUI's TestDispatcher to map dispatcher instances to sessions.
148 pub fn allocate_session_id(&self) -> SessionId {
149 let mut state = self.state.lock();
150 state.next_session_id.0 += 1;
151 state.next_session_id
152 }
153
154 /// Create a foreground executor for this scheduler
155 pub fn foreground(self: &Arc<Self>) -> ForegroundExecutor {
156 let session_id = self.allocate_session_id();
157 ForegroundExecutor::new(session_id, self.clone())
158 }
159
160 /// Create a background executor for this scheduler
161 pub fn background(self: &Arc<Self>) -> BackgroundExecutor {
162 BackgroundExecutor::new(self.clone())
163 }
164
165 pub fn yield_random(&self) -> Yield {
166 let rng = &mut *self.rng.lock();
167 if rng.random_bool(0.1) {
168 Yield(rng.random_range(10..20))
169 } else {
170 Yield(rng.random_range(0..2))
171 }
172 }
173
174 pub fn run(&self) {
175 while self.step() {
176 // Continue until no work remains
177 }
178 }
179
180 pub fn run_with_clock_advancement(&self) {
181 while self.step() || self.advance_clock_to_next_timer() {
182 // Continue until no work remains
183 }
184 }
185
186 /// Execute one tick of the scheduler, processing expired timers and running
187 /// at most one task. Returns true if any work was done.
188 ///
189 /// This is the public interface for GPUI's TestDispatcher to drive task execution.
190 pub fn tick(&self) -> bool {
191 self.step_filtered(false)
192 }
193
194 /// Execute one tick, but only run background tasks (no foreground/session tasks).
195 /// Returns true if any work was done.
196 pub fn tick_background_only(&self) -> bool {
197 self.step_filtered(true)
198 }
199
200 /// Check if there are any pending tasks or timers that could run.
201 pub fn has_pending_tasks(&self) -> bool {
202 let state = self.state.lock();
203 !state.runnables.is_empty() || !state.timers.is_empty()
204 }
205
206 /// Returns counts of (foreground_tasks, background_tasks) currently queued.
207 /// Foreground tasks are those with a session_id, background tasks have none.
208 pub fn pending_task_counts(&self) -> (usize, usize) {
209 let state = self.state.lock();
210 let foreground = state
211 .runnables
212 .iter()
213 .filter(|r| r.session_id.is_some())
214 .count();
215 let background = state
216 .runnables
217 .iter()
218 .filter(|r| r.session_id.is_none())
219 .count();
220 (foreground, background)
221 }
222
223 fn step(&self) -> bool {
224 self.step_filtered(false)
225 }
226
227 fn step_filtered(&self, background_only: bool) -> bool {
228 let (elapsed_count, runnables_before) = {
229 let mut state = self.state.lock();
230 let end_ix = state
231 .timers
232 .partition_point(|timer| timer.expiration <= self.clock.now());
233 let elapsed: Vec<_> = state.timers.drain(..end_ix).collect();
234 let count = elapsed.len();
235 let runnables = state.runnables.len();
236 drop(state);
237 // Dropping elapsed timers here wakes the waiting futures
238 drop(elapsed);
239 (count, runnables)
240 };
241
242 if elapsed_count > 0 {
243 let runnables_after = self.state.lock().runnables.len();
244 if std::env::var("DEBUG_SCHEDULER").is_ok() {
245 eprintln!(
246 "[scheduler] Expired {} timers at {:?}, runnables: {} -> {}",
247 elapsed_count,
248 self.clock.now(),
249 runnables_before,
250 runnables_after
251 );
252 }
253 return true;
254 }
255
256 let runnable = {
257 let state = &mut *self.state.lock();
258
259 // Find candidate tasks:
260 // - For foreground tasks (with session_id), only the first task from each session
261 // is a candidate (to preserve intra-session ordering)
262 // - For background tasks (no session_id), all are candidates
263 // - Tasks from blocked sessions are excluded
264 // - If background_only is true, skip foreground tasks entirely
265 let mut seen_sessions = HashSet::new();
266 let candidate_indices: Vec<usize> = state
267 .runnables
268 .iter()
269 .enumerate()
270 .filter(|(_, runnable)| {
271 if let Some(session_id) = runnable.session_id {
272 // Skip foreground tasks if background_only mode
273 if background_only {
274 return false;
275 }
276 // Exclude tasks from blocked sessions
277 if state.blocked_sessions.contains(&session_id) {
278 return false;
279 }
280 // Only include first task from each session (insert returns true if new)
281 seen_sessions.insert(session_id)
282 } else {
283 // Background tasks are always candidates
284 true
285 }
286 })
287 .map(|(ix, _)| ix)
288 .collect();
289
290 if candidate_indices.is_empty() {
291 None
292 } else if state.randomize_order {
293 // Use priority-weighted random selection
294 let weights: Vec<u32> = candidate_indices
295 .iter()
296 .map(|&ix| state.runnables[ix].priority.weight())
297 .collect();
298 let total_weight: u32 = weights.iter().sum();
299
300 if total_weight == 0 {
301 // Fallback to uniform random if all weights are zero
302 let choice = self.rng.lock().random_range(0..candidate_indices.len());
303 state.runnables.remove(candidate_indices[choice])
304 } else {
305 let mut target = self.rng.lock().random_range(0..total_weight);
306 let mut selected_idx = 0;
307 for (i, &weight) in weights.iter().enumerate() {
308 if target < weight {
309 selected_idx = i;
310 break;
311 }
312 target -= weight;
313 }
314 state.runnables.remove(candidate_indices[selected_idx])
315 }
316 } else {
317 // Non-randomized: just take the first candidate task
318 state.runnables.remove(candidate_indices[0])
319 }
320 };
321
322 if let Some(runnable) = runnable {
323 let is_foreground = runnable.session_id.is_some();
324 let was_main_thread = self.state.lock().is_main_thread;
325 self.state.lock().is_main_thread = is_foreground;
326 runnable.run();
327 self.state.lock().is_main_thread = was_main_thread;
328 return true;
329 }
330
331 false
332 }
333
334 /// Drops all runnable tasks from the scheduler.
335 ///
336 /// This is used by the leak detector to ensure that all tasks have been dropped as tasks may keep entities alive otherwise.
337 /// Why do we even have tasks left when tests finish you may ask. The reason for that is simple, the scheduler itself is the executor and it retains the scheduled runnables.
338 /// A lot of tasks, including every foreground task contain an executor handle that keeps the test scheduler alive, causing a reference cycle, thus the need for this function right now.
339 pub fn drain_tasks(&self) {
340 // dropping runnables may reschedule tasks
341 // due to drop impls with executors in them
342 // so drop until we reach a fixpoint
343 loop {
344 let mut state = self.state.lock();
345 if state.runnables.is_empty() && state.timers.is_empty() {
346 break;
347 }
348 let runnables = std::mem::take(&mut state.runnables);
349 let timers = std::mem::take(&mut state.timers);
350 drop(state);
351 drop(timers);
352 drop(runnables);
353 }
354 }
355
356 pub fn advance_clock_to_next_timer(&self) -> bool {
357 if let Some(timer) = self.state.lock().timers.first() {
358 self.clock.advance(timer.expiration - self.clock.now());
359 true
360 } else {
361 false
362 }
363 }
364
365 pub fn advance_clock(&self, duration: Duration) {
366 let debug = std::env::var("DEBUG_SCHEDULER").is_ok();
367 let start = self.clock.now();
368 let next_now = start + duration;
369 if debug {
370 let timer_count = self.state.lock().timers.len();
371 eprintln!(
372 "[scheduler] advance_clock({:?}) from {:?}, {} pending timers",
373 duration, start, timer_count
374 );
375 }
376 loop {
377 self.run();
378 if let Some(timer) = self.state.lock().timers.first()
379 && timer.expiration <= next_now
380 {
381 let advance_to = timer.expiration;
382 if debug {
383 eprintln!(
384 "[scheduler] Advancing clock {:?} -> {:?} for timer",
385 self.clock.now(),
386 advance_to
387 );
388 }
389 self.clock.advance(advance_to - self.clock.now());
390 } else {
391 break;
392 }
393 }
394 self.clock.advance(next_now - self.clock.now());
395 if debug {
396 eprintln!(
397 "[scheduler] advance_clock done, now at {:?}",
398 self.clock.now()
399 );
400 }
401 }
402
403 fn park(&self, deadline: Option<Instant>) -> bool {
404 if self.state.lock().allow_parking {
405 let start = Instant::now();
406 // Enforce a hard timeout to prevent tests from hanging indefinitely
407 let hard_deadline = start + Duration::from_secs(15);
408
409 // Use the earlier of the provided deadline or the hard timeout deadline
410 let effective_deadline = deadline
411 .map(|d| d.min(hard_deadline))
412 .unwrap_or(hard_deadline);
413
414 // Park in small intervals to allow checking both deadlines
415 const PARK_INTERVAL: Duration = Duration::from_millis(100);
416 loop {
417 let now = Instant::now();
418 if now >= effective_deadline {
419 // Check if we hit the hard timeout
420 if now >= hard_deadline {
421 panic!(
422 "Test timed out after 15 seconds while parking. \
423 This may indicate a deadlock or missing waker.",
424 );
425 }
426 // Hit the provided deadline
427 return false;
428 }
429
430 let remaining = effective_deadline.saturating_duration_since(now);
431 let park_duration = remaining.min(PARK_INTERVAL);
432 let before_park = Instant::now();
433 thread::park_timeout(park_duration);
434 let elapsed = before_park.elapsed();
435
436 // Advance the test clock by the real elapsed time while parking
437 self.clock.advance(elapsed);
438
439 // Check if any timers have expired after advancing the clock.
440 // If so, return so the caller can process them.
441 if self
442 .state
443 .lock()
444 .timers
445 .first()
446 .map_or(false, |t| t.expiration <= self.clock.now())
447 {
448 return true;
449 }
450
451 // Check if we were woken up by a different thread.
452 // We use a flag because timing-based detection is unreliable:
453 // OS scheduling delays can cause elapsed >= park_duration even when
454 // we were woken early by unpark().
455 if std::mem::take(&mut self.state.lock().unparked) {
456 return true;
457 }
458 }
459 } else if deadline.is_some() {
460 false
461 } else if self.state.lock().capture_pending_traces {
462 let mut pending_traces = String::new();
463 for (_, trace) in mem::take(&mut self.state.lock().pending_traces) {
464 writeln!(pending_traces, "{:?}", exclude_wakers_from_trace(trace)).unwrap();
465 }
466 panic!("Parking forbidden. Pending traces:\n{}", pending_traces);
467 } else {
468 panic!(
469 "Parking forbidden. Re-run with {PENDING_TRACES_VAR_NAME}=1 to show pending traces"
470 );
471 }
472 }
473}
474
475fn assert_correct_thread(expected: &Thread, state: &Arc<Mutex<SchedulerState>>) {
476 let current_thread = thread::current();
477 let mut state = state.lock();
478 if state.parking_allowed_once {
479 return;
480 }
481 if current_thread.id() == expected.id() {
482 return;
483 }
484
485 let message = format!(
486 "Detected activity on thread {:?} {:?}, but test scheduler is running on {:?} {:?}. Your test is not deterministic.",
487 current_thread.name(),
488 current_thread.id(),
489 expected.name(),
490 expected.id(),
491 );
492 let backtrace = Backtrace::new();
493 if state.finished {
494 panic!("{}", message);
495 } else {
496 state.non_determinism_error = Some((message, backtrace))
497 }
498}
499
500impl Scheduler for TestScheduler {
501 /// Block until the given future completes, with an optional timeout. If the
502 /// future is unable to make progress at any moment before the timeout and
503 /// no other tasks or timers remain, we panic unless parking is allowed. If
504 /// parking is allowed, we block up to the timeout or indefinitely if none
505 /// is provided. This is to allow testing a mix of deterministic and
506 /// non-deterministic async behavior, such as when interacting with I/O in
507 /// an otherwise deterministic test.
508 fn block(
509 &self,
510 session_id: Option<SessionId>,
511 mut future: Pin<&mut dyn Future<Output = ()>>,
512 timeout: Option<Duration>,
513 ) -> bool {
514 if let Some(session_id) = session_id {
515 self.state.lock().blocked_sessions.push(session_id);
516 }
517
518 let deadline = timeout.map(|timeout| Instant::now() + timeout);
519 let awoken = Arc::new(AtomicBool::new(false));
520 let waker = Box::new(TracingWaker {
521 id: None,
522 awoken: awoken.clone(),
523 thread: self.thread.clone(),
524 state: self.state.clone(),
525 });
526 let waker = unsafe { Waker::new(Box::into_raw(waker) as *const (), &WAKER_VTABLE) };
527 let max_ticks = if timeout.is_some() {
528 self.rng
529 .lock()
530 .random_range(self.state.lock().timeout_ticks.clone())
531 } else {
532 usize::MAX
533 };
534 let mut cx = Context::from_waker(&waker);
535
536 let mut completed = false;
537 for _ in 0..max_ticks {
538 match future.as_mut().poll(&mut cx) {
539 Poll::Ready(()) => {
540 completed = true;
541 break;
542 }
543 Poll::Pending => {}
544 }
545
546 let mut stepped = None;
547 while self.rng.lock().random() {
548 let stepped = stepped.get_or_insert(false);
549 if self.step() {
550 *stepped = true;
551 } else {
552 break;
553 }
554 }
555
556 let stepped = stepped.unwrap_or(true);
557 let awoken = awoken.swap(false, SeqCst);
558 if !stepped && !awoken {
559 let parking_allowed = self.state.lock().allow_parking;
560 // In deterministic mode (parking forbidden), instantly jump to the next timer.
561 // In non-deterministic mode (parking allowed), let real time pass instead.
562 let advanced_to_timer = !parking_allowed && self.advance_clock_to_next_timer();
563 if !advanced_to_timer && !self.park(deadline) {
564 break;
565 }
566 }
567 }
568
569 if session_id.is_some() {
570 self.state.lock().blocked_sessions.pop();
571 }
572
573 completed
574 }
575
576 fn schedule_foreground(&self, session_id: SessionId, runnable: Runnable<RunnableMeta>) {
577 assert_correct_thread(&self.thread, &self.state);
578 let mut state = self.state.lock();
579 let ix = if state.randomize_order {
580 let start_ix = state
581 .runnables
582 .iter()
583 .rposition(|task| task.session_id == Some(session_id))
584 .map_or(0, |ix| ix + 1);
585 self.rng
586 .lock()
587 .random_range(start_ix..=state.runnables.len())
588 } else {
589 state.runnables.len()
590 };
591 state.runnables.insert(
592 ix,
593 ScheduledRunnable {
594 session_id: Some(session_id),
595 priority: Priority::default(),
596 runnable,
597 },
598 );
599 state.unparked = true;
600 drop(state);
601 self.thread.unpark();
602 }
603
604 fn schedule_background_with_priority(
605 &self,
606 runnable: Runnable<RunnableMeta>,
607 priority: Priority,
608 ) {
609 assert_correct_thread(&self.thread, &self.state);
610 let mut state = self.state.lock();
611 let ix = if state.randomize_order {
612 self.rng.lock().random_range(0..=state.runnables.len())
613 } else {
614 state.runnables.len()
615 };
616 state.runnables.insert(
617 ix,
618 ScheduledRunnable {
619 session_id: None,
620 priority,
621 runnable,
622 },
623 );
624 state.unparked = true;
625 drop(state);
626 self.thread.unpark();
627 }
628
629 fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
630 std::thread::spawn(move || {
631 f();
632 });
633 }
634
635 #[track_caller]
636 fn timer(&self, duration: Duration) -> Timer {
637 let (tx, rx) = oneshot::channel();
638 let state = &mut *self.state.lock();
639 state.timers.push(ScheduledTimer {
640 expiration: self.clock.now() + duration,
641 _notify: tx,
642 });
643 state.timers.sort_by_key(|timer| timer.expiration);
644 Timer(rx)
645 }
646
647 fn clock(&self) -> Arc<dyn Clock> {
648 self.clock.clone()
649 }
650
651 fn as_test(&self) -> Option<&TestScheduler> {
652 Some(self)
653 }
654}
655
/// Construction-time settings for a [`TestScheduler`].
#[derive(Clone, Debug)]
pub struct TestSchedulerConfig {
    /// Seed for the scheduler's RNG; the same seed reproduces the same interleaving.
    pub seed: u64,
    /// Randomize where new tasks are queued and which eligible task runs next.
    pub randomize_order: bool,
    /// Whether `block` may park the thread in real time from the start.
    pub allow_parking: bool,
    /// Record a backtrace for every cloned waker so a forbidden park can report them.
    pub capture_pending_traces: bool,
    /// Range from which `block` draws how many times to poll when given a timeout.
    pub timeout_ticks: RangeInclusive<usize>,
}
664
665impl TestSchedulerConfig {
666 pub fn with_seed(seed: u64) -> Self {
667 Self {
668 seed,
669 ..Default::default()
670 }
671 }
672}
673
674impl Default for TestSchedulerConfig {
675 fn default() -> Self {
676 Self {
677 seed: 0,
678 randomize_order: true,
679 allow_parking: false,
680 capture_pending_traces: env::var(PENDING_TRACES_VAR_NAME)
681 .map_or(false, |var| var == "1" || var == "true"),
682 timeout_ticks: 1..=1000,
683 }
684 }
685}
686
687struct ScheduledRunnable {
688 session_id: Option<SessionId>,
689 priority: Priority,
690 runnable: Runnable<RunnableMeta>,
691}
692
693impl ScheduledRunnable {
694 fn run(self) {
695 self.runnable.run();
696 }
697}
698
699struct ScheduledTimer {
700 expiration: Instant,
701 _notify: oneshot::Sender<()>,
702}
703
704struct SchedulerState {
705 runnables: VecDeque<ScheduledRunnable>,
706 timers: Vec<ScheduledTimer>,
707 blocked_sessions: Vec<SessionId>,
708 randomize_order: bool,
709 allow_parking: bool,
710 timeout_ticks: RangeInclusive<usize>,
711 next_session_id: SessionId,
712 capture_pending_traces: bool,
713 next_trace_id: TraceId,
714 pending_traces: BTreeMap<TraceId, Backtrace>,
715 is_main_thread: bool,
716 non_determinism_error: Option<(String, Backtrace)>,
717 parking_allowed_once: bool,
718 finished: bool,
719 unparked: bool,
720}
721
722const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
723 TracingWaker::clone_raw,
724 TracingWaker::wake_raw,
725 TracingWaker::wake_by_ref_raw,
726 TracingWaker::drop_raw,
727);
728
/// Key identifying one recorded waker backtrace; ordered so traces list in creation order.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct TraceId(usize);
731
732struct TracingWaker {
733 id: Option<TraceId>,
734 awoken: Arc<AtomicBool>,
735 thread: Thread,
736 state: Arc<Mutex<SchedulerState>>,
737}
738
739impl Clone for TracingWaker {
740 fn clone(&self) -> Self {
741 let mut state = self.state.lock();
742 let id = if state.capture_pending_traces {
743 let id = state.next_trace_id;
744 state.next_trace_id.0 += 1;
745 state.pending_traces.insert(id, Backtrace::new_unresolved());
746 Some(id)
747 } else {
748 None
749 };
750 Self {
751 id,
752 awoken: self.awoken.clone(),
753 thread: self.thread.clone(),
754 state: self.state.clone(),
755 }
756 }
757}
758
759impl Drop for TracingWaker {
760 fn drop(&mut self) {
761 assert_correct_thread(&self.thread, &self.state);
762
763 if let Some(id) = self.id {
764 self.state.lock().pending_traces.remove(&id);
765 }
766 }
767}
768
769impl TracingWaker {
770 fn wake(self) {
771 self.wake_by_ref();
772 }
773
774 fn wake_by_ref(&self) {
775 assert_correct_thread(&self.thread, &self.state);
776
777 let mut state = self.state.lock();
778 if let Some(id) = self.id {
779 state.pending_traces.remove(&id);
780 }
781 state.unparked = true;
782 drop(state);
783 self.awoken.store(true, SeqCst);
784 self.thread.unpark();
785 }
786
787 fn clone_raw(waker: *const ()) -> RawWaker {
788 let waker = waker as *const TracingWaker;
789 let waker = unsafe { &*waker };
790 RawWaker::new(
791 Box::into_raw(Box::new(waker.clone())) as *const (),
792 &WAKER_VTABLE,
793 )
794 }
795
796 fn wake_raw(waker: *const ()) {
797 let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
798 waker.wake();
799 }
800
801 fn wake_by_ref_raw(waker: *const ()) {
802 let waker = waker as *const TracingWaker;
803 let waker = unsafe { &*waker };
804 waker.wake_by_ref();
805 }
806
807 fn drop_raw(waker: *const ()) {
808 let waker = unsafe { Box::from_raw(waker as *mut TracingWaker) };
809 drop(waker);
810 }
811}
812
/// Future that returns `Pending` (re-waking itself) the given number of times, then completes.
pub struct Yield(usize);
814
815/// A wrapper around `Arc<Mutex<StdRng>>` that provides convenient methods
816/// for random number generation without requiring explicit locking.
817#[derive(Clone)]
818pub struct SharedRng(Arc<Mutex<StdRng>>);
819
820impl SharedRng {
821 /// Lock the inner RNG for direct access. Use this when you need multiple
822 /// random operations without re-locking between each one.
823 pub fn lock(&self) -> MutexGuard<'_, StdRng> {
824 self.0.lock()
825 }
826
827 /// Generate a random value in the given range.
828 pub fn random_range<T, R>(&self, range: R) -> T
829 where
830 T: SampleUniform,
831 R: SampleRange<T>,
832 {
833 self.0.lock().random_range(range)
834 }
835
836 /// Generate a random boolean with the given probability of being true.
837 pub fn random_bool(&self, p: f64) -> bool {
838 self.0.lock().random_bool(p)
839 }
840
841 /// Generate a random value of the given type.
842 pub fn random<T>(&self) -> T
843 where
844 StandardUniform: Distribution<T>,
845 {
846 self.0.lock().random()
847 }
848
849 /// Generate a random ratio - true with probability `numerator/denominator`.
850 pub fn random_ratio(&self, numerator: u32, denominator: u32) -> bool {
851 self.0.lock().random_ratio(numerator, denominator)
852 }
853}
854
855impl Future for Yield {
856 type Output = ();
857
858 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
859 if self.0 == 0 {
860 Poll::Ready(())
861 } else {
862 self.0 -= 1;
863 cx.waker().wake_by_ref();
864 Poll::Pending
865 }
866 }
867}
868
869fn exclude_wakers_from_trace(mut trace: Backtrace) -> Backtrace {
870 trace.resolve();
871 let mut frames: Vec<BacktraceFrame> = trace.into();
872 let waker_clone_frame_ix = frames.iter().position(|frame| {
873 frame.symbols().iter().any(|symbol| {
874 symbol
875 .name()
876 .is_some_and(|name| format!("{name:#?}") == type_name_of_val(&Waker::clone))
877 })
878 });
879
880 if let Some(waker_clone_frame_ix) = waker_clone_frame_ix {
881 frames.drain(..waker_clone_frame_ix + 1);
882 }
883
884 Backtrace::from(frames)
885}