dispatcher.rs

  1use crate::{PlatformDispatcher, TaskLabel};
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, HashSet, VecDeque};
  5use parking::{Parker, Unparker};
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    ops::RangeInclusive,
 11    pin::Pin,
 12    sync::Arc,
 13    task::{Context, Poll},
 14    time::Duration,
 15};
 16use util::post_inc;
 17
 18#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 19struct TestDispatcherId(usize);
 20
 21pub struct TestDispatcher {
 22    id: TestDispatcherId,
 23    state: Arc<Mutex<TestDispatcherState>>,
 24    parker: Arc<Mutex<Parker>>,
 25    unparker: Unparker,
 26}
 27
 28struct TestDispatcherState {
 29    random: StdRng,
 30    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
 31    background: Vec<Runnable>,
 32    deprioritized_background: Vec<Runnable>,
 33    delayed: Vec<(Duration, Runnable)>,
 34    time: Duration,
 35    is_main_thread: bool,
 36    next_id: TestDispatcherId,
 37    allow_parking: bool,
 38    waiting_backtrace: Option<Backtrace>,
 39    deprioritized_task_labels: HashSet<TaskLabel>,
 40    block_on_ticks: RangeInclusive<usize>,
 41}
 42
 43impl TestDispatcher {
 44    pub fn new(random: StdRng) -> Self {
 45        let (parker, unparker) = parking::pair();
 46        let state = TestDispatcherState {
 47            random,
 48            foreground: HashMap::default(),
 49            background: Vec::new(),
 50            deprioritized_background: Vec::new(),
 51            delayed: Vec::new(),
 52            time: Duration::ZERO,
 53            is_main_thread: true,
 54            next_id: TestDispatcherId(1),
 55            allow_parking: false,
 56            waiting_backtrace: None,
 57            deprioritized_task_labels: Default::default(),
 58            block_on_ticks: 0..=1000,
 59        };
 60
 61        TestDispatcher {
 62            id: TestDispatcherId(0),
 63            state: Arc::new(Mutex::new(state)),
 64            parker: Arc::new(Mutex::new(parker)),
 65            unparker,
 66        }
 67    }
 68
 69    pub fn advance_clock(&self, by: Duration) {
 70        let new_now = self.state.lock().time + by;
 71        loop {
 72            self.run_until_parked();
 73            let state = self.state.lock();
 74            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 75            drop(state);
 76            if let Some(due_time) = next_due_time {
 77                if due_time <= new_now {
 78                    self.state.lock().time = due_time;
 79                    continue;
 80                }
 81            }
 82            break;
 83        }
 84        self.state.lock().time = new_now;
 85    }
 86
 87    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
 88        struct YieldNow {
 89            pub(crate) count: usize,
 90        }
 91
 92        impl Future for YieldNow {
 93            type Output = ();
 94
 95            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 96                if self.count > 0 {
 97                    self.count -= 1;
 98                    cx.waker().wake_by_ref();
 99                    Poll::Pending
100                } else {
101                    Poll::Ready(())
102                }
103            }
104        }
105
106        YieldNow {
107            count: self.state.lock().random.gen_range(0..10),
108        }
109    }
110
111    pub fn deprioritize(&self, task_label: TaskLabel) {
112        self.state
113            .lock()
114            .deprioritized_task_labels
115            .insert(task_label);
116    }
117
118    pub fn run_until_parked(&self) {
119        while self.tick(false) {}
120    }
121
122    pub fn parking_allowed(&self) -> bool {
123        self.state.lock().allow_parking
124    }
125
126    pub fn allow_parking(&self) {
127        self.state.lock().allow_parking = true
128    }
129
130    pub fn start_waiting(&self) {
131        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
132    }
133
134    pub fn finish_waiting(&self) {
135        self.state.lock().waiting_backtrace.take();
136    }
137
138    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
139        self.state.lock().waiting_backtrace.take().map(|mut b| {
140            b.resolve();
141            b
142        })
143    }
144
145    pub fn rng(&self) -> StdRng {
146        self.state.lock().random.clone()
147    }
148
149    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
150        self.state.lock().block_on_ticks = range;
151    }
152
153    pub fn gen_block_on_ticks(&self) -> usize {
154        let mut lock = self.state.lock();
155        let block_on_ticks = lock.block_on_ticks.clone();
156        lock.random.gen_range(block_on_ticks)
157    }
158}
159
160impl Clone for TestDispatcher {
161    fn clone(&self) -> Self {
162        let id = post_inc(&mut self.state.lock().next_id.0);
163        Self {
164            id: TestDispatcherId(id),
165            state: self.state.clone(),
166            parker: self.parker.clone(),
167            unparker: self.unparker.clone(),
168        }
169    }
170}
171
172impl PlatformDispatcher for TestDispatcher {
173    fn is_main_thread(&self) -> bool {
174        self.state.lock().is_main_thread
175    }
176
177    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
178        {
179            let mut state = self.state.lock();
180            if label.map_or(false, |label| {
181                state.deprioritized_task_labels.contains(&label)
182            }) {
183                state.deprioritized_background.push(runnable);
184            } else {
185                state.background.push(runnable);
186            }
187        }
188        self.unparker.unpark();
189    }
190
191    fn dispatch_on_main_thread(&self, runnable: Runnable) {
192        self.state
193            .lock()
194            .foreground
195            .entry(self.id)
196            .or_default()
197            .push_back(runnable);
198        self.unparker.unpark();
199    }
200
201    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
202        let mut state = self.state.lock();
203        let next_time = state.time + duration;
204        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
205            Ok(ix) | Err(ix) => ix,
206        };
207        state.delayed.insert(ix, (next_time, runnable));
208    }
209
210    fn tick(&self, background_only: bool) -> bool {
211        let mut state = self.state.lock();
212
213        while let Some((deadline, _)) = state.delayed.first() {
214            if *deadline > state.time {
215                break;
216            }
217            let (_, runnable) = state.delayed.remove(0);
218            state.background.push(runnable);
219        }
220
221        let foreground_len: usize = if background_only {
222            0
223        } else {
224            state
225                .foreground
226                .values()
227                .map(|runnables| runnables.len())
228                .sum()
229        };
230        let background_len = state.background.len();
231
232        let runnable;
233        let main_thread;
234        if foreground_len == 0 && background_len == 0 {
235            let deprioritized_background_len = state.deprioritized_background.len();
236            if deprioritized_background_len == 0 {
237                return false;
238            }
239            let ix = state.random.gen_range(0..deprioritized_background_len);
240            main_thread = false;
241            runnable = state.deprioritized_background.swap_remove(ix);
242        } else {
243            main_thread = state.random.gen_ratio(
244                foreground_len as u32,
245                (foreground_len + background_len) as u32,
246            );
247            if main_thread {
248                let state = &mut *state;
249                runnable = state
250                    .foreground
251                    .values_mut()
252                    .filter(|runnables| !runnables.is_empty())
253                    .choose(&mut state.random)
254                    .unwrap()
255                    .pop_front()
256                    .unwrap();
257            } else {
258                let ix = state.random.gen_range(0..background_len);
259                runnable = state.background.swap_remove(ix);
260            };
261        };
262
263        let was_main_thread = state.is_main_thread;
264        state.is_main_thread = main_thread;
265        drop(state);
266        runnable.run();
267        self.state.lock().is_main_thread = was_main_thread;
268
269        true
270    }
271
272    fn park(&self) {
273        self.parker.lock().park();
274    }
275
276    fn unparker(&self) -> Unparker {
277        self.unparker.clone()
278    }
279
280    fn as_test(&self) -> Option<&TestDispatcher> {
281        Some(self)
282    }
283}