executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3pub use async_task::Task;
  4use parking_lot::Mutex;
  5use rand::prelude::*;
  6use smol::{channel, prelude::*, Executor};
  7use std::{
  8    marker::PhantomData,
  9    mem,
 10    pin::Pin,
 11    rc::Rc,
 12    sync::{mpsc::SyncSender, Arc},
 13    thread,
 14};
 15
 16use crate::platform;
 17
 18pub enum Foreground {
 19    Platform {
 20        dispatcher: Arc<dyn platform::Dispatcher>,
 21        _not_send_or_sync: PhantomData<Rc<()>>,
 22    },
 23    Test(smol::LocalExecutor<'static>),
 24    Deterministic(Arc<Deterministic>),
 25}
 26
 27pub enum Background {
 28    Deterministic(Arc<Deterministic>),
 29    Production {
 30        executor: Arc<smol::Executor<'static>>,
 31        threads: usize,
 32        _stop: channel::Sender<()>,
 33    },
 34}
 35
 36#[derive(Default)]
 37struct Runnables {
 38    scheduled: Vec<Runnable>,
 39    spawned_from_foreground: Vec<Runnable>,
 40    waker: Option<SyncSender<()>>,
 41}
 42
 43pub struct Deterministic {
 44    seed: u64,
 45    runnables: Arc<Mutex<Runnables>>,
 46}
 47
 48impl Deterministic {
 49    fn new(seed: u64) -> Self {
 50        Self {
 51            seed,
 52            runnables: Default::default(),
 53        }
 54    }
 55
 56    pub fn spawn_from_foreground<F, T>(&self, future: F) -> Task<T>
 57    where
 58        T: 'static,
 59        F: Future<Output = T> + 'static,
 60    {
 61        let scheduled_once = Mutex::new(false);
 62        let runnables = self.runnables.clone();
 63        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
 64            let mut runnables = runnables.lock();
 65            if *scheduled_once.lock() {
 66                runnables.scheduled.push(runnable);
 67            } else {
 68                runnables.spawned_from_foreground.push(runnable);
 69                *scheduled_once.lock() = true;
 70            }
 71            if let Some(waker) = runnables.waker.as_ref() {
 72                waker.send(()).ok();
 73            }
 74        });
 75        runnable.schedule();
 76        task
 77    }
 78
 79    pub fn spawn<F, T>(&self, future: F) -> Task<T>
 80    where
 81        T: 'static + Send,
 82        F: 'static + Send + Future<Output = T>,
 83    {
 84        let runnables = self.runnables.clone();
 85        let (runnable, task) = async_task::spawn(future, move |runnable| {
 86            let mut runnables = runnables.lock();
 87            runnables.scheduled.push(runnable);
 88            if let Some(waker) = runnables.waker.as_ref() {
 89                waker.send(()).ok();
 90            }
 91        });
 92        runnable.schedule();
 93        task
 94    }
 95
 96    pub fn run<F, T>(&self, future: F) -> T
 97    where
 98        T: 'static,
 99        F: Future<Output = T> + 'static,
