1use crate::PlatformDispatcher;
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 pin::Pin,
11 sync::Arc,
12 task::{Context, Poll},
13 time::Duration,
14};
15use util::post_inc;
16
/// Identifies one `TestDispatcher` handle.
///
/// The original dispatcher uses id 0 and each clone receives a new id from the
/// shared counter; the id keys that handle's foreground queue.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
struct TestDispatcherId(usize);
19
/// Seeded, single-threaded [`PlatformDispatcher`] implementation for tests.
///
/// Clones share the scheduler `state`, the `parker` and the `unparker`, but each
/// clone is given its own `id`.
pub struct TestDispatcher {
    // Key of this handle's queue in `TestDispatcherState::foreground`
    // (used by `dispatch_on_main_thread`).
    id: TestDispatcherId,
    // Queues, simulated clock and RNG, shared by every clone.
    state: Arc<Mutex<TestDispatcherState>>,
    // Parker blocked on by `park`; shared by every clone.
    parker: Arc<Mutex<Parker>>,
    // Wakes `parker`; signalled by `dispatch` and `dispatch_on_main_thread`.
    unparker: Unparker,
}
26
/// Mutable scheduler state shared (behind a mutex) by all clones of a
/// `TestDispatcher`.
struct TestDispatcherState {
    // Seeded RNG used to choose which task runs next and for random delays.
    random: StdRng,
    // One FIFO queue of main-thread tasks per dispatcher handle, keyed by id.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    // Background tasks; `poll` removes one at a random index.
    background: Vec<Runnable>,
    // Timer tasks with their deadline on the simulated clock. Kept sorted by
    // deadline (see `dispatch_after`), so the due tasks form a prefix.
    delayed: Vec<(Duration, Runnable)>,
    // Current simulated time; moved forward by `advance_clock`.
    time: Duration,
    // Starts `true`; while `poll` runs a task it records whether that task
    // came from a foreground queue, and is restored afterwards.
    is_main_thread: bool,
    // Id handed to the next clone of the dispatcher.
    next_id: TestDispatcherId,
    // Flag reported by `parking_allowed`; set by `allow_parking`.
    allow_parking: bool,
    // Backtrace captured by `start_waiting`, cleared by `finish_waiting`.
    waiting_backtrace: Option<Backtrace>,
}
38
39impl TestDispatcher {
40 pub fn new(random: StdRng) -> Self {
41 let (parker, unparker) = parking::pair();
42 let state = TestDispatcherState {
43 random,
44 foreground: HashMap::default(),
45 background: Vec::new(),
46 delayed: Vec::new(),
47 time: Duration::ZERO,
48 is_main_thread: true,
49 next_id: TestDispatcherId(1),
50 allow_parking: false,
51 waiting_backtrace: None,
52 };
53
54 TestDispatcher {
55 id: TestDispatcherId(0),
56 state: Arc::new(Mutex::new(state)),
57 parker: Arc::new(Mutex::new(parker)),
58 unparker,
59 }
60 }
61
62 pub fn advance_clock(&self, by: Duration) {
63 let new_now = self.state.lock().time + by;
64 loop {
65 self.run_until_parked();
66 let state = self.state.lock();
67 let next_due_time = state.delayed.first().map(|(time, _)| *time);
68 drop(state);
69 if let Some(due_time) = next_due_time {
70 if due_time <= new_now {
71 self.state.lock().time = due_time;
72 continue;
73 }
74 }
75 break;
76 }
77 self.state.lock().time = new_now;
78 }
79
80 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
81 pub struct YieldNow {
82 count: usize,
83 }
84
85 impl Future for YieldNow {
86 type Output = ();
87
88 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
89 if self.count > 0 {
90 self.count -= 1;
91 cx.waker().wake_by_ref();
92 Poll::Pending
93 } else {
94 Poll::Ready(())
95 }
96 }
97 }
98
99 YieldNow {
100 count: self.state.lock().random.gen_range(0..10),
101 }
102 }
103
104 pub fn run_until_parked(&self) {
105 while self.poll(false) {}
106 }
107
108 pub fn parking_allowed(&self) -> bool {
109 self.state.lock().allow_parking
110 }
111
112 pub fn allow_parking(&self) {
113 self.state.lock().allow_parking = true
114 }
115
116 pub fn start_waiting(&self) {
117 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
118 }
119
120 pub fn finish_waiting(&self) {
121 self.state.lock().waiting_backtrace.take();
122 }
123
124 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
125 self.state.lock().waiting_backtrace.take().map(|mut b| {
126 b.resolve();
127 b
128 })
129 }
130
131 pub fn rng(&self) -> StdRng {
132 self.state.lock().random.clone()
133 }
134}
135
136impl Clone for TestDispatcher {
137 fn clone(&self) -> Self {
138 let id = post_inc(&mut self.state.lock().next_id.0);
139 Self {
140 id: TestDispatcherId(id),
141 state: self.state.clone(),
142 parker: self.parker.clone(),
143 unparker: self.unparker.clone(),
144 }
145 }
146}
147
148impl PlatformDispatcher for TestDispatcher {
149 fn is_main_thread(&self) -> bool {
150 self.state.lock().is_main_thread
151 }
152
153 fn dispatch(&self, runnable: Runnable) {
154 self.state.lock().background.push(runnable);
155 self.unparker.unpark();
156 }
157
158 fn dispatch_on_main_thread(&self, runnable: Runnable) {
159 self.state
160 .lock()
161 .foreground
162 .entry(self.id)
163 .or_default()
164 .push_back(runnable);
165 self.unparker.unpark();
166 }
167
168 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
169 let mut state = self.state.lock();
170 let next_time = state.time + duration;
171 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
172 Ok(ix) | Err(ix) => ix,
173 };
174 state.delayed.insert(ix, (next_time, runnable));
175 }
176
177 fn poll(&self, background_only: bool) -> bool {
178 let mut state = self.state.lock();
179
180 while let Some((deadline, _)) = state.delayed.first() {
181 if *deadline > state.time {
182 break;
183 }
184 let (_, runnable) = state.delayed.remove(0);
185 state.background.push(runnable);
186 }
187
188 let foreground_len: usize = if background_only {
189 0
190 } else {
191 state
192 .foreground
193 .values()
194 .map(|runnables| runnables.len())
195 .sum()
196 };
197 let background_len = state.background.len();
198
199 if foreground_len == 0 && background_len == 0 {
200 return false;
201 }
202
203 let main_thread = state.random.gen_ratio(
204 foreground_len as u32,
205 (foreground_len + background_len) as u32,
206 );
207 let was_main_thread = state.is_main_thread;
208 state.is_main_thread = main_thread;
209
210 let runnable = if main_thread {
211 let state = &mut *state;
212 let runnables = state
213 .foreground
214 .values_mut()
215 .filter(|runnables| !runnables.is_empty())
216 .choose(&mut state.random)
217 .unwrap();
218 runnables.pop_front().unwrap()
219 } else {
220 let ix = state.random.gen_range(0..background_len);
221 state.background.swap_remove(ix)
222 };
223
224 drop(state);
225 runnable.run();
226
227 self.state.lock().is_main_thread = was_main_thread;
228
229 true
230 }
231
232 fn park(&self) {
233 self.parker.lock().park();
234 }
235
236 fn unparker(&self) -> Unparker {
237 self.unparker.clone()
238 }
239
240 fn as_test(&self) -> Option<&TestDispatcher> {
241 Some(self)
242 }
243}