1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 ops::RangeInclusive,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14 time::{Duration, Instant},
15};
16use util::post_inc;
17
/// Identifier of a single `TestDispatcher` handle.
///
/// Each handle's main-thread runnables are queued under its own id in
/// `TestDispatcherState::foreground`.
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
struct TestDispatcherId(usize);
20
21#[doc(hidden)]
22pub struct TestDispatcher {
23 id: TestDispatcherId,
24 state: Arc<Mutex<TestDispatcherState>>,
25 parker: Arc<Mutex<Parker>>,
26 unparker: Unparker,
27}
28
29struct TestDispatcherState {
30 random: StdRng,
31 foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
32 background: Vec<Runnable>,
33 deprioritized_background: Vec<Runnable>,
34 delayed: Vec<(Duration, Runnable)>,
35 start_time: Instant,
36 time: Duration,
37 is_main_thread: bool,
38 next_id: TestDispatcherId,
39 allow_parking: bool,
40 waiting_hint: Option<String>,
41 waiting_backtrace: Option<Backtrace>,
42 deprioritized_task_labels: HashSet<TaskLabel>,
43 block_on_ticks: RangeInclusive<usize>,
44}
45
46impl TestDispatcher {
47 pub fn new(random: StdRng) -> Self {
48 let (parker, unparker) = parking::pair();
49 let state = TestDispatcherState {
50 random,
51 foreground: HashMap::default(),
52 background: Vec::new(),
53 deprioritized_background: Vec::new(),
54 delayed: Vec::new(),
55 time: Duration::ZERO,
56 start_time: Instant::now(),
57 is_main_thread: true,
58 next_id: TestDispatcherId(1),
59 allow_parking: false,
60 waiting_hint: None,
61 waiting_backtrace: None,
62 deprioritized_task_labels: Default::default(),
63 block_on_ticks: 0..=1000,
64 };
65
66 TestDispatcher {
67 id: TestDispatcherId(0),
68 state: Arc::new(Mutex::new(state)),
69 parker: Arc::new(Mutex::new(parker)),
70 unparker,
71 }
72 }
73
74 pub fn advance_clock(&self, by: Duration) {
75 let new_now = self.state.lock().time + by;
76 loop {
77 self.run_until_parked();
78 let state = self.state.lock();
79 let next_due_time = state.delayed.first().map(|(time, _)| *time);
80 drop(state);
81 if let Some(due_time) = next_due_time
82 && due_time <= new_now {
83 self.state.lock().time = due_time;
84 continue;
85 }
86 break;
87 }
88 self.state.lock().time = new_now;
89 }
90
91 pub fn advance_clock_to_next_delayed(&self) -> bool {
92 let next_due_time = self.state.lock().delayed.first().map(|(time, _)| *time);
93 if let Some(next_due_time) = next_due_time {
94 self.state.lock().time = next_due_time;
95 return true;
96 }
97 false
98 }
99
100 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
101 struct YieldNow {
102 pub(crate) count: usize,
103 }
104
105 impl Future for YieldNow {
106 type Output = ();
107
108 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
109 if self.count > 0 {
110 self.count -= 1;
111 cx.waker().wake_by_ref();
112 Poll::Pending
113 } else {
114 Poll::Ready(())
115 }
116 }
117 }
118
119 YieldNow {
120 count: self.state.lock().random.gen_range(0..10),
121 }
122 }
123
124 pub fn tick(&self, background_only: bool) -> bool {
125 let mut state = self.state.lock();
126
127 while let Some((deadline, _)) = state.delayed.first() {
128 if *deadline > state.time {
129 break;
130 }
131 let (_, runnable) = state.delayed.remove(0);
132 state.background.push(runnable);
133 }
134
135 let foreground_len: usize = if background_only {
136 0
137 } else {
138 state
139 .foreground
140 .values()
141 .map(|runnables| runnables.len())
142 .sum()
143 };
144 let background_len = state.background.len();
145
146 let runnable;
147 let main_thread;
148 if foreground_len == 0 && background_len == 0 {
149 let deprioritized_background_len = state.deprioritized_background.len();
150 if deprioritized_background_len == 0 {
151 return false;
152 }
153 let ix = state.random.gen_range(0..deprioritized_background_len);
154 main_thread = false;
155 runnable = state.deprioritized_background.swap_remove(ix);
156 } else {
157 main_thread = state.random.gen_ratio(
158 foreground_len as u32,
159 (foreground_len + background_len) as u32,
160 );
161 if main_thread {
162 let state = &mut *state;
163 runnable = state
164 .foreground
165 .values_mut()
166 .filter(|runnables| !runnables.is_empty())
167 .choose(&mut state.random)
168 .unwrap()
169 .pop_front()
170 .unwrap();
171 } else {
172 let ix = state.random.gen_range(0..background_len);
173 runnable = state.background.swap_remove(ix);
174 };
175 };
176
177 let was_main_thread = state.is_main_thread;
178 state.is_main_thread = main_thread;
179 drop(state);
180 runnable.run();
181 self.state.lock().is_main_thread = was_main_thread;
182
183 true
184 }
185
186 pub fn deprioritize(&self, task_label: TaskLabel) {
187 self.state
188 .lock()
189 .deprioritized_task_labels
190 .insert(task_label);
191 }
192
193 pub fn run_until_parked(&self) {
194 while self.tick(false) {}
195 }
196
197 pub fn parking_allowed(&self) -> bool {
198 self.state.lock().allow_parking
199 }
200
201 pub fn allow_parking(&self) {
202 self.state.lock().allow_parking = true
203 }
204
205 pub fn forbid_parking(&self) {
206 self.state.lock().allow_parking = false
207 }
208
209 pub fn set_waiting_hint(&self, msg: Option<String>) {
210 self.state.lock().waiting_hint = msg
211 }
212
213 pub fn waiting_hint(&self) -> Option<String> {
214 self.state.lock().waiting_hint.clone()
215 }
216
217 pub fn start_waiting(&self) {
218 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
219 }
220
221 pub fn finish_waiting(&self) {
222 self.state.lock().waiting_backtrace.take();
223 }
224
225 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
226 self.state.lock().waiting_backtrace.take().map(|mut b| {
227 b.resolve();
228 b
229 })
230 }
231
232 pub fn rng(&self) -> StdRng {
233 self.state.lock().random.clone()
234 }
235
236 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
237 self.state.lock().block_on_ticks = range;
238 }
239
240 pub fn gen_block_on_ticks(&self) -> usize {
241 let mut lock = self.state.lock();
242 let block_on_ticks = lock.block_on_ticks.clone();
243 lock.random.gen_range(block_on_ticks)
244 }
245}
246
247impl Clone for TestDispatcher {
248 fn clone(&self) -> Self {
249 let id = post_inc(&mut self.state.lock().next_id.0);
250 Self {
251 id: TestDispatcherId(id),
252 state: self.state.clone(),
253 parker: self.parker.clone(),
254 unparker: self.unparker.clone(),
255 }
256 }
257}
258
259impl PlatformDispatcher for TestDispatcher {
260 fn is_main_thread(&self) -> bool {
261 self.state.lock().is_main_thread
262 }
263
264 fn now(&self) -> Instant {
265 let state = self.state.lock();
266 state.start_time + state.time
267 }
268
269 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
270 {
271 let mut state = self.state.lock();
272 if label.map_or(false, |label| {
273 state.deprioritized_task_labels.contains(&label)
274 }) {
275 state.deprioritized_background.push(runnable);
276 } else {
277 state.background.push(runnable);
278 }
279 }
280 self.unparker.unpark();
281 }
282
283 fn dispatch_on_main_thread(&self, runnable: Runnable) {
284 self.state
285 .lock()
286 .foreground
287 .entry(self.id)
288 .or_default()
289 .push_back(runnable);
290 self.unparker.unpark();
291 }
292
293 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
294 let mut state = self.state.lock();
295 let next_time = state.time + duration;
296 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
297 Ok(ix) | Err(ix) => ix,
298 };
299 state.delayed.insert(ix, (next_time, runnable));
300 }
301 fn park(&self, _: Option<std::time::Duration>) -> bool {
302 self.parker.lock().park();
303 true
304 }
305
306 fn unparker(&self) -> Unparker {
307 self.unparker.clone()
308 }
309
310 fn as_test(&self) -> Option<&TestDispatcher> {
311 Some(self)
312 }
313}