dispatcher.rs

  1use crate::{PlatformDispatcher, TaskLabel};
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, HashSet, VecDeque};
  5use parking::{Parker, Unparker};
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    ops::RangeInclusive,
 11    pin::Pin,
 12    sync::Arc,
 13    task::{Context, Poll},
 14    time::{Duration, Instant},
 15};
 16use util::post_inc;
 17
 18#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 19struct TestDispatcherId(usize);
 20
 21#[doc(hidden)]
 22pub struct TestDispatcher {
 23    id: TestDispatcherId,
 24    state: Arc<Mutex<TestDispatcherState>>,
 25    parker: Arc<Mutex<Parker>>,
 26    unparker: Unparker,
 27}
 28
 29struct TestDispatcherState {
 30    random: StdRng,
 31    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
 32    background: Vec<Runnable>,
 33    deprioritized_background: Vec<Runnable>,
 34    delayed: Vec<(Duration, Runnable)>,
 35    start_time: Instant,
 36    time: Duration,
 37    is_main_thread: bool,
 38    next_id: TestDispatcherId,
 39    allow_parking: bool,
 40    waiting_hint: Option<String>,
 41    waiting_backtrace: Option<Backtrace>,
 42    deprioritized_task_labels: HashSet<TaskLabel>,
 43    block_on_ticks: RangeInclusive<usize>,
 44}
 45
 46impl TestDispatcher {
 47    pub fn new(random: StdRng) -> Self {
 48        let (parker, unparker) = parking::pair();
 49        let state = TestDispatcherState {
 50            random,
 51            foreground: HashMap::default(),
 52            background: Vec::new(),
 53            deprioritized_background: Vec::new(),
 54            delayed: Vec::new(),
 55            time: Duration::ZERO,
 56            start_time: Instant::now(),
 57            is_main_thread: true,
 58            next_id: TestDispatcherId(1),
 59            allow_parking: false,
 60            waiting_hint: None,
 61            waiting_backtrace: None,
 62            deprioritized_task_labels: Default::default(),
 63            block_on_ticks: 0..=1000,
 64        };
 65
 66        TestDispatcher {
 67            id: TestDispatcherId(0),
 68            state: Arc::new(Mutex::new(state)),
 69            parker: Arc::new(Mutex::new(parker)),
 70            unparker,
 71        }
 72    }
 73
 74    pub fn advance_clock(&self, by: Duration) {
 75        let new_now = self.state.lock().time + by;
 76        loop {
 77            self.run_until_parked();
 78            let state = self.state.lock();
 79            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 80            drop(state);
 81            if let Some(due_time) = next_due_time {
 82                if due_time <= new_now {
 83                    self.state.lock().time = due_time;
 84                    continue;
 85                }
 86            }
 87            break;
 88        }
 89        self.state.lock().time = new_now;
 90    }
 91
 92    pub fn advance_clock_to_next_delayed(&self) -> bool {
 93        let next_due_time = self.state.lock().delayed.first().map(|(time, _)| *time);
 94        if let Some(next_due_time) = next_due_time {
 95            self.state.lock().time = next_due_time;
 96            return true;
 97        }
 98        false
 99    }
100
101    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
102        struct YieldNow {
103            pub(crate) count: usize,
104        }
105
106        impl Future for YieldNow {
107            type Output = ();
108
109            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
110                if self.count > 0 {
111                    self.count -= 1;
112                    cx.waker().wake_by_ref();
113                    Poll::Pending
114                } else {
115                    Poll::Ready(())
116                }
117            }
118        }
119
120        YieldNow {
121            count: self.state.lock().random.gen_range(0..10),
122        }
123    }
124
125    pub fn deprioritize(&self, task_label: TaskLabel) {
126        self.state
127            .lock()
128            .deprioritized_task_labels
129            .insert(task_label);
130    }
131
132    pub fn run_until_parked(&self) {
133        while self.tick(false) {}
134    }
135
136    pub fn parking_allowed(&self) -> bool {
137        self.state.lock().allow_parking
138    }
139
140    pub fn allow_parking(&self) {
141        self.state.lock().allow_parking = true
142    }
143
144    pub fn forbid_parking(&self) {
145        self.state.lock().allow_parking = false
146    }
147
148    pub fn set_waiting_hint(&self, msg: Option<String>) {
149        self.state.lock().waiting_hint = msg
150    }
151
152    pub fn waiting_hint(&self) -> Option<String> {
153        self.state.lock().waiting_hint.clone()
154    }
155
156    pub fn start_waiting(&self) {
157        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
158    }
159
160    pub fn finish_waiting(&self) {
161        self.state.lock().waiting_backtrace.take();
162    }
163
164    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
165        self.state.lock().waiting_backtrace.take().map(|mut b| {
166            b.resolve();
167            b
168        })
169    }
170
171    pub fn rng(&self) -> StdRng {
172        self.state.lock().random.clone()
173    }
174
175    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
176        self.state.lock().block_on_ticks = range;
177    }
178
179    pub fn gen_block_on_ticks(&self) -> usize {
180        let mut lock = self.state.lock();
181        let block_on_ticks = lock.block_on_ticks.clone();
182        lock.random.gen_range(block_on_ticks)
183    }
184}
185
186impl Clone for TestDispatcher {
187    fn clone(&self) -> Self {
188        let id = post_inc(&mut self.state.lock().next_id.0);
189        Self {
190            id: TestDispatcherId(id),
191            state: self.state.clone(),
192            parker: self.parker.clone(),
193            unparker: self.unparker.clone(),
194        }
195    }
196}
197
198impl PlatformDispatcher for TestDispatcher {
199    fn is_main_thread(&self) -> bool {
200        self.state.lock().is_main_thread
201    }
202
203    fn now(&self) -> Instant {
204        let state = self.state.lock();
205        state.start_time + state.time
206    }
207
208    fn tick(&self, background_only: bool) -> bool {
209        let mut state = self.state.lock();
210
211        while let Some((deadline, _)) = state.delayed.first() {
212            if *deadline > state.time {
213                break;
214            }
215            let (_, runnable) = state.delayed.remove(0);
216            state.background.push(runnable);
217        }
218
219        let foreground_len: usize = if background_only {
220            0
221        } else {
222            state
223                .foreground
224                .values()
225                .map(|runnables| runnables.len())
226                .sum()
227        };
228        let background_len = state.background.len();
229
230        let runnable;
231        let main_thread;
232        if foreground_len == 0 && background_len == 0 {
233            let deprioritized_background_len = state.deprioritized_background.len();
234            if deprioritized_background_len == 0 {
235                return false;
236            }
237            let ix = state.random.gen_range(0..deprioritized_background_len);
238            main_thread = false;
239            runnable = state.deprioritized_background.swap_remove(ix);
240        } else {
241            main_thread = state.random.gen_ratio(
242                foreground_len as u32,
243                (foreground_len + background_len) as u32,
244            );
245            if main_thread {
246                let state = &mut *state;
247                runnable = state
248                    .foreground
249                    .values_mut()
250                    .filter(|runnables| !runnables.is_empty())
251                    .choose(&mut state.random)
252                    .unwrap()
253                    .pop_front()
254                    .unwrap();
255            } else {
256                let ix = state.random.gen_range(0..background_len);
257                runnable = state.background.swap_remove(ix);
258            };
259        };
260
261        let was_main_thread = state.is_main_thread;
262        state.is_main_thread = main_thread;
263        drop(state);
264        runnable.run();
265        self.state.lock().is_main_thread = was_main_thread;
266
267        true
268    }
269
270    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
271        {
272            let mut state = self.state.lock();
273            if label.map_or(false, |label| {
274                state.deprioritized_task_labels.contains(&label)
275            }) {
276                state.deprioritized_background.push(runnable);
277            } else {
278                state.background.push(runnable);
279            }
280        }
281        self.unparker.unpark();
282    }
283
284    fn dispatch_on_main_thread(&self, runnable: Runnable) {
285        self.state
286            .lock()
287            .foreground
288            .entry(self.id)
289            .or_default()
290            .push_back(runnable);
291        self.unparker.unpark();
292    }
293
294    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
295        let mut state = self.state.lock();
296        let next_time = state.time + duration;
297        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
298            Ok(ix) | Err(ix) => ix,
299        };
300        state.delayed.insert(ix, (next_time, runnable));
301    }
302    fn park(&self, _: Option<std::time::Duration>) -> bool {
303        self.parker.lock().park();
304        true
305    }
306
307    fn unparker(&self) -> Unparker {
308        self.unparker.clone()
309    }
310
311    fn as_test(&self) -> Option<&TestDispatcher> {
312        Some(self)
313    }
314}