dispatcher.rs

  1use crate::{PlatformDispatcher, TaskLabel};
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, HashSet, VecDeque};
  5use parking::{Parker, Unparker};
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    ops::RangeInclusive,
 11    pin::Pin,
 12    sync::Arc,
 13    task::{Context, Poll},
 14    time::{Duration, Instant},
 15};
 16use util::post_inc;
 17
 18#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 19struct TestDispatcherId(usize);
 20
 21#[doc(hidden)]
 22pub struct TestDispatcher {
 23    id: TestDispatcherId,
 24    state: Arc<Mutex<TestDispatcherState>>,
 25    parker: Arc<Mutex<Parker>>,
 26    unparker: Unparker,
 27}
 28
 29struct TestDispatcherState {
 30    random: StdRng,
 31    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
 32    background: Vec<Runnable>,
 33    deprioritized_background: Vec<Runnable>,
 34    delayed: Vec<(Duration, Runnable)>,
 35    start_time: Instant,
 36    time: Duration,
 37    is_main_thread: bool,
 38    next_id: TestDispatcherId,
 39    allow_parking: bool,
 40    waiting_hint: Option<String>,
 41    waiting_backtrace: Option<Backtrace>,
 42    deprioritized_task_labels: HashSet<TaskLabel>,
 43    block_on_ticks: RangeInclusive<usize>,
 44}
 45
 46impl TestDispatcher {
 47    pub fn new(random: StdRng) -> Self {
 48        let (parker, unparker) = parking::pair();
 49        let state = TestDispatcherState {
 50            random,
 51            foreground: HashMap::default(),
 52            background: Vec::new(),
 53            deprioritized_background: Vec::new(),
 54            delayed: Vec::new(),
 55            time: Duration::ZERO,
 56            start_time: Instant::now(),
 57            is_main_thread: true,
 58            next_id: TestDispatcherId(1),
 59            allow_parking: false,
 60            waiting_hint: None,
 61            waiting_backtrace: None,
 62            deprioritized_task_labels: Default::default(),
 63            block_on_ticks: 0..=1000,
 64        };
 65
 66        TestDispatcher {
 67            id: TestDispatcherId(0),
 68            state: Arc::new(Mutex::new(state)),
 69            parker: Arc::new(Mutex::new(parker)),
 70            unparker,
 71        }
 72    }
 73
 74    pub fn advance_clock(&self, by: Duration) {
 75        let new_now = self.state.lock().time + by;
 76        loop {
 77            self.run_until_parked();
 78            let state = self.state.lock();
 79            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 80            drop(state);
 81            if let Some(due_time) = next_due_time {
 82                if due_time <= new_now {
 83                    self.state.lock().time = due_time;
 84                    continue;
 85                }
 86            }
 87            break;
 88        }
 89        self.state.lock().time = new_now;
 90    }
 91
 92    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
 93        struct YieldNow {
 94            pub(crate) count: usize,
 95        }
 96
 97        impl Future for YieldNow {
 98            type Output = ();
 99
100            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
101                if self.count > 0 {
102                    self.count -= 1;
103                    cx.waker().wake_by_ref();
104                    Poll::Pending
105                } else {
106                    Poll::Ready(())
107                }
108            }
109        }
110
111        YieldNow {
112            count: self.state.lock().random.gen_range(0..10),
113        }
114    }
115
116    pub fn tick(&self, background_only: bool) -> bool {
117        let mut state = self.state.lock();
118
119        while let Some((deadline, _)) = state.delayed.first() {
120            if *deadline > state.time {
121                break;
122            }
123            let (_, runnable) = state.delayed.remove(0);
124            state.background.push(runnable);
125        }
126
127        let foreground_len: usize = if background_only {
128            0
129        } else {
130            state
131                .foreground
132                .values()
133                .map(|runnables| runnables.len())
134                .sum()
135        };
136        let background_len = state.background.len();
137
138        let runnable;
139        let main_thread;
140        if foreground_len == 0 && background_len == 0 {
141            let deprioritized_background_len = state.deprioritized_background.len();
142            if deprioritized_background_len == 0 {
143                return false;
144            }
145            let ix = state.random.gen_range(0..deprioritized_background_len);
146            main_thread = false;
147            runnable = state.deprioritized_background.swap_remove(ix);
148        } else {
149            main_thread = state.random.gen_ratio(
150                foreground_len as u32,
151                (foreground_len + background_len) as u32,
152            );
153            if main_thread {
154                let state = &mut *state;
155                runnable = state
156                    .foreground
157                    .values_mut()
158                    .filter(|runnables| !runnables.is_empty())
159                    .choose(&mut state.random)
160                    .unwrap()
161                    .pop_front()
162                    .unwrap();
163            } else {
164                let ix = state.random.gen_range(0..background_len);
165                runnable = state.background.swap_remove(ix);
166            };
167        };
168
169        let was_main_thread = state.is_main_thread;
170        state.is_main_thread = main_thread;
171        drop(state);
172        runnable.run();
173        self.state.lock().is_main_thread = was_main_thread;
174
175        true
176    }
177
178    pub fn deprioritize(&self, task_label: TaskLabel) {
179        self.state
180            .lock()
181            .deprioritized_task_labels
182            .insert(task_label);
183    }
184
185    pub fn run_until_parked(&self) {
186        while self.tick(false) {}
187    }
188
189    pub fn parking_allowed(&self) -> bool {
190        self.state.lock().allow_parking
191    }
192
193    pub fn allow_parking(&self) {
194        self.state.lock().allow_parking = true
195    }
196
197    pub fn forbid_parking(&self) {
198        self.state.lock().allow_parking = false
199    }
200
201    pub fn set_waiting_hint(&self, msg: Option<String>) {
202        self.state.lock().waiting_hint = msg
203    }
204
205    pub fn waiting_hint(&self) -> Option<String> {
206        self.state.lock().waiting_hint.clone()
207    }
208
209    pub fn start_waiting(&self) {
210        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
211    }
212
213    pub fn finish_waiting(&self) {
214        self.state.lock().waiting_backtrace.take();
215    }
216
217    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
218        self.state.lock().waiting_backtrace.take().map(|mut b| {
219            b.resolve();
220            b
221        })
222    }
223
224    pub fn rng(&self) -> StdRng {
225        self.state.lock().random.clone()
226    }
227
228    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
229        self.state.lock().block_on_ticks = range;
230    }
231
232    pub fn gen_block_on_ticks(&self) -> usize {
233        let mut lock = self.state.lock();
234        let block_on_ticks = lock.block_on_ticks.clone();
235        lock.random.gen_range(block_on_ticks)
236    }
237}
238
239impl Clone for TestDispatcher {
240    fn clone(&self) -> Self {
241        let id = post_inc(&mut self.state.lock().next_id.0);
242        Self {
243            id: TestDispatcherId(id),
244            state: self.state.clone(),
245            parker: self.parker.clone(),
246            unparker: self.unparker.clone(),
247        }
248    }
249}
250
251impl PlatformDispatcher for TestDispatcher {
252    fn is_main_thread(&self) -> bool {
253        self.state.lock().is_main_thread
254    }
255
256    fn now(&self) -> Instant {
257        let state = self.state.lock();
258        state.start_time + state.time
259    }
260
261    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
262        {
263            let mut state = self.state.lock();
264            if label.map_or(false, |label| {
265                state.deprioritized_task_labels.contains(&label)
266            }) {
267                state.deprioritized_background.push(runnable);
268            } else {
269                state.background.push(runnable);
270            }
271        }
272        self.unparker.unpark();
273    }
274
275    fn dispatch_on_main_thread(&self, runnable: Runnable) {
276        self.state
277            .lock()
278            .foreground
279            .entry(self.id)
280            .or_default()
281            .push_back(runnable);
282        self.unparker.unpark();
283    }
284
285    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
286        let mut state = self.state.lock();
287        let next_time = state.time + duration;
288        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
289            Ok(ix) | Err(ix) => ix,
290        };
291        state.delayed.insert(ix, (next_time, runnable));
292    }
293    fn park(&self, _: Option<std::time::Duration>) -> bool {
294        self.parker.lock().park();
295        true
296    }
297
298    fn unparker(&self) -> Unparker {
299        self.unparker.clone()
300    }
301
302    fn as_test(&self) -> Option<&TestDispatcher> {
303        Some(self)
304    }
305}