dispatcher.rs

  1use crate::{PlatformDispatcher, Priority, RunnableVariant, TaskLabel};
  2use backtrace::Backtrace;
  3use collections::{HashMap, HashSet, VecDeque};
  4use parking::Unparker;
  5use parking_lot::Mutex;
  6use rand::prelude::*;
  7use std::{
  8    future::Future,
  9    ops::RangeInclusive,
 10    pin::Pin,
 11    sync::Arc,
 12    task::{Context, Poll},
 13    time::{Duration, Instant},
 14};
 15use util::post_inc;
 16
 17#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 18struct TestDispatcherId(usize);
 19
 20#[doc(hidden)]
 21pub struct TestDispatcher {
 22    id: TestDispatcherId,
 23    state: Arc<Mutex<TestDispatcherState>>,
 24}
 25
/// Scheduler state shared (behind `Arc<Mutex<_>>`) by every clone of a
/// `TestDispatcher`.
struct TestDispatcherState {
    /// Source of every scheduling choice (`tick`, `simulate_random_delay`,
    /// `gen_block_on_ticks`).
    random: StdRng,
    /// Main-thread queues, one per dispatcher handle id;
    /// `dispatch_on_main_thread` pushes to the back and `tick` pops the front.
    foreground: HashMap<TestDispatcherId, VecDeque<RunnableVariant>>,
    /// Ready background tasks; `tick` removes one at a random index.
    background: Vec<RunnableVariant>,
    /// Tasks whose label was passed to `deprioritize`; `tick` only picks from
    /// here when no counted foreground task and no background task is queued.
    deprioritized_background: Vec<RunnableVariant>,
    /// Timers as `(deadline, task)`, inserted in deadline order by
    /// `dispatch_after`; `tick` moves entries whose deadline has passed into
    /// `background`.
    delayed: Vec<(Duration, RunnableVariant)>,
    /// Real instant captured at construction; `now()` reports
    /// `start_time + time`.
    start_time: Instant,
    /// Simulated elapsed time; written only by the clock-advancing methods.
    time: Duration,
    /// Whether the task being run came from a foreground queue; `tick` sets it
    /// while a task runs and restores the previous value afterwards.
    is_main_thread: bool,
    /// Next id handed out when a `TestDispatcher` is cloned.
    next_id: TestDispatcherId,
    /// Flag toggled by `allow_parking` / `forbid_parking` and read by
    /// `parking_allowed`.
    allow_parking: bool,
    /// Optional message stored by `set_waiting_hint`.
    waiting_hint: Option<String>,
    /// Unresolved backtrace captured by `start_waiting`; cleared by
    /// `finish_waiting` or taken (and resolved) by `waiting_backtrace`.
    waiting_backtrace: Option<Backtrace>,
    /// Labels that `dispatch` routes to `deprioritized_background`.
    deprioritized_task_labels: HashSet<TaskLabel>,
    /// Inclusive range sampled by `gen_block_on_ticks`.
    block_on_ticks: RangeInclusive<usize>,
    /// Handles notified by `unpark_all`, which `dispatch` and
    /// `dispatch_on_main_thread` call after queuing work.
    unparkers: Vec<Unparker>,
}
 43
 44impl TestDispatcher {
 45    pub fn new(random: StdRng) -> Self {
 46        let state = TestDispatcherState {
 47            random,
 48            foreground: HashMap::default(),
 49            background: Vec::new(),
 50            deprioritized_background: Vec::new(),
 51            delayed: Vec::new(),
 52            time: Duration::ZERO,
 53            start_time: Instant::now(),
 54            is_main_thread: true,
 55            next_id: TestDispatcherId(1),
 56            allow_parking: false,
 57            waiting_hint: None,
 58            waiting_backtrace: None,
 59            deprioritized_task_labels: Default::default(),
 60            block_on_ticks: 0..=1000,
 61            unparkers: Default::default(),
 62        };
 63
 64        TestDispatcher {
 65            id: TestDispatcherId(0),
 66            state: Arc::new(Mutex::new(state)),
 67        }
 68    }
 69
 70    pub fn advance_clock(&self, by: Duration) {
 71        let new_now = self.state.lock().time + by;
 72        loop {
 73            self.run_until_parked();
 74            let state = self.state.lock();
 75            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 76            drop(state);
 77            if let Some(due_time) = next_due_time
 78                && due_time <= new_now
 79            {
 80                self.state.lock().time = due_time;
 81                continue;
 82            }
 83            break;
 84        }
 85        self.state.lock().time = new_now;
 86    }
 87
 88    pub fn advance_clock_to_next_delayed(&self) -> bool {
 89        let next_due_time = self.state.lock().delayed.first().map(|(time, _)| *time);
 90        if let Some(next_due_time) = next_due_time {
 91            self.state.lock().time = next_due_time;
 92            return true;
 93        }
 94        false
 95    }
 96
 97    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
 98        struct YieldNow {
 99            pub(crate) count: usize,
100        }
101
102        impl Future for YieldNow {
103            type Output = ();
104
105            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
106                if self.count > 0 {
107                    self.count -= 1;
108                    cx.waker().wake_by_ref();
109                    Poll::Pending
110                } else {
111                    Poll::Ready(())
112                }
113            }
114        }
115
116        YieldNow {
117            count: self.state.lock().random.random_range(0..10),
118        }
119    }
120
121    pub fn tick(&self, background_only: bool) -> bool {
122        let mut state = self.state.lock();
123
124        while let Some((deadline, _)) = state.delayed.first() {
125            if *deadline > state.time {
126                break;
127            }
128            let (_, runnable) = state.delayed.remove(0);
129            state.background.push(runnable);
130        }
131
132        let foreground_len: usize = if background_only {
133            0
134        } else {
135            state
136                .foreground
137                .values()
138                .map(|runnables| runnables.len())
139                .sum()
140        };
141        let background_len = state.background.len();
142
143        let runnable;
144        let main_thread;
145        if foreground_len == 0 && background_len == 0 {
146            let deprioritized_background_len = state.deprioritized_background.len();
147            if deprioritized_background_len == 0 {
148                return false;
149            }
150            let ix = state.random.random_range(0..deprioritized_background_len);
151            main_thread = false;
152            runnable = state.deprioritized_background.swap_remove(ix);
153        } else {
154            main_thread = state.random.random_ratio(
155                foreground_len as u32,
156                (foreground_len + background_len) as u32,
157            );
158            if main_thread {
159                let state = &mut *state;
160                runnable = state
161                    .foreground
162                    .values_mut()
163                    .filter(|runnables| !runnables.is_empty())
164                    .choose(&mut state.random)
165                    .unwrap()
166                    .pop_front()
167                    .unwrap();
168            } else {
169                let ix = state.random.random_range(0..background_len);
170                runnable = state.background.swap_remove(ix);
171            };
172        };
173
174        let was_main_thread = state.is_main_thread;
175        state.is_main_thread = main_thread;
176        drop(state);
177
178        // todo(localcc): add timings to tests
179        match runnable {
180            RunnableVariant::Meta(runnable) => {
181                if !runnable.metadata().is_app_alive() {
182                    drop(runnable);
183                    self.state.lock().is_main_thread = was_main_thread;
184                    return true;
185                }
186                runnable.run()
187            }
188            RunnableVariant::Compat(runnable) => runnable.run(),
189        };
190
191        self.state.lock().is_main_thread = was_main_thread;
192
193        true
194    }
195
196    pub fn deprioritize(&self, task_label: TaskLabel) {
197        self.state
198            .lock()
199            .deprioritized_task_labels
200            .insert(task_label);
201    }
202
203    pub fn run_until_parked(&self) {
204        while self.tick(false) {}
205    }
206
207    pub fn parking_allowed(&self) -> bool {
208        self.state.lock().allow_parking
209    }
210
211    pub fn allow_parking(&self) {
212        self.state.lock().allow_parking = true
213    }
214
215    pub fn forbid_parking(&self) {
216        self.state.lock().allow_parking = false
217    }
218
219    pub fn set_waiting_hint(&self, msg: Option<String>) {
220        self.state.lock().waiting_hint = msg
221    }
222
223    pub fn waiting_hint(&self) -> Option<String> {
224        self.state.lock().waiting_hint.clone()
225    }
226
227    pub fn start_waiting(&self) {
228        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
229    }
230
231    pub fn finish_waiting(&self) {
232        self.state.lock().waiting_backtrace.take();
233    }
234
235    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
236        self.state.lock().waiting_backtrace.take().map(|mut b| {
237            b.resolve();
238            b
239        })
240    }
241
242    pub fn rng(&self) -> StdRng {
243        self.state.lock().random.clone()
244    }
245
246    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
247        self.state.lock().block_on_ticks = range;
248    }
249
250    pub fn gen_block_on_ticks(&self) -> usize {
251        let mut lock = self.state.lock();
252        let block_on_ticks = lock.block_on_ticks.clone();
253        lock.random.random_range(block_on_ticks)
254    }
255
256    pub fn unpark_all(&self) {
257        self.state.lock().unparkers.retain(|parker| parker.unpark());
258    }
259
260    pub fn push_unparker(&self, unparker: Unparker) {
261        let mut state = self.state.lock();
262        state.unparkers.push(unparker);
263    }
264}
265
266impl Clone for TestDispatcher {
267    fn clone(&self) -> Self {
268        let id = post_inc(&mut self.state.lock().next_id.0);
269        Self {
270            id: TestDispatcherId(id),
271            state: self.state.clone(),
272        }
273    }
274}
275
276impl PlatformDispatcher for TestDispatcher {
277    fn get_all_timings(&self) -> Vec<crate::ThreadTaskTimings> {
278        Vec::new()
279    }
280
281    fn get_current_thread_timings(&self) -> Vec<crate::TaskTiming> {
282        Vec::new()
283    }
284
285    fn is_main_thread(&self) -> bool {
286        self.state.lock().is_main_thread
287    }
288
289    fn now(&self) -> Instant {
290        let state = self.state.lock();
291        state.start_time + state.time
292    }
293
294    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, _priority: Priority) {
295        {
296            let mut state = self.state.lock();
297            if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
298                state.deprioritized_background.push(runnable);
299            } else {
300                state.background.push(runnable);
301            }
302        }
303        self.unpark_all();
304    }
305
306    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
307        self.state
308            .lock()
309            .foreground
310            .entry(self.id)
311            .or_default()
312            .push_back(runnable);
313        self.unpark_all();
314    }
315
316    fn dispatch_after(&self, duration: std::time::Duration, runnable: RunnableVariant) {
317        let mut state = self.state.lock();
318        let next_time = state.time + duration;
319        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
320            Ok(ix) | Err(ix) => ix,
321        };
322        state.delayed.insert(ix, (next_time, runnable));
323    }
324
325    fn as_test(&self) -> Option<&TestDispatcher> {
326        Some(self)
327    }
328
329    fn spawn_realtime(&self, _priority: crate::RealtimePriority, f: Box<dyn FnOnce() + Send>) {
330        std::thread::spawn(move || {
331            f();
332        });
333    }
334}