dispatcher.rs

  1use crate::{PlatformDispatcher, TaskLabel};
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, HashSet, VecDeque};
  5use parking::{Parker, Unparker};
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    ops::RangeInclusive,
 11    pin::Pin,
 12    sync::Arc,
 13    task::{Context, Poll},
 14    time::{Duration, Instant},
 15};
 16use util::post_inc;
 17
 18#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 19struct TestDispatcherId(usize);
 20
/// Test `PlatformDispatcher` that queues tasks instead of running them eagerly.
///
/// Tasks run one at a time when the test drives the dispatcher (`tick`, `run_until_parked`,
/// `advance_clock`), in an order chosen by the caller-supplied RNG. Time is simulated rather
/// than read from the wall clock. All clones share the same state, parker and unparker; each
/// clone has its own id.
#[doc(hidden)]
pub struct TestDispatcher {
    /// Identifies this handle; keys this handle's main-thread queue. Each clone gets a new id.
    id: TestDispatcherId,
    /// Scheduler state shared by every clone of this dispatcher.
    state: Arc<Mutex<TestDispatcherState>>,
    /// Parker shared by every clone; `park` blocks on it.
    parker: Arc<Mutex<Parker>>,
    /// Wakes the parker; signalled by `dispatch` and `dispatch_on_main_thread`
    /// (but not by `dispatch_after`).
    unparker: Unparker,
}
 28
