executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3pub use async_task::Task;
  4use parking_lot::Mutex;
  5use rand::prelude::*;
  6use smol::{channel, prelude::*, Executor};
  7use std::{
  8    marker::PhantomData,
  9    mem,
 10    pin::Pin,
 11    rc::Rc,
 12    sync::{mpsc::SyncSender, Arc},
 13    thread,
 14};
 15
 16use crate::platform;
 17
 18pub enum Foreground {
 19    Platform {
 20        dispatcher: Arc<dyn platform::Dispatcher>,
 21        _not_send_or_sync: PhantomData<Rc<()>>,
 22    },
 23    Test(smol::LocalExecutor<'static>),
 24    Deterministic(Arc<Deterministic>),
 25}
 26
 27pub enum Background {
 28    Deterministic(Arc<Deterministic>),
 29    Production {
 30        executor: Arc<smol::Executor<'static>>,
 31        threads: usize,
 32        _stop: channel::Sender<()>,
 33    },
 34}
 35
 36pub struct Deterministic {
 37    seed: u64,
 38    runnables: Arc<Mutex<(Vec<Runnable>, Option<SyncSender<()>>)>>,
 39}
 40
 41impl Deterministic {
 42    fn new(seed: u64) -> Self {
 43        Self {
 44            seed,
 45            runnables: Default::default(),
 46        }
 47    }
 48
 49    pub fn spawn_local<F, T>(&self, future: F) -> Task<T>
 50    where
 51        T: 'static,
 52        F: Future<Output = T> + 'static,
 53    {
 54        let runnables = self.runnables.clone();
 55        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
 56            let mut runnables = runnables.lock();
 57            runnables.0.push(runnable);
 58            if let Some(wake_tx) = runnables.1.as_ref() {
 59                wake_tx.send(()).ok();
 60            }
 61        });
 62        runnable.schedule();
 63        task
 64    }
 65
 66    pub fn spawn<F, T>(&self, future: F) -> Task<T>
 67    where
 68        T: 'static + Send,
 69        F: 'static + Send + Future<Output = T>,
 70    {
 71        let runnables = self.runnables.clone();
 72        let (runnable, task) = async_task::spawn(future, move |runnable| {
 73            let mut runnables = runnables.lock();
 74            runnables.0.push(runnable);
 75            if let Some(wake_tx) = runnables.1.as_ref() {
 76                wake_tx.send(()).ok();
 77            }
 78        });
 79        runnable.schedule();
 80        task
 81    }
 82
 83    pub fn run<F, T>(&self, future: F) -> T
 84    where
 85        T: 'static,
 86        F: Future<Output = T> + 'static,
 87    {
 88        let (wake_tx, wake_rx) = std::sync::mpsc::sync_channel(32);
 89        let runnables = self.runnables.clone();
 90        runnables.lock().1 = Some(wake_tx);
 91
 92        let (output_tx, output_rx) = std::sync::mpsc::channel();
 93        self.spawn_local(async move {
 94            let output = future.await;
 95            output_tx.send(output).unwrap();
 96        })
 97        .detach();
 98
 99        let mut rng = StdRng::seed_from_u64(self.seed);
100        loop {
101            if let Ok(value) = output_rx.try_recv() {
102                runnables.lock().1 = None;
103                return value;
104            }
105
106            wake_rx.recv().unwrap();
107            let runnable = {
108                let mut runnables = runnables.lock();
109                let runnables = &mut runnables.0;
110                let index = rng.gen_range(0..runnables.len());
111                runnables.remove(index)
112            };
113
114            runnable.run();
115        }
116    }
117}
118
119impl Foreground {
120    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
121        if dispatcher.is_main_thread() {
122            Ok(Self::Platform {
123                dispatcher,
124                _not_send_or_sync: PhantomData,
125            })
126        } else {
127            Err(anyhow!("must be constructed on main thread"))
128        }
129    }
130
131    pub fn test() -> Self {
132        Self::Test(smol::LocalExecutor::new())
133    }
134
135    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
136        match self {
137            Self::Platform { dispatcher, .. } => {
138                let dispatcher = dispatcher.clone();
139                let schedule = move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
140                let (runnable, task) = async_task::spawn_local(future, schedule);
141                runnable.schedule();
142                task
143            }
144            Self::Test(executor) => executor.spawn(future),
145            Self::Deterministic(executor) => executor.spawn_local(future),
146        }
147    }
148
149    pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
150        match self {
151            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
152            Self::Test(executor) => smol::block_on(executor.run(future)),
153            Self::Deterministic(executor) => executor.run(future),
154        }
155    }
156}
157
158impl Background {
159    pub fn new() -> Self {
160        let executor = Arc::new(Executor::new());
161        let stop = channel::unbounded::<()>();
162        let threads = num_cpus::get();
163
164        for i in 0..threads {
165            let executor = executor.clone();
166            let stop = stop.1.clone();
167            thread::Builder::new()
168                .name(format!("background-executor-{}", i))
169                .spawn(move || smol::block_on(executor.run(stop.recv())))
170                .unwrap();
171        }
172
173        Self::Production {
174            executor,
175            threads,
176            _stop: stop.0,
177        }
178    }
179
180    pub fn threads(&self) -> usize {
181        match self {
182            Self::Deterministic(_) => 1,
183            Self::Production { threads, .. } => *threads,
184        }
185    }
186
187    pub fn spawn<T, F>(&self, future: F) -> Task<T>
188    where
189        T: 'static + Send,
190        F: Send + Future<Output = T> + 'static,
191    {
192        match self {
193            Self::Production { executor, .. } => executor.spawn(future),
194            Self::Deterministic(executor) => executor.spawn(future),
195        }
196    }
197
198    pub async fn scoped<'scope, F>(&self, scheduler: F)
199    where
200        F: FnOnce(&mut Scope<'scope>),
201    {
202        let mut scope = Scope {
203            futures: Default::default(),
204            _phantom: PhantomData,
205        };
206        (scheduler)(&mut scope);
207        let spawned = scope
208            .futures
209            .into_iter()
210            .map(|f| self.spawn(f))
211            .collect::<Vec<_>>();
212        for task in spawned {
213            task.await;
214        }
215    }
216}
217
/// Collects futures that may borrow data living for `'a`.
///
/// It is filled by the closure passed to `Background::scoped`, which then spawns and
/// awaits every collected future.
pub struct Scope<'a> {
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    _phantom: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    /// Queues `f` to be spawned after the scope's scheduler closure returns.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let borrowed: Pin<Box<dyn Future<Output = ()> + Send + 'a>> = Box::pin(f);
        // SAFETY: only the trait object's lifetime bound changes; the pointer layout
        // is identical. The erased future must not outlive `'a`. `Background::scoped`
        // upholds this by awaiting every queued task before it returns.
        // NOTE(review): that guarantee holds only if the `scoped` future itself runs to
        // completion. Cancelling or leaking it can leave tasks touching freed borrows.
        let erased = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(borrowed)
        };
        self.futures.push(erased);
    }
}
237
238pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
239    let executor = Arc::new(Deterministic::new(seed));
240    (
241        Rc::new(Foreground::Deterministic(executor.clone())),
242        Arc::new(Background::Deterministic(executor)),
243    )
244}