1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 ops::RangeInclusive,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14 time::{Duration, Instant},
15};
16use util::post_inc;
17
/// Identifies a single `TestDispatcher` handle. Main-thread tasks are queued
/// per id, so each clone of a dispatcher gets its own foreground queue.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
struct TestDispatcherId(usize);
20
21#[doc(hidden)]
22pub struct TestDispatcher {
23 id: TestDispatcherId,
24 state: Arc<Mutex<TestDispatcherState>>,
25 parker: Arc<Mutex<Parker>>,
26 unparker: Unparker,
27}
28
29struct TestDispatcherState {
30 random: StdRng,
31 foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
32 background: Vec<Runnable>,
33 deprioritized_background: Vec<Runnable>,
34 delayed: Vec<(Duration, Runnable)>,
35 start_time: Instant,
36 time: Duration,
37 is_main_thread: bool,
38 next_id: TestDispatcherId,
39 allow_parking: bool,
40 waiting_hint: Option<String>,
41 waiting_backtrace: Option<Backtrace>,
42 deprioritized_task_labels: HashSet<TaskLabel>,
43 block_on_ticks: RangeInclusive<usize>,
44}
45
46impl TestDispatcher {
47 pub fn new(random: StdRng) -> Self {
48 let (parker, unparker) = parking::pair();
49 let state = TestDispatcherState {
50 random,
51 foreground: HashMap::default(),
52 background: Vec::new(),
53 deprioritized_background: Vec::new(),
54 delayed: Vec::new(),
55 time: Duration::ZERO,
56 start_time: Instant::now(),
57 is_main_thread: true,
58 next_id: TestDispatcherId(1),
59 allow_parking: false,
60 waiting_hint: None,
61 waiting_backtrace: None,
62 deprioritized_task_labels: Default::default(),
63 block_on_ticks: 0..=1000,
64 };
65
66 TestDispatcher {
67 id: TestDispatcherId(0),
68 state: Arc::new(Mutex::new(state)),
69 parker: Arc::new(Mutex::new(parker)),
70 unparker,
71 }
72 }
73
74 pub fn advance_clock(&self, by: Duration) {
75 let new_now = self.state.lock().time + by;
76 loop {
77 self.run_until_parked();
78 let state = self.state.lock();
79 let next_due_time = state.delayed.first().map(|(time, _)| *time);
80 drop(state);
81 if let Some(due_time) = next_due_time {
82 if due_time <= new_now {
83 self.state.lock().time = due_time;
84 continue;
85 }
86 }
87 break;
88 }
89 self.state.lock().time = new_now;
90 }
91
92 pub fn advance_clock_to_next_delayed(&self) -> bool {
93 let next_due_time = self.state.lock().delayed.first().map(|(time, _)| *time);
94 if let Some(next_due_time) = next_due_time {
95 self.state.lock().time = next_due_time;
96 return true;
97 }
98 false
99 }
100
101 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
102 struct YieldNow {
103 pub(crate) count: usize,
104 }
105
106 impl Future for YieldNow {
107 type Output = ();
108
109 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
110 if self.count > 0 {
111 self.count -= 1;
112 cx.waker().wake_by_ref();
113 Poll::Pending
114 } else {
115 Poll::Ready(())
116 }
117 }
118 }
119
120 YieldNow {
121 count: self.state.lock().random.gen_range(0..10),
122 }
123 }
124
125 pub fn deprioritize(&self, task_label: TaskLabel) {
126 self.state
127 .lock()
128 .deprioritized_task_labels
129 .insert(task_label);
130 }
131
132 pub fn run_until_parked(&self) {
133 while self.tick(false) {}
134 }
135
136 pub fn parking_allowed(&self) -> bool {
137 self.state.lock().allow_parking
138 }
139
140 pub fn allow_parking(&self) {
141 self.state.lock().allow_parking = true
142 }
143
144 pub fn forbid_parking(&self) {
145 self.state.lock().allow_parking = false
146 }
147
148 pub fn set_waiting_hint(&self, msg: Option<String>) {
149 self.state.lock().waiting_hint = msg
150 }
151
152 pub fn waiting_hint(&self) -> Option<String> {
153 self.state.lock().waiting_hint.clone()
154 }
155
156 pub fn start_waiting(&self) {
157 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
158 }
159
160 pub fn finish_waiting(&self) {
161 self.state.lock().waiting_backtrace.take();
162 }
163
164 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
165 self.state.lock().waiting_backtrace.take().map(|mut b| {
166 b.resolve();
167 b
168 })
169 }
170
171 pub fn rng(&self) -> StdRng {
172 self.state.lock().random.clone()
173 }
174
175 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
176 self.state.lock().block_on_ticks = range;
177 }
178
179 pub fn gen_block_on_ticks(&self) -> usize {
180 let mut lock = self.state.lock();
181 let block_on_ticks = lock.block_on_ticks.clone();
182 lock.random.gen_range(block_on_ticks)
183 }
184}
185
186impl Clone for TestDispatcher {
187 fn clone(&self) -> Self {
188 let id = post_inc(&mut self.state.lock().next_id.0);
189 Self {
190 id: TestDispatcherId(id),
191 state: self.state.clone(),
192 parker: self.parker.clone(),
193 unparker: self.unparker.clone(),
194 }
195 }
196}
197
198impl PlatformDispatcher for TestDispatcher {
199 fn is_main_thread(&self) -> bool {
200 self.state.lock().is_main_thread
201 }
202
203 fn now(&self) -> Instant {
204 let state = self.state.lock();
205 state.start_time + state.time
206 }
207
208 fn tick(&self, background_only: bool) -> bool {
209 let mut state = self.state.lock();
210
211 while let Some((deadline, _)) = state.delayed.first() {
212 if *deadline > state.time {
213 break;
214 }
215 let (_, runnable) = state.delayed.remove(0);
216 state.background.push(runnable);
217 }
218
219 let foreground_len: usize = if background_only {
220 0
221 } else {
222 state
223 .foreground
224 .values()
225 .map(|runnables| runnables.len())
226 .sum()
227 };
228 let background_len = state.background.len();
229
230 let runnable;
231 let main_thread;
232 if foreground_len == 0 && background_len == 0 {
233 let deprioritized_background_len = state.deprioritized_background.len();
234 if deprioritized_background_len == 0 {
235 return false;
236 }
237 let ix = state.random.gen_range(0..deprioritized_background_len);
238 main_thread = false;
239 runnable = state.deprioritized_background.swap_remove(ix);
240 } else {
241 main_thread = state.random.gen_ratio(
242 foreground_len as u32,
243 (foreground_len + background_len) as u32,
244 );
245 if main_thread {
246 let state = &mut *state;
247 runnable = state
248 .foreground
249 .values_mut()
250 .filter(|runnables| !runnables.is_empty())
251 .choose(&mut state.random)
252 .unwrap()
253 .pop_front()
254 .unwrap();
255 } else {
256 let ix = state.random.gen_range(0..background_len);
257 runnable = state.background.swap_remove(ix);
258 };
259 };
260
261 let was_main_thread = state.is_main_thread;
262 state.is_main_thread = main_thread;
263 drop(state);
264 runnable.run();
265 self.state.lock().is_main_thread = was_main_thread;
266
267 true
268 }
269
270 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
271 {
272 let mut state = self.state.lock();
273 if label.map_or(false, |label| {
274 state.deprioritized_task_labels.contains(&label)
275 }) {
276 state.deprioritized_background.push(runnable);
277 } else {
278 state.background.push(runnable);
279 }
280 }
281 self.unparker.unpark();
282 }
283
284 fn dispatch_on_main_thread(&self, runnable: Runnable) {
285 self.state
286 .lock()
287 .foreground
288 .entry(self.id)
289 .or_default()
290 .push_back(runnable);
291 self.unparker.unpark();
292 }
293
294 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
295 let mut state = self.state.lock();
296 let next_time = state.time + duration;
297 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
298 Ok(ix) | Err(ix) => ix,
299 };
300 state.delayed.insert(ix, (next_time, runnable));
301 }
302 fn park(&self, _: Option<std::time::Duration>) -> bool {
303 self.parker.lock().park();
304 true
305 }
306
307 fn unparker(&self) -> Unparker {
308 self.unparker.clone()
309 }
310
311 fn as_test(&self) -> Option<&TestDispatcher> {
312 Some(self)
313 }
314}