/// Mutable scheduler state shared by every clone of a `TestDispatcher`.
struct TestDispatcherState {
    /// RNG behind every scheduling choice in `tick`, and also used by
    /// `simulate_random_delay` and `gen_block_on_ticks`.
    random: StdRng,
    /// Main-thread tasks: one FIFO queue per dispatcher handle, keyed by the handle's id.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    /// Ready background tasks; `tick` picks among them uniformly at random.
    background: Vec<Runnable>,
    /// Background tasks dispatched with a deprioritized label; `tick` only picks from here
    /// when no foreground (unless background-only) or regular background task is eligible.
    deprioritized_background: Vec<Runnable>,
    /// Timer tasks as `(due time, task)`, kept sorted by due time by `dispatch_after`.
    delayed: Vec<(Duration, Runnable)>,
    /// Real instant captured at construction; the simulated "now" is `start_time + time`.
    start_time: Instant,
    /// Simulated time elapsed since `start_time`; moved forward by `advance_clock` and
    /// `advance_clock_to_next_delayed`.
    time: Duration,
    /// Value reported by `is_main_thread`. Starts as `true`; `tick` sets it to whether the
    /// task it is running came from a foreground queue and restores it afterwards.
    is_main_thread: bool,
    /// Id handed to the next clone of the dispatcher (the original handle is id 0).
    next_id: TestDispatcherId,
    /// Flag toggled by `allow_parking` / `forbid_parking` and read by `parking_allowed`.
    allow_parking: bool,
    /// Free-form hint stored by `set_waiting_hint`.
    waiting_hint: Option<String>,
    /// Unresolved backtrace captured by `start_waiting`; cleared by `finish_waiting`.
    waiting_backtrace: Option<Backtrace>,
    /// Labels whose tasks `dispatch` routes to `deprioritized_background`.
    deprioritized_task_labels: HashSet<TaskLabel>,
    /// Range sampled by `gen_block_on_ticks`.
    block_on_ticks: RangeInclusive<usize>,
}
 45
 46impl TestDispatcher {
 47    pub fn new(random: StdRng) -> Self {
 48        let (parker, unparker) = parking::pair();
 49        let state = TestDispatcherState {
 50            random,
 51            foreground: HashMap::default(),
 52            background: Vec::new(),
 53            deprioritized_background: Vec::new(),
 54            delayed: Vec::new(),
 55            time: Duration::ZERO,
 56            start_time: Instant::now(),
 57            is_main_thread: true,
 58            next_id: TestDispatcherId(1),
 59            allow_parking: false,
 60            waiting_hint: None,
 61            waiting_backtrace: None,
 62            deprioritized_task_labels: Default::default(),
 63            block_on_ticks: 0..=1000,
 64        };
 65
 66        TestDispatcher {
 67            id: TestDispatcherId(0),
 68            state: Arc::new(Mutex::new(state)),
 69            parker: Arc::new(Mutex::new(parker)),
 70            unparker,
 71        }
 72    }
 73
 74    pub fn advance_clock(&self, by: Duration) {
 75        let new_now = self.state.lock().time + by;
 76        loop {
 77            self.run_until_parked();
 78            let state = self.state.lock();
 79            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 80            drop(state);
 81            if let Some(due_time) = next_due_time
 82                && due_time <= new_now
 83            {
 84                self.state.lock().time = due_time;
 85                continue;
 86            }
 87            break;
 88        }
 89        self.state.lock().time = new_now;
 90    }
 91
 92    pub fn advance_clock_to_next_delayed(&self) -> bool {
 93        let next_due_time = self.state.lock().delayed.first().map(|(time, _)| *time);
 94        if let Some(next_due_time) = next_due_time {
 95            self.state.lock().time = next_due_time;
 96            return true;
 97        }
 98        false
 99    }
100
101    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
102        struct YieldNow {
103            pub(crate) count: usize,
104        }
105
106        impl Future for YieldNow {
107            type Output = ();
108
109            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
110                if self.count > 0 {
111                    self.count -= 1;
112                    cx.waker().wake_by_ref();
113                    Poll::Pending
114                } else {
115                    Poll::Ready(())
116                }
117            }
118        }
119
120        YieldNow {
121            count: self.state.lock().random.random_range(0..10),
122        }
123    }
124
125    pub fn tick(&self, background_only: bool) -> bool {
126        let mut state = self.state.lock();
127
128        while let Some((deadline, _)) = state.delayed.first() {
129            if *deadline > state.time {
130                break;
131            }
132            let (_, runnable) = state.delayed.remove(0);
133            state.background.push(runnable);
134        }
135
136        let foreground_len: usize = if background_only {
137            0
138        } else {
139            state
140                .foreground
141                .values()
142                .map(|runnables| runnables.len())
143                .sum()
144        };
145        let background_len = state.background.len();
146
147        let runnable;
148        let main_thread;
149        if foreground_len == 0 && background_len == 0 {
150            let deprioritized_background_len = state.deprioritized_background.len();
151            if deprioritized_background_len == 0 {
152                return false;
153            }
154            let ix = state.random.random_range(0..deprioritized_background_len);
155            main_thread = false;
156            runnable = state.deprioritized_background.swap_remove(ix);
157        } else {
158            main_thread = state.random.random_ratio(
159                foreground_len as u32,
160                (foreground_len + background_len) as u32,
161            );
162            if main_thread {
163                let state = &mut *state;
164                runnable = state
165                    .foreground
166                    .values_mut()
167                    .filter(|runnables| !runnables.is_empty())
168                    .choose(&mut state.random)
169                    .unwrap()
170                    .pop_front()
171                    .unwrap();
172            } else {
173                let ix = state.random.random_range(0..background_len);
174                runnable = state.background.swap_remove(ix);
175            };
176        };
177
178        let was_main_thread = state.is_main_thread;
179        state.is_main_thread = main_thread;
180        drop(state);
181        runnable.run();
182        self.state.lock().is_main_thread = was_main_thread;
183
184        true
185    }
186
187    pub fn deprioritize(&self, task_label: TaskLabel) {
188        self.state
189            .lock()
190            .deprioritized_task_labels
191            .insert(task_label);
192    }
193
194    pub fn run_until_parked(&self) {
195        while self.tick(false) {}
196    }
197
198    pub fn parking_allowed(&self) -> bool {
199        self.state.lock().allow_parking
200    }
201
202    pub fn allow_parking(&self) {
203        self.state.lock().allow_parking = true
204    }
205
206    pub fn forbid_parking(&self) {
207        self.state.lock().allow_parking = false
208    }
209
210    pub fn set_waiting_hint(&self, msg: Option<String>) {
211        self.state.lock().waiting_hint = msg
212    }
213
214    pub fn waiting_hint(&self) -> Option<String> {
215        self.state.lock().waiting_hint.clone()
216    }
217
218    pub fn start_waiting(&self) {
219        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
220    }
221
222    pub fn finish_waiting(&self) {
223        self.state.lock().waiting_backtrace.take();
224    }
225
226    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
227        self.state.lock().waiting_backtrace.take().map(|mut b| {
228            b.resolve();
229            b
230        })
231    }
232
233    pub fn rng(&self) -> StdRng {
234        self.state.lock().random.clone()
235    }
236
237    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
238        self.state.lock().block_on_ticks = range;
239    }
240
241    pub fn gen_block_on_ticks(&self) -> usize {
242        let mut lock = self.state.lock();
243        let block_on_ticks = lock.block_on_ticks.clone();
244        lock.random.random_range(block_on_ticks)
245    }
246}
247
248impl Clone for TestDispatcher {
249    fn clone(&self) -> Self {
250        let id = post_inc(&mut self.state.lock().next_id.0);
251        Self {
252            id: TestDispatcherId(id),
253            state: self.state.clone(),
254            parker: self.parker.clone(),
255            unparker: self.unparker.clone(),
256        }
257    }
258}
259
260impl PlatformDispatcher for TestDispatcher {
261    fn is_main_thread(&self) -> bool {
262        self.state.lock().is_main_thread
263    }
264
265    fn now(&self) -> Instant {
266        let state = self.state.lock();
267        state.start_time + state.time
268    }
269
270    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
271        {
272            let mut state = self.state.lock();
273            if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
274                state.deprioritized_background.push(runnable);
275            } else {
276                state.background.push(runnable);
277            }
278        }
279        self.unparker.unpark();
280    }
281
282    fn dispatch_on_main_thread(&self, runnable: Runnable) {
283        self.state
284            .lock()
285            .foreground
286            .entry(self.id)
287            .or_default()
288            .push_back(runnable);
289        self.unparker.unpark();
290    }
291
292    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
293        let mut state = self.state.lock();
294        let next_time = state.time + duration;
295        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
296            Ok(ix) | Err(ix) => ix,
297        };
298        state.delayed.insert(ix, (next_time, runnable));
299    }
300    fn park(&self, _: Option<std::time::Duration>) -> bool {
301        self.parker.lock().park();
302        true
303    }
304
305    fn unparker(&self) -> Unparker {
306        self.unparker.clone()
307    }
308
309    fn as_test(&self) -> Option<&TestDispatcher> {
310        Some(self)
311    }
312}