1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::Unparker;
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 ops::RangeInclusive,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14 time::{Duration, Instant},
15};
16use util::post_inc;
17
/// Identifies one `TestDispatcher` handle among all handles that share the
/// same underlying state; used as the key for per-handle foreground queues.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
struct TestDispatcherId(usize);
20
21#[doc(hidden)]
22pub struct TestDispatcher {
23 id: TestDispatcherId,
24 state: Arc<Mutex<TestDispatcherState>>,
25}
26
27struct TestDispatcherState {
28 random: StdRng,
29 foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
30 background: Vec<Runnable>,
31 deprioritized_background: Vec<Runnable>,
32 delayed: Vec<(Duration, Runnable)>,
33 start_time: Instant,
34 time: Duration,
35 is_main_thread: bool,
36 next_id: TestDispatcherId,
37 allow_parking: bool,
38 waiting_hint: Option<String>,
39 waiting_backtrace: Option<Backtrace>,
40 deprioritized_task_labels: HashSet<TaskLabel>,
41 block_on_ticks: RangeInclusive<usize>,
42 last_parked: Option<Unparker>,
43}
44
45impl TestDispatcher {
46 pub fn new(random: StdRng) -> Self {
47 let state = TestDispatcherState {
48 random,
49 foreground: HashMap::default(),
50 background: Vec::new(),
51 deprioritized_background: Vec::new(),
52 delayed: Vec::new(),
53 time: Duration::ZERO,
54 start_time: Instant::now(),
55 is_main_thread: true,
56 next_id: TestDispatcherId(1),
57 allow_parking: false,
58 waiting_hint: None,
59 waiting_backtrace: None,
60 deprioritized_task_labels: Default::default(),
61 block_on_ticks: 0..=1000,
62 last_parked: None,
63 };
64
65 TestDispatcher {
66 id: TestDispatcherId(0),
67 state: Arc::new(Mutex::new(state)),
68 }
69 }
70
71 pub fn advance_clock(&self, by: Duration) {
72 let new_now = self.state.lock().time + by;
73 loop {
74 self.run_until_parked();
75 let state = self.state.lock();
76 let next_due_time = state.delayed.first().map(|(time, _)| *time);
77 drop(state);
78 if let Some(due_time) = next_due_time
79 && due_time <= new_now
80 {
81 self.state.lock().time = due_time;
82 continue;
83 }
84 break;
85 }
86 self.state.lock().time = new_now;
87 }
88
89 pub fn advance_clock_to_next_delayed(&self) -> bool {
90 let next_due_time = self.state.lock().delayed.first().map(|(time, _)| *time);
91 if let Some(next_due_time) = next_due_time {
92 self.state.lock().time = next_due_time;
93 return true;
94 }
95 false
96 }
97
98 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
99 struct YieldNow {
100 pub(crate) count: usize,
101 }
102
103 impl Future for YieldNow {
104 type Output = ();
105
106 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
107 if self.count > 0 {
108 self.count -= 1;
109 cx.waker().wake_by_ref();
110 Poll::Pending
111 } else {
112 Poll::Ready(())
113 }
114 }
115 }
116
117 YieldNow {
118 count: self.state.lock().random.random_range(0..10),
119 }
120 }
121
122 pub fn tick(&self, background_only: bool) -> bool {
123 let mut state = self.state.lock();
124
125 while let Some((deadline, _)) = state.delayed.first() {
126 if *deadline > state.time {
127 break;
128 }
129 let (_, runnable) = state.delayed.remove(0);
130 state.background.push(runnable);
131 }
132
133 let foreground_len: usize = if background_only {
134 0
135 } else {
136 state
137 .foreground
138 .values()
139 .map(|runnables| runnables.len())
140 .sum()
141 };
142 let background_len = state.background.len();
143
144 let runnable;
145 let main_thread;
146 if foreground_len == 0 && background_len == 0 {
147 let deprioritized_background_len = state.deprioritized_background.len();
148 if deprioritized_background_len == 0 {
149 return false;
150 }
151 let ix = state.random.random_range(0..deprioritized_background_len);
152 main_thread = false;
153 runnable = state.deprioritized_background.swap_remove(ix);
154 } else {
155 main_thread = state.random.random_ratio(
156 foreground_len as u32,
157 (foreground_len + background_len) as u32,
158 );
159 if main_thread {
160 let state = &mut *state;
161 runnable = state
162 .foreground
163 .values_mut()
164 .filter(|runnables| !runnables.is_empty())
165 .choose(&mut state.random)
166 .unwrap()
167 .pop_front()
168 .unwrap();
169 } else {
170 let ix = state.random.random_range(0..background_len);
171 runnable = state.background.swap_remove(ix);
172 };
173 };
174
175 let was_main_thread = state.is_main_thread;
176 state.is_main_thread = main_thread;
177 drop(state);
178 runnable.run();
179 self.state.lock().is_main_thread = was_main_thread;
180
181 true
182 }
183
184 pub fn deprioritize(&self, task_label: TaskLabel) {
185 self.state
186 .lock()
187 .deprioritized_task_labels
188 .insert(task_label);
189 }
190
191 pub fn run_until_parked(&self) {
192 while self.tick(false) {}
193 }
194
195 pub fn parking_allowed(&self) -> bool {
196 self.state.lock().allow_parking
197 }
198
199 pub fn allow_parking(&self) {
200 self.state.lock().allow_parking = true
201 }
202
203 pub fn forbid_parking(&self) {
204 self.state.lock().allow_parking = false
205 }
206
207 pub fn set_waiting_hint(&self, msg: Option<String>) {
208 self.state.lock().waiting_hint = msg
209 }
210
211 pub fn waiting_hint(&self) -> Option<String> {
212 self.state.lock().waiting_hint.clone()
213 }
214
215 pub fn start_waiting(&self) {
216 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
217 }
218
219 pub fn finish_waiting(&self) {
220 self.state.lock().waiting_backtrace.take();
221 }
222
223 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
224 self.state.lock().waiting_backtrace.take().map(|mut b| {
225 b.resolve();
226 b
227 })
228 }
229
230 pub fn rng(&self) -> StdRng {
231 self.state.lock().random.clone()
232 }
233
234 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
235 self.state.lock().block_on_ticks = range;
236 }
237
238 pub fn gen_block_on_ticks(&self) -> usize {
239 let mut lock = self.state.lock();
240 let block_on_ticks = lock.block_on_ticks.clone();
241 lock.random.random_range(block_on_ticks)
242 }
243 pub fn unpark_last(&self) {
244 self.state
245 .lock()
246 .last_parked
247 .take()
248 .as_ref()
249 .map(Unparker::unpark);
250 }
251
252 pub fn set_unparker(&self, unparker: Unparker) {
253 let last = { self.state.lock().last_parked.replace(unparker) };
254 if let Some(last) = last {
255 last.unpark();
256 }
257 }
258}
259
260impl Clone for TestDispatcher {
261 fn clone(&self) -> Self {
262 let id = post_inc(&mut self.state.lock().next_id.0);
263 Self {
264 id: TestDispatcherId(id),
265 state: self.state.clone(),
266 }
267 }
268}
269
270impl PlatformDispatcher for TestDispatcher {
271 fn is_main_thread(&self) -> bool {
272 self.state.lock().is_main_thread
273 }
274
275 fn now(&self) -> Instant {
276 let state = self.state.lock();
277 state.start_time + state.time
278 }
279
280 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
281 {
282 let mut state = self.state.lock();
283 if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
284 state.deprioritized_background.push(runnable);
285 } else {
286 state.background.push(runnable);
287 }
288 }
289 self.unpark_last();
290 }
291
292 fn dispatch_on_main_thread(&self, runnable: Runnable) {
293 self.state
294 .lock()
295 .foreground
296 .entry(self.id)
297 .or_default()
298 .push_back(runnable);
299 self.unpark_last();
300 }
301
302 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
303 let mut state = self.state.lock();
304 let next_time = state.time + duration;
305 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
306 Ok(ix) | Err(ix) => ix,
307 };
308 state.delayed.insert(ix, (next_time, runnable));
309 }
310
311 fn as_test(&self) -> Option<&TestDispatcher> {
312 Some(self)
313 }
314}