dispatcher.rs

  1use crate::PlatformDispatcher;
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, VecDeque};
  5use parking::{Parker, Unparker};
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    pin::Pin,
 11    sync::Arc,
 12    task::{Context, Poll},
 13    time::Duration,
 14};
 15use util::post_inc;
 16
 17#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 18struct TestDispatcherId(usize);
 19
/// Test implementation of `PlatformDispatcher` whose scheduling decisions come from a
/// caller-supplied RNG.
///
/// All scheduling state sits behind one shared lock, so clones observe the same queues,
/// clock and RNG. A clone differs from the original only in its `id`.
pub struct TestDispatcher {
    /// Distinguishes this handle from its clones; selects the foreground queue that
    /// `dispatch_on_main_thread` pushes onto.
    id: TestDispatcherId,
    /// Scheduling state shared by every clone.
    state: Arc<Mutex<TestDispatcherState>>,
    /// Parker shared by every clone; `park` blocks on it.
    parker: Arc<Mutex<Parker>>,
    /// Wakes `parker`; signalled by `dispatch` and `dispatch_on_main_thread`.
    unparker: Unparker,
}
 26
/// Mutable scheduling state shared by a `TestDispatcher` and all of its clones.
struct TestDispatcherState {
    /// Source of every scheduling choice (which queue, which runnable, random delays).
    random: StdRng,
    /// Main-thread runnables, one FIFO queue per dispatcher handle id.
    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
    /// Background runnables; `poll` picks one at a random index.
    background: Vec<Runnable>,
    /// Pending timers as `(deadline, runnable)`, kept sorted by ascending deadline.
    delayed: Vec<(Duration, Runnable)>,
    /// Simulated clock; only `advance_clock` moves it forward.
    time: Duration,
    /// Whether the runnable currently executing in `poll` came from a foreground queue
    /// (restored after it finishes; starts out `true`).
    is_main_thread: bool,
    /// Id handed to the next clone; id 0 belongs to the dispatcher built by `new`.
    next_id: TestDispatcherId,
    /// Set by `allow_parking`, read by `parking_allowed`.
    /// NOTE(review): the consumer that enforces this flag is not in this file.
    allow_parking: bool,
    /// Backtrace captured by `start_waiting`; cleared by `finish_waiting` and taken by
    /// `waiting_backtrace`.
    waiting_backtrace: Option<Backtrace>,
}
 38
 39impl TestDispatcher {
 40    pub fn new(random: StdRng) -> Self {
 41        let (parker, unparker) = parking::pair();
 42        let state = TestDispatcherState {
 43            random,
 44            foreground: HashMap::default(),
 45            background: Vec::new(),
 46            delayed: Vec::new(),
 47            time: Duration::ZERO,
 48            is_main_thread: true,
 49            next_id: TestDispatcherId(1),
 50            allow_parking: false,
 51            waiting_backtrace: None,
 52        };
 53
 54        TestDispatcher {
 55            id: TestDispatcherId(0),
 56            state: Arc::new(Mutex::new(state)),
 57            parker: Arc::new(Mutex::new(parker)),
 58            unparker,
 59        }
 60    }
 61
 62    pub fn advance_clock(&self, by: Duration) {
 63        let new_now = self.state.lock().time + by;
 64        loop {
 65            self.run_until_parked();
 66            let state = self.state.lock();
 67            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 68            drop(state);
 69            if let Some(due_time) = next_due_time {
 70                if due_time <= new_now {
 71                    self.state.lock().time = due_time;
 72                    continue;
 73                }
 74            }
 75            break;
 76        }
 77        self.state.lock().time = new_now;
 78    }
 79
 80    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
 81        pub struct YieldNow {
 82            count: usize,
 83        }
 84
 85        impl Future for YieldNow {
 86            type Output = ();
 87
 88            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 89                if self.count > 0 {
 90                    self.count -= 1;
 91                    cx.waker().wake_by_ref();
 92                    Poll::Pending
 93                } else {
 94                    Poll::Ready(())
 95                }
 96            }
 97        }
 98
 99        YieldNow {
100            count: self.state.lock().random.gen_range(0..10),
101        }
102    }
103
104    pub fn run_until_parked(&self) {
105        while self.poll(false) {}
106    }
107
108    pub fn parking_allowed(&self) -> bool {
109        self.state.lock().allow_parking
110    }
111
112    pub fn allow_parking(&self) {
113        self.state.lock().allow_parking = true
114    }
115
116    pub fn start_waiting(&self) {
117        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
118    }
119
120    pub fn finish_waiting(&self) {
121        self.state.lock().waiting_backtrace.take();
122    }
123
124    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
125        self.state.lock().waiting_backtrace.take().map(|mut b| {
126            b.resolve();
127            b
128        })
129    }
130}
131
132impl Clone for TestDispatcher {
133    fn clone(&self) -> Self {
134        let id = post_inc(&mut self.state.lock().next_id.0);
135        Self {
136            id: TestDispatcherId(id),
137            state: self.state.clone(),
138            parker: self.parker.clone(),
139            unparker: self.unparker.clone(),
140        }
141    }
142}
143
144impl PlatformDispatcher for TestDispatcher {
145    fn is_main_thread(&self) -> bool {
146        self.state.lock().is_main_thread
147    }
148
149    fn dispatch(&self, runnable: Runnable) {
150        self.state.lock().background.push(runnable);
151        self.unparker.unpark();
152    }
153
154    fn dispatch_on_main_thread(&self, runnable: Runnable) {
155        self.state
156            .lock()
157            .foreground
158            .entry(self.id)
159            .or_default()
160            .push_back(runnable);
161        self.unparker.unpark();
162    }
163
164    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
165        let mut state = self.state.lock();
166        let next_time = state.time + duration;
167        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
168            Ok(ix) | Err(ix) => ix,
169        };
170        state.delayed.insert(ix, (next_time, runnable));
171    }
172
173    fn poll(&self, background_only: bool) -> bool {
174        let mut state = self.state.lock();
175
176        while let Some((deadline, _)) = state.delayed.first() {
177            if *deadline > state.time {
178                break;
179            }
180            let (_, runnable) = state.delayed.remove(0);
181            state.background.push(runnable);
182        }
183
184        let foreground_len: usize = if background_only {
185            0
186        } else {
187            state
188                .foreground
189                .values()
190                .map(|runnables| runnables.len())
191                .sum()
192        };
193        let background_len = state.background.len();
194
195        if foreground_len == 0 && background_len == 0 {
196            return false;
197        }
198
199        let main_thread = state.random.gen_ratio(
200            foreground_len as u32,
201            (foreground_len + background_len) as u32,
202        );
203        let was_main_thread = state.is_main_thread;
204        state.is_main_thread = main_thread;
205
206        let runnable = if main_thread {
207            let state = &mut *state;
208            let runnables = state
209                .foreground
210                .values_mut()
211                .filter(|runnables| !runnables.is_empty())
212                .choose(&mut state.random)
213                .unwrap();
214            runnables.pop_front().unwrap()
215        } else {
216            let ix = state.random.gen_range(0..background_len);
217            state.background.swap_remove(ix)
218        };
219
220        drop(state);
221        runnable.run();
222
223        self.state.lock().is_main_thread = was_main_thread;
224
225        true
226    }
227
228    fn park(&self) {
229        self.parker.lock().park();
230    }
231
232    fn unparker(&self) -> Unparker {
233        self.unparker.clone()
234    }
235
236    fn as_test(&self) -> Option<&TestDispatcher> {
237        Some(self)
238    }
239}