1use crate::PlatformDispatcher;
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::{HashMap, VecDeque};
5use parking::{Parker, Unparker};
6use parking_lot::Mutex;
7use rand::prelude::*;
8use std::{
9 future::Future,
10 pin::Pin,
11 sync::Arc,
12 task::{Context, Poll},
13 time::Duration,
14};
15use util::post_inc;
16
/// Identifies one `TestDispatcher` handle. The root dispatcher uses id 0;
/// each clone draws a new id from the shared `next_id` counter.
#[derive(Clone, Copy, Hash, PartialEq, Eq)]
struct TestDispatcherId(usize);
19
/// A [`PlatformDispatcher`] for tests.
///
/// Queued work only runs when it is explicitly polled. A caller-supplied
/// `StdRng` decides which queue and which runnable run next, and delayed work
/// is driven by a simulated clock. All clones share one scheduler state and
/// parker; each clone gets its own id, which keys its own foreground queue.
pub struct TestDispatcher {
    /// Key of this handle's foreground queue in the shared state.
    id: TestDispatcherId,
    /// Scheduler state shared by every clone of this dispatcher.
    state: Arc<Mutex<TestDispatcherState>>,
    /// Parker shared by every clone; `park` blocks on it while holding this mutex.
    parker: Arc<Mutex<Parker>>,
    /// Unparks the shared parker; signalled by `dispatch` and `dispatch_on_main_thread`.
    unparker: Unparker,
}
26
/// Scheduler state shared, behind a mutex, by all clones of a `TestDispatcher`.
struct TestDispatcherState {
    /// Source of every scheduling choice (queue selection, runnable selection,
    /// and the yield count of `simulate_random_delay`).
    random: StdRng,
    /// Runnables sent via `dispatch_on_main_thread`, one FIFO queue per handle id.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    /// Background runnables; `poll` removes one at a random index.
    background: Vec<Runnable>,
    /// Delayed runnables paired with their deadline on the simulated clock,
    /// kept sorted by deadline on insertion.
    delayed: Vec<(Duration, Runnable)>,
    /// Current simulated time; written by `advance_clock`.
    time: Duration,
    /// Value reported by `is_main_thread`; `poll` overrides it while a runnable runs.
    is_main_thread: bool,
    /// Id handed to the next clone.
    next_id: TestDispatcherId,
    /// Flag set by `allow_parking` and read by `parking_allowed`.
    allow_parking: bool,
    /// Unresolved backtrace captured by `start_waiting`; cleared by
    /// `finish_waiting` and taken by `waiting_backtrace`.
    waiting_backtrace: Option<Backtrace>,
}
38
39impl TestDispatcher {
40 pub fn new(random: StdRng) -> Self {
41 let (parker, unparker) = parking::pair();
42 let state = TestDispatcherState {
43 random,
44 foreground: HashMap::default(),
45 background: Vec::new(),
46 delayed: Vec::new(),
47 time: Duration::ZERO,
48 is_main_thread: true,
49 next_id: TestDispatcherId(1),
50 allow_parking: false,
51 waiting_backtrace: None,
52 };
53
54 TestDispatcher {
55 id: TestDispatcherId(0),
56 state: Arc::new(Mutex::new(state)),
57 parker: Arc::new(Mutex::new(parker)),
58 unparker,
59 }
60 }
61
62 pub fn advance_clock(&self, by: Duration) {
63 let new_now = self.state.lock().time + by;
64 loop {
65 self.run_until_parked();
66 let state = self.state.lock();
67 let next_due_time = state.delayed.first().map(|(time, _)| *time);
68 drop(state);
69 if let Some(due_time) = next_due_time {
70 if due_time <= new_now {
71 self.state.lock().time = due_time;
72 continue;
73 }
74 }
75 break;
76 }
77 self.state.lock().time = new_now;
78 }
79
80 pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
81 pub struct YieldNow {
82 count: usize,
83 }
84
85 impl Future for YieldNow {
86 type Output = ();
87
88 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
89 if self.count > 0 {
90 self.count -= 1;
91 cx.waker().wake_by_ref();
92 Poll::Pending
93 } else {
94 Poll::Ready(())
95 }
96 }
97 }
98
99 YieldNow {
100 count: self.state.lock().random.gen_range(0..10),
101 }
102 }
103
104 pub fn run_until_parked(&self) {
105 while self.poll(false) {}
106 }
107
108 pub fn parking_allowed(&self) -> bool {
109 self.state.lock().allow_parking
110 }
111
112 pub fn allow_parking(&self) {
113 self.state.lock().allow_parking = true
114 }
115
116 pub fn start_waiting(&self) {
117 self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
118 }
119
120 pub fn finish_waiting(&self) {
121 self.state.lock().waiting_backtrace.take();
122 }
123
124 pub fn waiting_backtrace(&self) -> Option<Backtrace> {
125 self.state.lock().waiting_backtrace.take().map(|mut b| {
126 b.resolve();
127 b
128 })
129 }
130}
131
132impl Clone for TestDispatcher {
133 fn clone(&self) -> Self {
134 let id = post_inc(&mut self.state.lock().next_id.0);
135 Self {
136 id: TestDispatcherId(id),
137 state: self.state.clone(),
138 parker: self.parker.clone(),
139 unparker: self.unparker.clone(),
140 }
141 }
142}
143
144impl PlatformDispatcher for TestDispatcher {
145 fn is_main_thread(&self) -> bool {
146 self.state.lock().is_main_thread
147 }
148
149 fn dispatch(&self, runnable: Runnable) {
150 self.state.lock().background.push(runnable);
151 self.unparker.unpark();
152 }
153
154 fn dispatch_on_main_thread(&self, runnable: Runnable) {
155 self.state
156 .lock()
157 .foreground
158 .entry(self.id)
159 .or_default()
160 .push_back(runnable);
161 self.unparker.unpark();
162 }
163
164 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
165 let mut state = self.state.lock();
166 let next_time = state.time + duration;
167 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
168 Ok(ix) | Err(ix) => ix,
169 };
170 state.delayed.insert(ix, (next_time, runnable));
171 }
172
173 fn poll(&self, background_only: bool) -> bool {
174 let mut state = self.state.lock();
175
176 while let Some((deadline, _)) = state.delayed.first() {
177 if *deadline > state.time {
178 break;
179 }
180 let (_, runnable) = state.delayed.remove(0);
181 state.background.push(runnable);
182 }
183
184 let foreground_len: usize = if background_only {
185 0
186 } else {
187 state
188 .foreground
189 .values()
190 .map(|runnables| runnables.len())
191 .sum()
192 };
193 let background_len = state.background.len();
194
195 if foreground_len == 0 && background_len == 0 {
196 return false;
197 }
198
199 let main_thread = state.random.gen_ratio(
200 foreground_len as u32,
201 (foreground_len + background_len) as u32,
202 );
203 let was_main_thread = state.is_main_thread;
204 state.is_main_thread = main_thread;
205
206 let runnable = if main_thread {
207 let state = &mut *state;
208 let runnables = state
209 .foreground
210 .values_mut()
211 .filter(|runnables| !runnables.is_empty())
212 .choose(&mut state.random)
213 .unwrap();
214 runnables.pop_front().unwrap()
215 } else {
216 let ix = state.random.gen_range(0..background_len);
217 state.background.swap_remove(ix)
218 };
219
220 drop(state);
221 runnable.run();
222
223 self.state.lock().is_main_thread = was_main_thread;
224
225 true
226 }
227
228 fn park(&self) {
229 self.parker.lock().park();
230 }
231
232 fn unparker(&self) -> Unparker {
233 self.unparker.clone()
234 }
235
236 fn as_test(&self) -> Option<&TestDispatcher> {
237 Some(self)
238 }
239}