1use crate::{PlatformDispatcher, TaskLabel};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, HashSet, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 pin::Pin,
11 sync::Arc,
12 task::{Context, Poll},
13 time::Duration,
14};
15use util::post_inc;
16
/// Identity of one `TestDispatcher` handle.
///
/// `TestDispatcher::new` uses id 0 and every `clone` draws a fresh value from the
/// shared `next_id` counter; the id keys that handle's foreground queue.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
struct TestDispatcherId(usize);
19
/// A `PlatformDispatcher` for tests.
///
/// Runnables are queued in shared state and run one at a time by `tick`. The
/// choice of which one runs next is made with the `StdRng` passed to `new`.
/// All clones share the same state, parker, and unparker; each has its own id.
pub struct TestDispatcher {
    // Identity of this handle; selects the foreground queue used by
    // `dispatch_on_main_thread`.
    id: TestDispatcherId,
    // Scheduler state shared by every clone (`clone` copies the `Arc`).
    state: Arc<Mutex<TestDispatcherState>>,
    // Shared parker; locked and parked on in `park`.
    parker: Arc<Mutex<Parker>>,
    // Paired with `parker`; signalled by `dispatch` and `dispatch_on_main_thread`.
    unparker: Unparker,
}
26
/// Mutable scheduler state shared by all clones of a `TestDispatcher`.
struct TestDispatcherState {
    // Source of every scheduling choice in `tick` and of `simulate_random_delay`.
    random: StdRng,
    // Per-handle FIFO queues of main-thread runnables (`tick` pops from the front).
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    // Background runnables; `tick` picks one at a random index.
    background: Vec<Runnable>,
    // Background runnables with a deprioritized label. `tick` only runs these
    // when no foreground (if considered) or regular background work exists.
    deprioritized_background: Vec<Runnable>,
    // `(deadline, runnable)` pairs kept sorted by deadline via `dispatch_after`.
    delayed: Vec<(Duration, Runnable)>,
    // Simulated clock; moved forward by `advance_clock`.
    time: Duration,
    // Reported by `is_main_thread`. `tick` sets it while a runnable executes,
    // based on whether that runnable came from a foreground queue.
    is_main_thread: bool,
    // Next id handed out by `TestDispatcher::clone`.
    next_id: TestDispatcherId,
    // Set by `allow_parking`, read by `parking_allowed`.
    allow_parking: bool,
    // Captured by `start_waiting`; cleared by `finish_waiting` / `waiting_backtrace`.
    waiting_backtrace: Option<Backtrace>,
    // Labels registered via `deprioritize`; consulted by `dispatch`.
    deprioritized_task_labels: HashSet<TaskLabel>,
}
40
41impl TestDispatcher {
42 pub fn new(random: StdRng) -> Self {
43 let (parker, unparker) = parking::pair();
44 let state = TestDispatcherState {
45 random,
46 foreground: HashMap::default(),
47 background: Vec::new(),
48 deprioritized_background: Vec::new(),
49 delayed: Vec::new(),
50 time: Duration::ZERO,
51 is_main_thread: true,
52 next_id: TestDispatcherId(1),
53 allow_parking: false,
54 waiting_backtrace: None,
55 deprioritized_task_labels: Default::default(),
56 };
57
58 TestDispatcher {
59 id: TestDispatcherId(0),
60 state: Arc::new(Mutex::new(state)),
61 parker: Arc::new(Mutex::new(parker)),
62 unparker,
63 }
64 }
65
66 pub fn advance_clock(&self, by: Duration) {
67 let new_now = self.state.lock().time + by;
68 loop {
69 self.run_until_parked();
70 let state = self.state.lock();
71 let next_due_time = state.delayed.first().map(|(time, _)| *time);
72 drop(state);
73 if let Some(due_time) = next_due_time {
74 if due_time <= new_now {
75 self.state.lock().time = due_time;
76 continue;
77 }
78 }
79 break;
80 }
81 self.state.lock().time = new_now;
82 }
83
84 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
85 pub struct YieldNow {
86 count: usize,
87 }
88
89 impl Future for YieldNow {
90 type Output = ();
91
92 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
93 if self.count > 0 {
94 self.count -= 1;
95 cx.waker().wake_by_ref();
96 Poll::Pending
97 } else {
98 Poll::Ready(())
99 }
100 }
101 }
102
103 YieldNow {
104 count: self.state.lock().random.gen_range(0..10),
105 }
106 }
107
108 pub fn deprioritize(&self, task_label: TaskLabel) {
109 self.state
110 .lock()
111 .deprioritized_task_labels
112 .insert(task_label);
113 }
114
115 pub fn run_until_parked(&self) {
116 while self.tick(false) {}
117 }
118
119 pub fn parking_allowed(&self) -> bool {
120 self.state.lock().allow_parking
121 }
122
123 pub fn allow_parking(&self) {
124 self.state.lock().allow_parking = true
125 }
126
127 pub fn start_waiting(&self) {
128 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
129 }
130
131 pub fn finish_waiting(&self) {
132 self.state.lock().waiting_backtrace.take();
133 }
134
135 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
136 self.state.lock().waiting_backtrace.take().map(|mut b| {
137 b.resolve();
138 b
139 })
140 }
141
142 pub fn rng(&self) -> StdRng {
143 self.state.lock().random.clone()
144 }
145}
146
147impl Clone for TestDispatcher {
148 fn clone(&self) -> Self {
149 let id = post_inc(&mut self.state.lock().next_id.0);
150 Self {
151 id: TestDispatcherId(id),
152 state: self.state.clone(),
153 parker: self.parker.clone(),
154 unparker: self.unparker.clone(),
155 }
156 }
157}
158
159impl PlatformDispatcher for TestDispatcher {
160 fn is_main_thread(&self) -> bool {
161 self.state.lock().is_main_thread
162 }
163
164 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
165 {
166 let mut state = self.state.lock();
167 if label.map_or(false, |label| {
168 state.deprioritized_task_labels.contains(&label)
169 }) {
170 state.deprioritized_background.push(runnable);
171 } else {
172 state.background.push(runnable);
173 }
174 }
175 self.unparker.unpark();
176 }
177
178 fn dispatch_on_main_thread(&self, runnable: Runnable) {
179 self.state
180 .lock()
181 .foreground
182 .entry(self.id)
183 .or_default()
184 .push_back(runnable);
185 self.unparker.unpark();
186 }
187
188 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
189 let mut state = self.state.lock();
190 let next_time = state.time + duration;
191 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
192 Ok(ix) | Err(ix) => ix,
193 };
194 state.delayed.insert(ix, (next_time, runnable));
195 }
196
197 fn tick(&self, background_only: bool) -> bool {
198 let mut state = self.state.lock();
199
200 while let Some((deadline, _)) = state.delayed.first() {
201 if *deadline > state.time {
202 break;
203 }
204 let (_, runnable) = state.delayed.remove(0);
205 state.background.push(runnable);
206 }
207
208 let foreground_len: usize = if background_only {
209 0
210 } else {
211 state
212 .foreground
213 .values()
214 .map(|runnables| runnables.len())
215 .sum()
216 };
217 let background_len = state.background.len();
218
219 let runnable;
220 let main_thread;
221 if foreground_len == 0 && background_len == 0 {
222 let deprioritized_background_len = state.deprioritized_background.len();
223 if deprioritized_background_len == 0 {
224 return false;
225 }
226 let ix = state.random.gen_range(0..deprioritized_background_len);
227 main_thread = false;
228 runnable = state.deprioritized_background.swap_remove(ix);
229 } else {
230 main_thread = state.random.gen_ratio(
231 foreground_len as u32,
232 (foreground_len + background_len) as u32,
233 );
234 if main_thread {
235 let state = &mut *state;
236 runnable = state
237 .foreground
238 .values_mut()
239 .filter(|runnables| !runnables.is_empty())
240 .choose(&mut state.random)
241 .unwrap()
242 .pop_front()
243 .unwrap();
244 } else {
245 let ix = state.random.gen_range(0..background_len);
246 runnable = state.background.swap_remove(ix);
247 };
248 };
249
250 let was_main_thread = state.is_main_thread;
251 state.is_main_thread = main_thread;
252 drop(state);
253 runnable.run();
254 self.state.lock().is_main_thread = was_main_thread;
255
256 true
257 }
258
259 fn park(&self) {
260 self.parker.lock().park();
261 }
262
263 fn unparker(&self) -> Unparker {
264 self.unparker.clone()
265 }
266
267 fn as_test(&self) -> Option<&TestDispatcher> {
268 Some(self)
269 }
270}