dispatcher.rs

  1use crate::{PlatformDispatcher, RunnableVariant, TaskLabel};
  2use backtrace::Backtrace;
  3use collections::{HashMap, HashSet, VecDeque};
  4use parking::Unparker;
  5use parking_lot::Mutex;
  6use rand::prelude::*;
  7use std::{
  8    future::Future,
  9    ops::RangeInclusive,
 10    pin::Pin,
 11    sync::Arc,
 12    task::{Context, Poll},
 13    time::{Duration, Instant},
 14};
 15use util::post_inc;
 16
 17#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 18struct TestDispatcherId(usize);
 19
 20#[doc(hidden)]
 21pub struct TestDispatcher {
 22    id: TestDispatcherId,
 23    state: Arc<Mutex<TestDispatcherState>>,
 24}
 25
 26struct TestDispatcherState {
 27    random: StdRng,
 28    foreground: HashMap<TestDispatcherId, VecDeque<RunnableVariant>>,
 29    background: Vec<RunnableVariant>,
 30    deprioritized_background: Vec<RunnableVariant>,
 31    delayed: Vec<(Duration, RunnableVariant)>,
 32    start_time: Instant,
 33    time: Duration,
 34    is_main_thread: bool,
 35    next_id: TestDispatcherId,
 36    allow_parking: bool,
 37    waiting_hint: Option<String>,
 38    waiting_backtrace: Option<Backtrace>,
 39    deprioritized_task_labels: HashSet<TaskLabel>,
 40    block_on_ticks: RangeInclusive<usize>,
 41    last_parked: Option<Unparker>,
 42}
 43
 44impl TestDispatcher {
 45    pub fn new(random: StdRng) -> Self {
 46        let state = TestDispatcherState {
 47            random,
 48            foreground: HashMap::default(),
 49            background: Vec::new(),
 50            deprioritized_background: Vec::new(),
 51            delayed: Vec::new(),
 52            time: Duration::ZERO,
 53            start_time: Instant::now(),
 54            is_main_thread: true,
 55            next_id: TestDispatcherId(1),
 56            allow_parking: false,
 57            waiting_hint: None,
 58            waiting_backtrace: None,
 59            deprioritized_task_labels: Default::default(),
 60            block_on_ticks: 0..=1000,
 61            last_parked: None,
 62        };
 63
 64        TestDispatcher {
 65            id: TestDispatcherId(0),
 66            state: Arc::new(Mutex::new(state)),
 67        }
 68    }
 69
 70    pub fn advance_clock(&self, by: Duration) {
 71        let new_now = self.state.lock().time + by;
 72        loop {
 73            self.run_until_parked();
 74            let state = self.state.lock();
 75            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 76            drop(state);
 77            if let Some(due_time) = next_due_time
 78                && due_time <= new_now
 79            {
 80                self.state.lock().time = due_time;
 81                continue;
 82            }
 83            break;
 84        }
 85        self.state.lock().time = new_now;
 86    }
 87
 88    pub fn advance_clock_to_next_delayed(&self) -> bool {
 89        let next_due_time = self.state.lock().delayed.first().map(|(time, _)| *time);
 90        if let Some(next_due_time) = next_due_time {
 91            self.state.lock().time = next_due_time;
 92            return true;
 93        }
 94        false
 95    }
 96
 97    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
 98        struct YieldNow {
 99            pub(crate) count: usize,
100        }
101
102        impl Future for YieldNow {
103            type Output = ();
104
105            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
106                if self.count > 0 {
107                    self.count -= 1;
108                    cx.waker().wake_by_ref();
109                    Poll::Pending
110                } else {
111                    Poll::Ready(())
112                }
113            }
114        }
115
116        YieldNow {
117            count: self.state.lock().random.random_range(0..10),
118        }
119    }
120
121    pub fn tick(&self, background_only: bool) -> bool {
122        let mut state = self.state.lock();
123
124        while let Some((deadline, _)) = state.delayed.first() {
125            if *deadline > state.time {
126                break;
127            }
128            let (_, runnable) = state.delayed.remove(0);
129            state.background.push(runnable);
130        }
131
132        let foreground_len: usize = if background_only {
133            0
134        } else {
135            state
136                .foreground
137                .values()
138                .map(|runnables| runnables.len())
139                .sum()
140        };
141        let background_len = state.background.len();
142
143        let runnable;
144        let main_thread;
145        if foreground_len == 0 && background_len == 0 {
146            let deprioritized_background_len = state.deprioritized_background.len();
147            if deprioritized_background_len == 0 {
148                return false;
149            }
150            let ix = state.random.random_range(0..deprioritized_background_len);
151            main_thread = false;
152            runnable = state.deprioritized_background.swap_remove(ix);
153        } else {
154            main_thread = state.random.random_ratio(
155                foreground_len as u32,
156                (foreground_len + background_len) as u32,
157            );
158            if main_thread {
159                let state = &mut *state;
160                runnable = state
161                    .foreground
162                    .values_mut()
163                    .filter(|runnables| !runnables.is_empty())
164                    .choose(&mut state.random)
165                    .unwrap()
166                    .pop_front()
167                    .unwrap();
168            } else {
169                let ix = state.random.random_range(0..background_len);
170                runnable = state.background.swap_remove(ix);
171            };
172        };
173
174        let was_main_thread = state.is_main_thread;
175        state.is_main_thread = main_thread;
176        drop(state);
177
178        // todo(localcc): add timings to tests
179        match runnable {
180            RunnableVariant::Meta(runnable) => runnable.run(),
181            RunnableVariant::Compat(runnable) => runnable.run(),
182        };
183
184        self.state.lock().is_main_thread = was_main_thread;
185
186        true
187    }
188
189    pub fn deprioritize(&self, task_label: TaskLabel) {
190        self.state
191            .lock()
192            .deprioritized_task_labels
193            .insert(task_label);
194    }
195
196    pub fn run_until_parked(&self) {
197        while self.tick(false) {}
198    }
199
200    pub fn parking_allowed(&self) -> bool {
201        self.state.lock().allow_parking
202    }
203
204    pub fn allow_parking(&self) {
205        self.state.lock().allow_parking = true
206    }
207
208    pub fn forbid_parking(&self) {
209        self.state.lock().allow_parking = false
210    }
211
212    pub fn set_waiting_hint(&self, msg: Option<String>) {
213        self.state.lock().waiting_hint = msg
214    }
215
216    pub fn waiting_hint(&self) -> Option<String> {
217        self.state.lock().waiting_hint.clone()
218    }
219
220    pub fn start_waiting(&self) {
221        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
222    }
223
224    pub fn finish_waiting(&self) {
225        self.state.lock().waiting_backtrace.take();
226    }
227
228    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
229        self.state.lock().waiting_backtrace.take().map(|mut b| {
230            b.resolve();
231            b
232        })
233    }
234
235    pub fn rng(&self) -> StdRng {
236        self.state.lock().random.clone()
237    }
238
239    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
240        self.state.lock().block_on_ticks = range;
241    }
242
243    pub fn gen_block_on_ticks(&self) -> usize {
244        let mut lock = self.state.lock();
245        let block_on_ticks = lock.block_on_ticks.clone();
246        lock.random.random_range(block_on_ticks)
247    }
248    pub fn unpark_last(&self) {
249        self.state
250            .lock()
251            .last_parked
252            .take()
253            .as_ref()
254            .map(Unparker::unpark);
255    }
256
257    pub fn set_unparker(&self, unparker: Unparker) {
258        let last = { self.state.lock().last_parked.replace(unparker) };
259        if let Some(last) = last {
260            last.unpark();
261        }
262    }
263}
264
265impl Clone for TestDispatcher {
266    fn clone(&self) -> Self {
267        let id = post_inc(&mut self.state.lock().next_id.0);
268        Self {
269            id: TestDispatcherId(id),
270            state: self.state.clone(),
271        }
272    }
273}
274
275impl PlatformDispatcher for TestDispatcher {
276    fn get_all_timings(&self) -> Vec<crate::ThreadTaskTimings> {
277        Vec::new()
278    }
279
280    fn get_current_thread_timings(&self) -> Vec<crate::TaskTiming> {
281        Vec::new()
282    }
283
284    fn is_main_thread(&self) -> bool {
285        self.state.lock().is_main_thread
286    }
287
288    fn now(&self) -> Instant {
289        let state = self.state.lock();
290        state.start_time + state.time
291    }
292
293    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
294        {
295            let mut state = self.state.lock();
296            if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
297                state.deprioritized_background.push(runnable);
298            } else {
299                state.background.push(runnable);
300            }
301        }
302        self.unpark_last();
303    }
304
305    fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
306        self.state
307            .lock()
308            .foreground
309            .entry(self.id)
310            .or_default()
311            .push_back(runnable);
312        self.unpark_last();
313    }
314
315    fn dispatch_after(&self, duration: std::time::Duration, runnable: RunnableVariant) {
316        let mut state = self.state.lock();
317        let next_time = state.time + duration;
318        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
319            Ok(ix) | Err(ix) => ix,
320        };
321        state.delayed.insert(ix, (next_time, runnable));
322    }
323
324    fn as_test(&self) -> Option<&TestDispatcher> {
325        Some(self)
326    }
327}