executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
  4use parking_lot::Mutex;
  5use postage::{barrier, prelude::Stream as _};
  6use rand::prelude::*;
  7use smol::{channel, prelude::*, Executor, Timer};
  8use std::{
  9    any::Any,
 10    fmt::{self, Debug},
 11    marker::PhantomData,
 12    mem,
 13    ops::RangeInclusive,
 14    pin::Pin,
 15    rc::Rc,
 16    sync::{
 17        atomic::{AtomicBool, Ordering::SeqCst},
 18        Arc,
 19    },
 20    task::{Context, Poll},
 21    thread,
 22    time::{Duration, Instant},
 23};
 24use waker_fn::waker_fn;
 25
 26use crate::{
 27    platform::{self, Dispatcher},
 28    util,
 29};
 30
 31pub enum Foreground {
 32    Platform {
 33        dispatcher: Arc<dyn platform::Dispatcher>,
 34        _not_send_or_sync: PhantomData<Rc<()>>,
 35    },
 36    Test(smol::LocalExecutor<'static>),
 37    Deterministic(Arc<Deterministic>),
 38}
 39
 40pub enum Background {
 41    Deterministic(Arc<Deterministic>),
 42    Production {
 43        executor: Arc<smol::Executor<'static>>,
 44        _stop: channel::Sender<()>,
 45    },
 46}
 47
 48type AnyLocalFuture = Pin<Box<dyn 'static + Future<Output = Box<dyn Any + 'static>>>>;
 49type AnyFuture = Pin<Box<dyn 'static + Send + Future<Output = Box<dyn Any + Send + 'static>>>>;
 50type AnyTask = async_task::Task<Box<dyn Any + Send + 'static>>;
 51type AnyLocalTask = async_task::Task<Box<dyn Any + 'static>>;
 52
 53pub enum Task<T> {
 54    Local {
 55        any_task: AnyLocalTask,
 56        result_type: PhantomData<T>,
 57    },
 58    Send {
 59        any_task: AnyTask,
 60        result_type: PhantomData<T>,
 61    },
 62}
 63
// SAFETY(review): this asserts `Send` for *both* variants whenever `T: Send`, including
// `Task::Local`, whose `AnyLocalTask` handle carries a `Box<dyn Any>` output that is not
// `Send` by type. The assertion relies on the boxed value always being a `T` (it is downcast
// to `T` in `impl Future for Task`). NOTE(review): it also relies on `async_task` tolerating a
// `spawn_local` task's handle being moved to another thread — confirm this against the
// `async_task` version in use.
unsafe impl<T: Send> Send for Task<T> {}
 65
 66struct DeterministicState {
 67    rng: StdRng,
 68    seed: u64,
 69    scheduled_from_foreground: Vec<(Runnable, Backtrace)>,
 70    scheduled_from_background: Vec<(Runnable, Backtrace)>,
 71    spawned_from_foreground: Vec<(Runnable, Backtrace)>,
 72    forbid_parking: bool,
 73    block_on_ticks: RangeInclusive<usize>,
 74    now: Instant,
 75    pending_timers: Vec<(Instant, barrier::Sender)>,
 76}
 77
 78pub struct Deterministic {
 79    state: Arc<Mutex<DeterministicState>>,
 80    parker: Mutex<parking::Parker>,
 81}
 82
impl Deterministic {
    /// Creates a scheduler whose random choices are driven by `seed`.
    ///
    /// `block_on_ticks` starts at `0..=1000`, and the virtual clock starts at the real
    /// current instant.
    fn new(seed: u64) -> Self {
        Self {
            state: Arc::new(Mutex::new(DeterministicState {
                rng: StdRng::seed_from_u64(seed),
                seed,
                scheduled_from_foreground: Default::default(),
                scheduled_from_background: Default::default(),
                spawned_from_foreground: Default::default(),
                forbid_parking: false,
                block_on_ticks: 0..=1000,
                now: Instant::now(),
                pending_timers: Default::default(),
            })),
            parker: Default::default(),
        }
    }

    /// Spawns a future that need not be `Send` onto the foreground queues and returns its
    /// task handle.
    ///
    /// The first schedule of the runnable (the immediate `schedule()` below) is queued in
    /// `spawned_from_foreground`; every later schedule goes to `scheduled_from_foreground`.
    /// Each schedule unparks the driving thread.
    fn spawn_from_foreground(&self, future: AnyLocalFuture) -> AnyLocalTask {
        // Captured once at spawn time and attached to every schedule of this task for tracing.
        let backtrace = Backtrace::new_unresolved();
        let scheduled_once = AtomicBool::new(false);
        let state = self.state.clone();
        let unparker = self.parker.lock().unparker();
        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
            let mut state = state.lock();
            let backtrace = backtrace.clone();
            // `fetch_or` returns the previous value, which is `false` only on the first schedule.
            if scheduled_once.fetch_or(true, SeqCst) {
                state.scheduled_from_foreground.push((runnable, backtrace));
            } else {
                state.spawned_from_foreground.push((runnable, backtrace));
            }
            unparker.unpark();
        });
        runnable.schedule();
        task
    }

    /// Spawns a `Send` future. Every schedule of its runnable is queued in
    /// `scheduled_from_background` and unparks the driving thread.
    fn spawn(&self, future: AnyFuture) -> AnyTask {
        let backtrace = Backtrace::new_unresolved();
        let state = self.state.clone();
        let unparker = self.parker.lock().unparker();
        let (runnable, task) = async_task::spawn(future, move |runnable| {
            let mut state = state.lock();
            state
                .scheduled_from_background
                .push((runnable, backtrace.clone()));
            unparker.unpark();
        });
        runnable.schedule();
        task
    }

    /// Drives the scheduler until `future` completes and returns its boxed output.
    ///
    /// Whenever `run_internal` runs out of work, the thread parks until something unparks it.
    ///
    /// # Panics
    ///
    /// Panics if it is about to park after `forbid_parking` was set and `future`'s waker has
    /// not fired since the previous park.
    fn run(&self, mut future: AnyLocalFuture) -> Box<dyn Any> {
        // Set by the waker that `run_internal` installs for `future`. It lets a park that
        // follows a wake of `future` be distinguished from a genuine stall.
        let woken = Arc::new(AtomicBool::new(false));
        loop {
            if let Some(result) = self.run_internal(woken.clone(), &mut future) {
                return result;
            }

            if !woken.load(SeqCst) && self.state.lock().forbid_parking {
                panic!("deterministic executor parked after a call to forbid_parking");
            }

            woken.store(false, SeqCst);
            self.parker.lock().park();
        }
    }

    /// Runs scheduled runnables until all three queues are empty. A never-completing
    /// `pending()` future stands in for the caller-supplied future that `run_internal` expects.
    fn run_until_parked(&self) {
        let woken = Arc::new(AtomicBool::new(false));
        let mut future = any_local_future(std::future::pending::<()>());
        self.run_internal(woken, &mut future);
    }

    /// Executes one randomly chosen unit of work per iteration. Returns `Some(output)` once
    /// `future` completes, or `None` when polling it leaves no runnable queued.
    ///
    /// Each iteration draws `ix` from `0..=runnable_count`. Indices select, in order,
    /// `scheduled_from_foreground`, then `scheduled_from_background`, then the *oldest* entry
    /// of `spawned_from_foreground`; the final index polls `future`. The exact sequence of RNG
    /// draws is what makes a given seed reproduce the same schedule.
    fn run_internal(
        &self,
        woken: Arc<AtomicBool>,
        future: &mut AnyLocalFuture,
    ) -> Option<Box<dyn Any>> {
        let unparker = self.parker.lock().unparker();
        // Waking `future` records the wake in `woken` and unparks the driving thread.
        let waker = waker_fn(move || {
            woken.store(true, SeqCst);
            unparker.unpark();
        });

        let mut cx = Context::from_waker(&waker);
        // Dropped when this call returns or unwinds; see `impl Drop for Trace` for when it prints.
        let mut trace = Trace::default();
        loop {
            let mut state = self.state.lock();
            let runnable_count = state.scheduled_from_foreground.len()
                + state.scheduled_from_background.len()
                + state.spawned_from_foreground.len();

            let ix = state.rng.gen_range(0..=runnable_count);
            if ix < state.scheduled_from_foreground.len() {
                let (_, backtrace) = &state.scheduled_from_foreground[ix];
                trace.record(&state, backtrace.clone());
                let runnable = state.scheduled_from_foreground.remove(ix).0;
                // Release the lock before running: the schedule closures lock `state` as well.
                drop(state);
                runnable.run();
            } else if ix - state.scheduled_from_foreground.len()
                < state.scheduled_from_background.len()
            {
                // No underflow: the branch above established ix >= scheduled_from_foreground.len().
                let ix = ix - state.scheduled_from_foreground.len();
                let (_, backtrace) = &state.scheduled_from_background[ix];
                trace.record(&state, backtrace.clone());
                let runnable = state.scheduled_from_background.remove(ix).0;
                drop(state);
                runnable.run();
            } else if ix < runnable_count {
                // Newly spawned foreground tasks always get their first poll in spawn order.
                let (_, backtrace) = &state.spawned_from_foreground[0];
                trace.record(&state, backtrace.clone());
                let runnable = state.spawned_from_foreground.remove(0).0;
                drop(state);
                runnable.run();
            } else {
                drop(state);
                if let Poll::Ready(result) = future.poll(&mut cx) {
                    return Some(result);
                }

                // Stop only when polling the future left every queue empty; otherwise keep drawing.
                let state = self.state.lock();
                if state.scheduled_from_foreground.is_empty()
                    && state.scheduled_from_background.is_empty()
                    && state.spawned_from_foreground.is_empty()
                {
                    return None;
                }
            }
        }
    }

    /// Polls `future`, interleaved with background runnables, for at most a tick count drawn
    /// from `block_on_ticks`. Returns `Some(output)` if `future` completes within that budget
    /// and `None` otherwise.
    ///
    /// Foreground queues are not run here. When `future` is pending and no background runnable
    /// is queued, the thread parks (and that tick still counts toward the budget).
    ///
    /// # Panics
    ///
    /// Panics if it is about to park after `forbid_parking` was set. NOTE(review): unlike `run`,
    /// this does not check whether `future`'s waker already fired, so a future that wakes
    /// itself while returning `Pending` also triggers the panic — confirm this is intended.
    fn block_on(&self, future: &mut AnyLocalFuture) -> Option<Box<dyn Any>> {
        let unparker = self.parker.lock().unparker();
        let waker = waker_fn(move || {
            unparker.unpark();
        });
        let max_ticks = {
            let mut state = self.state.lock();
            let range = state.block_on_ticks.clone();
            state.rng.gen_range(range)
        };

        let mut cx = Context::from_waker(&waker);
        let mut trace = Trace::default();
        for _ in 0..max_ticks {
            let mut state = self.state.lock();
            let runnable_count = state.scheduled_from_background.len();
            // Draws in `0..runnable_count` run a background runnable; the last index polls `future`.
            let ix = state.rng.gen_range(0..=runnable_count);
            if ix < state.scheduled_from_background.len() {
                let (_, backtrace) = &state.scheduled_from_background[ix];
                trace.record(&state, backtrace.clone());
                let runnable = state.scheduled_from_background.remove(ix).0;
                drop(state);
                runnable.run();
            } else {
                drop(state);
                if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
                    return Some(result);
                }
                let state = self.state.lock();
                if state.scheduled_from_background.is_empty() {
                    if state.forbid_parking {
                        panic!("deterministic executor parked after a call to forbid_parking");
                    }
                    // Release the lock before blocking so scheduling from other threads can proceed.
                    drop(state);
                    self.parker.lock().park();
                }

                continue;
            }
        }

        None
    }
}
259
260#[derive(Default)]
261struct Trace {
262    executed: Vec<Backtrace>,
263    scheduled: Vec<Vec<Backtrace>>,
264    spawned_from_foreground: Vec<Vec<Backtrace>>,
265}
266
267impl Trace {
268    fn record(&mut self, state: &DeterministicState, executed: Backtrace) {
269        self.scheduled.push(
270            state
271                .scheduled_from_foreground
272                .iter()
273                .map(|(_, backtrace)| backtrace.clone())
274                .collect(),
275        );
276        self.spawned_from_foreground.push(
277            state
278                .spawned_from_foreground
279                .iter()
280                .map(|(_, backtrace)| backtrace.clone())
281                .collect(),
282        );
283        self.executed.push(executed);
284    }
285
286    fn resolve(&mut self) {
287        for backtrace in &mut self.executed {
288            backtrace.resolve();
289        }
290
291        for backtraces in &mut self.scheduled {
292            for backtrace in backtraces {
293                backtrace.resolve();
294            }
295        }
296
297        for backtraces in &mut self.spawned_from_foreground {
298            for backtrace in backtraces {
299                backtrace.resolve();
300            }
301        }
302    }
303}
304
305impl Debug for Trace {
306    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307        struct FirstCwdFrameInBacktrace<'a>(&'a Backtrace);
308
309        impl<'a> Debug for FirstCwdFrameInBacktrace<'a> {
310            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
311                let cwd = std::env::current_dir().unwrap();
312                let mut print_path = |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
313                    fmt::Display::fmt(&path, fmt)
314                };
315                let mut fmt = BacktraceFmt::new(f, backtrace::PrintFmt::Full, &mut print_path);
316                for frame in self.0.frames() {
317                    let mut formatted_frame = fmt.frame();
318                    if frame
319                        .symbols()
320                        .iter()
321                        .any(|s| s.filename().map_or(false, |f| f.starts_with(&cwd)))
322                    {
323                        formatted_frame.backtrace_frame(frame)?;
324                        break;
325                    }
326                }
327                fmt.finish()
328            }
329        }
330
331        for ((backtrace, scheduled), spawned_from_foreground) in self
332            .executed
333            .iter()
334            .zip(&self.scheduled)
335            .zip(&self.spawned_from_foreground)
336        {
337            writeln!(f, "Scheduled")?;
338            for backtrace in scheduled {
339                writeln!(f, "- {:?}", FirstCwdFrameInBacktrace(backtrace))?;
340            }
341            if scheduled.is_empty() {
342                writeln!(f, "None")?;
343            }
344            writeln!(f, "==========")?;
345
346            writeln!(f, "Spawned from foreground")?;
347            for backtrace in spawned_from_foreground {
348                writeln!(f, "- {:?}", FirstCwdFrameInBacktrace(backtrace))?;
349            }
350            if spawned_from_foreground.is_empty() {
351                writeln!(f, "None")?;
352            }
353            writeln!(f, "==========")?;
354
355            writeln!(f, "Run: {:?}", FirstCwdFrameInBacktrace(backtrace))?;
356            writeln!(f, "+++++++++++++++++++")?;
357        }
358
359        Ok(())
360    }
361}
362
363impl Drop for Trace {
364    fn drop(&mut self) {
365        let trace_on_panic = if let Ok(trace_on_panic) = std::env::var("EXECUTOR_TRACE_ON_PANIC") {
366            trace_on_panic == "1" || trace_on_panic == "true"
367        } else {
368            false
369        };
370        let trace_always = if let Ok(trace_always) = std::env::var("EXECUTOR_TRACE_ALWAYS") {
371            trace_always == "1" || trace_always == "true"
372        } else {
373            false
374        };
375
376        if trace_always || (trace_on_panic && thread::panicking()) {
377            self.resolve();
378            dbg!(self);
379        }
380    }
381}
382
383impl Foreground {
384    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
385        if dispatcher.is_main_thread() {
386            Ok(Self::Platform {
387                dispatcher,
388                _not_send_or_sync: PhantomData,
389            })
390        } else {
391            Err(anyhow!("must be constructed on main thread"))
392        }
393    }
394
395    pub fn test() -> Self {
396        Self::Test(smol::LocalExecutor::new())
397    }
398
399    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
400        let future = any_local_future(future);
401        let any_task = match self {
402            Self::Deterministic(executor) => executor.spawn_from_foreground(future),
403            Self::Platform { dispatcher, .. } => {
404                fn spawn_inner(
405                    future: AnyLocalFuture,
406                    dispatcher: &Arc<dyn Dispatcher>,
407                ) -> AnyLocalTask {
408                    let dispatcher = dispatcher.clone();
409                    let schedule =
410                        move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
411                    let (runnable, task) = async_task::spawn_local(future, schedule);
412                    runnable.schedule();
413                    task
414                }
415                spawn_inner(future, dispatcher)
416            }
417            Self::Test(executor) => executor.spawn(future),
418        };
419        Task::local(any_task)
420    }
421
422    pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
423        let future = any_local_future(future);
424        let any_value = match self {
425            Self::Deterministic(executor) => executor.run(future),
426            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
427            Self::Test(executor) => smol::block_on(executor.run(future)),
428        };
429        *any_value.downcast().unwrap()
430    }
431
432    pub fn forbid_parking(&self) {
433        match self {
434            Self::Deterministic(executor) => {
435                let mut state = executor.state.lock();
436                state.forbid_parking = true;
437                state.rng = StdRng::seed_from_u64(state.seed);
438            }
439            _ => panic!("this method can only be called on a deterministic executor"),
440        }
441    }
442
443    pub async fn timer(&self, duration: Duration) {
444        match self {
445            Self::Deterministic(executor) => {
446                let (tx, mut rx) = barrier::channel();
447                {
448                    let mut state = executor.state.lock();
449                    let wakeup_at = state.now + duration;
450                    state.pending_timers.push((wakeup_at, tx));
451                }
452                rx.recv().await;
453            }
454            _ => {
455                Timer::after(duration).await;
456            }
457        }
458    }
459
460    pub fn advance_clock(&self, duration: Duration) {
461        match self {
462            Self::Deterministic(executor) => {
463                executor.run_until_parked();
464
465                let mut state = executor.state.lock();
466                state.now += duration;
467                let now = state.now;
468                let mut pending_timers = mem::take(&mut state.pending_timers);
469                drop(state);
470
471                pending_timers.retain(|(wakeup, _)| *wakeup > now);
472                executor.state.lock().pending_timers.extend(pending_timers);
473            }
474            _ => panic!("this method can only be called on a deterministic executor"),
475        }
476    }
477
478    pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
479        match self {
480            Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,
481            _ => panic!("this method can only be called on a deterministic executor"),
482        }
483    }
484}
485
486impl Background {
487    pub fn new() -> Self {
488        let executor = Arc::new(Executor::new());
489        let stop = channel::unbounded::<()>();
490
491        for i in 0..2 * num_cpus::get() {
492            let executor = executor.clone();
493            let stop = stop.1.clone();
494            thread::Builder::new()
495                .name(format!("background-executor-{}", i))
496                .spawn(move || smol::block_on(executor.run(stop.recv())))
497                .unwrap();
498        }
499
500        Self::Production {
501            executor,
502            _stop: stop.0,
503        }
504    }
505
506    pub fn num_cpus(&self) -> usize {
507        num_cpus::get()
508    }
509
510    pub fn spawn<T, F>(&self, future: F) -> Task<T>
511    where
512        T: 'static + Send,
513        F: Send + Future<Output = T> + 'static,
514    {
515        let future = any_future(future);
516        let any_task = match self {
517            Self::Production { executor, .. } => executor.spawn(future),
518            Self::Deterministic(executor) => executor.spawn(future),
519        };
520        Task::send(any_task)
521    }
522
523    pub fn block_with_timeout<F, T>(
524        &self,
525        timeout: Duration,
526        future: F,
527    ) -> Result<T, impl Future<Output = T>>
528    where
529        T: 'static,
530        F: 'static + Unpin + Future<Output = T>,
531    {
532        let mut future = any_local_future(future);
533        if !timeout.is_zero() {
534            let output = match self {
535                Self::Production { .. } => smol::block_on(util::timeout(timeout, &mut future)).ok(),
536                Self::Deterministic(executor) => executor.block_on(&mut future),
537            };
538            if let Some(output) = output {
539                return Ok(*output.downcast().unwrap());
540            }
541        }
542        Err(async { *future.await.downcast().unwrap() })
543    }
544
545    pub async fn scoped<'scope, F>(&self, scheduler: F)
546    where
547        F: FnOnce(&mut Scope<'scope>),
548    {
549        let mut scope = Scope {
550            futures: Default::default(),
551            _phantom: PhantomData,
552        };
553        (scheduler)(&mut scope);
554        let spawned = scope
555            .futures
556            .into_iter()
557            .map(|f| self.spawn(f))
558            .collect::<Vec<_>>();
559        for task in spawned {
560            task.await;
561        }
562    }
563}
564
565pub struct Scope<'a> {
566    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
567    _phantom: PhantomData<&'a ()>,
568}
569
impl<'a> Scope<'a> {
    /// Registers `f`. `Background::scoped` spawns it once the scheduler closure returns.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        // SAFETY: only the trait object's lifetime bound changes (`'a` -> `'static`); the
        // pointer and vtable are otherwise identical, so the layout is unchanged. Soundness
        // additionally requires that the future is never touched after `'a` ends.
        // `Background::scoped` provides this by awaiting every spawned task before it returns.
        // NOTE(review): that guarantee does not hold if the `scoped` future is dropped before
        // completion — confirm callers never cancel it.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(f))
        };
        self.futures.push(f);
    }
}
584
585pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
586    let executor = Arc::new(Deterministic::new(seed));
587    (
588        Rc::new(Foreground::Deterministic(executor.clone())),
589        Arc::new(Background::Deterministic(executor)),
590    )
591}
592
593impl<T> Task<T> {
594    fn local(any_task: AnyLocalTask) -> Self {
595        Self::Local {
596            any_task,
597            result_type: PhantomData,
598        }
599    }
600
601    pub fn detach(self) {
602        match self {
603            Task::Local { any_task, .. } => any_task.detach(),
604            Task::Send { any_task, .. } => any_task.detach(),
605        }
606    }
607}
608
609impl<T: Send> Task<T> {
610    fn send(any_task: AnyTask) -> Self {
611        Self::Send {
612            any_task,
613            result_type: PhantomData,
614        }
615    }
616}
617
618impl<T: fmt::Debug> fmt::Debug for Task<T> {
619    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
620        match self {
621            Task::Local { any_task, .. } => any_task.fmt(f),
622            Task::Send { any_task, .. } => any_task.fmt(f),
623        }
624    }
625}
626
impl<T: 'static> Future for Task<T> {
    type Output = T;

    /// Polls the wrapped `async_task` handle and downcasts its boxed output back to `T`.
    ///
    /// # Panics
    ///
    /// Panics if the boxed output is not a `T`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing is moved out of `self`; the inner handle is only polled through a
        // `&mut` borrow, and the `PhantomData` field holds no data. NOTE(review): the `.poll(cx)`
        // calls below presumably resolve to the `&mut self` `FutureExt::poll` helper from
        // `smol::prelude`, which requires the handle to be `Unpin`. That requirement is what
        // makes unpinning here harmless.
        match unsafe { self.get_unchecked_mut() } {
            Task::Local { any_task, .. } => {
                any_task.poll(cx).map(|value| *value.downcast().unwrap())
            }
            Task::Send { any_task, .. } => {
                any_task.poll(cx).map(|value| *value.downcast().unwrap())
            }
        }
    }
}
641
642fn any_future<T, F>(future: F) -> AnyFuture
643where
644    T: 'static + Send,
645    F: Future<Output = T> + Send + 'static,
646{
647    async { Box::new(future.await) as Box<dyn Any + Send> }.boxed()
648}
649
650fn any_local_future<T, F>(future: F) -> AnyLocalFuture
651where
652    T: 'static,
653    F: Future<Output = T> + 'static,
654{
655    async { Box::new(future.await) as Box<dyn Any> }.boxed_local()
656}