dispatcher.rs

  1use crate::{PlatformDispatcher, TaskLabel};
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, HashSet, VecDeque};
  5use parking::{Parker, Unparker};
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    pin::Pin,
 11    sync::Arc,
 12    task::{Context, Poll},
 13    time::Duration,
 14};
 15use util::post_inc;
 16
 17#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 18struct TestDispatcherId(usize);
 19
 20pub struct TestDispatcher {
 21    id: TestDispatcherId,
 22    state: Arc<Mutex<TestDispatcherState>>,
 23    parker: Arc<Mutex<Parker>>,
 24    unparker: Unparker,
 25}
 26
 27struct TestDispatcherState {
 28    random: StdRng,
 29    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
 30    background: Vec<Runnable>,
 31    deprioritized_background: Vec<Runnable>,
 32    delayed: Vec<(Duration, Runnable)>,
 33    time: Duration,
 34    is_main_thread: bool,
 35    next_id: TestDispatcherId,
 36    allow_parking: bool,
 37    waiting_backtrace: Option<Backtrace>,
 38    deprioritized_task_labels: HashSet<TaskLabel>,
 39}
 40
 41impl TestDispatcher {
 42    pub fn new(random: StdRng) -> Self {
 43        let (parker, unparker) = parking::pair();
 44        let state = TestDispatcherState {
 45            random,
 46            foreground: HashMap::default(),
 47            background: Vec::new(),
 48            deprioritized_background: Vec::new(),
 49            delayed: Vec::new(),
 50            time: Duration::ZERO,
 51            is_main_thread: true,
 52            next_id: TestDispatcherId(1),
 53            allow_parking: false,
 54            waiting_backtrace: None,
 55            deprioritized_task_labels: Default::default(),
 56        };
 57
 58        TestDispatcher {
 59            id: TestDispatcherId(0),
 60            state: Arc::new(Mutex::new(state)),
 61            parker: Arc::new(Mutex::new(parker)),
 62            unparker,
 63        }
 64    }
 65
 66    pub fn advance_clock(&self, by: Duration) {
 67        let new_now = self.state.lock().time + by;
 68        loop {
 69            self.run_until_parked();
 70            let state = self.state.lock();
 71            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 72            drop(state);
 73            if let Some(due_time) = next_due_time {
 74                if due_time <= new_now {
 75                    self.state.lock().time = due_time;
 76                    continue;
 77                }
 78            }
 79            break;
 80        }
 81        self.state.lock().time = new_now;
 82    }
 83
 84    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
 85        pub struct YieldNow {
 86            count: usize,
 87        }
 88
 89        impl Future for YieldNow {
 90            type Output = ();
 91
 92            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 93                if self.count > 0 {
 94                    self.count -= 1;
 95                    cx.waker().wake_by_ref();
 96                    Poll::Pending
 97                } else {
 98                    Poll::Ready(())
 99                }
100            }
101        }
102
103        YieldNow {
104            count: self.state.lock().random.gen_range(0..10),
105        }
106    }
107
108    pub fn deprioritize(&self, task_label: TaskLabel) {
109        self.state
110            .lock()
111            .deprioritized_task_labels
112            .insert(task_label);
113    }
114
115    pub fn run_until_parked(&self) {
116        while self.tick(false) {}
117    }
118
119    pub fn parking_allowed(&self) -> bool {
120        self.state.lock().allow_parking
121    }
122
123    pub fn allow_parking(&self) {
124        self.state.lock().allow_parking = true
125    }
126
127    pub fn start_waiting(&self) {
128        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
129    }
130
131    pub fn finish_waiting(&self) {
132        self.state.lock().waiting_backtrace.take();
133    }
134
135    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
136        self.state.lock().waiting_backtrace.take().map(|mut b| {
137            b.resolve();
138            b
139        })
140    }
141
142    pub fn rng(&self) -> StdRng {
143        self.state.lock().random.clone()
144    }
145}
146
147impl Clone for TestDispatcher {
148    fn clone(&self) -> Self {
149        let id = post_inc(&mut self.state.lock().next_id.0);
150        Self {
151            id: TestDispatcherId(id),
152            state: self.state.clone(),
153            parker: self.parker.clone(),
154            unparker: self.unparker.clone(),
155        }
156    }
157}
158
159impl PlatformDispatcher for TestDispatcher {
160    fn is_main_thread(&self) -> bool {
161        self.state.lock().is_main_thread
162    }
163
164    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
165        {
166            let mut state = self.state.lock();
167            if label.map_or(false, |label| {
168                state.deprioritized_task_labels.contains(&label)
169            }) {
170                state.deprioritized_background.push(runnable);
171            } else {
172                state.background.push(runnable);
173            }
174        }
175        self.unparker.unpark();
176    }
177
178    fn dispatch_on_main_thread(&self, runnable: Runnable) {
179        self.state
180            .lock()
181            .foreground
182            .entry(self.id)
183            .or_default()
184            .push_back(runnable);
185        self.unparker.unpark();
186    }
187
188    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
189        let mut state = self.state.lock();
190        let next_time = state.time + duration;
191        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
192            Ok(ix) | Err(ix) => ix,
193        };
194        state.delayed.insert(ix, (next_time, runnable));
195    }
196
197    fn tick(&self, background_only: bool) -> bool {
198        let mut state = self.state.lock();
199
200        while let Some((deadline, _)) = state.delayed.first() {
201            if *deadline > state.time {
202                break;
203            }
204            let (_, runnable) = state.delayed.remove(0);
205            state.background.push(runnable);
206        }
207
208        let foreground_len: usize = if background_only {
209            0
210        } else {
211            state
212                .foreground
213                .values()
214                .map(|runnables| runnables.len())
215                .sum()
216        };
217        let background_len = state.background.len();
218
219        let runnable;
220        let main_thread;
221        if foreground_len == 0 && background_len == 0 {
222            let deprioritized_background_len = state.deprioritized_background.len();
223            if deprioritized_background_len == 0 {
224                return false;
225            }
226            let ix = state.random.gen_range(0..deprioritized_background_len);
227            main_thread = false;
228            runnable = state.deprioritized_background.swap_remove(ix);
229        } else {
230            main_thread = state.random.gen_ratio(
231                foreground_len as u32,
232                (foreground_len + background_len) as u32,
233            );
234            if main_thread {
235                let state = &mut *state;
236                runnable = state
237                    .foreground
238                    .values_mut()
239                    .filter(|runnables| !runnables.is_empty())
240                    .choose(&mut state.random)
241                    .unwrap()
242                    .pop_front()
243                    .unwrap();
244            } else {
245                let ix = state.random.gen_range(0..background_len);
246                runnable = state.background.swap_remove(ix);
247            };
248        };
249
250        let was_main_thread = state.is_main_thread;
251        state.is_main_thread = main_thread;
252        drop(state);
253        runnable.run();
254        self.state.lock().is_main_thread = was_main_thread;
255
256        true
257    }
258
259    fn park(&self) {
260        self.parker.lock().park();
261    }
262
263    fn unparker(&self) -> Unparker {
264        self.unparker.clone()
265    }
266
267    fn as_test(&self) -> Option<&TestDispatcher> {
268        Some(self)
269    }
270}