executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3pub use async_task::Task;
  4use parking_lot::Mutex;
  5use rand::prelude::*;
  6use smol::{channel, prelude::*, Executor};
  7use std::{
  8    marker::PhantomData,
  9    mem,
 10    pin::Pin,
 11    rc::Rc,
 12    sync::{
 13        atomic::{AtomicBool, Ordering::SeqCst},
 14        mpsc::SyncSender,
 15        Arc,
 16    },
 17    thread,
 18};
 19
 20use crate::platform;
 21
 22pub enum Foreground {
 23    Platform {
 24        dispatcher: Arc<dyn platform::Dispatcher>,
 25        _not_send_or_sync: PhantomData<Rc<()>>,
 26    },
 27    Test(smol::LocalExecutor<'static>),
 28    Deterministic(Arc<Deterministic>),
 29}
 30
 31pub enum Background {
 32    Deterministic(Arc<Deterministic>),
 33    Production {
 34        executor: Arc<smol::Executor<'static>>,
 35        threads: usize,
 36        _stop: channel::Sender<()>,
 37    },
 38}
 39
 40struct DeterministicState {
 41    rng: StdRng,
 42    seed: u64,
 43    scheduled: Vec<Runnable>,
 44    spawned_from_foreground: Vec<Runnable>,
 45    waker: Option<SyncSender<()>>,
 46}
 47
 48pub struct Deterministic(Arc<Mutex<DeterministicState>>);
 49
 50impl Deterministic {
 51    fn new(seed: u64) -> Self {
 52        Self(Arc::new(Mutex::new(DeterministicState {
 53            rng: StdRng::seed_from_u64(seed),
 54            seed,
 55            scheduled: Default::default(),
 56            spawned_from_foreground: Default::default(),
 57            waker: None,
 58        })))
 59    }
 60
 61    pub fn spawn_from_foreground<F, T>(&self, future: F) -> Task<T>
 62    where
 63        T: 'static,
 64        F: Future<Output = T> + 'static,
 65    {
 66        let scheduled_once = AtomicBool::new(false);
 67        let state = self.0.clone();
 68        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
 69            let mut state = state.lock();
 70            if scheduled_once.fetch_or(true, SeqCst) {
 71                state.scheduled.push(runnable);
 72            } else {
 73                state.spawned_from_foreground.push(runnable);
 74            }
 75            if let Some(waker) = state.waker.as_ref() {
 76                waker.send(()).ok();
 77            }
 78        });
 79        runnable.schedule();
 80        task
 81    }
 82
 83    pub fn spawn<F, T>(&self, future: F) -> Task<T>
 84    where
 85        T: 'static + Send,
 86        F: 'static + Send + Future<Output = T>,
 87    {
 88        let state = self.0.clone();
 89        let (runnable, task) = async_task::spawn(future, move |runnable| {
 90            let mut state = state.lock();
 91            state.scheduled.push(runnable);
 92            if let Some(waker) = state.waker.as_ref() {
 93                waker.send(()).ok();
 94            }
 95        });
 96        runnable.schedule();
 97        task
 98    }
 99
