// dispatcher.rs

  1use crate::PlatformDispatcher;
  2use async_task::Runnable;
  3use backtrace::Backtrace;
  4use collections::{HashMap, VecDeque};
  5use parking::{Parker, Unparker};
  6use parking_lot::Mutex;
  7use rand::prelude::*;
  8use std::{
  9    future::Future,
 10    pin::Pin,
 11    sync::Arc,
 12    task::{Context, Poll},
 13    time::Duration,
 14};
 15use util::post_inc;
 16
 17#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 18struct TestDispatcherId(usize);
 19
 20pub struct TestDispatcher {
 21    id: TestDispatcherId,
 22    state: Arc<Mutex<TestDispatcherState>>,
 23    parker: Arc<Mutex<Parker>>,
 24    unparker: Unparker,
 25}
 26
 27struct TestDispatcherState {
 28    random: StdRng,
 29    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
 30    background: Vec<Runnable>,
 31    delayed: Vec<(Duration, Runnable)>,
 32    time: Duration,
 33    is_main_thread: bool,
 34    next_id: TestDispatcherId,
 35    allow_parking: bool,
 36    waiting_backtrace: Option<Backtrace>,
 37}
 38
 39impl TestDispatcher {
 40    pub fn new(random: StdRng) -> Self {
 41        let (parker, unparker) = parking::pair();
 42        let state = TestDispatcherState {
 43            random,
 44            foreground: HashMap::default(),
 45            background: Vec::new(),
 46            delayed: Vec::new(),
 47            time: Duration::ZERO,
 48            is_main_thread: true,
 49            next_id: TestDispatcherId(1),
 50            allow_parking: false,
 51            waiting_backtrace: None,
 52        };
 53
 54        TestDispatcher {
 55            id: TestDispatcherId(0),
 56            state: Arc::new(Mutex::new(state)),
 57            parker: Arc::new(Mutex::new(parker)),
 58            unparker,
 59        }
 60    }
 61
 62    pub fn advance_clock(&self, by: Duration) {
 63        let new_now = self.state.lock().time + by;
 64        loop {
 65            self.run_until_parked();
 66            let state = self.state.lock();
 67            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 68            drop(state);
 69            if let Some(due_time) = next_due_time {
 70                if due_time <= new_now {
 71                    self.state.lock().time = due_time;
 72                    continue;
 73                }
 74            }
 75            break;
 76        }
 77        self.state.lock().time = new_now;
 78    }
 79
 80    pub fn simulate_random_delay(&self) -> impl 'static + Send + Future<Output = ()> {
 81        pub struct YieldNow {
 82            count: usize,
 83        }
 84
 85        impl Future for YieldNow {
 86            type Output = ();
 87
 88            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 89                if self.count > 0 {
 90                    self.count -= 1;
 91                    cx.waker().wake_by_ref();
 92                    Poll::Pending
 93                } else {
 94                    Poll::Ready(())
 95                }
 96            }
 97        }
 98
 99        YieldNow {
100            count: self.state.lock().random.gen_range(0..10),
101        }
102    }
103
104    pub fn run_until_parked(&self) {
105        while self.poll(false) {}
106    }
107
108    pub fn parking_allowed(&self) -> bool {
109        self.state.lock().allow_parking
110    }
111
112    pub fn allow_parking(&self) {
113        self.state.lock().allow_parking = true
114    }
115
116    pub fn start_waiting(&self) {
117        self.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
118    }
119
120    pub fn finish_waiting(&self) {
121        self.state.lock().waiting_backtrace.take();
122    }
123
124    pub fn waiting_backtrace(&self) -> Option<Backtrace> {
125        self.state.lock().waiting_backtrace.take().map(|mut b| {
126            b.resolve();
127            b
128        })
129    }
130
131    pub fn rng(&self) -> StdRng {
132        self.state.lock().random.clone()
133    }
134}
135
136impl Clone for TestDispatcher {
137    fn clone(&self) -> Self {
138        let id = post_inc(&mut self.state.lock().next_id.0);
139        Self {
140            id: TestDispatcherId(id),
141            state: self.state.clone(),
142            parker: self.parker.clone(),
143            unparker: self.unparker.clone(),
144        }
145    }
146}
147
148impl PlatformDispatcher for TestDispatcher {
149    fn is_main_thread(&self) -> bool {
150        self.state.lock().is_main_thread
151    }
152
153    fn dispatch(&self, runnable: Runnable) {
154        self.state.lock().background.push(runnable);
155        self.unparker.unpark();
156    }
157
158    fn dispatch_on_main_thread(&self, runnable: Runnable) {
159        self.state
160            .lock()
161            .foreground
162            .entry(self.id)
163            .or_default()
164            .push_back(runnable);
165        self.unparker.unpark();
166    }
167
168    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
169        let mut state = self.state.lock();
170        let next_time = state.time + duration;
171        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
172            Ok(ix) | Err(ix) => ix,
173        };
174        state.delayed.insert(ix, (next_time, runnable));
175    }
176
177    fn poll(&self, background_only: bool) -> bool {
178        let mut state = self.state.lock();
179
180        while let Some((deadline, _)) = state.delayed.first() {
181            if *deadline > state.time {
182                break;
183            }
184            let (_, runnable) = state.delayed.remove(0);
185            state.background.push(runnable);
186        }
187
188        let foreground_len: usize = if background_only {
189            0
190        } else {
191            state
192                .foreground
193                .values()
194                .map(|runnables| runnables.len())
195                .sum()
196        };
197        let background_len = state.background.len();
198
199        if foreground_len == 0 && background_len == 0 {
200            return false;
201        }
202
203        let main_thread = state.random.gen_ratio(
204            foreground_len as u32,
205            (foreground_len + background_len) as u32,
206        );
207        let was_main_thread = state.is_main_thread;
208        state.is_main_thread = main_thread;
209
210        let runnable = if main_thread {
211            let state = &mut *state;
212            let runnables = state
213                .foreground
214                .values_mut()
215                .filter(|runnables| !runnables.is_empty())
216                .choose(&mut state.random)
217                .unwrap();
218            runnables.pop_front().unwrap()
219        } else {
220            let ix = state.random.gen_range(0..background_len);
221            state.background.swap_remove(ix)
222        };
223
224        drop(state);
225        runnable.run();
226
227        self.state.lock().is_main_thread = was_main_thread;
228
229        true
230    }
231
232    fn park(&self) {
233        self.parker.lock().park();
234    }
235
236    fn unparker(&self) -> Unparker {
237        self.unparker.clone()
238    }
239
240    fn as_test(&self) -> Option<&TestDispatcher> {
241        Some(self)
242    }
243}