100    {
101        let (wake_tx, wake_rx) = std::sync::mpsc::sync_channel(32);
102        let runnables = self.runnables.clone();
103        runnables.lock().waker = Some(wake_tx);
104
105        let (output_tx, output_rx) = std::sync::mpsc::channel();
106        self.spawn_from_foreground(async move {
107            let output = future.await;
108            output_tx.send(output).unwrap();
109        })
110        .detach();
111
112        let mut rng = StdRng::seed_from_u64(self.seed);
113        loop {
114            if let Ok(value) = output_rx.try_recv() {
115                runnables.lock().waker = None;
116                return value;
117            }
118
119            wake_rx.recv().unwrap();
120            let runnable = {
121                let mut runnables = runnables.lock();
122                let ix = rng.gen_range(
123                    0..runnables.scheduled.len() + runnables.spawned_from_foreground.len(),
124                );
125                if ix < runnables.scheduled.len() {
126                    runnables.scheduled.remove(ix)
127                } else {
128                    runnables.spawned_from_foreground.remove(0)
129                }
130            };
131
132            runnable.run();
133        }
134    }
135}
136
137impl Foreground {
138    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
139        if dispatcher.is_main_thread() {
140            Ok(Self::Platform {
141                dispatcher,
142                _not_send_or_sync: PhantomData,
143            })
144        } else {
145            Err(anyhow!("must be constructed on main thread"))
146        }
147    }
148
149    pub fn test() -> Self {
150        Self::Test(smol::LocalExecutor::new())
151    }
152
153    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
154        match self {
155            Self::Platform { dispatcher, .. } => {
156                let dispatcher = dispatcher.clone();
157                let schedule = move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
158                let (runnable, task) = async_task::spawn_local(future, schedule);
159                runnable.schedule();
160                task
161            }
162            Self::Test(executor) => executor.spawn(future),
163            Self::Deterministic(executor) => executor.spawn_from_foreground(future),
164        }
165    }
166
167    pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
168        match self {
169            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
170            Self::Test(executor) => smol::block_on(executor.run(future)),
171            Self::Deterministic(executor) => executor.run(future),
172        }
173    }
174}
175
176impl Background {
177    pub fn new() -> Self {
178        let executor = Arc::new(Executor::new());
179        let stop = channel::unbounded::<()>();
180        let threads = num_cpus::get();
181
182        for i in 0..threads {
183            let executor = executor.clone();
184            let stop = stop.1.clone();
185            thread::Builder::new()
186                .name(format!("background-executor-{}", i))
187                .spawn(move || smol::block_on(executor.run(stop.recv())))
188                .unwrap();
189        }
190
191        Self::Production {
192            executor,
193            threads,
194            _stop: stop.0,
195        }
196    }
197
198    pub fn threads(&self) -> usize {
199        match self {
200            Self::Deterministic(_) => 1,
201            Self::Production { threads, .. } => *threads,
202        }
203    }
204
205    pub fn spawn<T, F>(&self, future: F) -> Task<T>
206    where
207        T: 'static + Send,
208        F: Send + Future<Output = T> + 'static,
209    {
210        match self {
211            Self::Production { executor, .. } => executor.spawn(future),
212            Self::Deterministic(executor) => executor.spawn(future),
213        }
214    }
215
216    pub async fn scoped<'scope, F>(&self, scheduler: F)
217    where
218        F: FnOnce(&mut Scope<'scope>),
219    {
220        let mut scope = Scope {
221            futures: Default::default(),
222            _phantom: PhantomData,
223        };
224        (scheduler)(&mut scope);
225        let spawned = scope
226            .futures
227            .into_iter()
228            .map(|f| self.spawn(f))
229            .collect::<Vec<_>>();
230        for task in spawned {
231            task.await;
232        }
233    }
234}
235
/// Collects futures that borrow data living for `'a`, so that
/// `Background::scoped` can spawn them.
///
/// The futures are stored with their lifetime erased to `'static`. The
/// `PhantomData<&'a ()>` keeps `'a` attached to the `Scope` type.
pub struct Scope<'a> {
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    _phantom: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    /// Registers `f` to be spawned after the scheduler closure passed to
    /// `Background::scoped` returns.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let borrowed: Pin<Box<dyn Future<Output = ()> + Send + 'a>> = Box::pin(f);
        // SAFETY: only the trait object's lifetime bound changes (`'a` to
        // `'static`), so the boxed pointer's layout is identical. Soundness
        // relies on `Background::scoped` awaiting every spawned task before its
        // own future completes. That keeps each future from outliving `'a`,
        // provided the `scoped` future is itself driven to completion.
        let erased = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(borrowed)
        };
        self.futures.push(erased);
    }
}
255
256pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
257    let executor = Arc::new(Deterministic::new(seed));
258    (
259        Rc::new(Foreground::Deterministic(executor.clone())),
260        Arc::new(Background::Deterministic(executor)),
261    )
262}