dispatcher.rs

  1use crate::{PlatformDispatcher, TaskLabel};
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, HashSet, VecDeque};
  5use parking::Unparker;
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    ops::RangeInclusive,
 11    pin::Pin,
 12    sync::Arc,
 13    task::{Context, Poll},
 14    time::{Duration, Instant},
 15};
 16use util::post_inc;
 17
 18#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 19struct TestDispatcherId(usize);
 20
 21#[doc(hidden)]
 22pub struct TestDispatcher {
 23    id: TestDispatcherId,
 24    state: Arc<Mutex<TestDispatcherState>>,
 25}
 26
 27struct TestDispatcherState {
 28    random: StdRng,
 29    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
 30    background: Vec<Runnable>,
 31    deprioritized_background: Vec<Runnable>,
 32    delayed: Vec<(Duration, Runnable)>,
 33    start_time: Instant,
 34    time: Duration,
 35    is_main_thread: bool,
 36    next_id: TestDispatcherId,
 37    allow_parking: bool,
 38    waiting_hint: Option<String>,
 39    waiting_backtrace: Option<Backtrace>,
 40    deprioritized_task_labels: HashSet<TaskLabel>,
 41    block_on_ticks: RangeInclusive<usize>,
 42    last_parked: Option<Unparker>,
 43}
 44
 45impl TestDispatcher {
 46    pub fn new(random: StdRng) -> Self {
 47        let state = TestDispatcherState {
 48            random,
 49            foreground: HashMap::default(),
 50            background: Vec::new(),
 51            deprioritized_background: Vec::new(),
 52            delayed: Vec::new(),
 53            time: Duration::ZERO,
 54            start_time: Instant::now(),
 55            is_main_thread: true,
 56            next_id: TestDispatcherId(1),
 57            allow_parking: false,
 58            waiting_hint: None,
 59            waiting_backtrace: None,
 60            deprioritized_task_labels: Default::default(),
 61            block_on_ticks: 0..=1000,
 62            last_parked: None,
 63        };
 64
 65        TestDispatcher {
 66            id: TestDispatcherId(0),
 67            state: Arc::new(Mutex::new(state)),
 68        }
 69    }
 70
 71    pub fn advance_clock(&self, by: Duration) {
 72        let new_now = self.state.lock().time + by;
 73        loop {
 74            self.run_until_parked();
 75            let state = self.state.lock();
 76            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 77            drop(state);
 78            if let Some(due_time) = next_due_time
 79                && due_time <= new_now
 80            {
 81                self.state.lock().time = due_time;
 82                continue;
 83            }
 84            break;
 85        }
 86        self.state.lock().time = new_now;
 87    }
 88
 89    pub fn advance_clock_to_next_delayed(&self) -> bool {
 90        let next_due_time = self.state.lock().delayed.first().map(|(time, _)| *time);
 91        if let Some(next_due_time) = next_due_time {
 92            self.state.lock().time = next_due_time;
 93            return true;
 94        }
 95        false
 96    }
 97
 98    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
 99        struct YieldNow {
100            pub(crate) count: usize,
101        }
102
103        impl Future for YieldNow {
104            type Output = ();
105
106            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
107                if self.count > 0 {
108                    self.count -= 1;
109                    cx.waker().wake_by_ref();
110                    Poll::Pending
111                } else {
112                    Poll::Ready(())
113                }
114            }
115        }
116
117        YieldNow {
118            count: self.state.lock().random.random_range(0..10),
119        }
120    }
121
122    pub fn tick(&self, background_only: bool) -> bool {
123        let mut state = self.state.lock();
124
125        while let Some((deadline, _)) = state.delayed.first() {
126            if *deadline > state.time {
127                break;
128            }
129            let (_, runnable) = state.delayed.remove(0);
130            state.background.push(runnable);
131        }
132
133        let foreground_len: usize = if background_only {
134            0
135        } else {
136            state
137                .foreground
138                .values()
139                .map(|runnables| runnables.len())
140                .sum()
141        };
142        let background_len = state.background.len();
143
144        let runnable;
145        let main_thread;
146        if foreground_len == 0 && background_len == 0 {
147            let deprioritized_background_len = state.deprioritized_background.len();
148            if deprioritized_background_len == 0 {
149                return false;
150            }
151            let ix = state.random.random_range(0..deprioritized_background_len);
152            main_thread = false;
153            runnable = state.deprioritized_background.swap_remove(ix);
154        } else {
155            main_thread = state.random.random_ratio(
156                foreground_len as u32,
157                (foreground_len + background_len) as u32,
158            );
159            if main_thread {
160                let state = &mut *state;
161                runnable = state
162                    .foreground
163                    .values_mut()
164                    .filter(|runnables| !runnables.is_empty())
165                    .choose(&mut state.random)
166                    .unwrap()
167                    .pop_front()
168                    .unwrap();
169            } else {
170                let ix = state.random.random_range(0..background_len);
171                runnable = state.background.swap_remove(ix);
172            };
173        };
174
175        let was_main_thread = state.is_main_thread;
176        state.is_main_thread = main_thread;
177        drop(state);
178        runnable.run();
179        self.state.lock().is_main_thread = was_main_thread;
180
181        true
182    }
183
184    pub fn deprioritize(&self, task_label: TaskLabel) {
185        self.state
186            .lock()
187            .deprioritized_task_labels
188            .insert(task_label);
189    }
190
191    pub fn run_until_parked(&self) {
192        while self.tick(false) {}
193    }
194
195    pub fn parking_allowed(&self) -> bool {
196        self.state.lock().allow_parking
197    }
198
199    pub fn allow_parking(&self) {
200        self.state.lock().allow_parking = true
201    }
202
203    pub fn forbid_parking(&self) {
204        self.state.lock().allow_parking = false
205    }
206
207    pub fn set_waiting_hint(&self, msg: Option<String>) {
208        self.state.lock().waiting_hint = msg
209    }
210
211    pub fn waiting_hint(&self) -> Option<String> {
212        self.state.lock().waiting_hint.clone()
213    }
214
215    pub fn start_waiting(&self) {
216        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
217    }
218
219    pub fn finish_waiting(&self) {
220        self.state.lock().waiting_backtrace.take();
221    }
222
223    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
224        self.state.lock().waiting_backtrace.take().map(|mut b| {
225            b.resolve();
226            b
227        })
228    }
229
230    pub fn rng(&self) -> StdRng {
231        self.state.lock().random.clone()
232    }
233
234    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
235        self.state.lock().block_on_ticks = range;
236    }
237
238    pub fn gen_block_on_ticks(&self) -> usize {
239        let mut lock = self.state.lock();
240        let block_on_ticks = lock.block_on_ticks.clone();
241        lock.random.random_range(block_on_ticks)
242    }
243    pub fn unpark_last(&self) {
244        self.state
245            .lock()
246            .last_parked
247            .take()
248            .as_ref()
249            .map(Unparker::unpark);
250    }
251
252    pub fn set_unparker(&self, unparker: Unparker) {
253        let last = { self.state.lock().last_parked.replace(unparker) };
254        if let Some(last) = last {
255            last.unpark();
256        }
257    }
258}
259
260impl Clone for TestDispatcher {
261    fn clone(&self) -> Self {
262        let id = post_inc(&mut self.state.lock().next_id.0);
263        Self {
264            id: TestDispatcherId(id),
265            state: self.state.clone(),
266        }
267    }
268}
269
270impl PlatformDispatcher for TestDispatcher {
271    fn is_main_thread(&self) -> bool {
272        self.state.lock().is_main_thread
273    }
274
275    fn now(&self) -> Instant {
276        let state = self.state.lock();
277        state.start_time + state.time
278    }
279
280    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
281        {
282            let mut state = self.state.lock();
283            if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
284                state.deprioritized_background.push(runnable);
285            } else {
286                state.background.push(runnable);
287            }
288        }
289        self.unpark_last();
290    }
291
292    fn dispatch_on_main_thread(&self, runnable: Runnable) {
293        self.state
294            .lock()
295            .foreground
296            .entry(self.id)
297            .or_default()
298            .push_back(runnable);
299        self.unpark_last();
300    }
301
302    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
303        let mut state = self.state.lock();
304        let next_time = state.time + duration;
305        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
306            Ok(ix) | Err(ix) => ix,
307        };
308        state.delayed.insert(ix, (next_time, runnable));
309    }
310
311    fn as_test(&self) -> Option<&TestDispatcher> {
312        Some(self)
313    }
314}