1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 ops::RangeInclusive,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14 time::Duration,
15};
16use util::post_inc;
17
/// Identifier of a single `TestDispatcher` handle.
///
/// Main-thread work is queued per id, so every handle keeps its own foreground queue.
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
struct TestDispatcherId(usize);
20
/// Dispatcher for tests that runs queued tasks one at a time whenever `tick` is called,
/// choosing among ready tasks with a seeded RNG stored in the shared state.
///
/// Clones share one `TestDispatcherState`; each clone receives its own `id`, which keys
/// its foreground (main-thread) queue.
#[doc(hidden)]
pub struct TestDispatcher {
    // Key of this handle's queue in `TestDispatcherState::foreground`.
    id: TestDispatcherId,
    // Scheduler state shared by every clone of this dispatcher.
    state: Arc<Mutex<TestDispatcherState>>,
    // Parker used by `park`; shared between clones.
    parker: Arc<Mutex<Parker>>,
    // Unparked after `dispatch` and `dispatch_on_main_thread` (not after `dispatch_after`).
    unparker: Unparker,
}
28
/// Mutable scheduler state shared (behind a mutex) by all clones of a `TestDispatcher`.
struct TestDispatcherState {
    // Source of every scheduling choice; seeded by the caller of `TestDispatcher::new`.
    random: StdRng,
    // Per-handle FIFO queues of main-thread tasks (pushed at the back, popped at the front).
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    // Ready background tasks; `tick` removes one at a random index.
    background: Vec<Runnable>,
    // Background tasks whose label was deprioritized; only run when `foreground` and
    // `background` are both empty.
    deprioritized_background: Vec<Runnable>,
    // `(deadline, task)` pairs kept sorted by deadline; moved to `background` once `time`
    // reaches the deadline.
    delayed: Vec<(Duration, Runnable)>,
    // Simulated clock; starts at zero and is moved forward by `advance_clock`.
    time: Duration,
    // Whether the task currently being run by `tick` came from a foreground queue.
    is_main_thread: bool,
    // Next id handed out when a `TestDispatcher` is cloned.
    next_id: TestDispatcherId,
    // Toggled by `allow_parking` / `forbid_parking`; `false` by default.
    allow_parking: bool,
    // Unresolved backtrace captured by `start_waiting`, if any.
    waiting_backtrace: Option<Backtrace>,
    // Labels whose dispatched tasks are routed to `deprioritized_background`.
    deprioritized_task_labels: HashSet<TaskLabel>,
    // Range sampled by `gen_block_on_ticks`; presumably consumed by a `block_on`
    // implementation elsewhere — TODO confirm.
    block_on_ticks: RangeInclusive<usize>,
}
43
44impl TestDispatcher {
45 pub fn new(random: StdRng) -> Self {
46 let (parker, unparker) = parking::pair();
47 let state = TestDispatcherState {
48 random,
49 foreground: HashMap::default(),
50 background: Vec::new(),
51 deprioritized_background: Vec::new(),
52 delayed: Vec::new(),
53 time: Duration::ZERO,
54 is_main_thread: true,
55 next_id: TestDispatcherId(1),
56 allow_parking: false,
57 waiting_backtrace: None,
58 deprioritized_task_labels: Default::default(),
59 block_on_ticks: 0..=1000,
60 };
61
62 TestDispatcher {
63 id: TestDispatcherId(0),
64 state: Arc::new(Mutex::new(state)),
65 parker: Arc::new(Mutex::new(parker)),
66 unparker,
67 }
68 }
69
70 pub fn advance_clock(&self, by: Duration) {
71 let new_now = self.state.lock().time + by;
72 loop {
73 self.run_until_parked();
74 let state = self.state.lock();
75 let next_due_time = state.delayed.first().map(|(time, _)| *time);
76 drop(state);
77 if let Some(due_time) = next_due_time {
78 if due_time <= new_now {
79 self.state.lock().time = due_time;
80 continue;
81 }
82 }
83 break;
84 }
85 self.state.lock().time = new_now;
86 }
87
88 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
89 struct YieldNow {
90 pub(crate) count: usize,
91 }
92
93 impl Future for YieldNow {
94 type Output = ();
95
96 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
97 if self.count > 0 {
98 self.count -= 1;
99 cx.waker().wake_by_ref();
100 Poll::Pending
101 } else {
102 Poll::Ready(())
103 }
104 }
105 }
106
107 YieldNow {
108 count: self.state.lock().random.gen_range(0..10),
109 }
110 }
111
112 pub fn deprioritize(&self, task_label: TaskLabel) {
113 self.state
114 .lock()
115 .deprioritized_task_labels
116 .insert(task_label);
117 }
118
119 pub fn run_until_parked(&self) {
120 while self.tick(false) {}
121 }
122
123 pub fn parking_allowed(&self) -> bool {
124 self.state.lock().allow_parking
125 }
126
127 pub fn allow_parking(&self) {
128 self.state.lock().allow_parking = true
129 }
130
131 pub fn forbid_parking(&self) {
132 self.state.lock().allow_parking = false
133 }
134
135 pub fn start_waiting(&self) {
136 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
137 }
138
139 pub fn finish_waiting(&self) {
140 self.state.lock().waiting_backtrace.take();
141 }
142
143 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
144 self.state.lock().waiting_backtrace.take().map(|mut b| {
145 b.resolve();
146 b
147 })
148 }
149
150 pub fn rng(&self) -> StdRng {
151 self.state.lock().random.clone()
152 }
153
154 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
155 self.state.lock().block_on_ticks = range;
156 }
157
158 pub fn gen_block_on_ticks(&self) -> usize {
159 let mut lock = self.state.lock();
160 let block_on_ticks = lock.block_on_ticks.clone();
161 lock.random.gen_range(block_on_ticks)
162 }
163}
164
165impl Clone for TestDispatcher {
166 fn clone(&self) -> Self {
167 let id = post_inc(&mut self.state.lock().next_id.0);
168 Self {
169 id: TestDispatcherId(id),
170 state: self.state.clone(),
171 parker: self.parker.clone(),
172 unparker: self.unparker.clone(),
173 }
174 }
175}
176
177impl PlatformDispatcher for TestDispatcher {
178 fn is_main_thread(&self) -> bool {
179 self.state.lock().is_main_thread
180 }
181
182 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
183 {
184 let mut state = self.state.lock();
185 if label.map_or(false, |label| {
186 state.deprioritized_task_labels.contains(&label)
187 }) {
188 state.deprioritized_background.push(runnable);
189 } else {
190 state.background.push(runnable);
191 }
192 }
193 self.unparker.unpark();
194 }
195
196 fn dispatch_on_main_thread(&self, runnable: Runnable) {
197 self.state
198 .lock()
199 .foreground
200 .entry(self.id)
201 .or_default()
202 .push_back(runnable);
203 self.unparker.unpark();
204 }
205
206 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
207 let mut state = self.state.lock();
208 let next_time = state.time + duration;
209 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
210 Ok(ix) | Err(ix) => ix,
211 };
212 state.delayed.insert(ix, (next_time, runnable));
213 }
214
215 fn tick(&self, background_only: bool) -> bool {
216 let mut state = self.state.lock();
217
218 while let Some((deadline, _)) = state.delayed.first() {
219 if *deadline > state.time {
220 break;
221 }
222 let (_, runnable) = state.delayed.remove(0);
223 state.background.push(runnable);
224 }
225
226 let foreground_len: usize = if background_only {
227 0
228 } else {
229 state
230 .foreground
231 .values()
232 .map(|runnables| runnables.len())
233 .sum()
234 };
235 let background_len = state.background.len();
236
237 let runnable;
238 let main_thread;
239 if foreground_len == 0 && background_len == 0 {
240 let deprioritized_background_len = state.deprioritized_background.len();
241 if deprioritized_background_len == 0 {
242 return false;
243 }
244 let ix = state.random.gen_range(0..deprioritized_background_len);
245 main_thread = false;
246 runnable = state.deprioritized_background.swap_remove(ix);
247 } else {
248 main_thread = state.random.gen_ratio(
249 foreground_len as u32,
250 (foreground_len + background_len) as u32,
251 );
252 if main_thread {
253 let state = &mut *state;
254 runnable = state
255 .foreground
256 .values_mut()
257 .filter(|runnables| !runnables.is_empty())
258 .choose(&mut state.random)
259 .unwrap()
260 .pop_front()
261 .unwrap();
262 } else {
263 let ix = state.random.gen_range(0..background_len);
264 runnable = state.background.swap_remove(ix);
265 };
266 };
267
268 let was_main_thread = state.is_main_thread;
269 state.is_main_thread = main_thread;
270 drop(state);
271 runnable.run();
272 self.state.lock().is_main_thread = was_main_thread;
273
274 true
275 }
276
277 fn park(&self) {
278 self.parker.lock().park();
279 }
280
281 fn unparker(&self) -> Unparker {
282 self.unparker.clone()
283 }
284
285 fn as_test(&self) -> Option<&TestDispatcher> {
286 Some(self)
287 }
288}