1use crate::PlatformDispatcher;
2use async_task::Runnable;
3use collections::{HashMap, VecDeque};
4use parking_lot::Mutex;
5use rand::prelude::*;
6use std::{
7 future::Future,
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11 time::Duration,
12};
13use util::post_inc;
14
/// Identifies a single `TestDispatcher` handle. Main-thread work is queued per id, so
/// every handle (the original and each clone) has its own foreground queue.
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
struct TestDispatcherId(usize);
17
18pub struct TestDispatcher {
19 id: TestDispatcherId,
20 state: Arc<Mutex<TestDispatcherState>>,
21}
22
/// Mutable state shared (behind an `Arc<Mutex<_>>`) by a `TestDispatcher` and all of
/// its clones.
struct TestDispatcherState {
    /// RNG behind the random choices made in `poll` and `simulate_random_delay`;
    /// supplied by the caller of `TestDispatcher::new`.
    random: StdRng,
    /// Main-thread queues keyed by dispatcher id; `poll` pops from the front of a
    /// randomly chosen non-empty queue.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    /// Background runnables; `poll` removes one at a random index.
    background: Vec<Runnable>,
    /// Timer runnables paired with their absolute deadline (`time + duration` at
    /// dispatch), kept sorted by deadline by `dispatch_after`.
    delayed: Vec<(Duration, Runnable)>,
    /// Simulated current time; starts at zero and is moved forward by `advance_clock`.
    time: Duration,
    /// Value returned by `is_main_thread`; `poll` sets it while a runnable runs and
    /// restores the previous value afterwards.
    is_main_thread: bool,
    /// Id assigned to the next clone (the original dispatcher uses id 0).
    next_id: TestDispatcherId,
}
32
33impl TestDispatcher {
34 pub fn new(random: StdRng) -> Self {
35 let state = TestDispatcherState {
36 random,
37 foreground: HashMap::default(),
38 background: Vec::new(),
39 delayed: Vec::new(),
40 time: Duration::ZERO,
41 is_main_thread: true,
42 next_id: TestDispatcherId(1),
43 };
44
45 TestDispatcher {
46 id: TestDispatcherId(0),
47 state: Arc::new(Mutex::new(state)),
48 }
49 }
50
51 pub fn advance_clock(&self, by: Duration) {
52 let new_now = self.state.lock().time + by;
53 loop {
54 self.run_until_parked();
55 let state = self.state.lock();
56 let next_due_time = state.delayed.first().map(|(time, _)| *time);
57 drop(state);
58 if let Some(due_time) = next_due_time {
59 if due_time <= new_now {
60 self.state.lock().time = due_time;
61 continue;
62 }
63 }
64 break;
65 }
66 self.state.lock().time = new_now;
67 }
68
69 pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
70 pub struct YieldNow {
71 count: usize,
72 }
73
74 impl Future for YieldNow {
75 type Output = ();
76
77 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
78 if self.count > 0 {
79 self.count -= 1;
80 cx.waker().wake_by_ref();
81 Poll::Pending
82 } else {
83 Poll::Ready(())
84 }
85 }
86 }
87
88 YieldNow {
89 count: self.state.lock().random.gen_range(0..10),
90 }
91 }
92
93 pub fn run_until_parked(&self) {
94 while self.poll() {}
95 }
96}
97
98impl Clone for TestDispatcher {
99 fn clone(&self) -> Self {
100 let id = post_inc(&mut self.state.lock().next_id.0);
101 Self {
102 id: TestDispatcherId(id),
103 state: self.state.clone(),
104 }
105 }
106}
107
108impl PlatformDispatcher for TestDispatcher {
109 fn is_main_thread(&self) -> bool {
110 self.state.lock().is_main_thread
111 }
112
113 fn dispatch(&self, runnable: Runnable) {
114 self.state.lock().background.push(runnable);
115 }
116
117 fn dispatch_on_main_thread(&self, runnable: Runnable) {
118 self.state
119 .lock()
120 .foreground
121 .entry(self.id)
122 .or_default()
123 .push_back(runnable);
124 }
125
126 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
127 let mut state = self.state.lock();
128 let next_time = state.time + duration;
129 let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
130 Ok(ix) | Err(ix) => ix,
131 };
132 state.delayed.insert(ix, (next_time, runnable));
133 }
134
135 fn poll(&self) -> bool {
136 let mut state = self.state.lock();
137
138 while let Some((deadline, _)) = state.delayed.first() {
139 if *deadline > state.time {
140 break;
141 }
142 let (_, runnable) = state.delayed.remove(0);
143 state.background.push(runnable);
144 }
145
146 let foreground_len: usize = state
147 .foreground
148 .values()
149 .map(|runnables| runnables.len())
150 .sum();
151 let background_len = state.background.len();
152
153 if foreground_len == 0 && background_len == 0 {
154 return false;
155 }
156
157 let main_thread = state.random.gen_ratio(
158 foreground_len as u32,
159 (foreground_len + background_len) as u32,
160 );
161 let was_main_thread = state.is_main_thread;
162 state.is_main_thread = main_thread;
163
164 let runnable = if main_thread {
165 let state = &mut *state;
166 let runnables = state
167 .foreground
168 .values_mut()
169 .filter(|runnables| !runnables.is_empty())
170 .choose(&mut state.random)
171 .unwrap();
172 runnables.pop_front().unwrap()
173 } else {
174 let ix = state.random.gen_range(0..background_len);
175 state.background.swap_remove(ix)
176 };
177
178 drop(state);
179 runnable.run();
180
181 self.state.lock().is_main_thread = was_main_thread;
182
183 true
184 }
185
186 fn as_test(&self) -> Option<&TestDispatcher> {
187 Some(self)
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Executor;
    use std::sync::Arc;

    /// Hops between `spawn_on_main` and `spawn` through the test dispatcher. Asserts at
    /// each level what `is_main_thread` reports, and that results propagate back out.
    #[test]
    fn test_dispatch() {
        // Fixed seed so the dispatcher's random scheduling choices are reproducible.
        let dispatcher = TestDispatcher::new(StdRng::seed_from_u64(0));
        let executor = Executor::new(Arc::new(dispatcher));

        // Round trip through `run_on_main`: the closure's value comes back out of `block`.
        let result = executor.block(async { executor.run_on_main(|| 1).await });
        assert_eq!(result, 1);

        // Nested hops: main-thread task -> background task -> main-thread task, with the
        // innermost value (2) returned through every level.
        let result = executor.block({
            let executor = executor.clone();
            async move {
                executor
                    .spawn_on_main({
                        let executor = executor.clone();
                        // Evaluated inside the outer async block, before the closure is
                        // passed to `spawn_on_main`.
                        assert!(executor.is_main_thread());
                        || async move {
                            assert!(executor.is_main_thread());
                            let result = executor
                                .spawn({
                                    let executor = executor.clone();
                                    async move {
                                        assert!(!executor.is_main_thread());

                                        let result = executor
                                            .spawn_on_main({
                                                let executor = executor.clone();
                                                || async move {
                                                    assert!(executor.is_main_thread());
                                                    2
                                                }
                                            })
                                            .await;

                                        // Still reported as background after awaiting the
                                        // main-thread task.
                                        assert!(!executor.is_main_thread());
                                        result
                                    }
                                })
                                .await;
                            assert!(executor.is_main_thread());
                            result
                        }
                    })
                    .await
            }
        });
        assert_eq!(result, 2);
    }
}