executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3pub use async_task::Task;
  4use parking_lot::Mutex;
  5use rand::prelude::*;
  6use smol::prelude::*;
  7use smol::{channel, Executor};
  8use std::rc::Rc;
  9use std::sync::mpsc::SyncSender;
 10use std::sync::Arc;
 11use std::{marker::PhantomData, thread};
 12
 13use crate::platform;
 14
 15pub enum Foreground {
 16    Platform {
 17        dispatcher: Arc<dyn platform::Dispatcher>,
 18        _not_send_or_sync: PhantomData<Rc<()>>,
 19    },
 20    Test(smol::LocalExecutor<'static>),
 21    Deterministic(Arc<Deterministic>),
 22}
 23
 24pub enum Background {
 25    Deterministic(Arc<Deterministic>),
 26    Production {
 27        executor: Arc<smol::Executor<'static>>,
 28        _stop: channel::Sender<()>,
 29    },
 30}
 31
 32pub struct Deterministic {
 33    seed: u64,
 34    runnables: Arc<Mutex<(Vec<Runnable>, Option<SyncSender<()>>)>>,
 35}
 36
 37impl Deterministic {
 38    fn new(seed: u64) -> Self {
 39        Self {
 40            seed,
 41            runnables: Default::default(),
 42        }
 43    }
 44
 45    pub fn spawn_local<F, T>(&self, future: F) -> Task<T>
 46    where
 47        T: 'static,
 48        F: Future<Output = T> + 'static,
 49    {
 50        let runnables = self.runnables.clone();
 51        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
 52            let mut runnables = runnables.lock();
 53            runnables.0.push(runnable);
 54            runnables.1.as_ref().unwrap().send(()).ok();
 55        });
 56        runnable.schedule();
 57        task
 58    }
 59
 60    pub fn spawn<F, T>(&self, future: F) -> Task<T>
 61    where
 62        T: 'static + Send,
 63        F: 'static + Send + Future<Output = T>,
 64    {
 65        let runnables = self.runnables.clone();
 66        let (runnable, task) = async_task::spawn(future, move |runnable| {
 67            let mut runnables = runnables.lock();
 68            runnables.0.push(runnable);
 69            runnables.1.as_ref().unwrap().send(()).ok();
 70        });
 71        runnable.schedule();
 72        task
 73    }
 74
 75    pub fn run<F, T>(&self, future: F) -> T
 76    where
 77        T: 'static,
 78        F: Future<Output = T> + 'static,
 79    {
 80        let (wake_tx, wake_rx) = std::sync::mpsc::sync_channel(32);
 81        let runnables = self.runnables.clone();
 82        runnables.lock().1 = Some(wake_tx);
 83
 84        let (output_tx, output_rx) = std::sync::mpsc::channel();
 85        self.spawn_local(async move {
 86            let output = future.await;
 87            output_tx.send(output).unwrap();
 88        })
 89        .detach();
 90
 91        let mut rng = StdRng::seed_from_u64(self.seed);
 92        loop {
 93            if let Ok(value) = output_rx.try_recv() {
 94                runnables.lock().1 = None;
 95                return value;
 96            }
 97
 98            wake_rx.recv().unwrap();
 99            let runnable = {
100                let mut runnables = runnables.lock();
101                let runnables = &mut runnables.0;
102                let index = rng.gen_range(0..runnables.len());
103                runnables.remove(index)
104            };
105
106            runnable.run();
107        }
108    }
109}
110
111impl Foreground {
112    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
113        if dispatcher.is_main_thread() {
114            Ok(Self::Platform {
115                dispatcher,
116                _not_send_or_sync: PhantomData,
117            })
118        } else {
119            Err(anyhow!("must be constructed on main thread"))
120        }
121    }
122
123    pub fn test() -> Self {
124        Self::Test(smol::LocalExecutor::new())
125    }
126
127    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
128        match self {
129            Self::Platform { dispatcher, .. } => {
130                let dispatcher = dispatcher.clone();
131                let schedule = move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
132                let (runnable, task) = async_task::spawn_local(future, schedule);
133                runnable.schedule();
134                task
135            }
136            Self::Test(executor) => executor.spawn(future),
137            Self::Deterministic(executor) => executor.spawn_local(future),
138        }
139    }
140
141    pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
142        match self {
143            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
144            Self::Test(executor) => smol::block_on(executor.run(future)),
145            Self::Deterministic(executor) => executor.run(future),
146        }
147    }
148}
149
150impl Background {
151    pub fn new() -> Self {
152        let executor = Arc::new(Executor::new());
153        let stop = channel::unbounded::<()>();
154
155        for i in 0..num_cpus::get() {
156            let executor = executor.clone();
157            let stop = stop.1.clone();
158            thread::Builder::new()
159                .name(format!("background-executor-{}", i))
160                .spawn(move || smol::block_on(executor.run(stop.recv())))
161                .unwrap();
162        }
163
164        Self::Production {
165            executor,
166            _stop: stop.0,
167        }
168    }
169
170    pub fn spawn<T, F>(&self, future: F) -> Task<T>
171    where
172        T: 'static + Send,
173        F: Send + Future<Output = T> + 'static,
174    {
175        match self {
176            Self::Production { executor, .. } => executor.spawn(future),
177            Self::Deterministic(executor) => executor.spawn(future),
178        }
179    }
180}
181
182pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
183    let executor = Arc::new(Deterministic::new(seed));
184    (
185        Rc::new(Foreground::Deterministic(executor.clone())),
186        Arc::new(Background::Deterministic(executor)),
187    )
188}