dispatcher.rs

  1use crate::PlatformDispatcher;
  2use async_task::Runnable;
  3use collections::{HashMap, VecDeque};
  4use parking_lot::Mutex;
  5use rand::prelude::*;
  6use std::{
  7    future::Future,
  8    pin::Pin,
  9    sync::Arc,
 10    task::{Context, Poll},
 11    time::Duration,
 12};
 13use util::post_inc;
 14
 15#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 16struct TestDispatcherId(usize);
 17
 18pub struct TestDispatcher {
 19    id: TestDispatcherId,
 20    state: Arc<Mutex<TestDispatcherState>>,
 21}
 22
 23struct TestDispatcherState {
 24    random: StdRng,
 25    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
 26    background: Vec<Runnable>,
 27    delayed: Vec<(Duration, Runnable)>,
 28    time: Duration,
 29    is_main_thread: bool,
 30    next_id: TestDispatcherId,
 31}
 32
 33impl TestDispatcher {
 34    pub fn new(random: StdRng) -> Self {
 35        let state = TestDispatcherState {
 36            random,
 37            foreground: HashMap::default(),
 38            background: Vec::new(),
 39            delayed: Vec::new(),
 40            time: Duration::ZERO,
 41            is_main_thread: true,
 42            next_id: TestDispatcherId(1),
 43        };
 44
 45        TestDispatcher {
 46            id: TestDispatcherId(0),
 47            state: Arc::new(Mutex::new(state)),
 48        }
 49    }
 50
 51    pub fn advance_clock(&self, by: Duration) {
 52        let new_now = self.state.lock().time + by;
 53        loop {
 54            self.run_until_parked();
 55            let state = self.state.lock();
 56            let next_due_time = state.delayed.first().map(|(time, _)| *time);
 57            drop(state);
 58            if let Some(due_time) = next_due_time {
 59                if due_time <= new_now {
 60                    self.state.lock().time = due_time;
 61                    continue;
 62                }
 63            }
 64            break;
 65        }
 66        self.state.lock().time = new_now;
 67    }
 68
 69    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
 70        pub struct YieldNow {
 71            count: usize,
 72        }
 73
 74        impl Future for YieldNow {
 75            type Output = ();
 76
 77            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 78                if self.count > 0 {
 79                    self.count -= 1;
 80                    cx.waker().wake_by_ref();
 81                    Poll::Pending
 82                } else {
 83                    Poll::Ready(())
 84                }
 85            }
 86        }
 87
 88        YieldNow {
 89            count: self.state.lock().random.gen_range(0..10),
 90        }
 91    }
 92
 93    pub fn run_until_parked(&self) {
 94        while self.poll() {}
 95    }
 96}
 97
 98impl Clone for TestDispatcher {
 99    fn clone(&self) -> Self {
100        let id = post_inc(&mut self.state.lock().next_id.0);
101        Self {
102            id: TestDispatcherId(id),
103            state: self.state.clone(),
104        }
105    }
106}
107
108impl PlatformDispatcher for TestDispatcher {
109    fn is_main_thread(&self) -> bool {
110        self.state.lock().is_main_thread
111    }
112
113    fn dispatch(&self, runnable: Runnable) {
114        self.state.lock().background.push(runnable);
115    }
116
117    fn dispatch_on_main_thread(&self, runnable: Runnable) {
118        self.state
119            .lock()
120            .foreground
121            .entry(self.id)
122            .or_default()
123            .push_back(runnable);
124    }
125
126    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
127        let mut state = self.state.lock();
128        let next_time = state.time + duration;
129        let ix = match state.delayed.binary_search_by_key(&next_time, |e| e.0) {
130            Ok(ix) | Err(ix) => ix,
131        };
132        state.delayed.insert(ix, (next_time, runnable));
133    }
134
135    fn poll(&self) -> bool {
136        let mut state = self.state.lock();
137
138        while let Some((deadline, _)) = state.delayed.first() {
139            if *deadline > state.time {
140                break;
141            }
142            let (_, runnable) = state.delayed.remove(0);
143            state.background.push(runnable);
144        }
145
146        let foreground_len: usize = state
147            .foreground
148            .values()
149            .map(|runnables| runnables.len())
150            .sum();
151        let background_len = state.background.len();
152
153        if foreground_len == 0 && background_len == 0 {
154            return false;
155        }
156
157        let main_thread = state.random.gen_ratio(
158            foreground_len as u32,
159            (foreground_len + background_len) as u32,
160        );
161        let was_main_thread = state.is_main_thread;
162        state.is_main_thread = main_thread;
163
164        let runnable = if main_thread {
165            let state = &mut *state;
166            let runnables = state
167                .foreground
168                .values_mut()
169                .filter(|runnables| !runnables.is_empty())
170                .choose(&mut state.random)
171                .unwrap();
172            runnables.pop_front().unwrap()
173        } else {
174            let ix = state.random.gen_range(0..background_len);
175            state.background.swap_remove(ix)
176        };
177
178        drop(state);
179        runnable.run();
180
181        self.state.lock().is_main_thread = was_main_thread;
182
183        true
184    }
185
186    fn as_test(&self) -> Option<&TestDispatcher> {
187        Some(self)
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Executor;
    use std::sync::Arc;

    /// Hops main thread -> background -> main thread under the seeded
    /// `TestDispatcher`. It checks `is_main_thread` at every hop and threads the
    /// innermost value back out.
    #[test]
    fn test_dispatch() {
        let executor = Executor::new(Arc::new(TestDispatcher::new(StdRng::seed_from_u64(0))));

        // A single round trip through the main thread.
        let first = executor.block(async { executor.run_on_main(|| 1).await });
        assert_eq!(first, 1);

        // Nested hops; the value 2 produced on the main thread must reach the top.
        let nested = executor.block({
            let exec = executor.clone();
            async move {
                exec.spawn_on_main({
                    let on_main = exec.clone();
                    assert!(on_main.is_main_thread());
                    || async move {
                        assert!(on_main.is_main_thread());
                        let from_background = on_main
                            .spawn({
                                let in_background = on_main.clone();
                                async move {
                                    assert!(!in_background.is_main_thread());

                                    let from_main = in_background
                                        .spawn_on_main({
                                            let back_on_main = in_background.clone();
                                            || async move {
                                                assert!(back_on_main.is_main_thread());
                                                2
                                            }
                                        })
                                        .await;

                                    assert!(!in_background.is_main_thread());
                                    from_main
                                }
                            })
                            .await;
                        assert!(on_main.is_main_thread());
                        from_background
                    }
                })
                .await
            }
        });
        assert_eq!(nested, 2);
    }
}