100    pub fn run<F, T>(&self, future: F) -> T
101    where
102        T: 'static,
103        F: Future<Output = T> + 'static,
104    {
105        let (wake_tx, wake_rx) = std::sync::mpsc::sync_channel(32);
106        let state = self.0.clone();
107        state.lock().waker = Some(wake_tx);
108
109        let (output_tx, output_rx) = std::sync::mpsc::channel();
110        self.spawn_from_foreground(async move {
111            let output = future.await;
112            output_tx.send(output).unwrap();
113        })
114        .detach();
115
116        loop {
117            if let Ok(value) = output_rx.try_recv() {
118                state.lock().waker = None;
119                return value;
120            }
121
122            wake_rx.recv().unwrap();
123            let runnable = {
124                let state = &mut *state.lock();
125                let ix = state
126                    .rng
127                    .gen_range(0..state.scheduled.len() + state.spawned_from_foreground.len());
128                if ix < state.scheduled.len() {
129                    state.scheduled.remove(ix)
130                } else {
131                    state.spawned_from_foreground.remove(0)
132                }
133            };
134
135            runnable.run();
136        }
137    }
138}
139
140impl Foreground {
141    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
142        if dispatcher.is_main_thread() {
143            Ok(Self::Platform {
144                dispatcher,
145                _not_send_or_sync: PhantomData,
146            })
147        } else {
148            Err(anyhow!("must be constructed on main thread"))
149        }
150    }
151
152    pub fn test() -> Self {
153        Self::Test(smol::LocalExecutor::new())
154    }
155
156    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
157        match self {
158            Self::Platform { dispatcher, .. } => {
159                let dispatcher = dispatcher.clone();
160                let schedule = move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
161                let (runnable, task) = async_task::spawn_local(future, schedule);
162                runnable.schedule();
163                task
164            }
165            Self::Test(executor) => executor.spawn(future),
166            Self::Deterministic(executor) => executor.spawn_from_foreground(future),
167        }
168    }
169
170    pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
171        match self {
172            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
173            Self::Test(executor) => smol::block_on(executor.run(future)),
174            Self::Deterministic(executor) => executor.run(future),
175        }
176    }
177
178    pub fn reset(&self) {
179        match self {
180            Self::Platform { .. } => panic!("can't call this method on a platform executor"),
181            Self::Test(_) => panic!("can't call this method on a test executor"),
182            Self::Deterministic(executor) => {
183                let state = &mut *executor.0.lock();
184                state.rng = StdRng::seed_from_u64(state.seed);
185            }
186        }
187    }
188}
189
190impl Background {
191    pub fn new() -> Self {
192        let executor = Arc::new(Executor::new());
193        let stop = channel::unbounded::<()>();
194        let threads = num_cpus::get();
195
196        for i in 0..threads {
197            let executor = executor.clone();
198            let stop = stop.1.clone();
199            thread::Builder::new()
200                .name(format!("background-executor-{}", i))
201                .spawn(move || smol::block_on(executor.run(stop.recv())))
202                .unwrap();
203        }
204
205        Self::Production {
206            executor,
207            threads,
208            _stop: stop.0,
209        }
210    }
211
212    pub fn threads(&self) -> usize {
213        match self {
214            Self::Deterministic(_) => 1,
215            Self::Production { threads, .. } => *threads,
216        }
217    }
218
219    pub fn spawn<T, F>(&self, future: F) -> Task<T>
220    where
221        T: 'static + Send,
222        F: Send + Future<Output = T> + 'static,
223    {
224        match self {
225            Self::Production { executor, .. } => executor.spawn(future),
226            Self::Deterministic(executor) => executor.spawn(future),
227        }
228    }
229
230    pub async fn scoped<'scope, F>(&self, scheduler: F)
231    where
232        F: FnOnce(&mut Scope<'scope>),
233    {
234        let mut scope = Scope {
235            futures: Default::default(),
236            _phantom: PhantomData,
237        };
238        (scheduler)(&mut scope);
239        let spawned = scope
240            .futures
241            .into_iter()
242            .map(|f| self.spawn(f))
243            .collect::<Vec<_>>();
244        for task in spawned {
245            task.await;
246        }
247    }
248}
249
250pub struct Scope<'a> {
251    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
252    _phantom: PhantomData<&'a ()>,
253}
254
255impl<'a> Scope<'a> {
256    pub fn spawn<F>(&mut self, f: F)
257    where
258        F: Future<Output = ()> + Send + 'a,
259    {
260        let f = unsafe {
261            mem::transmute::<
262                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
263                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
264            >(Box::pin(f))
265        };
266        self.futures.push(f);
267    }
268}
269
270pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
271    let executor = Arc::new(Deterministic::new(seed));
272    (
273        Rc::new(Foreground::Deterministic(executor.clone())),
274        Arc::new(Background::Deterministic(executor)),
275    )
276}