1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 ops::RangeInclusive,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14 time::Duration,
15};
16use util::post_inc;
17
/// Identifier of a single `TestDispatcher` handle.
///
/// The dispatcher built by `TestDispatcher::new` gets id 0; every clone takes the
/// next value of the shared `next_id` counter. The id is the key into the
/// per-handle foreground queues.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
struct TestDispatcherId(usize);
20
/// A `PlatformDispatcher` for tests that only runs queued work when it is
/// explicitly ticked, picking the next runnable with a seeded RNG.
///
/// Clones share the same state, parker and unparker but carry their own `id`.
#[doc(hidden)]
pub struct TestDispatcher {
    /// Id of this handle; selects which foreground queue
    /// `dispatch_on_main_thread` pushes onto.
    id: TestDispatcherId,
    /// Queues, virtual clock and settings shared by every clone.
    state: Arc<Mutex<TestDispatcherState>>,
    /// Parker that `park` blocks on, shared by every clone.
    parker: Arc<Mutex<Parker>>,
    /// Unparker paired with `parker`; signalled whenever work is dispatched.
    unparker: Unparker,
}
28
/// Mutable state shared (behind a mutex) by a `TestDispatcher` and all its clones.
struct TestDispatcherState {
    /// RNG behind every scheduling decision in `tick`, the yield count of
    /// `simulate_random_delay`, and `gen_block_on_ticks`.
    random: StdRng,
    /// FIFO queue of main-thread runnables for each dispatcher id.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    /// Runnables from `dispatch` (unless deprioritized) plus timers that came due.
    background: Vec<Runnable>,
    /// Runnables dispatched with a deprioritized label; `tick` only picks from
    /// here when there is no other work to choose from.
    deprioritized_background: Vec<Runnable>,
    /// Pending timers as `(deadline, runnable)`; `dispatch_after` inserts by
    /// binary search on the deadline, keeping the earliest first.
    delayed: Vec<(Duration, Runnable)>,
    /// Current virtual time; only changed by `advance_clock`.
    time: Duration,
    /// Value reported by `is_main_thread`; while `tick` runs a runnable it is
    /// `true` only if that runnable came from a foreground queue.
    is_main_thread: bool,
    /// Id that the next clone of the dispatcher will receive.
    next_id: TestDispatcherId,
    /// Flag toggled by `allow_parking` / `forbid_parking` and read through
    /// `parking_allowed`; it is not consulted anywhere else in this file.
    allow_parking: bool,
    /// Optional message stored by `set_waiting_hint`.
    waiting_hint: Option<String>,
    /// Unresolved backtrace captured by `start_waiting`, cleared by
    /// `finish_waiting` or taken by `waiting_backtrace`.
    waiting_backtrace: Option<Backtrace>,
    /// Labels whose tasks `dispatch` routes to `deprioritized_background`.
    deprioritized_task_labels: HashSet<TaskLabel>,
    /// Range sampled by `gen_block_on_ticks`.
    block_on_ticks: RangeInclusive<usize>,
}
44
45impl TestDispatcher {
46 pub fn new(random: StdRng) -> Self {
47 let (parker, unparker) = parking::pair();
48 let state = TestDispatcherState {
49 random,
50 foreground: HashMap::default(),
51 background: Vec::new(),
52 deprioritized_background: Vec::new(),
53 delayed: Vec::new(),
54 time: Duration::ZERO,
55 is_main_thread: true,
56 next_id: TestDispatcherId(1),
57 allow_parking: false,
58 waiting_hint: None,
59 waiting_backtrace: None,
60 deprioritized_task_labels: Default::default(),
61 block_on_ticks: 0..=1000,
62 };
63
64 TestDispatcher {
65 id: TestDispatcherId(0),
66 state: Arc::new(Mutex::new(state)),
67 parker: Arc::new(Mutex::new(parker)),
68 unparker,
69 }
70 }
71
72 pub fn advance_clock(&self, by: Duration) {
73 let new_now = self.state.lock().time + by;
74 loop {
75 self.run_until_parked();
76 let state = self.state.lock();
77 let next_due_time = state.delayed.first().map(|(time, _)| *time);
78 drop(state);
79 if let Some(due_time) = next_due_time {
80 if due_time <= new_now {
81 self.state.lock().time = due_time;
82 continue;
83 }
84 }
85 break;
86 }
87 self.state.lock().time = new_now;
88 }
89
90 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
91 struct YieldNow {
92 pub(crate) count: usize,
93 }
94
95 impl Future for YieldNow {
96 type Output = ();
97
98 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
99 if self.count > 0 {
100 self.count -= 1;
101 cx.waker().wake_by_ref();
102 Poll::Pending
103 } else {
104 Poll::Ready(())
105 }
106 }
107 }
108
109 YieldNow {
110 count: self.state.lock().random.gen_range(0..10),
111 }
112 }
113
114 pub fn tick(&self, background_only: bool) -> bool {
115 let mut state = self.state.lock();
116
117 while let Some((deadline, _)) = state.delayed.first() {
118 if *deadline > state.time {
119 break;
120 }
121 let (_, runnable) = state.delayed.remove(0);
122 state.background.push(runnable);
123 }
124
125 let foreground_len: usize = if background_only {
126 0
127 } else {
128 state
129 .foreground
130 .values()
131 .map(|runnables| runnables.len())
132 .sum()
133 };
134 let background_len = state.background.len();
135
136 let runnable;
137 let main_thread;
138 if foreground_len == 0 && background_len == 0 {
139 let deprioritized_background_len = state.deprioritized_background.len();
140 if deprioritized_background_len == 0 {
141 return false;
142 }
143 let ix = state.random.gen_range(0..deprioritized_background_len);
144 main_thread = false;
145 runnable = state.deprioritized_background.swap_remove(ix);
146 } else {
147 main_thread = state.random.gen_ratio(
148 foreground_len as u32,
149 (foreground_len + background_len) as u32,
150 );
151 if main_thread {
152 let state = &mut *state;
153 runnable = state
154 .foreground
155 .values_mut()
156 .filter(|runnables| !runnables.is_empty())
157 .choose(&mut state.random)
158 .unwrap()
159 .pop_front()
160 .unwrap();
161 } else {
162 let ix = state.random.gen_range(0..background_len);
163 runnable = state.background.swap_remove(ix);
164 };
165 };
166
167 let was_main_thread = state.is_main_thread;
168 state.is_main_thread = main_thread;
169 drop(state);
170 runnable.run();
171 self.state.lock().is_main_thread = was_main_thread;
172
173 true
174 }
175
176 pub fn deprioritize(&self, task_label: TaskLabel) {
177 self.state
178 .lock()
179 .deprioritized_task_labels
180 .insert(task_label);
181 }
182
183 pub fn run_until_parked(&self) {
184 while self.tick(false) {}
185 }
186
187 pub fn parking_allowed(&self) -> bool {
188 self.state.lock().allow_parking
189 }
190
191 pub fn allow_parking(&self) {
192 self.state.lock().allow_parking = true
193 }
194
195 pub fn forbid_parking(&self) {
196 self.state.lock().allow_parking = false
197 }
198
199 pub fn set_waiting_hint(&self, msg: Option<String>) {
200 self.state.lock().waiting_hint = msg
201 }
202
203 pub fn waiting_hint(&self) -> Option<String> {
204 self.state.lock().waiting_hint.clone()
205 }
206
207 pub fn start_waiting(&self) {
208 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
209 }
210
211 pub fn finish_waiting(&self) {
212 self.state.lock().waiting_backtrace.take();
213 }
214
215 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
216 self.state.lock().waiting_backtrace.take().map(|mut b| {
217 b.resolve();
218 b
219 })
220 }
221
222 pub fn rng(&self) -> StdRng {
223 self.state.lock().random.clone()
224 }
225
226 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
227 self.state.lock().block_on_ticks = range;
228 }
229
230 pub fn gen_block_on_ticks(&self) -> usize {
231 let mut lock = self.state.lock();
232 let block_on_ticks = lock.block_on_ticks.clone();
233 lock.random.gen_range(block_on_ticks)
234 }
235}
236
237impl Clone for TestDispatcher {
238 fn clone(&self) -> Self {
239 let id = post_inc(&mut self.state.lock().next_id.0);
240 Self {
241 id: TestDispatcherId(id),
242 state: self.state.clone(),
243 parker: self.parker.clone(),
244 unparker: self.unparker.clone(),
245 }
246 }
247}
248
249impl PlatformDispatcher for TestDispatcher {
250 fn is_main_thread(&self) -> bool {
251 self.state.lock().is_main_thread
252 }
253
254 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
255 {
256 let mut state = self.state.lock();
257 if label.map_or(false, |label| {
258 state.deprioritized_task_labels.contains(&label)
259 }) {
260 state.deprioritized_background.push(runnable);
261 } else {
262 state.background.push(runnable);
263 }
264 }
265 self.unparker.unpark();
266 }
267
268 fn dispatch_on_main_thread(&self, runnable: Runnable) {
269 self.state
270 .lock()
271 .foreground
272 .entry(self.id)
273 .or_default()
274 .push_back(runnable);
275 self.unparker.unpark();
276 }
277
278 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
279 let mut state = self.state.lock();
280 let next_time = state.time + duration;
281 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
282 Ok(ix) | Err(ix) => ix,
283 };
284 state.delayed.insert(ix, (next_time, runnable));
285 }
286 fn park(&self, _: Option<std::time::Duration>) -> bool {
287 self.parker.lock().park();
288 true
289 }
290
291 fn unparker(&self) -> Unparker {
292 self.unparker.clone()
293 }
294
295 fn as_test(&self) -> Option<&TestDispatcher> {
296 Some(self)
297 }
298}