dispatcher.rs

  1use crate::{PlatformDispatcher, TaskLabel};
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, HashSet, VecDeque};
  5use parking::{Parker, Unparker};
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    ops::RangeInclusive,
 11    pin::Pin,
 12    sync::Arc,
 13    task::{Context, Poll},
 14    time::Duration,
 15};
 16use util::post_inc;
 17
 18#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 19struct TestDispatcherId(usize);
 20
 21#[doc(hidden)]
 22pub struct TestDispatcher {
 23    id: TestDispatcherId,
 24    state: Arc<Mutex<TestDispatcherState>>,
 25    parker: Arc<Mutex<Parker>>,
 26    unparker: Unparker,
 27}
 28
 29struct TestDispatcherState {
 30    random: StdRng,
 31    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
 32    background: Vec<Runnable>,
 33    deprioritized_background: Vec<Runnable>,
 34    delayed: Vec<(Duration, Runnable)>,
 35    time: Duration,
 36    is_main_thread: bool,
 37    next_id: TestDispatcherId,
 38    allow_parking: bool,
 39    waiting_backtrace: Option<Backtrace>,
 40    deprioritized_task_labels: HashSet<TaskLabel>,
 41    block_on_ticks: RangeInclusive<usize>,
 42}
 43
 44impl TestDispatcher {
 45    pub fn new(random: StdRng) -> Self {
 46        let (parker, unparker) = parking::pair();
 47        let state = TestDispatcherState {
 48            random,
 49            foreground: HashMap::default(),
 50            background: Vec::new(),
 51            deprioritized_background: Vec::new(),
 52            delayed: Vec::new(),
 53            time: Duration::ZERO,
 54            is_main_thread: true,
 55            next_id: TestDispatcherId(1),
 56            allow_parking: false,
 57            waiting_backtrace: None,
 58            deprioritized_task_labels: Default::default(),
 59            block_on_ticks: 0..=1000,
 60        };
 61
 62        TestDispatcher {
 63            id: TestDispatcherId(0),
 64            state: Arc::new(Mutex::new(state)),
 65            parker: Arc::new(Mutex::new(parker)),
 66            unparker,
 67        }
 68    }
 69
 70    pub fn advance_clock(&self, by: Duration) {
 71        let new_now = self.state.lock().time + by;
 72        loop {
 73            self.run_until_parked();
 74            let state = self.state.lock();
 75            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 76            drop(state);
 77            if let Some(due_time) = next_due_time {
 78                if due_time <= new_now {
 79                    self.state.lock().time = due_time;
 80                    continue;
 81                }
 82            }
 83            break;
 84        }
 85        self.state.lock().time = new_now;
 86    }
 87
 88    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
 89        struct YieldNow {
 90            pub(crate) count: usize,
 91        }
 92
 93        impl Future for YieldNow {
 94            type Output = ();
 95
 96            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 97                if self.count > 0 {
 98                    self.count -= 1;
 99                    cx.waker().wake_by_ref();
100                    Poll::Pending
101                } else {
102                    Poll::Ready(())
103                }
104            }
105        }
106
107        YieldNow {
108            count: self.state.lock().random.gen_range(0..10),
109        }
110    }
111
112    pub fn deprioritize(&self, task_label: TaskLabel) {
113        self.state
114            .lock()
115            .deprioritized_task_labels
116            .insert(task_label);
117    }
118
119    pub fn run_until_parked(&self) {
120        while self.tick(false) {}
121    }
122
123    pub fn parking_allowed(&self) -> bool {
124        self.state.lock().allow_parking
125    }
126
127    pub fn allow_parking(&self) {
128        self.state.lock().allow_parking = true
129    }
130
131    pub fn start_waiting(&self) {
132        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
133    }
134
135    pub fn finish_waiting(&self) {
136        self.state.lock().waiting_backtrace.take();
137    }
138
139    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
140        self.state.lock().waiting_backtrace.take().map(|mut b| {
141            b.resolve();
142            b
143        })
144    }
145
146    pub fn rng(&self) -> StdRng {
147        self.state.lock().random.clone()
148    }
149
150    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
151        self.state.lock().block_on_ticks = range;
152    }
153
154    pub fn gen_block_on_ticks(&self) -> usize {
155        let mut lock = self.state.lock();
156        let block_on_ticks = lock.block_on_ticks.clone();
157        lock.random.gen_range(block_on_ticks)
158    }
159}
160
161impl Clone for TestDispatcher {
162    fn clone(&self) -> Self {
163        let id = post_inc(&mut self.state.lock().next_id.0);
164        Self {
165            id: TestDispatcherId(id),
166            state: self.state.clone(),
167            parker: self.parker.clone(),
168            unparker: self.unparker.clone(),
169        }
170    }
171}
172
173impl PlatformDispatcher for TestDispatcher {
174    fn is_main_thread(&self) -> bool {
175        self.state.lock().is_main_thread
176    }
177
178    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
179        {
180            let mut state = self.state.lock();
181            if label.map_or(false, |label| {
182                state.deprioritized_task_labels.contains(&label)
183            }) {
184                state.deprioritized_background.push(runnable);
185            } else {
186                state.background.push(runnable);
187            }
188        }
189        self.unparker.unpark();
190    }
191
192    fn dispatch_on_main_thread(&self, runnable: Runnable) {
193        self.state
194            .lock()
195            .foreground
196            .entry(self.id)
197            .or_default()
198            .push_back(runnable);
199        self.unparker.unpark();
200    }
201
202    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
203        let mut state = self.state.lock();
204        let next_time = state.time + duration;
205        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
206            Ok(ix) | Err(ix) => ix,
207        };
208        state.delayed.insert(ix, (next_time, runnable));
209    }
210
211    fn tick(&self, background_only: bool) -> bool {
212        let mut state = self.state.lock();
213
214        while let Some((deadline, _)) = state.delayed.first() {
215            if *deadline > state.time {
216                break;
217            }
218            let (_, runnable) = state.delayed.remove(0);
219            state.background.push(runnable);
220        }
221
222        let foreground_len: usize = if background_only {
223            0
224        } else {
225            state
226                .foreground
227                .values()
228                .map(|runnables| runnables.len())
229                .sum()
230        };
231        let background_len = state.background.len();
232
233        let runnable;
234        let main_thread;
235        if foreground_len == 0 && background_len == 0 {
236            let deprioritized_background_len = state.deprioritized_background.len();
237            if deprioritized_background_len == 0 {
238                return false;
239            }
240            let ix = state.random.gen_range(0..deprioritized_background_len);
241            main_thread = false;
242            runnable = state.deprioritized_background.swap_remove(ix);
243        } else {
244            main_thread = state.random.gen_ratio(
245                foreground_len as u32,
246                (foreground_len + background_len) as u32,
247            );
248            if main_thread {
249                let state = &mut *state;
250                runnable = state
251                    .foreground
252                    .values_mut()
253                    .filter(|runnables| !runnables.is_empty())
254                    .choose(&mut state.random)
255                    .unwrap()
256                    .pop_front()
257                    .unwrap();
258            } else {
259                let ix = state.random.gen_range(0..background_len);
260                runnable = state.background.swap_remove(ix);
261            };
262        };
263
264        let was_main_thread = state.is_main_thread;
265        state.is_main_thread = main_thread;
266        drop(state);
267        runnable.run();
268        self.state.lock().is_main_thread = was_main_thread;
269
270        true
271    }
272
273    fn park(&self) {
274        self.parker.lock().park();
275    }
276
277    fn unparker(&self) -> Unparker {
278        self.unparker.clone()
279    }
280
281    fn as_test(&self) -> Option<&TestDispatcher> {
282        Some(self)
283    }
284}