executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
  4use parking_lot::Mutex;
  5use postage::{barrier, prelude::Stream as _};
  6use rand::prelude::*;
  7use smol::{channel, prelude::*, Executor, Timer};
  8use std::{
  9    any::Any,
 10    fmt::{self, Debug},
 11    marker::PhantomData,
 12    mem,
 13    ops::RangeInclusive,
 14    pin::Pin,
 15    rc::Rc,
 16    sync::{
 17        atomic::{AtomicBool, Ordering::SeqCst},
 18        Arc,
 19    },
 20    task::{Context, Poll},
 21    thread,
 22    time::{Duration, Instant},
 23};
 24use waker_fn::waker_fn;
 25
 26use crate::{
 27    platform::{self, Dispatcher},
 28    util,
 29};
 30
 31pub enum Foreground {
 32    Platform {
 33        dispatcher: Arc<dyn platform::Dispatcher>,
 34        _not_send_or_sync: PhantomData<Rc<()>>,
 35    },
 36    Test(smol::LocalExecutor<'static>),
 37    Deterministic(Arc<Deterministic>),
 38}
 39
 40pub enum Background {
 41    Deterministic(Arc<Deterministic>),
 42    Production {
 43        executor: Arc<smol::Executor<'static>>,
 44        _stop: channel::Sender<()>,
 45    },
 46}
 47
 48type AnyLocalFuture = Pin<Box<dyn 'static + Future<Output = Box<dyn Any + 'static>>>>;
 49type AnyFuture = Pin<Box<dyn 'static + Send + Future<Output = Box<dyn Any + Send + 'static>>>>;
 50type AnyTask = async_task::Task<Box<dyn Any + Send + 'static>>;
 51type AnyLocalTask = async_task::Task<Box<dyn Any + 'static>>;
 52
 53#[must_use]
 54pub enum Task<T> {
 55    Local {
 56        any_task: AnyLocalTask,
 57        result_type: PhantomData<T>,
 58    },
 59    Send {
 60        any_task: AnyTask,
 61        result_type: PhantomData<T>,
 62    },
 63}
 64
 65unsafe impl<T: Send> Send for Task<T> {}
 66
 67struct DeterministicState {
 68    rng: StdRng,
 69    seed: u64,
 70    scheduled_from_foreground: Vec<(Runnable, Backtrace)>,
 71    scheduled_from_background: Vec<(Runnable, Backtrace)>,
 72    spawned_from_foreground: Vec<(Runnable, Backtrace)>,
 73    forbid_parking: bool,
 74    block_on_ticks: RangeInclusive<usize>,
 75    now: Instant,
 76    pending_timers: Vec<(Instant, barrier::Sender)>,
 77}
 78
 79pub struct Deterministic {
 80    state: Arc<Mutex<DeterministicState>>,
 81    parker: Mutex<parking::Parker>,
 82}
 83
 84impl Deterministic {
 85    fn new(seed: u64) -> Self {
 86        Self {
 87            state: Arc::new(Mutex::new(DeterministicState {
 88                rng: StdRng::seed_from_u64(seed),
 89                seed,
 90                scheduled_from_foreground: Default::default(),
 91                scheduled_from_background: Default::default(),
 92                spawned_from_foreground: Default::default(),
 93                forbid_parking: false,
 94                block_on_ticks: 0..=1000,
 95                now: Instant::now(),
 96                pending_timers: Default::default(),
 97            })),
 98            parker: Default::default(),
 99        }
100    }
101
102    fn spawn_from_foreground(&self, future: AnyLocalFuture) -> AnyLocalTask {
103        let backtrace = Backtrace::new_unresolved();
104        let scheduled_once = AtomicBool::new(false);
105        let state = self.state.clone();
106        let unparker = self.parker.lock().unparker();
107        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
108            let mut state = state.lock();
109            let backtrace = backtrace.clone();
110            if scheduled_once.fetch_or(true, SeqCst) {
111                state.scheduled_from_foreground.push((runnable, backtrace));
112            } else {
113                state.spawned_from_foreground.push((runnable, backtrace));
114            }
115            unparker.unpark();
116        });
117        runnable.schedule();
118        task
119    }
120
121    fn spawn(&self, future: AnyFuture) -> AnyTask {
122        let backtrace = Backtrace::new_unresolved();
123        let state = self.state.clone();
124        let unparker = self.parker.lock().unparker();
125        let (runnable, task) = async_task::spawn(future, move |runnable| {
126            let mut state = state.lock();
127            state
128                .scheduled_from_background
129                .push((runnable, backtrace.clone()));
130            unparker.unpark();
131        });
132        runnable.schedule();
133        task
134    }
135
136    fn run(&self, mut future: AnyLocalFuture) -> Box<dyn Any> {
137        let woken = Arc::new(AtomicBool::new(false));
138        loop {
139            if let Some(result) = self.run_internal(woken.clone(), &mut future) {
140                return result;
141            }
142
143            if !woken.load(SeqCst) && self.state.lock().forbid_parking {
144                panic!("deterministic executor parked after a call to forbid_parking");
145            }
146
147            woken.store(false, SeqCst);
148            self.parker.lock().park();
149        }
150    }
151
152    fn run_until_parked(&self) {
153        let woken = Arc::new(AtomicBool::new(false));
154        let mut future = any_local_future(std::future::pending::<()>());
155        self.run_internal(woken, &mut future);
156    }
157
158    fn run_internal(
159        &self,
160        woken: Arc<AtomicBool>,
161        future: &mut AnyLocalFuture,
162    ) -> Option<Box<dyn Any>> {
163        let unparker = self.parker.lock().unparker();
164        let waker = waker_fn(move || {
165            woken.store(true, SeqCst);
166            unparker.unpark();
167        });
168
169        let mut cx = Context::from_waker(&waker);
170        let mut trace = Trace::default();
171        loop {
172            let mut state = self.state.lock();
173            let runnable_count = state.scheduled_from_foreground.len()
174                + state.scheduled_from_background.len()
175                + state.spawned_from_foreground.len();
176
177            let ix = state.rng.gen_range(0..=runnable_count);
178            if ix < state.scheduled_from_foreground.len() {
179                let (_, backtrace) = &state.scheduled_from_foreground[ix];
180                trace.record(&state, backtrace.clone());
181                let runnable = state.scheduled_from_foreground.remove(ix).0;
182                drop(state);
183                runnable.run();
184            } else if ix - state.scheduled_from_foreground.len()
185                < state.scheduled_from_background.len()
186            {
187                let ix = ix - state.scheduled_from_foreground.len();
188                let (_, backtrace) = &state.scheduled_from_background[ix];
189                trace.record(&state, backtrace.clone());
190                let runnable = state.scheduled_from_background.remove(ix).0;
191                drop(state);
192                runnable.run();
193            } else if ix < runnable_count {
194                let (_, backtrace) = &state.spawned_from_foreground[0];
195                trace.record(&state, backtrace.clone());
196                let runnable = state.spawned_from_foreground.remove(0).0;
197                drop(state);
198                runnable.run();
199            } else {
200                drop(state);
201                if let Poll::Ready(result) = future.poll(&mut cx) {
202                    return Some(result);
203                }
204
205                let state = self.state.lock();
206                if state.scheduled_from_foreground.is_empty()
207                    && state.scheduled_from_background.is_empty()
208                    && state.spawned_from_foreground.is_empty()
209                {
210                    return None;
211                }
212            }
213        }
214    }
215
216    fn block_on(&self, future: &mut AnyLocalFuture) -> Option<Box<dyn Any>> {
217        let unparker = self.parker.lock().unparker();
218        let waker = waker_fn(move || {
219            unparker.unpark();
220        });
221        let max_ticks = {
222            let mut state = self.state.lock();
223            let range = state.block_on_ticks.clone();
224            state.rng.gen_range(range)
225        };
226
227        let mut cx = Context::from_waker(&waker);
228        let mut trace = Trace::default();
229        for _ in 0..max_ticks {
230            let mut state = self.state.lock();
231            let runnable_count = state.scheduled_from_background.len();
232            let ix = state.rng.gen_range(0..=runnable_count);
233            if ix < state.scheduled_from_background.len() {
234                let (_, backtrace) = &state.scheduled_from_background[ix];
235                trace.record(&state, backtrace.clone());
236                let runnable = state.scheduled_from_background.remove(ix).0;
237                drop(state);
238                runnable.run();
239            } else {
240                drop(state);
241                if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
242                    return Some(result);
243                }
244                let state = self.state.lock();
245                if state.scheduled_from_background.is_empty() {
246                    if state.forbid_parking {
247                        panic!("deterministic executor parked after a call to forbid_parking");
248                    }
249                    drop(state);
250                    self.parker.lock().park();
251                }
252
253                continue;
254            }
255        }
256
257        None
258    }
259}
260
261#[derive(Default)]
262struct Trace {
263    executed: Vec<Backtrace>,
264    scheduled: Vec<Vec<Backtrace>>,
265    spawned_from_foreground: Vec<Vec<Backtrace>>,
266}
267
268impl Trace {
269    fn record(&mut self, state: &DeterministicState, executed: Backtrace) {
270        self.scheduled.push(
271            state
272                .scheduled_from_foreground
273                .iter()
274                .map(|(_, backtrace)| backtrace.clone())
275                .collect(),
276        );
277        self.spawned_from_foreground.push(
278            state
279                .spawned_from_foreground
280                .iter()
281                .map(|(_, backtrace)| backtrace.clone())
282                .collect(),
283        );
284        self.executed.push(executed);
285    }
286
287    fn resolve(&mut self) {
288        for backtrace in &mut self.executed {
289            backtrace.resolve();
290        }
291
292        for backtraces in &mut self.scheduled {
293            for backtrace in backtraces {
294                backtrace.resolve();
295            }
296        }
297
298        for backtraces in &mut self.spawned_from_foreground {
299            for backtrace in backtraces {
300                backtrace.resolve();
301            }
302        }
303    }
304}
305
306impl Debug for Trace {
307    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
308        struct FirstCwdFrameInBacktrace<'a>(&'a Backtrace);
309
310        impl<'a> Debug for FirstCwdFrameInBacktrace<'a> {
311            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
312                let cwd = std::env::current_dir().unwrap();
313                let mut print_path = |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
314                    fmt::Display::fmt(&path, fmt)
315                };
316                let mut fmt = BacktraceFmt::new(f, backtrace::PrintFmt::Full, &mut print_path);
317                for frame in self.0.frames() {
318                    let mut formatted_frame = fmt.frame();
319                    if frame
320                        .symbols()
321                        .iter()
322                        .any(|s| s.filename().map_or(false, |f| f.starts_with(&cwd)))
323                    {
324                        formatted_frame.backtrace_frame(frame)?;
325                        break;
326                    }
327                }
328                fmt.finish()
329            }
330        }
331
332        for ((backtrace, scheduled), spawned_from_foreground) in self
333            .executed
334            .iter()
335            .zip(&self.scheduled)
336            .zip(&self.spawned_from_foreground)
337        {
338            writeln!(f, "Scheduled")?;
339            for backtrace in scheduled {
340                writeln!(f, "- {:?}", FirstCwdFrameInBacktrace(backtrace))?;
341            }
342            if scheduled.is_empty() {
343                writeln!(f, "None")?;
344            }
345            writeln!(f, "==========")?;
346
347            writeln!(f, "Spawned from foreground")?;
348            for backtrace in spawned_from_foreground {
349                writeln!(f, "- {:?}", FirstCwdFrameInBacktrace(backtrace))?;
350            }
351            if spawned_from_foreground.is_empty() {
352                writeln!(f, "None")?;
353            }
354            writeln!(f, "==========")?;
355
356            writeln!(f, "Run: {:?}", FirstCwdFrameInBacktrace(backtrace))?;
357            writeln!(f, "+++++++++++++++++++")?;
358        }
359
360        Ok(())
361    }
362}
363
364impl Drop for Trace {
365    fn drop(&mut self) {
366        let trace_on_panic = if let Ok(trace_on_panic) = std::env::var("EXECUTOR_TRACE_ON_PANIC") {
367            trace_on_panic == "1" || trace_on_panic == "true"
368        } else {
369            false
370        };
371        let trace_always = if let Ok(trace_always) = std::env::var("EXECUTOR_TRACE_ALWAYS") {
372            trace_always == "1" || trace_always == "true"
373        } else {
374            false
375        };
376
377        if trace_always || (trace_on_panic && thread::panicking()) {
378            self.resolve();
379            dbg!(self);
380        }
381    }
382}
383
384impl Foreground {
385    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
386        if dispatcher.is_main_thread() {
387            Ok(Self::Platform {
388                dispatcher,
389                _not_send_or_sync: PhantomData,
390            })
391        } else {
392            Err(anyhow!("must be constructed on main thread"))
393        }
394    }
395
396    pub fn test() -> Self {
397        Self::Test(smol::LocalExecutor::new())
398    }
399
400    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
401        let future = any_local_future(future);
402        let any_task = match self {
403            Self::Deterministic(executor) => executor.spawn_from_foreground(future),
404            Self::Platform { dispatcher, .. } => {
405                fn spawn_inner(
406                    future: AnyLocalFuture,
407                    dispatcher: &Arc<dyn Dispatcher>,
408                ) -> AnyLocalTask {
409                    let dispatcher = dispatcher.clone();
410                    let schedule =
411                        move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
412                    let (runnable, task) = async_task::spawn_local(future, schedule);
413                    runnable.schedule();
414                    task
415                }
416                spawn_inner(future, dispatcher)
417            }
418            Self::Test(executor) => executor.spawn(future),
419        };
420        Task::local(any_task)
421    }
422
423    pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
424        let future = any_local_future(future);
425        let any_value = match self {
426            Self::Deterministic(executor) => executor.run(future),
427            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
428            Self::Test(executor) => smol::block_on(executor.run(future)),
429        };
430        *any_value.downcast().unwrap()
431    }
432
433    pub fn forbid_parking(&self) {
434        match self {
435            Self::Deterministic(executor) => {
436                let mut state = executor.state.lock();
437                state.forbid_parking = true;
438                state.rng = StdRng::seed_from_u64(state.seed);
439            }
440            _ => panic!("this method can only be called on a deterministic executor"),
441        }
442    }
443
444    pub async fn timer(&self, duration: Duration) {
445        match self {
446            Self::Deterministic(executor) => {
447                let (tx, mut rx) = barrier::channel();
448                {
449                    let mut state = executor.state.lock();
450                    let wakeup_at = state.now + duration;
451                    state.pending_timers.push((wakeup_at, tx));
452                }
453                rx.recv().await;
454            }
455            _ => {
456                Timer::after(duration).await;
457            }
458        }
459    }
460
461    pub fn advance_clock(&self, duration: Duration) {
462        match self {
463            Self::Deterministic(executor) => {
464                executor.run_until_parked();
465
466                let mut state = executor.state.lock();
467                state.now += duration;
468                let now = state.now;
469                let mut pending_timers = mem::take(&mut state.pending_timers);
470                drop(state);
471
472                pending_timers.retain(|(wakeup, _)| *wakeup > now);
473                executor.state.lock().pending_timers.extend(pending_timers);
474            }
475            _ => panic!("this method can only be called on a deterministic executor"),
476        }
477    }
478
479    pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
480        match self {
481            Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,
482            _ => panic!("this method can only be called on a deterministic executor"),
483        }
484    }
485}
486
487impl Background {
488    pub fn new() -> Self {
489        let executor = Arc::new(Executor::new());
490        let stop = channel::unbounded::<()>();
491
492        for i in 0..2 * num_cpus::get() {
493            let executor = executor.clone();
494            let stop = stop.1.clone();
495            thread::Builder::new()
496                .name(format!("background-executor-{}", i))
497                .spawn(move || smol::block_on(executor.run(stop.recv())))
498                .unwrap();
499        }
500
501        Self::Production {
502            executor,
503            _stop: stop.0,
504        }
505    }
506
507    pub fn num_cpus(&self) -> usize {
508        num_cpus::get()
509    }
510
511    pub fn spawn<T, F>(&self, future: F) -> Task<T>
512    where
513        T: 'static + Send,
514        F: Send + Future<Output = T> + 'static,
515    {
516        let future = any_future(future);
517        let any_task = match self {
518            Self::Production { executor, .. } => executor.spawn(future),
519            Self::Deterministic(executor) => executor.spawn(future),
520        };
521        Task::send(any_task)
522    }
523
524    pub fn block_with_timeout<F, T>(
525        &self,
526        timeout: Duration,
527        future: F,
528    ) -> Result<T, impl Future<Output = T>>
529    where
530        T: 'static,
531        F: 'static + Unpin + Future<Output = T>,
532    {
533        let mut future = any_local_future(future);
534        if !timeout.is_zero() {
535            let output = match self {
536                Self::Production { .. } => smol::block_on(util::timeout(timeout, &mut future)).ok(),
537                Self::Deterministic(executor) => executor.block_on(&mut future),
538            };
539            if let Some(output) = output {
540                return Ok(*output.downcast().unwrap());
541            }
542        }
543        Err(async { *future.await.downcast().unwrap() })
544    }
545
546    pub async fn scoped<'scope, F>(&self, scheduler: F)
547    where
548        F: FnOnce(&mut Scope<'scope>),
549    {
550        let mut scope = Scope {
551            futures: Default::default(),
552            _phantom: PhantomData,
553        };
554        (scheduler)(&mut scope);
555        let spawned = scope
556            .futures
557            .into_iter()
558            .map(|f| self.spawn(f))
559            .collect::<Vec<_>>();
560        for task in spawned {
561            task.await;
562        }
563    }
564}
565
/// Collects futures to be spawned by `Background::scoped`.
///
/// The futures may borrow data for `'a`, but they are stored with the lifetime
/// erased to `'static` (see `Scope::spawn`).
pub struct Scope<'a> {
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    /// Ties the scope to the `'a` borrows its futures may hold.
    _phantom: PhantomData<&'a ()>,
}
570
571impl<'a> Scope<'a> {
572    pub fn spawn<F>(&mut self, f: F)
573    where
574        F: Future<Output = ()> + Send + 'a,
575    {
576        let f = unsafe {
577            mem::transmute::<
578                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
579                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
580            >(Box::pin(f))
581        };
582        self.futures.push(f);
583    }
584}
585
586pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
587    let executor = Arc::new(Deterministic::new(seed));
588    (
589        Rc::new(Foreground::Deterministic(executor.clone())),
590        Arc::new(Background::Deterministic(executor)),
591    )
592}
593
594impl<T> Task<T> {
595    fn local(any_task: AnyLocalTask) -> Self {
596        Self::Local {
597            any_task,
598            result_type: PhantomData,
599        }
600    }
601
602    pub fn detach(self) {
603        match self {
604            Task::Local { any_task, .. } => any_task.detach(),
605            Task::Send { any_task, .. } => any_task.detach(),
606        }
607    }
608}
609
610impl<T: Send> Task<T> {
611    fn send(any_task: AnyTask) -> Self {
612        Self::Send {
613            any_task,
614            result_type: PhantomData,
615        }
616    }
617}
618
619impl<T: fmt::Debug> fmt::Debug for Task<T> {
620    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
621        match self {
622            Task::Local { any_task, .. } => any_task.fmt(f),
623            Task::Send { any_task, .. } => any_task.fmt(f),
624        }
625    }
626}
627
628impl<T: 'static> Future for Task<T> {
629    type Output = T;
630
631    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
632        match unsafe { self.get_unchecked_mut() } {
633            Task::Local { any_task, .. } => {
634                any_task.poll(cx).map(|value| *value.downcast().unwrap())
635            }
636            Task::Send { any_task, .. } => {
637                any_task.poll(cx).map(|value| *value.downcast().unwrap())
638            }
639        }
640    }
641}
642
643fn any_future<T, F>(future: F) -> AnyFuture
644where
645    T: 'static + Send,
646    F: Future<Output = T> + Send + 'static,
647{
648    async { Box::new(future.await) as Box<dyn Any + Send> }.boxed()
649}
650
651fn any_local_future<T, F>(future: F) -> AnyLocalFuture
652where
653    T: 'static,
654    F: Future<Output = T> + 'static,
655{
656    async { Box::new(future.await) as Box<dyn Any> }.boxed_local()
657}