1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 ops::RangeInclusive,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14 time::Duration,
15};
16use util::post_inc;
17
/// Identifier of one `TestDispatcher` handle; each handle's main-thread
/// runnables are queued under its own id.
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
struct TestDispatcherId(usize);
20
/// Test dispatcher that queues runnables in shared state and only executes them
/// when `tick` (or a helper built on it) is called, using a seeded RNG to choose
/// which queued runnable runs next.
///
/// Cloning produces another handle onto the same state with a new id.
#[doc(hidden)]
pub struct TestDispatcher {
    /// Id of this handle; `dispatch_on_main_thread` queues runnables under it.
    id: TestDispatcherId,
    /// Scheduler state shared by every clone of this dispatcher.
    state: Arc<Mutex<TestDispatcherState>>,
    /// Parker shared by all clones; `park` blocks on it.
    parker: Arc<Mutex<Parker>>,
    /// Wakes the shared parker; used after dispatching new work.
    unparker: Unparker,
}
28
/// Mutable scheduler state shared, behind a mutex, by every clone of a
/// `TestDispatcher`.
struct TestDispatcherState {
    /// Seeded RNG that drives every random scheduling choice.
    random: StdRng,
    /// Main-thread runnables: one FIFO queue per dispatcher handle id.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    /// Background runnables; `tick` picks one of them at random.
    background: Vec<Runnable>,
    /// Background runnables whose label was deprioritized; `tick` only draws
    /// from here when no foreground or regular background work is available.
    deprioritized_background: Vec<Runnable>,
    /// Timer runnables with their deadline on the simulated clock, kept sorted
    /// by deadline (earliest first).
    delayed: Vec<(Duration, Runnable)>,
    /// Current simulated time, advanced by `advance_clock`.
    time: Duration,
    /// Whether the runnable currently executing inside `tick` was taken from a
    /// foreground queue.
    is_main_thread: bool,
    /// Id handed out to the next clone of the dispatcher.
    next_id: TestDispatcherId,
    /// Set by `allow_parking`, read by `parking_allowed`; how it is enforced is
    /// not visible here (NOTE(review): confirm at the call sites).
    allow_parking: bool,
    /// Unresolved backtrace captured by `start_waiting`; cleared by
    /// `finish_waiting` and taken by `waiting_backtrace`.
    waiting_backtrace: Option<Backtrace>,
    /// Labels whose dispatched runnables are routed to `deprioritized_background`.
    deprioritized_task_labels: HashSet<TaskLabel>,
    /// Inclusive range sampled by `gen_block_on_ticks`; presumably a tick budget
    /// for blocking helpers elsewhere (TODO confirm).
    block_on_ticks: RangeInclusive<usize>,
}
43
44impl TestDispatcher {
45 pub fn new(random: StdRng) -> Self {
46 let (parker, unparker) = parking::pair();
47 let state = TestDispatcherState {
48 random,
49 foreground: HashMap::default(),
50 background: Vec::new(),
51 deprioritized_background: Vec::new(),
52 delayed: Vec::new(),
53 time: Duration::ZERO,
54 is_main_thread: true,
55 next_id: TestDispatcherId(1),
56 allow_parking: false,
57 waiting_backtrace: None,
58 deprioritized_task_labels: Default::default(),
59 block_on_ticks: 0..=1000,
60 };
61
62 TestDispatcher {
63 id: TestDispatcherId(0),
64 state: Arc::new(Mutex::new(state)),
65 parker: Arc::new(Mutex::new(parker)),
66 unparker,
67 }
68 }
69
70 pub fn advance_clock(&self, by: Duration) {
71 let new_now = self.state.lock().time + by;
72 loop {
73 self.run_until_parked();
74 let state = self.state.lock();
75 let next_due_time = state.delayed.first().map(|(time, _)| *time);
76 drop(state);
77 if let Some(due_time) = next_due_time {
78 if due_time <= new_now {
79 self.state.lock().time = due_time;
80 continue;
81 }
82 }
83 break;
84 }
85 self.state.lock().time = new_now;
86 }
87
88 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
89 struct YieldNow {
90 pub(crate) count: usize,
91 }
92
93 impl Future for YieldNow {
94 type Output = ();
95
96 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
97 if self.count > 0 {
98 self.count -= 1;
99 cx.waker().wake_by_ref();
100 Poll::Pending
101 } else {
102 Poll::Ready(())
103 }
104 }
105 }
106
107 YieldNow {
108 count: self.state.lock().random.gen_range(0..10),
109 }
110 }
111
112 pub fn deprioritize(&self, task_label: TaskLabel) {
113 self.state
114 .lock()
115 .deprioritized_task_labels
116 .insert(task_label);
117 }
118
119 pub fn run_until_parked(&self) {
120 while self.tick(false) {}
121 }
122
123 pub fn parking_allowed(&self) -> bool {
124 self.state.lock().allow_parking
125 }
126
127 pub fn allow_parking(&self) {
128 self.state.lock().allow_parking = true
129 }
130
131 pub fn start_waiting(&self) {
132 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
133 }
134
135 pub fn finish_waiting(&self) {
136 self.state.lock().waiting_backtrace.take();
137 }
138
139 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
140 self.state.lock().waiting_backtrace.take().map(|mut b| {
141 b.resolve();
142 b
143 })
144 }
145
146 pub fn rng(&self) -> StdRng {
147 self.state.lock().random.clone()
148 }
149
150 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
151 self.state.lock().block_on_ticks = range;
152 }
153
154 pub fn gen_block_on_ticks(&self) -> usize {
155 let mut lock = self.state.lock();
156 let block_on_ticks = lock.block_on_ticks.clone();
157 lock.random.gen_range(block_on_ticks)
158 }
159}
160
161impl Clone for TestDispatcher {
162 fn clone(&self) -> Self {
163 let id = post_inc(&mut self.state.lock().next_id.0);
164 Self {
165 id: TestDispatcherId(id),
166 state: self.state.clone(),
167 parker: self.parker.clone(),
168 unparker: self.unparker.clone(),
169 }
170 }
171}
172
173impl PlatformDispatcher for TestDispatcher {
174 fn is_main_thread(&self) -> bool {
175 self.state.lock().is_main_thread
176 }
177
178 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
179 {
180 let mut state = self.state.lock();
181 if label.map_or(false, |label| {
182 state.deprioritized_task_labels.contains(&label)
183 }) {
184 state.deprioritized_background.push(runnable);
185 } else {
186 state.background.push(runnable);
187 }
188 }
189 self.unparker.unpark();
190 }
191
192 fn dispatch_on_main_thread(&self, runnable: Runnable) {
193 self.state
194 .lock()
195 .foreground
196 .entry(self.id)
197 .or_default()
198 .push_back(runnable);
199 self.unparker.unpark();
200 }
201
202 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
203 let mut state = self.state.lock();
204 let next_time = state.time + duration;
205 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
206 Ok(ix) | Err(ix) => ix,
207 };
208 state.delayed.insert(ix, (next_time, runnable));
209 }
210
211 fn tick(&self, background_only: bool) -> bool {
212 let mut state = self.state.lock();
213
214 while let Some((deadline, _)) = state.delayed.first() {
215 if *deadline > state.time {
216 break;
217 }
218 let (_, runnable) = state.delayed.remove(0);
219 state.background.push(runnable);
220 }
221
222 let foreground_len: usize = if background_only {
223 0
224 } else {
225 state
226 .foreground
227 .values()
228 .map(|runnables| runnables.len())
229 .sum()
230 };
231 let background_len = state.background.len();
232
233 let runnable;
234 let main_thread;
235 if foreground_len == 0 && background_len == 0 {
236 let deprioritized_background_len = state.deprioritized_background.len();
237 if deprioritized_background_len == 0 {
238 return false;
239 }
240 let ix = state.random.gen_range(0..deprioritized_background_len);
241 main_thread = false;
242 runnable = state.deprioritized_background.swap_remove(ix);
243 } else {
244 main_thread = state.random.gen_ratio(
245 foreground_len as u32,
246 (foreground_len + background_len) as u32,
247 );
248 if main_thread {
249 let state = &mut *state;
250 runnable = state
251 .foreground
252 .values_mut()
253 .filter(|runnables| !runnables.is_empty())
254 .choose(&mut state.random)
255 .unwrap()
256 .pop_front()
257 .unwrap();
258 } else {
259 let ix = state.random.gen_range(0..background_len);
260 runnable = state.background.swap_remove(ix);
261 };
262 };
263
264 let was_main_thread = state.is_main_thread;
265 state.is_main_thread = main_thread;
266 drop(state);
267 runnable.run();
268 self.state.lock().is_main_thread = was_main_thread;
269
270 true
271 }
272
273 fn park(&self) {
274 self.parker.lock().park();
275 }
276
277 fn unparker(&self) -> Unparker {
278 self.unparker.clone()
279 }
280
281 fn as_test(&self) -> Option<&TestDispatcher> {
282 Some(self)
283 }
284}