1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 ops::RangeInclusive,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14 time::Duration,
15};
16use util::post_inc;
17
/// Identifier for one `TestDispatcher` handle; used as the key of that
/// handle's foreground (main-thread) run queue.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
struct TestDispatcherId(usize);
20
/// Test implementation of `PlatformDispatcher` whose scheduling choices are
/// driven by the `StdRng` passed to `TestDispatcher::new`.
///
/// Cloning yields another handle onto the same shared state, parker, and
/// unparker, but with its own `id` (and therefore its own foreground queue).
#[doc(hidden)]
pub struct TestDispatcher {
    // Key of this handle's foreground queue in `TestDispatcherState::foreground`.
    id: TestDispatcherId,
    // Scheduler state shared by every clone of this dispatcher.
    state: Arc<Mutex<TestDispatcherState>>,
    // Shared by all clones; the mutex is held for the duration of `park`.
    parker: Arc<Mutex<Parker>>,
    // Signalled by `dispatch` and `dispatch_on_main_thread` after queueing work.
    unparker: Unparker,
}
28
/// Mutable scheduler state shared (behind a mutex) by all handles of a
/// `TestDispatcher`.
struct TestDispatcherState {
    // Source of every scheduling decision in `tick` and of the yield count in
    // `simulate_random_delay`.
    random: StdRng,
    // Main-thread runnables, one FIFO queue per dispatcher handle id.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    // Ordinary background runnables; `tick` picks one at a random index.
    background: Vec<Runnable>,
    // Background runnables whose label was deprioritized; only run when both
    // `foreground` and `background` are empty.
    deprioritized_background: Vec<Runnable>,
    // (deadline, runnable) pairs, kept sorted by deadline by `dispatch_after`.
    delayed: Vec<(Duration, Runnable)>,
    // Simulated clock; only moves forward via `advance_clock`.
    time: Duration,
    // Value reported by `is_main_thread`; `tick` overrides it while a runnable
    // executes and restores it afterwards.
    is_main_thread: bool,
    // Next id handed out by `Clone` (the handle from `new` uses id 0).
    next_id: TestDispatcherId,
    // Toggled by `allow_parking` / `forbid_parking`, read by `parking_allowed`.
    // NOTE(review): the consumer of this flag is outside this file — confirm.
    allow_parking: bool,
    // Free-form message set via `set_waiting_hint`.
    waiting_hint: Option<String>,
    // Unresolved backtrace captured by `start_waiting`, cleared by
    // `finish_waiting`, and taken (and resolved) by `waiting_backtrace`.
    waiting_backtrace: Option<Backtrace>,
    // Labels whose tasks `dispatch` routes to `deprioritized_background`.
    deprioritized_task_labels: HashSet<TaskLabel>,
    // Range sampled by `gen_block_on_ticks` (defaults to 0..=1000).
    // Presumably bounds how many ticks a `block_on` runs — TODO confirm.
    block_on_ticks: RangeInclusive<usize>,
}
44
45impl TestDispatcher {
46 pub fn new(random: StdRng) -> Self {
47 let (parker, unparker) = parking::pair();
48 let state = TestDispatcherState {
49 random,
50 foreground: HashMap::default(),
51 background: Vec::new(),
52 deprioritized_background: Vec::new(),
53 delayed: Vec::new(),
54 time: Duration::ZERO,
55 is_main_thread: true,
56 next_id: TestDispatcherId(1),
57 allow_parking: false,
58 waiting_hint: None,
59 waiting_backtrace: None,
60 deprioritized_task_labels: Default::default(),
61 block_on_ticks: 0..=1000,
62 };
63
64 TestDispatcher {
65 id: TestDispatcherId(0),
66 state: Arc::new(Mutex::new(state)),
67 parker: Arc::new(Mutex::new(parker)),
68 unparker,
69 }
70 }
71
72 pub fn advance_clock(&self, by: Duration) {
73 let new_now = self.state.lock().time + by;
74 loop {
75 self.run_until_parked();
76 let state = self.state.lock();
77 let next_due_time = state.delayed.first().map(|(time, _)| *time);
78 drop(state);
79 if let Some(due_time) = next_due_time {
80 if due_time <= new_now {
81 self.state.lock().time = due_time;
82 continue;
83 }
84 }
85 break;
86 }
87 self.state.lock().time = new_now;
88 }
89
90 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
91 struct YieldNow {
92 pub(crate) count: usize,
93 }
94
95 impl Future for YieldNow {
96 type Output = ();
97
98 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
99 if self.count > 0 {
100 self.count -= 1;
101 cx.waker().wake_by_ref();
102 Poll::Pending
103 } else {
104 Poll::Ready(())
105 }
106 }
107 }
108
109 YieldNow {
110 count: self.state.lock().random.gen_range(0..10),
111 }
112 }
113
114 pub fn deprioritize(&self, task_label: TaskLabel) {
115 self.state
116 .lock()
117 .deprioritized_task_labels
118 .insert(task_label);
119 }
120
121 pub fn run_until_parked(&self) {
122 while self.tick(false) {}
123 }
124
125 pub fn parking_allowed(&self) -> bool {
126 self.state.lock().allow_parking
127 }
128
129 pub fn allow_parking(&self) {
130 self.state.lock().allow_parking = true
131 }
132
133 pub fn forbid_parking(&self) {
134 self.state.lock().allow_parking = false
135 }
136
137 pub fn set_waiting_hint(&self, msg: Option<String>) {
138 self.state.lock().waiting_hint = msg
139 }
140
141 pub fn waiting_hint(&self) -> Option<String> {
142 self.state.lock().waiting_hint.clone()
143 }
144
145 pub fn start_waiting(&self) {
146 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
147 }
148
149 pub fn finish_waiting(&self) {
150 self.state.lock().waiting_backtrace.take();
151 }
152
153 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
154 self.state.lock().waiting_backtrace.take().map(|mut b| {
155 b.resolve();
156 b
157 })
158 }
159
160 pub fn rng(&self) -> StdRng {
161 self.state.lock().random.clone()
162 }
163
164 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
165 self.state.lock().block_on_ticks = range;
166 }
167
168 pub fn gen_block_on_ticks(&self) -> usize {
169 let mut lock = self.state.lock();
170 let block_on_ticks = lock.block_on_ticks.clone();
171 lock.random.gen_range(block_on_ticks)
172 }
173}
174
175impl Clone for TestDispatcher {
176 fn clone(&self) -> Self {
177 let id = post_inc(&mut self.state.lock().next_id.0);
178 Self {
179 id: TestDispatcherId(id),
180 state: self.state.clone(),
181 parker: self.parker.clone(),
182 unparker: self.unparker.clone(),
183 }
184 }
185}
186
187impl PlatformDispatcher for TestDispatcher {
188 fn is_main_thread(&self) -> bool {
189 self.state.lock().is_main_thread
190 }
191
192 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
193 {
194 let mut state = self.state.lock();
195 if label.map_or(false, |label| {
196 state.deprioritized_task_labels.contains(&label)
197 }) {
198 state.deprioritized_background.push(runnable);
199 } else {
200 state.background.push(runnable);
201 }
202 }
203 self.unparker.unpark();
204 }
205
206 fn dispatch_on_main_thread(&self, runnable: Runnable) {
207 self.state
208 .lock()
209 .foreground
210 .entry(self.id)
211 .or_default()
212 .push_back(runnable);
213 self.unparker.unpark();
214 }
215
216 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
217 let mut state = self.state.lock();
218 let next_time = state.time + duration;
219 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
220 Ok(ix) | Err(ix) => ix,
221 };
222 state.delayed.insert(ix, (next_time, runnable));
223 }
224
225 fn tick(&self, background_only: bool) -> bool {
226 let mut state = self.state.lock();
227
228 while let Some((deadline, _)) = state.delayed.first() {
229 if *deadline > state.time {
230 break;
231 }
232 let (_, runnable) = state.delayed.remove(0);
233 state.background.push(runnable);
234 }
235
236 let foreground_len: usize = if background_only {
237 0
238 } else {
239 state
240 .foreground
241 .values()
242 .map(|runnables| runnables.len())
243 .sum()
244 };
245 let background_len = state.background.len();
246
247 let runnable;
248 let main_thread;
249 if foreground_len == 0 && background_len == 0 {
250 let deprioritized_background_len = state.deprioritized_background.len();
251 if deprioritized_background_len == 0 {
252 return false;
253 }
254 let ix = state.random.gen_range(0..deprioritized_background_len);
255 main_thread = false;
256 runnable = state.deprioritized_background.swap_remove(ix);
257 } else {
258 main_thread = state.random.gen_ratio(
259 foreground_len as u32,
260 (foreground_len + background_len) as u32,
261 );
262 if main_thread {
263 let state = &mut *state;
264 runnable = state
265 .foreground
266 .values_mut()
267 .filter(|runnables| !runnables.is_empty())
268 .choose(&mut state.random)
269 .unwrap()
270 .pop_front()
271 .unwrap();
272 } else {
273 let ix = state.random.gen_range(0..background_len);
274 runnable = state.background.swap_remove(ix);
275 };
276 };
277
278 let was_main_thread = state.is_main_thread;
279 state.is_main_thread = main_thread;
280 drop(state);
281 runnable.run();
282 self.state.lock().is_main_thread = was_main_thread;
283
284 true
285 }
286
287 fn park(&self) {
288 self.parker.lock().park();
289 }
290
291 fn unparker(&self) -> Unparker {
292 self.unparker.clone()
293 }
294
295 fn as_test(&self) -> Option<&TestDispatcher> {
296 Some(self)
297 }
298}