dispatcher.rs

  1use crate::{PlatformDispatcher, Priority, RunnableVariant, TaskLabel};
  2use backtrace::Backtrace;
  3use collections::{HashMap, HashSet, VecDeque};
  4use parking::Unparker;
  5use parking_lot::Mutex;
  6use rand::prelude::*;
  7use std::{
  8    future::Future,
  9    ops::RangeInclusive,
 10    pin::Pin,
 11    sync::Arc,
 12    task::{Context, Poll},
 13    time::{Duration, Instant},
 14};
 15use util::post_inc;
 16
 17#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 18struct TestDispatcherId(usize);
 19
 20#[doc(hidden)]
 21pub struct TestDispatcher {
 22    id: TestDispatcherId,
 23    state: Arc<Mutex<TestDispatcherState>>,
 24}
 25
 26struct TestDispatcherState {
 27    random: StdRng,
 28    foreground: HashMap<TestDispatcherId, VecDeque<RunnableVariant>>,
 29    background: Vec<RunnableVariant>,
 30    deprioritized_background: Vec<RunnableVariant>,
 31    delayed: Vec<(Duration, RunnableVariant)>,
 32    start_time: Instant,
 33    time: Duration,
 34    is_main_thread: bool,
 35    next_id: TestDispatcherId,
 36    allow_parking: bool,
 37    waiting_hint: Option<String>,
 38    waiting_backtrace: Option<Backtrace>,
 39    deprioritized_task_labels: HashSet<TaskLabel>,
 40    block_on_ticks: RangeInclusive<usize>,
 41    unparkers: Vec<Unparker>,
 42}
 43
 44impl TestDispatcher {
 45    pub fn new(random: StdRng) -> Self {
 46        let state = TestDispatcherState {
 47            random,
 48            foreground: HashMap::default(),
 49            background: Vec::new(),
 50            deprioritized_background: Vec::new(),
 51            delayed: Vec::new(),
 52            time: Duration::ZERO,
 53            start_time: Instant::now(),
 54            is_main_thread: true,
 55            next_id: TestDispatcherId(1),
 56            allow_parking: false,
 57            waiting_hint: None,
 58            waiting_backtrace: None,
 59            deprioritized_task_labels: Default::default(),
 60            block_on_ticks: 0..=1000,
 61            unparkers: Default::default(),
 62        };
 63
 64        TestDispatcher {
 65            id: TestDispatcherId(0),
 66            state: Arc::new(Mutex::new(state)),
 67        }
 68    }
 69
 70    pub fn advance_clock(&self, by: Duration) {
 71        let new_now = self.state.lock().time + by;
 72        loop {
 73            self.run_until_parked();
 74            let state = self.state.lock();
 75            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 76            drop(state);
 77            if let Some(due_time) = next_due_time
 78                && due_time <= new_now
 79            {
 80                self.state.lock().time = due_time;
 81                continue;
 82            }
 83            break;
 84        }
 85        self.state.lock().time = new_now;
 86    }
 87
 88    pub fn advance_clock_to_next_delayed(&self) -> bool {
 89        let next_due_time = self.state.lock().delayed.first().map(|(time, _)| *time);
 90        if let Some(next_due_time) = next_due_time {
 91            self.state.lock().time = next_due_time;
 92            return true;
 93        }
 94        false
 95    }
 96
 97    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> + use<> {
 98        struct YieldNow {
 99            pub(crate) count: usize,
100        }
101
102        impl Future for YieldNow {
103            type Output = ();
104
105            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
106                if self.count > 0 {
107                    self.count -= 1;
108                    cx.waker().wake_by_ref();
109                    Poll::Pending
110                } else {
111                    Poll::Ready(())
112                }
113            }
114        }
115
116        YieldNow {
117            count: self.state.lock().random.random_range(0..10),
118        }
119    }
120
121    pub fn tick(&self, background_only: bool) -> bool {
122        let mut state = self.state.lock();
123
124        while let Some((deadline, _)) = state.delayed.first() {
125            if *deadline > state.time {
126                break;
127            }
128            let (_, runnable) = state.delayed.remove(0);
129            state.background.push(runnable);
130        }
131
132        let foreground_len: usize = if background_only {
133            0
134        } else {
135            state
136                .foreground
137                .values()
138                .map(|runnables| runnables.len())
139                .sum()
140        };
141        let background_len = state.background.len();
142
143        let runnable;
144        let main_thread;
145        if foreground_len == 0 && background_len == 0 {
146            let deprioritized_background_len = state.deprioritized_background.len();
147            if deprioritized_background_len == 0 {
148                return false;
149            }
150            let ix = state.random.random_range(0..deprioritized_background_len);
151            main_thread = false;
152            runnable = state.deprioritized_background.swap_remove(ix);
153        } else {
154            main_thread = state.random.random_ratio(
155                foreground_len as u32,
156                (foreground_len + background_len) as u32,
157            );
158            if main_thread {
159                let state = &mut *state;
160                runnable = state
161                    .foreground
162                    .values_mut()
163                    .filter(|runnables| !runnables.is_empty())
164                    .choose(&mut state.random)
165                    .unwrap()
166                    .pop_front()
167                    .unwrap();
168            } else {
169                let ix = state.random.random_range(0..background_len);
170                runnable = state.background.swap_remove(ix);
171            };
172        };
173
174        let was_main_thread = state.is_main_thread;
175        state.is_main_thread = main_thread;
176        drop(state);
177
178        // todo(localcc): add timings to tests
179        match runnable {
180            RunnableVariant::Meta(runnable) => {
181                if !runnable.metadata().is_app_alive() {
182                    drop(runnable);
183                } else {
184                    runnable.run();
185                }
186            }
187            RunnableVariant::Compat(runnable) => {
188                runnable.run();
189            }
190        };
191
192        self.state.lock().is_main_thread = was_main_thread;
193
194        true
195    }
196
197    pub fn deprioritize(&self, task_label: TaskLabel) {
198        self.state
199            .lock()
200            .deprioritized_task_labels
201            .insert(task_label);
202    }
203
204    pub fn run_until_parked(&self) {
205        while self.tick(false) {}
206    }
207
208    pub fn parking_allowed(&self) -> bool {
209        self.state.lock().allow_parking
210    }
211
212    pub fn allow_parking(&self) {
213        self.state.lock().allow_parking = true
214    }
215
216    pub fn forbid_parking(&self) {
217        self.state.lock().allow_parking = false
218    }
219
220    pub fn set_waiting_hint(&self, msg: Option<String>) {
221        self.state.lock().waiting_hint = msg
222    }
223
224    pub fn waiting_hint(&self) -> Option<String> {
225        self.state.lock().waiting_hint.clone()
226    }
227
228    pub fn start_waiting(&self) {
229        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
230    }
231
232    pub fn finish_waiting(&self) {
233        self.state.lock().waiting_backtrace.take();
234    }
235
236    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
237        self.state.lock().waiting_backtrace.take().map(|mut b| {
238            b.resolve();
239            b
240        })
241    }
242
243    pub fn rng(&self) -> StdRng {
244        self.state.lock().random.clone()
245    }
246
247    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
248        self.state.lock().block_on_ticks = range;
249    }
250
251    pub fn gen_block_on_ticks(&self) -> usize {
252        let mut lock = self.state.lock();
253        let block_on_ticks = lock.block_on_ticks.clone();
254        lock.random.random_range(block_on_ticks)
255    }
256
257    pub fn unpark_all(&self) {
258        self.state.lock().unparkers.retain(|parker| parker.unpark());
259    }
260
261    pub fn push_unparker(&self, unparker: Unparker) {
262        let mut state = self.state.lock();
263        state.unparkers.push(unparker);
264    }
265}
266
267impl Clone for TestDispatcher {
268    fn clone(&self) -> Self {
269        let id = post_inc(&mut self.state.lock().next_id.0);
270        Self {
271            id: TestDispatcherId(id),
272            state: self.state.clone(),
273        }
274    }
275}
276
277impl PlatformDispatcher for TestDispatcher {
278    fn get_all_timings(&self) -> Vec<crate::ThreadTaskTimings> {
279        Vec::new()
280    }
281
282    fn get_current_thread_timings(&self) -> Vec<crate::TaskTiming> {
283        Vec::new()
284    }
285
286    fn is_main_thread(&self) -> bool {
287        self.state.lock().is_main_thread
288    }
289
290    fn now(&self) -> Instant {
291        let state = self.state.lock();
292        state.start_time + state.time
293    }
294
295    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, _priority: Priority) {
296        {
297            let mut state = self.state.lock();
298            if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
299                state.deprioritized_background.push(runnable);
300            } else {
301                state.background.push(runnable);
302            }
303        }
304        self.unpark_all();
305    }
306
307    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
308        self.state
309            .lock()
310            .foreground
311            .entry(self.id)
312            .or_default()
313            .push_back(runnable);
314        self.unpark_all();
315    }
316
317    fn dispatch_after(&self, duration: std::time::Duration, runnable: RunnableVariant) {
318        let mut state = self.state.lock();
319        let next_time = state.time + duration;
320        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
321            Ok(ix) | Err(ix) => ix,
322        };
323        state.delayed.insert(ix, (next_time, runnable));
324    }
325
326    fn as_test(&self) -> Option<&TestDispatcher> {
327        Some(self)
328    }
329
330    fn spawn_realtime(&self, _priority: crate::RealtimePriority, f: Box<dyn FnOnce() + Send>) {
331        std::thread::spawn(move || {
332            f();
333        });
334    }
335}