dispatcher.rs

  1use crate::{PlatformDispatcher, TaskLabel};
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, HashSet, VecDeque};
  5use parking::{Parker, Unparker};
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    ops::RangeInclusive,
 11    pin::Pin,
 12    sync::Arc,
 13    task::{Context, Poll},
 14    time::Duration,
 15};
 16use util::post_inc;
 17
 18#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 19struct TestDispatcherId(usize);
 20
 21#[doc(hidden)]
 22pub struct TestDispatcher {
 23    id: TestDispatcherId,
 24    state: Arc<Mutex<TestDispatcherState>>,
 25    parker: Arc<Mutex<Parker>>,
 26    unparker: Unparker,
 27}
 28
 29struct TestDispatcherState {
 30    random: StdRng,
 31    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
 32    background: Vec<Runnable>,
 33    deprioritized_background: Vec<Runnable>,
 34    delayed: Vec<(Duration, Runnable)>,
 35    time: Duration,
 36    is_main_thread: bool,
 37    next_id: TestDispatcherId,
 38    allow_parking: bool,
 39    waiting_backtrace: Option<Backtrace>,
 40    deprioritized_task_labels: HashSet<TaskLabel>,
 41    block_on_ticks: RangeInclusive<usize>,
 42}
 43
 44impl TestDispatcher {
 45    pub fn new(random: StdRng) -> Self {
 46        let (parker, unparker) = parking::pair();
 47        let state = TestDispatcherState {
 48            random,
 49            foreground: HashMap::default(),
 50            background: Vec::new(),
 51            deprioritized_background: Vec::new(),
 52            delayed: Vec::new(),
 53            time: Duration::ZERO,
 54            is_main_thread: true,
 55            next_id: TestDispatcherId(1),
 56            allow_parking: false,
 57            waiting_backtrace: None,
 58            deprioritized_task_labels: Default::default(),
 59            block_on_ticks: 0..=1000,
 60        };
 61
 62        TestDispatcher {
 63            id: TestDispatcherId(0),
 64            state: Arc::new(Mutex::new(state)),
 65            parker: Arc::new(Mutex::new(parker)),
 66            unparker,
 67        }
 68    }
 69
 70    pub fn advance_clock(&self, by: Duration) {
 71        let new_now = self.state.lock().time + by;
 72        loop {
 73            self.run_until_parked();
 74            let state = self.state.lock();
 75            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 76            drop(state);
 77            if let Some(due_time) = next_due_time {
 78                if due_time <= new_now {
 79                    self.state.lock().time = due_time;
 80                    continue;
 81                }
 82            }
 83            break;
 84        }
 85        self.state.lock().time = new_now;
 86    }
 87
 88    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
 89        struct YieldNow {
 90            pub(crate) count: usize,
 91        }
 92
 93        impl Future for YieldNow {
 94            type Output = ();
 95
 96            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 97                if self.count > 0 {
 98                    self.count -= 1;
 99                    cx.waker().wake_by_ref();
100                    Poll::Pending
101                } else {
102                    Poll::Ready(())
103                }
104            }
105        }
106
107        YieldNow {
108            count: self.state.lock().random.gen_range(0..10),
109        }
110    }
111
112    pub fn deprioritize(&self, task_label: TaskLabel) {
113        self.state
114            .lock()
115            .deprioritized_task_labels
116            .insert(task_label);
117    }
118
119    pub fn run_until_parked(&self) {
120        while self.tick(false) {}
121    }
122
123    pub fn parking_allowed(&self) -> bool {
124        self.state.lock().allow_parking
125    }
126
127    pub fn allow_parking(&self) {
128        self.state.lock().allow_parking = true
129    }
130
131    pub fn forbid_parking(&self) {
132        self.state.lock().allow_parking = false
133    }
134
135    pub fn start_waiting(&self) {
136        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
137    }
138
139    pub fn finish_waiting(&self) {
140        self.state.lock().waiting_backtrace.take();
141    }
142
143    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
144        self.state.lock().waiting_backtrace.take().map(|mut b| {
145            b.resolve();
146            b
147        })
148    }
149
150    pub fn rng(&self) -> StdRng {
151        self.state.lock().random.clone()
152    }
153
154    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
155        self.state.lock().block_on_ticks = range;
156    }
157
158    pub fn gen_block_on_ticks(&self) -> usize {
159        let mut lock = self.state.lock();
160        let block_on_ticks = lock.block_on_ticks.clone();
161        lock.random.gen_range(block_on_ticks)
162    }
163}
164
165impl Clone for TestDispatcher {
166    fn clone(&self) -> Self {
167        let id = post_inc(&mut self.state.lock().next_id.0);
168        Self {
169            id: TestDispatcherId(id),
170            state: self.state.clone(),
171            parker: self.parker.clone(),
172            unparker: self.unparker.clone(),
173        }
174    }
175}
176
177impl PlatformDispatcher for TestDispatcher {
178    fn is_main_thread(&self) -> bool {
179        self.state.lock().is_main_thread
180    }
181
182    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
183        {
184            let mut state = self.state.lock();
185            if label.map_or(false, |label| {
186                state.deprioritized_task_labels.contains(&label)
187            }) {
188                state.deprioritized_background.push(runnable);
189            } else {
190                state.background.push(runnable);
191            }
192        }
193        self.unparker.unpark();
194    }
195
196    fn dispatch_on_main_thread(&self, runnable: Runnable) {
197        self.state
198            .lock()
199            .foreground
200            .entry(self.id)
201            .or_default()
202            .push_back(runnable);
203        self.unparker.unpark();
204    }
205
206    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
207        let mut state = self.state.lock();
208        let next_time = state.time + duration;
209        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
210            Ok(ix) | Err(ix) => ix,
211        };
212        state.delayed.insert(ix, (next_time, runnable));
213    }
214
215    fn tick(&self, background_only: bool) -> bool {
216        let mut state = self.state.lock();
217
218        while let Some((deadline, _)) = state.delayed.first() {
219            if *deadline > state.time {
220                break;
221            }
222            let (_, runnable) = state.delayed.remove(0);
223            state.background.push(runnable);
224        }
225
226        let foreground_len: usize = if background_only {
227            0
228        } else {
229            state
230                .foreground
231                .values()
232                .map(|runnables| runnables.len())
233                .sum()
234        };
235        let background_len = state.background.len();
236
237        let runnable;
238        let main_thread;
239        if foreground_len == 0 && background_len == 0 {
240            let deprioritized_background_len = state.deprioritized_background.len();
241            if deprioritized_background_len == 0 {
242                return false;
243            }
244            let ix = state.random.gen_range(0..deprioritized_background_len);
245            main_thread = false;
246            runnable = state.deprioritized_background.swap_remove(ix);
247        } else {
248            main_thread = state.random.gen_ratio(
249                foreground_len as u32,
250                (foreground_len + background_len) as u32,
251            );
252            if main_thread {
253                let state = &mut *state;
254                runnable = state
255                    .foreground
256                    .values_mut()
257                    .filter(|runnables| !runnables.is_empty())
258                    .choose(&mut state.random)
259                    .unwrap()
260                    .pop_front()
261                    .unwrap();
262            } else {
263                let ix = state.random.gen_range(0..background_len);
264                runnable = state.background.swap_remove(ix);
265            };
266        };
267
268        let was_main_thread = state.is_main_thread;
269        state.is_main_thread = main_thread;
270        drop(state);
271        runnable.run();
272        self.state.lock().is_main_thread = was_main_thread;
273
274        true
275    }
276
277    fn park(&self) {
278        self.parker.lock().park();
279    }
280
281    fn unparker(&self) -> Unparker {
282        self.unparker.clone()
283    }
284
285    fn as_test(&self) -> Option<&TestDispatcher> {
286        Some(self)
287    }
288}