// dispatcher.rs

  1use crate::{PlatformDispatcher, TaskLabel};
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, HashSet, VecDeque};
  5use parking::{Parker, Unparker};
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    ops::RangeInclusive,
 11    pin::Pin,
 12    sync::Arc,
 13    task::{Context, Poll},
 14    time::Duration,
 15};
 16use util::post_inc;
 17
 18#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 19struct TestDispatcherId(usize);
 20
 21#[doc(hidden)]
 22pub struct TestDispatcher {
 23    id: TestDispatcherId,
 24    state: Arc<Mutex<TestDispatcherState>>,
 25    parker: Arc<Mutex<Parker>>,
 26    unparker: Unparker,
 27}
 28
 29struct TestDispatcherState {
 30    random: StdRng,
 31    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
 32    background: Vec<Runnable>,
 33    deprioritized_background: Vec<Runnable>,
 34    delayed: Vec<(Duration, Runnable)>,
 35    time: Duration,
 36    is_main_thread: bool,
 37    next_id: TestDispatcherId,
 38    allow_parking: bool,
 39    waiting_hint: Option<String>,
 40    waiting_backtrace: Option<Backtrace>,
 41    deprioritized_task_labels: HashSet<TaskLabel>,
 42    block_on_ticks: RangeInclusive<usize>,
 43}
 44
 45impl TestDispatcher {
 46    pub fn new(random: StdRng) -> Self {
 47        let (parker, unparker) = parking::pair();
 48        let state = TestDispatcherState {
 49            random,
 50            foreground: HashMap::default(),
 51            background: Vec::new(),
 52            deprioritized_background: Vec::new(),
 53            delayed: Vec::new(),
 54            time: Duration::ZERO,
 55            is_main_thread: true,
 56            next_id: TestDispatcherId(1),
 57            allow_parking: false,
 58            waiting_hint: None,
 59            waiting_backtrace: None,
 60            deprioritized_task_labels: Default::default(),
 61            block_on_ticks: 0..=1000,
 62        };
 63
 64        TestDispatcher {
 65            id: TestDispatcherId(0),
 66            state: Arc::new(Mutex::new(state)),
 67            parker: Arc::new(Mutex::new(parker)),
 68            unparker,
 69        }
 70    }
 71
 72    pub fn advance_clock(&self, by: Duration) {
 73        let new_now = self.state.lock().time + by;
 74        loop {
 75            self.run_until_parked();
 76            let state = self.state.lock();
 77            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 78            drop(state);
 79            if let Some(due_time) = next_due_time {
 80                if due_time <= new_now {
 81                    self.state.lock().time = due_time;
 82                    continue;
 83                }
 84            }
 85            break;
 86        }
 87        self.state.lock().time = new_now;
 88    }
 89
 90    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
 91        struct YieldNow {
 92            pub(crate) count: usize,
 93        }
 94
 95        impl Future for YieldNow {
 96            type Output = ();
 97
 98            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 99                if self.count > 0 {
100                    self.count -= 1;
101                    cx.waker().wake_by_ref();
102                    Poll::Pending
103                } else {
104                    Poll::Ready(())
105                }
106            }
107        }
108
109        YieldNow {
110            count: self.state.lock().random.gen_range(0..10),
111        }
112    }
113
114    pub fn deprioritize(&self, task_label: TaskLabel) {
115        self.state
116            .lock()
117            .deprioritized_task_labels
118            .insert(task_label);
119    }
120
121    pub fn run_until_parked(&self) {
122        while self.tick(false) {}
123    }
124
125    pub fn parking_allowed(&self) -> bool {
126        self.state.lock().allow_parking
127    }
128
129    pub fn allow_parking(&self) {
130        self.state.lock().allow_parking = true
131    }
132
133    pub fn forbid_parking(&self) {
134        self.state.lock().allow_parking = false
135    }
136
137    pub fn set_waiting_hint(&self, msg: Option<String>) {
138        self.state.lock().waiting_hint = msg
139    }
140
141    pub fn waiting_hint(&self) -> Option<String> {
142        self.state.lock().waiting_hint.clone()
143    }
144
145    pub fn start_waiting(&self) {
146        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
147    }
148
149    pub fn finish_waiting(&self) {
150        self.state.lock().waiting_backtrace.take();
151    }
152
153    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
154        self.state.lock().waiting_backtrace.take().map(|mut b| {
155            b.resolve();
156            b
157        })
158    }
159
160    pub fn rng(&self) -> StdRng {
161        self.state.lock().random.clone()
162    }
163
164    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
165        self.state.lock().block_on_ticks = range;
166    }
167
168    pub fn gen_block_on_ticks(&self) -> usize {
169        let mut lock = self.state.lock();
170        let block_on_ticks = lock.block_on_ticks.clone();
171        lock.random.gen_range(block_on_ticks)
172    }
173}
174
175impl Clone for TestDispatcher {
176    fn clone(&self) -> Self {
177        let id = post_inc(&mut self.state.lock().next_id.0);
178        Self {
179            id: TestDispatcherId(id),
180            state: self.state.clone(),
181            parker: self.parker.clone(),
182            unparker: self.unparker.clone(),
183        }
184    }
185}
186
187impl PlatformDispatcher for TestDispatcher {
188    fn is_main_thread(&self) -> bool {
189        self.state.lock().is_main_thread
190    }
191
192    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
193        {
194            let mut state = self.state.lock();
195            if label.map_or(false, |label| {
196                state.deprioritized_task_labels.contains(&label)
197            }) {
198                state.deprioritized_background.push(runnable);
199            } else {
200                state.background.push(runnable);
201            }
202        }
203        self.unparker.unpark();
204    }
205
206    fn dispatch_on_main_thread(&self, runnable: Runnable) {
207        self.state
208            .lock()
209            .foreground
210            .entry(self.id)
211            .or_default()
212            .push_back(runnable);
213        self.unparker.unpark();
214    }
215
216    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
217        let mut state = self.state.lock();
218        let next_time = state.time + duration;
219        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
220            Ok(ix) | Err(ix) => ix,
221        };
222        state.delayed.insert(ix, (next_time, runnable));
223    }
224
225    fn tick(&self, background_only: bool) -> bool {
226        let mut state = self.state.lock();
227
228        while let Some((deadline, _)) = state.delayed.first() {
229            if *deadline > state.time {
230                break;
231            }
232            let (_, runnable) = state.delayed.remove(0);
233            state.background.push(runnable);
234        }
235
236        let foreground_len: usize = if background_only {
237            0
238        } else {
239            state
240                .foreground
241                .values()
242                .map(|runnables| runnables.len())
243                .sum()
244        };
245        let background_len = state.background.len();
246
247        let runnable;
248        let main_thread;
249        if foreground_len == 0 && background_len == 0 {
250            let deprioritized_background_len = state.deprioritized_background.len();
251            if deprioritized_background_len == 0 {
252                return false;
253            }
254            let ix = state.random.gen_range(0..deprioritized_background_len);
255            main_thread = false;
256            runnable = state.deprioritized_background.swap_remove(ix);
257        } else {
258            main_thread = state.random.gen_ratio(
259                foreground_len as u32,
260                (foreground_len + background_len) as u32,
261            );
262            if main_thread {
263                let state = &mut *state;
264                runnable = state
265                    .foreground
266                    .values_mut()
267                    .filter(|runnables| !runnables.is_empty())
268                    .choose(&mut state.random)
269                    .unwrap()
270                    .pop_front()
271                    .unwrap();
272            } else {
273                let ix = state.random.gen_range(0..background_len);
274                runnable = state.background.swap_remove(ix);
275            };
276        };
277
278        let was_main_thread = state.is_main_thread;
279        state.is_main_thread = main_thread;
280        drop(state);
281        runnable.run();
282        self.state.lock().is_main_thread = was_main_thread;
283
284        true
285    }
286
287    fn park(&self) {
288        self.parker.lock().park();
289    }
290
291    fn unparker(&self) -> Unparker {
292        self.unparker.clone()
293    }
294
295    fn as_test(&self) -> Option<&TestDispatcher> {
296        Some(self)
297    }
298}