1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 ops::RangeInclusive,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14 time::{Duration, Instant},
15};
16use util::post_inc;
17
/// Identifies one `TestDispatcher` handle. `TestDispatcher::new` hands out id 0
/// and every clone draws a fresh id from the shared `next_id` counter; the id
/// keys that handle's foreground (main-thread) queue.
#[derive(Hash, Eq, PartialEq, Clone, Copy)]
struct TestDispatcherId(usize);
20
/// Test implementation of `PlatformDispatcher` that executes tasks itself (on
/// the thread that calls `tick`), choosing among ready tasks with a
/// caller-supplied RNG and keeping time on a simulated clock.
///
/// Clones share all scheduler state; each clone has its own id and therefore
/// its own foreground (main-thread) queue.
#[doc(hidden)]
pub struct TestDispatcher {
    /// Identifies this handle; selects the foreground queue that
    /// `dispatch_on_main_thread` pushes to.
    id: TestDispatcherId,
    /// Scheduler state shared by every clone of this dispatcher.
    state: Arc<Mutex<TestDispatcherState>>,
    /// Blocks the caller in `park`; shared by every clone.
    parker: Arc<Mutex<Parker>>,
    /// Wakes `parker`. Signalled by `dispatch` and `dispatch_on_main_thread`
    /// (but not by `dispatch_after`).
    unparker: Unparker,
}
28
/// Mutable scheduler state, shared behind a mutex by all clones of a
/// `TestDispatcher`.
struct TestDispatcherState {
    /// Source of every scheduling decision: task selection in `tick`, the
    /// length of `simulate_random_delay`, and `gen_block_on_ticks`.
    random: StdRng,
    /// FIFO queues of main-thread runnables, one per dispatcher handle id.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    /// Runnables ready to run off the main thread; `tick` picks one at random.
    background: Vec<Runnable>,
    /// Ready runnables whose label was deprioritized; `tick` only picks from
    /// here when no foreground or background runnable is eligible.
    deprioritized_background: Vec<Runnable>,
    /// `(deadline, runnable)` pairs kept sorted by deadline; deadlines are
    /// offsets on the simulated clock (`time`).
    delayed: Vec<(Duration, Runnable)>,
    /// Real instant captured at construction; `now()` reports
    /// `start_time + time`.
    start_time: Instant,
    /// Simulated time elapsed since `start_time`, moved forward by
    /// `advance_clock` and `advance_clock_to_next_delayed`.
    time: Duration,
    /// Whether the task currently being run by `tick` came from a foreground
    /// queue. Starts `true` and is restored after each tick.
    is_main_thread: bool,
    /// Next id to hand out when the dispatcher is cloned.
    next_id: TestDispatcherId,
    /// Flag toggled by `allow_parking` / `forbid_parking`. This type only stores
    /// it; NOTE(review): enforcement presumably lives in callers of `park`.
    allow_parking: bool,
    /// Free-form description of what a test is waiting on, set via
    /// `set_waiting_hint`.
    waiting_hint: Option<String>,
    /// Unresolved backtrace recorded by `start_waiting`, cleared by
    /// `finish_waiting` or taken by `waiting_backtrace`.
    waiting_backtrace: Option<Backtrace>,
    /// Labels whose tasks `dispatch` routes into `deprioritized_background`.
    deprioritized_task_labels: HashSet<TaskLabel>,
    /// Range `gen_block_on_ticks` draws from (defaults to `0..=1000`).
    block_on_ticks: RangeInclusive<usize>,
}
45
46impl TestDispatcher {
47 pub fn new(random: StdRng) -> Self {
48 let (parker, unparker) = parking::pair();
49 let state = TestDispatcherState {
50 random,
51 foreground: HashMap::default(),
52 background: Vec::new(),
53 deprioritized_background: Vec::new(),
54 delayed: Vec::new(),
55 time: Duration::ZERO,
56 start_time: Instant::now(),
57 is_main_thread: true,
58 next_id: TestDispatcherId(1),
59 allow_parking: false,
60 waiting_hint: None,
61 waiting_backtrace: None,
62 deprioritized_task_labels: Default::default(),
63 block_on_ticks: 0..=1000,
64 };
65
66 TestDispatcher {
67 id: TestDispatcherId(0),
68 state: Arc::new(Mutex::new(state)),
69 parker: Arc::new(Mutex::new(parker)),
70 unparker,
71 }
72 }
73
74 pub fn advance_clock(&self, by: Duration) {
75 let new_now = self.state.lock().time + by;
76 loop {
77 self.run_until_parked();
78 let state = self.state.lock();
79 let next_due_time = state.delayed.first().map(|(time, _)| *time);
80 drop(state);
81 if let Some(due_time) = next_due_time
82 && due_time <= new_now
83 {
84 self.state.lock().time = due_time;
85 continue;
86 }
87 break;
88 }
89 self.state.lock().time = new_now;
90 }
91
92 pub fn advance_clock_to_next_delayed(&self) -> bool {
93 let next_due_time = self.state.lock().delayed.first().map(|(time, _)| *time);
94 if let Some(next_due_time) = next_due_time {
95 self.state.lock().time = next_due_time;
96 return true;
97 }
98 false
99 }
100
101 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
102 struct YieldNow {
103 pub(crate) count: usize,
104 }
105
106 impl Future for YieldNow {
107 type Output = ();
108
109 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
110 if self.count > 0 {
111 self.count -= 1;
112 cx.waker().wake_by_ref();
113 Poll::Pending
114 } else {
115 Poll::Ready(())
116 }
117 }
118 }
119
120 YieldNow {
121 count: self.state.lock().random.random_range(0..10),
122 }
123 }
124
125 pub fn tick(&self, background_only: bool) -> bool {
126 let mut state = self.state.lock();
127
128 while let Some((deadline, _)) = state.delayed.first() {
129 if *deadline > state.time {
130 break;
131 }
132 let (_, runnable) = state.delayed.remove(0);
133 state.background.push(runnable);
134 }
135
136 let foreground_len: usize = if background_only {
137 0
138 } else {
139 state
140 .foreground
141 .values()
142 .map(|runnables| runnables.len())
143 .sum()
144 };
145 let background_len = state.background.len();
146
147 let runnable;
148 let main_thread;
149 if foreground_len == 0 && background_len == 0 {
150 let deprioritized_background_len = state.deprioritized_background.len();
151 if deprioritized_background_len == 0 {
152 return false;
153 }
154 let ix = state.random.random_range(0..deprioritized_background_len);
155 main_thread = false;
156 runnable = state.deprioritized_background.swap_remove(ix);
157 } else {
158 main_thread = state.random.random_ratio(
159 foreground_len as u32,
160 (foreground_len + background_len) as u32,
161 );
162 if main_thread {
163 let state = &mut *state;
164 runnable = state
165 .foreground
166 .values_mut()
167 .filter(|runnables| !runnables.is_empty())
168 .choose(&mut state.random)
169 .unwrap()
170 .pop_front()
171 .unwrap();
172 } else {
173 let ix = state.random.random_range(0..background_len);
174 runnable = state.background.swap_remove(ix);
175 };
176 };
177
178 let was_main_thread = state.is_main_thread;
179 state.is_main_thread = main_thread;
180 drop(state);
181 runnable.run();
182 self.state.lock().is_main_thread = was_main_thread;
183
184 true
185 }
186
187 pub fn deprioritize(&self, task_label: TaskLabel) {
188 self.state
189 .lock()
190 .deprioritized_task_labels
191 .insert(task_label);
192 }
193
194 pub fn run_until_parked(&self) {
195 while self.tick(false) {}
196 }
197
198 pub fn parking_allowed(&self) -> bool {
199 self.state.lock().allow_parking
200 }
201
202 pub fn allow_parking(&self) {
203 self.state.lock().allow_parking = true
204 }
205
206 pub fn forbid_parking(&self) {
207 self.state.lock().allow_parking = false
208 }
209
210 pub fn set_waiting_hint(&self, msg: Option<String>) {
211 self.state.lock().waiting_hint = msg
212 }
213
214 pub fn waiting_hint(&self) -> Option<String> {
215 self.state.lock().waiting_hint.clone()
216 }
217
218 pub fn start_waiting(&self) {
219 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
220 }
221
222 pub fn finish_waiting(&self) {
223 self.state.lock().waiting_backtrace.take();
224 }
225
226 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
227 self.state.lock().waiting_backtrace.take().map(|mut b| {
228 b.resolve();
229 b
230 })
231 }
232
233 pub fn rng(&self) -> StdRng {
234 self.state.lock().random.clone()
235 }
236
237 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
238 self.state.lock().block_on_ticks = range;
239 }
240
241 pub fn gen_block_on_ticks(&self) -> usize {
242 let mut lock = self.state.lock();
243 let block_on_ticks = lock.block_on_ticks.clone();
244 lock.random.random_range(block_on_ticks)
245 }
246}
247
248impl Clone for TestDispatcher {
249 fn clone(&self) -> Self {
250 let id = post_inc(&mut self.state.lock().next_id.0);
251 Self {
252 id: TestDispatcherId(id),
253 state: self.state.clone(),
254 parker: self.parker.clone(),
255 unparker: self.unparker.clone(),
256 }
257 }
258}
259
260impl PlatformDispatcher for TestDispatcher {
261 fn is_main_thread(&self) -> bool {
262 self.state.lock().is_main_thread
263 }
264
265 fn now(&self) -> Instant {
266 let state = self.state.lock();
267 state.start_time + state.time
268 }
269
270 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
271 {
272 let mut state = self.state.lock();
273 if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
274 state.deprioritized_background.push(runnable);
275 } else {
276 state.background.push(runnable);
277 }
278 }
279 self.unparker.unpark();
280 }
281
282 fn dispatch_on_main_thread(&self, runnable: Runnable) {
283 self.state
284 .lock()
285 .foreground
286 .entry(self.id)
287 .or_default()
288 .push_back(runnable);
289 self.unparker.unpark();
290 }
291
292 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
293 let mut state = self.state.lock();
294 let next_time = state.time + duration;
295 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
296 Ok(ix) | Err(ix) => ix,
297 };
298 state.delayed.insert(ix, (next_time, runnable));
299 }
300 fn park(&self, _: Option<std::time::Duration>) -> bool {
301 self.parker.lock().park();
302 true
303 }
304
305 fn unparker(&self) -> Unparker {
306 self.unparker.clone()
307 }
308
309 fn as_test(&self) -> Option<&TestDispatcher> {
310 Some(self)
311 }
312}