1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 ops::RangeInclusive,
11 pin::Pin,
12 sync::Arc,
13 task::{Context, Poll},
14 time::Duration,
15};
16use util::post_inc;
17
/// Identifies one `TestDispatcher` handle.
///
/// The dispatcher built by `TestDispatcher::new` uses id 0; every `clone` takes the next value
/// from the shared `next_id` counter. Foreground queues are keyed by this id.
#[derive(Clone, Copy, Eq, PartialEq, Hash)]
struct TestDispatcherId(usize);
20
/// Test implementation of `PlatformDispatcher`.
///
/// Runnables are only queued when dispatched; they run when `tick` is called, which picks the
/// next one using the caller-supplied `StdRng`. All handles produced by `clone` share one
/// scheduler state, parker and unparker; each handle has its own `id`.
pub struct TestDispatcher {
    // Identity of this handle; `dispatch_on_main_thread` queues runnables under this id.
    id: TestDispatcherId,
    // Scheduler state shared by every clone of this dispatcher.
    state: Arc<Mutex<TestDispatcherState>>,
    // Parker used by `park`; behind a mutex so all clones can share the single parker.
    parker: Arc<Mutex<Parker>>,
    // Unparked by `dispatch` and `dispatch_on_main_thread` after queueing a runnable.
    unparker: Unparker,
}
27
/// Scheduler state shared by all clones of a `TestDispatcher`.
///
/// NOTE: field order determines drop order of the queued runnables; keep it stable.
struct TestDispatcherState {
    // Source of every scheduling choice in `tick` and of `simulate_random_delay` counts.
    random: StdRng,
    // Main-thread runnables: one FIFO queue per dispatcher handle id.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    // Background runnables; `tick` picks one at random (order is not preserved).
    background: Vec<Runnable>,
    // Background runnables whose label was deprioritized; `tick` only picks from here when
    // it sees no eligible foreground or regular background work.
    deprioritized_background: Vec<Runnable>,
    // Timers as (absolute simulated deadline, runnable); `dispatch_after` keeps this sorted
    // ascending by deadline.
    delayed: Vec<(Duration, Runnable)>,
    // Current simulated time, advanced by `advance_clock`.
    time: Duration,
    // Value reported by `is_main_thread`; `tick` overrides it while a runnable executes.
    is_main_thread: bool,
    // Next id handed out by `clone` (the root dispatcher is id 0).
    next_id: TestDispatcherId,
    // Set by `allow_parking`, read by `parking_allowed`.
    allow_parking: bool,
    // Captured by `start_waiting`, cleared by `finish_waiting`, taken by `waiting_backtrace`.
    waiting_backtrace: Option<Backtrace>,
    // Labels registered via `deprioritize`; `dispatch` routes matching runnables to
    // `deprioritized_background`.
    deprioritized_task_labels: HashSet<TaskLabel>,
    // Range sampled by `gen_block_on_ticks` (initialized to `0..=1000` in `new`).
    block_on_ticks: RangeInclusive<usize>,
}
42
43impl TestDispatcher {
44 pub fn new(random: StdRng) -> Self {
45 let (parker, unparker) = parking::pair();
46 let state = TestDispatcherState {
47 random,
48 foreground: HashMap::default(),
49 background: Vec::new(),
50 deprioritized_background: Vec::new(),
51 delayed: Vec::new(),
52 time: Duration::ZERO,
53 is_main_thread: true,
54 next_id: TestDispatcherId(1),
55 allow_parking: false,
56 waiting_backtrace: None,
57 deprioritized_task_labels: Default::default(),
58 block_on_ticks: 0..=1000,
59 };
60
61 TestDispatcher {
62 id: TestDispatcherId(0),
63 state: Arc::new(Mutex::new(state)),
64 parker: Arc::new(Mutex::new(parker)),
65 unparker,
66 }
67 }
68
69 pub fn advance_clock(&self, by: Duration) {
70 let new_now = self.state.lock().time + by;
71 loop {
72 self.run_until_parked();
73 let state = self.state.lock();
74 let next_due_time = state.delayed.first().map(|(time, _)| *time);
75 drop(state);
76 if let Some(due_time) = next_due_time {
77 if due_time <= new_now {
78 self.state.lock().time = due_time;
79 continue;
80 }
81 }
82 break;
83 }
84 self.state.lock().time = new_now;
85 }
86
87 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
88 struct YieldNow {
89 pub(crate) count: usize,
90 }
91
92 impl Future for YieldNow {
93 type Output = ();
94
95 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
96 if self.count > 0 {
97 self.count -= 1;
98 cx.waker().wake_by_ref();
99 Poll::Pending
100 } else {
101 Poll::Ready(())
102 }
103 }
104 }
105
106 YieldNow {
107 count: self.state.lock().random.gen_range(0..10),
108 }
109 }
110
111 pub fn deprioritize(&self, task_label: TaskLabel) {
112 self.state
113 .lock()
114 .deprioritized_task_labels
115 .insert(task_label);
116 }
117
118 pub fn run_until_parked(&self) {
119 while self.tick(false) {}
120 }
121
122 pub fn parking_allowed(&self) -> bool {
123 self.state.lock().allow_parking
124 }
125
126 pub fn allow_parking(&self) {
127 self.state.lock().allow_parking = true
128 }
129
130 pub fn start_waiting(&self) {
131 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
132 }
133
134 pub fn finish_waiting(&self) {
135 self.state.lock().waiting_backtrace.take();
136 }
137
138 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
139 self.state.lock().waiting_backtrace.take().map(|mut b| {
140 b.resolve();
141 b
142 })
143 }
144
145 pub fn rng(&self) -> StdRng {
146 self.state.lock().random.clone()
147 }
148
149 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
150 self.state.lock().block_on_ticks = range;
151 }
152
153 pub fn gen_block_on_ticks(&self) -> usize {
154 let mut lock = self.state.lock();
155 let block_on_ticks = lock.block_on_ticks.clone();
156 lock.random.gen_range(block_on_ticks)
157 }
158}
159
160impl Clone for TestDispatcher {
161 fn clone(&self) -> Self {
162 let id = post_inc(&mut self.state.lock().next_id.0);
163 Self {
164 id: TestDispatcherId(id),
165 state: self.state.clone(),
166 parker: self.parker.clone(),
167 unparker: self.unparker.clone(),
168 }
169 }
170}
171
172impl PlatformDispatcher for TestDispatcher {
173 fn is_main_thread(&self) -> bool {
174 self.state.lock().is_main_thread
175 }
176
177 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
178 {
179 let mut state = self.state.lock();
180 if label.map_or(false, |label| {
181 state.deprioritized_task_labels.contains(&label)
182 }) {
183 state.deprioritized_background.push(runnable);
184 } else {
185 state.background.push(runnable);
186 }
187 }
188 self.unparker.unpark();
189 }
190
191 fn dispatch_on_main_thread(&self, runnable: Runnable) {
192 self.state
193 .lock()
194 .foreground
195 .entry(self.id)
196 .or_default()
197 .push_back(runnable);
198 self.unparker.unpark();
199 }
200
201 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
202 let mut state = self.state.lock();
203 let next_time = state.time + duration;
204 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
205 Ok(ix) | Err(ix) => ix,
206 };
207 state.delayed.insert(ix, (next_time, runnable));
208 }
209
210 fn tick(&self, background_only: bool) -> bool {
211 let mut state = self.state.lock();
212
213 while let Some((deadline, _)) = state.delayed.first() {
214 if *deadline > state.time {
215 break;
216 }
217 let (_, runnable) = state.delayed.remove(0);
218 state.background.push(runnable);
219 }
220
221 let foreground_len: usize = if background_only {
222 0
223 } else {
224 state
225 .foreground
226 .values()
227 .map(|runnables| runnables.len())
228 .sum()
229 };
230 let background_len = state.background.len();
231
232 let runnable;
233 let main_thread;
234 if foreground_len == 0 && background_len == 0 {
235 let deprioritized_background_len = state.deprioritized_background.len();
236 if deprioritized_background_len == 0 {
237 return false;
238 }
239 let ix = state.random.gen_range(0..deprioritized_background_len);
240 main_thread = false;
241 runnable = state.deprioritized_background.swap_remove(ix);
242 } else {
243 main_thread = state.random.gen_ratio(
244 foreground_len as u32,
245 (foreground_len + background_len) as u32,
246 );
247 if main_thread {
248 let state = &mut *state;
249 runnable = state
250 .foreground
251 .values_mut()
252 .filter(|runnables| !runnables.is_empty())
253 .choose(&mut state.random)
254 .unwrap()
255 .pop_front()
256 .unwrap();
257 } else {
258 let ix = state.random.gen_range(0..background_len);
259 runnable = state.background.swap_remove(ix);
260 };
261 };
262
263 let was_main_thread = state.is_main_thread;
264 state.is_main_thread = main_thread;
265 drop(state);
266 runnable.run();
267 self.state.lock().is_main_thread = was_main_thread;
268
269 true
270 }
271
272 fn park(&self) {
273 self.parker.lock().park();
274 }
275
276 fn unparker(&self) -> Unparker {
277 self.unparker.clone()
278 }
279
280 fn as_test(&self) -> Option<&TestDispatcher> {
281 Some(self)
282 }
283}