1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 ops::RangeInclusive,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14 time::{Duration, Instant},
15};
16use util::post_inc;
17
/// Identifies one handle (the original or a clone) of a `TestDispatcher`; used as the
/// key of that handle's foreground queue.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
struct TestDispatcherId(usize);
20
/// Test implementation of `PlatformDispatcher`.
///
/// Dispatched runnables are queued rather than run immediately. They execute one at a time
/// when `tick` is called, and a caller-supplied RNG decides which ready runnable goes next.
/// Cloning produces another handle onto the same shared state, with its own foreground id.
#[doc(hidden)]
pub struct TestDispatcher {
    /// Key of this handle's foreground queue; `new` uses 0 and each clone takes the next
    /// id from the shared state.
    id: TestDispatcherId,
    /// Scheduler state shared by every clone of this dispatcher.
    state: Arc<Mutex<TestDispatcherState>>,
    /// Parker that `park` blocks on; shared by every clone.
    parker: Arc<Mutex<Parker>>,
    /// Wakes `parker`; signalled by `dispatch` and `dispatch_on_main_thread`.
    unparker: Unparker,
}
28
/// Mutable scheduler state shared (behind a mutex) by all handles of a `TestDispatcher`.
struct TestDispatcherState {
    /// Source of every scheduling decision (task choice, random delays, block_on tick counts).
    random: StdRng,
    /// Per-handle FIFO queues of main-thread runnables, keyed by the dispatching handle's id.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    /// Ready background runnables; `tick` removes them at a random index.
    background: Vec<Runnable>,
    /// Background runnables whose label was deprioritized; only run when nothing else is ready.
    deprioritized_background: Vec<Runnable>,
    /// Timer runnables paired with their virtual deadline, kept sorted by deadline.
    delayed: Vec<(Duration, Runnable)>,
    /// Real instant captured at construction; `now()` reports `start_time + time`.
    start_time: Instant,
    /// Virtual time elapsed since `start_time`; only moved by `advance_clock`.
    time: Duration,
    /// Whether the runnable currently executing under `tick` came from a foreground queue.
    is_main_thread: bool,
    /// Id handed to the next clone of the dispatcher.
    next_id: TestDispatcherId,
    /// Flag toggled by `allow_parking`/`forbid_parking` and read by `parking_allowed`;
    /// NOTE(review): enforcement happens outside this file — confirm at call sites.
    allow_parking: bool,
    /// Free-form message set via `set_waiting_hint`.
    waiting_hint: Option<String>,
    /// Unresolved backtrace captured by `start_waiting`, cleared by `finish_waiting`.
    waiting_backtrace: Option<Backtrace>,
    /// Labels whose dispatched runnables are routed to `deprioritized_background`.
    deprioritized_task_labels: HashSet<TaskLabel>,
    /// Inclusive range that `gen_block_on_ticks` samples from.
    block_on_ticks: RangeInclusive<usize>,
}
45
46impl TestDispatcher {
47 pub fn new(random: StdRng) -> Self {
48 let (parker, unparker) = parking::pair();
49 let state = TestDispatcherState {
50 random,
51 foreground: HashMap::default(),
52 background: Vec::new(),
53 deprioritized_background: Vec::new(),
54 delayed: Vec::new(),
55 time: Duration::ZERO,
56 start_time: Instant::now(),
57 is_main_thread: true,
58 next_id: TestDispatcherId(1),
59 allow_parking: false,
60 waiting_hint: None,
61 waiting_backtrace: None,
62 deprioritized_task_labels: Default::default(),
63 block_on_ticks: 0..=1000,
64 };
65
66 TestDispatcher {
67 id: TestDispatcherId(0),
68 state: Arc::new(Mutex::new(state)),
69 parker: Arc::new(Mutex::new(parker)),
70 unparker,
71 }
72 }
73
74 pub fn advance_clock(&self, by: Duration) {
75 let new_now = self.state.lock().time + by;
76 loop {
77 self.run_until_parked();
78 let state = self.state.lock();
79 let next_due_time = state.delayed.first().map(|(time, _)| *time);
80 drop(state);
81 if let Some(due_time) = next_due_time {
82 if due_time <= new_now {
83 self.state.lock().time = due_time;
84 continue;
85 }
86 }
87 break;
88 }
89 self.state.lock().time = new_now;
90 }
91
92 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
93 struct YieldNow {
94 pub(crate) count: usize,
95 }
96
97 impl Future for YieldNow {
98 type Output = ();
99
100 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
101 if self.count > 0 {
102 self.count -= 1;
103 cx.waker().wake_by_ref();
104 Poll::Pending
105 } else {
106 Poll::Ready(())
107 }
108 }
109 }
110
111 YieldNow {
112 count: self.state.lock().random.gen_range(0..10),
113 }
114 }
115
116 pub fn tick(&self, background_only: bool) -> bool {
117 let mut state = self.state.lock();
118
119 while let Some((deadline, _)) = state.delayed.first() {
120 if *deadline > state.time {
121 break;
122 }
123 let (_, runnable) = state.delayed.remove(0);
124 state.background.push(runnable);
125 }
126
127 let foreground_len: usize = if background_only {
128 0
129 } else {
130 state
131 .foreground
132 .values()
133 .map(|runnables| runnables.len())
134 .sum()
135 };
136 let background_len = state.background.len();
137
138 let runnable;
139 let main_thread;
140 if foreground_len == 0 && background_len == 0 {
141 let deprioritized_background_len = state.deprioritized_background.len();
142 if deprioritized_background_len == 0 {
143 return false;
144 }
145 let ix = state.random.gen_range(0..deprioritized_background_len);
146 main_thread = false;
147 runnable = state.deprioritized_background.swap_remove(ix);
148 } else {
149 main_thread = state.random.gen_ratio(
150 foreground_len as u32,
151 (foreground_len + background_len) as u32,
152 );
153 if main_thread {
154 let state = &mut *state;
155 runnable = state
156 .foreground
157 .values_mut()
158 .filter(|runnables| !runnables.is_empty())
159 .choose(&mut state.random)
160 .unwrap()
161 .pop_front()
162 .unwrap();
163 } else {
164 let ix = state.random.gen_range(0..background_len);
165 runnable = state.background.swap_remove(ix);
166 };
167 };
168
169 let was_main_thread = state.is_main_thread;
170 state.is_main_thread = main_thread;
171 drop(state);
172 runnable.run();
173 self.state.lock().is_main_thread = was_main_thread;
174
175 true
176 }
177
178 pub fn deprioritize(&self, task_label: TaskLabel) {
179 self.state
180 .lock()
181 .deprioritized_task_labels
182 .insert(task_label);
183 }
184
185 pub fn run_until_parked(&self) {
186 while self.tick(false) {}
187 }
188
189 pub fn parking_allowed(&self) -> bool {
190 self.state.lock().allow_parking
191 }
192
193 pub fn allow_parking(&self) {
194 self.state.lock().allow_parking = true
195 }
196
197 pub fn forbid_parking(&self) {
198 self.state.lock().allow_parking = false
199 }
200
201 pub fn set_waiting_hint(&self, msg: Option<String>) {
202 self.state.lock().waiting_hint = msg
203 }
204
205 pub fn waiting_hint(&self) -> Option<String> {
206 self.state.lock().waiting_hint.clone()
207 }
208
209 pub fn start_waiting(&self) {
210 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
211 }
212
213 pub fn finish_waiting(&self) {
214 self.state.lock().waiting_backtrace.take();
215 }
216
217 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
218 self.state.lock().waiting_backtrace.take().map(|mut b| {
219 b.resolve();
220 b
221 })
222 }
223
224 pub fn rng(&self) -> StdRng {
225 self.state.lock().random.clone()
226 }
227
228 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
229 self.state.lock().block_on_ticks = range;
230 }
231
232 pub fn gen_block_on_ticks(&self) -> usize {
233 let mut lock = self.state.lock();
234 let block_on_ticks = lock.block_on_ticks.clone();
235 lock.random.gen_range(block_on_ticks)
236 }
237}
238
239impl Clone for TestDispatcher {
240 fn clone(&self) -> Self {
241 let id = post_inc(&mut self.state.lock().next_id.0);
242 Self {
243 id: TestDispatcherId(id),
244 state: self.state.clone(),
245 parker: self.parker.clone(),
246 unparker: self.unparker.clone(),
247 }
248 }
249}
250
251impl PlatformDispatcher for TestDispatcher {
252 fn is_main_thread(&self) -> bool {
253 self.state.lock().is_main_thread
254 }
255
256 fn now(&self) -> Instant {
257 let state = self.state.lock();
258 state.start_time + state.time
259 }
260
261 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
262 {
263 let mut state = self.state.lock();
264 if label.map_or(false, |label| {
265 state.deprioritized_task_labels.contains(&label)
266 }) {
267 state.deprioritized_background.push(runnable);
268 } else {
269 state.background.push(runnable);
270 }
271 }
272 self.unparker.unpark();
273 }
274
275 fn dispatch_on_main_thread(&self, runnable: Runnable) {
276 self.state
277 .lock()
278 .foreground
279 .entry(self.id)
280 .or_default()
281 .push_back(runnable);
282 self.unparker.unpark();
283 }
284
285 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
286 let mut state = self.state.lock();
287 let next_time = state.time + duration;
288 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
289 Ok(ix) | Err(ix) => ix,
290 };
291 state.delayed.insert(ix, (next_time, runnable));
292 }
293 fn park(&self, _: Option<std::time::Duration>) -> bool {
294 self.parker.lock().park();
295 true
296 }
297
298 fn unparker(&self) -> Unparker {
299 self.unparker.clone()
300 }
301
302 fn as_test(&self) -> Option<&TestDispatcher> {
303 Some(self)
304 }
305}