executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
  4use parking_lot::Mutex;
  5use postage::{barrier, prelude::Stream as _};
  6use rand::prelude::*;
  7use smol::{channel, prelude::*, Executor, Timer};
  8use std::{
  9    any::Any,
 10    fmt::{self, Debug, Display},
 11    marker::PhantomData,
 12    mem,
 13    ops::RangeInclusive,
 14    pin::Pin,
 15    rc::Rc,
 16    sync::{
 17        atomic::{AtomicBool, Ordering::SeqCst},
 18        Arc,
 19    },
 20    task::{Context, Poll},
 21    thread,
 22    time::{Duration, Instant},
 23};
 24use waker_fn::waker_fn;
 25
 26use crate::{
 27    platform::{self, Dispatcher},
 28    util, MutableAppContext,
 29};
 30
/// Executor for `!Send` futures that are driven from the thread that created it.
///
/// `Platform` hands every runnable to the platform dispatcher's `run_on_main_thread`;
/// `Deterministic` shares a seeded scheduler with a `Background::Deterministic`
/// (both are created together by `deterministic`).
pub enum Foreground {
    Platform {
        dispatcher: Arc<dyn platform::Dispatcher>,
        // `Rc` is neither `Send` nor `Sync`, so this marker keeps `Foreground` on the
        // thread it was constructed on (`Foreground::platform` requires the main thread).
        _not_send_or_sync: PhantomData<Rc<()>>,
    },
    Deterministic(Arc<Deterministic>),
}
 38
/// Executor for `Send` futures.
pub enum Background {
    /// Shares the seeded scheduler with a `Foreground::Deterministic`. Spawned tasks are
    /// only queued; they run when that scheduler is driven (`Foreground::run`,
    /// `Foreground::advance_clock`, `Background::block_with_timeout`).
    Deterministic {
        executor: Arc<Deterministic>,
    },
    /// A `smol` executor polled by the worker threads started in `Background::new`.
    Production {
        executor: Arc<smol::Executor<'static>>,
        // Never sent on. Each worker runs until `recv` on the paired receiver completes,
        // which happens once this sender is dropped and the channel closes.
        _stop: channel::Sender<()>,
    },
}
 48
// Type-erased futures and task handles. Erasing outputs to `Box<dyn Any>` keeps the
// executor internals non-generic; `Task<T>` downcasts the output back to `T` when polled.

/// A boxed future (not required to be `Send`) whose output is erased to `Box<dyn Any>`.
type AnyLocalFuture = Pin<Box<dyn 'static + Future<Output = Box<dyn Any + 'static>>>>;
/// A boxed `Send` future whose output is erased to `Box<dyn Any + Send>`.
type AnyFuture = Pin<Box<dyn 'static + Send + Future<Output = Box<dyn Any + Send + 'static>>>>;
/// Handle to a task spawned from an `AnyFuture`.
type AnyTask = async_task::Task<Box<dyn Any + Send + 'static>>;
/// Handle to a task spawned from an `AnyLocalFuture`.
type AnyLocalTask = async_task::Task<Box<dyn Any + 'static>>;
 53
/// Handle to the eventual output of a spawned future; awaiting it yields `T`.
///
/// NOTE(review): dropping a `Local`/`Send` handle without calling `Task::detach`
/// presumably cancels the underlying `async_task` task — confirm against `async_task` docs.
#[must_use]
pub enum Task<T> {
    /// An already-available value (see `Task::ready`). Wrapped in `Option` so `poll` can
    /// move it out exactly once.
    Ready(Option<T>),
    /// A task spawned on the foreground executor.
    Local {
        any_task: AnyLocalTask,
        result_type: PhantomData<T>,
    },
    /// A task spawned on the background executor.
    Send {
        any_task: AnyTask,
        result_type: PhantomData<T>,
    },
}
 66
// SAFETY (NOTE(review) — TODO confirm): `Task::Local` wraps an `async_task::Task` whose
// output is `Box<dyn Any>`, which is not `Send`, so this impl asserts more than the
// compiler can check. The boxed output is always a `T` (see `any_local_future`), and
// `T: Send` is required here. Whether moving or dropping a `Local` handle off the thread
// that spawned it is sound depends on the guarantees of `async_task::spawn_local` and
// should be verified.
unsafe impl<T: Send> Send for Task<T> {}
 68
/// Mutable state of the deterministic executor, shared behind a mutex.
///
/// Each queued runnable is paired with the backtrace captured where its task was spawned.
struct DeterministicState {
    // Drives every scheduling decision; reseeded from `seed` by `Foreground::forbid_parking`.
    rng: StdRng,
    seed: u64,
    // Foreground runnables after their first schedule; picked at a random index.
    scheduled_from_foreground: Vec<(Runnable, Backtrace)>,
    // Background runnables; picked at a random index.
    scheduled_from_background: Vec<(Runnable, Backtrace)>,
    // First schedule of each foreground task; always run oldest-first.
    spawned_from_foreground: Vec<(Runnable, Backtrace)>,
    // When set, `will_park` panics instead of letting the executor park.
    forbid_parking: bool,
    // Range from which `Deterministic::block_on` draws its tick budget.
    block_on_ticks: RangeInclusive<usize>,
    // Virtual clock for `Foreground::timer`; only `Foreground::advance_clock` moves it.
    now: Instant,
    // Unexpired timers. `advance_clock` drops the sender of each expired one, which is
    // what lets the matching `timer` call's `rx.recv()` finish.
    pending_timers: Vec<(Instant, barrier::Sender)>,
    // Captured by `Foreground::start_waiting`; included in the `will_park` panic message.
    waiting_backtrace: Option<Backtrace>,
}
 81
/// Seeded executor that interleaves foreground and background tasks in a randomized
/// order determined by the RNG seed.
pub struct Deterministic {
    state: Arc<Mutex<DeterministicState>>,
    // Parks the driving thread when there is nothing to run. Schedule callbacks and
    // wakers hold its `Unparker`.
    parker: Mutex<parking::Parker>,
}
 86
 87impl Deterministic {
 88    fn new(seed: u64) -> Self {
 89        Self {
 90            state: Arc::new(Mutex::new(DeterministicState {
 91                rng: StdRng::seed_from_u64(seed),
 92                seed,
 93                scheduled_from_foreground: Default::default(),
 94                scheduled_from_background: Default::default(),
 95                spawned_from_foreground: Default::default(),
 96                forbid_parking: false,
 97                block_on_ticks: 0..=1000,
 98                now: Instant::now(),
 99                pending_timers: Default::default(),
100                waiting_backtrace: None,
101            })),
102            parker: Default::default(),
103        }
104    }
105
106    fn spawn_from_foreground(&self, future: AnyLocalFuture) -> AnyLocalTask {
107        let backtrace = Backtrace::new_unresolved();
108        let scheduled_once = AtomicBool::new(false);
109        let state = self.state.clone();
110        let unparker = self.parker.lock().unparker();
111        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
112            let mut state = state.lock();
113            let backtrace = backtrace.clone();
114            if scheduled_once.fetch_or(true, SeqCst) {
115                state.scheduled_from_foreground.push((runnable, backtrace));
116            } else {
117                state.spawned_from_foreground.push((runnable, backtrace));
118            }
119            unparker.unpark();
120        });
121        runnable.schedule();
122        task
123    }
124
125    fn spawn(&self, future: AnyFuture) -> AnyTask {
126        let backtrace = Backtrace::new_unresolved();
127        let state = self.state.clone();
128        let unparker = self.parker.lock().unparker();
129        let (runnable, task) = async_task::spawn(future, move |runnable| {
130            let mut state = state.lock();
131            state
132                .scheduled_from_background
133                .push((runnable, backtrace.clone()));
134            unparker.unpark();
135        });
136        runnable.schedule();
137        task
138    }
139
140    fn run(&self, mut future: AnyLocalFuture) -> Box<dyn Any> {
141        let woken = Arc::new(AtomicBool::new(false));
142        loop {
143            if let Some(result) = self.run_internal(woken.clone(), &mut future) {
144                return result;
145            }
146
147            if !woken.load(SeqCst) {
148                self.state.lock().will_park();
149            }
150
151            woken.store(false, SeqCst);
152            self.parker.lock().park();
153        }
154    }
155
156    fn run_until_parked(&self) {
157        let woken = Arc::new(AtomicBool::new(false));
158        let mut future = any_local_future(std::future::pending::<()>());
159        self.run_internal(woken, &mut future);
160    }
161
162    fn run_internal(
163        &self,
164        woken: Arc<AtomicBool>,
165        future: &mut AnyLocalFuture,
166    ) -> Option<Box<dyn Any>> {
167        let unparker = self.parker.lock().unparker();
168        let waker = waker_fn(move || {
169            woken.store(true, SeqCst);
170            unparker.unpark();
171        });
172
173        let mut cx = Context::from_waker(&waker);
174        let mut trace = Trace::default();
175        loop {
176            let mut state = self.state.lock();
177            let runnable_count = state.scheduled_from_foreground.len()
178                + state.scheduled_from_background.len()
179                + state.spawned_from_foreground.len();
180
181            let ix = state.rng.gen_range(0..=runnable_count);
182            if ix < state.scheduled_from_foreground.len() {
183                let (_, backtrace) = &state.scheduled_from_foreground[ix];
184                trace.record(&state, backtrace.clone());
185                let runnable = state.scheduled_from_foreground.remove(ix).0;
186                drop(state);
187                runnable.run();
188            } else if ix - state.scheduled_from_foreground.len()
189                < state.scheduled_from_background.len()
190            {
191                let ix = ix - state.scheduled_from_foreground.len();
192                let (_, backtrace) = &state.scheduled_from_background[ix];
193                trace.record(&state, backtrace.clone());
194                let runnable = state.scheduled_from_background.remove(ix).0;
195                drop(state);
196                runnable.run();
197            } else if ix < runnable_count {
198                let (_, backtrace) = &state.spawned_from_foreground[0];
199                trace.record(&state, backtrace.clone());
200                let runnable = state.spawned_from_foreground.remove(0).0;
201                drop(state);
202                runnable.run();
203            } else {
204                drop(state);
205                if let Poll::Ready(result) = future.poll(&mut cx) {
206                    return Some(result);
207                }
208
209                let state = self.state.lock();
210
211                if state.scheduled_from_foreground.is_empty()
212                    && state.scheduled_from_background.is_empty()
213                    && state.spawned_from_foreground.is_empty()
214                {
215                    return None;
216                }
217            }
218        }
219    }
220
221    fn block_on(&self, future: &mut AnyLocalFuture) -> Option<Box<dyn Any>> {
222        let unparker = self.parker.lock().unparker();
223        let waker = waker_fn(move || {
224            unparker.unpark();
225        });
226        let max_ticks = {
227            let mut state = self.state.lock();
228            let range = state.block_on_ticks.clone();
229            state.rng.gen_range(range)
230        };
231
232        let mut cx = Context::from_waker(&waker);
233        let mut trace = Trace::default();
234        for _ in 0..max_ticks {
235            let mut state = self.state.lock();
236            let runnable_count = state.scheduled_from_background.len();
237            let ix = state.rng.gen_range(0..=runnable_count);
238            if ix < state.scheduled_from_background.len() {
239                let (_, backtrace) = &state.scheduled_from_background[ix];
240                trace.record(&state, backtrace.clone());
241                let runnable = state.scheduled_from_background.remove(ix).0;
242                drop(state);
243                runnable.run();
244            } else {
245                drop(state);
246                if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
247                    return Some(result);
248                }
249                let mut state = self.state.lock();
250                if state.scheduled_from_background.is_empty() {
251                    state.will_park();
252                    drop(state);
253                    self.parker.lock().park();
254                }
255
256                continue;
257            }
258        }
259
260        None
261    }
262}
263
264impl DeterministicState {
265    fn will_park(&mut self) {
266        if self.forbid_parking {
267            let mut backtrace_message = String::new();
268            if let Some(backtrace) = self.waiting_backtrace.as_mut() {
269                backtrace.resolve();
270                backtrace_message = format!(
271                    "\nbacktrace of waiting future:\n{:?}",
272                    CwdBacktrace::new(backtrace)
273                );
274            }
275
276            panic!(
277                "deterministic executor parked after a call to forbid_parking{}",
278                backtrace_message
279            );
280        }
281    }
282}
283
/// Per-step record of the deterministic scheduler's decisions.
///
/// It is printed when dropped if `EXECUTOR_TRACE_ALWAYS` is enabled, or if
/// `EXECUTOR_TRACE_ON_PANIC` is enabled and the thread is panicking.
#[derive(Default)]
struct Trace {
    // Spawn-site backtrace of the runnable executed at each step.
    executed: Vec<Backtrace>,
    // Snapshot of `scheduled_from_foreground` at each step.
    scheduled: Vec<Vec<Backtrace>>,
    // Snapshot of `spawned_from_foreground` at each step.
    spawned_from_foreground: Vec<Vec<Backtrace>>,
}
290
291impl Trace {
292    fn record(&mut self, state: &DeterministicState, executed: Backtrace) {
293        self.scheduled.push(
294            state
295                .scheduled_from_foreground
296                .iter()
297                .map(|(_, backtrace)| backtrace.clone())
298                .collect(),
299        );
300        self.spawned_from_foreground.push(
301            state
302                .spawned_from_foreground
303                .iter()
304                .map(|(_, backtrace)| backtrace.clone())
305                .collect(),
306        );
307        self.executed.push(executed);
308    }
309
310    fn resolve(&mut self) {
311        for backtrace in &mut self.executed {
312            backtrace.resolve();
313        }
314
315        for backtraces in &mut self.scheduled {
316            for backtrace in backtraces {
317                backtrace.resolve();
318            }
319        }
320
321        for backtraces in &mut self.spawned_from_foreground {
322            for backtrace in backtraces {
323                backtrace.resolve();
324            }
325        }
326    }
327}
328
/// `Debug` adapter that narrows a `Backtrace` to frames from source files under the current
/// working directory; the filtering itself is done by its `Debug` impl.
struct CwdBacktrace<'a> {
    backtrace: &'a Backtrace,
    // When true, printing stops after the first matching frame.
    first_frame_only: bool,
}
333
334impl<'a> CwdBacktrace<'a> {
335    fn new(backtrace: &'a Backtrace) -> Self {
336        Self {
337            backtrace,
338            first_frame_only: false,
339        }
340    }
341
342    fn first_frame(backtrace: &'a Backtrace) -> Self {
343        Self {
344            backtrace,
345            first_frame_only: true,
346        }
347    }
348}
349
350impl<'a> Debug for CwdBacktrace<'a> {
351    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
352        let cwd = std::env::current_dir().unwrap();
353        let mut print_path = |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
354            fmt::Display::fmt(&path, fmt)
355        };
356        let mut fmt = BacktraceFmt::new(f, backtrace::PrintFmt::Full, &mut print_path);
357        for frame in self.backtrace.frames() {
358            let mut formatted_frame = fmt.frame();
359            if frame
360                .symbols()
361                .iter()
362                .any(|s| s.filename().map_or(false, |f| f.starts_with(&cwd)))
363            {
364                formatted_frame.backtrace_frame(frame)?;
365                if self.first_frame_only {
366                    break;
367                }
368            }
369        }
370        fmt.finish()
371    }
372}
373
374impl Debug for Trace {
375    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376        for ((backtrace, scheduled), spawned_from_foreground) in self
377            .executed
378            .iter()
379            .zip(&self.scheduled)
380            .zip(&self.spawned_from_foreground)
381        {
382            writeln!(f, "Scheduled")?;
383            for backtrace in scheduled {
384                writeln!(f, "- {:?}", CwdBacktrace::first_frame(backtrace))?;
385            }
386            if scheduled.is_empty() {
387                writeln!(f, "None")?;
388            }
389            writeln!(f, "==========")?;
390
391            writeln!(f, "Spawned from foreground")?;
392            for backtrace in spawned_from_foreground {
393                writeln!(f, "- {:?}", CwdBacktrace::first_frame(backtrace))?;
394            }
395            if spawned_from_foreground.is_empty() {
396                writeln!(f, "None")?;
397            }
398            writeln!(f, "==========")?;
399
400            writeln!(f, "Run: {:?}", CwdBacktrace::first_frame(backtrace))?;
401            writeln!(f, "+++++++++++++++++++")?;
402        }
403
404        Ok(())
405    }
406}
407
408impl Drop for Trace {
409    fn drop(&mut self) {
410        let trace_on_panic = if let Ok(trace_on_panic) = std::env::var("EXECUTOR_TRACE_ON_PANIC") {
411            trace_on_panic == "1" || trace_on_panic == "true"
412        } else {
413            false
414        };
415        let trace_always = if let Ok(trace_always) = std::env::var("EXECUTOR_TRACE_ALWAYS") {
416            trace_always == "1" || trace_always == "true"
417        } else {
418            false
419        };
420
421        if trace_always || (trace_on_panic && thread::panicking()) {
422            self.resolve();
423            dbg!(self);
424        }
425    }
426}
427
428impl Foreground {
429    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
430        if dispatcher.is_main_thread() {
431            Ok(Self::Platform {
432                dispatcher,
433                _not_send_or_sync: PhantomData,
434            })
435        } else {
436            Err(anyhow!("must be constructed on main thread"))
437        }
438    }
439
440    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
441        let future = any_local_future(future);
442        let any_task = match self {
443            Self::Deterministic(executor) => executor.spawn_from_foreground(future),
444            Self::Platform { dispatcher, .. } => {
445                fn spawn_inner(
446                    future: AnyLocalFuture,
447                    dispatcher: &Arc<dyn Dispatcher>,
448                ) -> AnyLocalTask {
449                    let dispatcher = dispatcher.clone();
450                    let schedule =
451                        move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
452                    let (runnable, task) = async_task::spawn_local(future, schedule);
453                    runnable.schedule();
454                    task
455                }
456                spawn_inner(future, dispatcher)
457            }
458        };
459        Task::local(any_task)
460    }
461
462    pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
463        let future = any_local_future(future);
464        let any_value = match self {
465            Self::Deterministic(executor) => executor.run(future),
466            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
467        };
468        *any_value.downcast().unwrap()
469    }
470
471    pub fn parking_forbidden(&self) -> bool {
472        match self {
473            Self::Deterministic(executor) => executor.state.lock().forbid_parking,
474            _ => panic!("this method can only be called on a deterministic executor"),
475        }
476    }
477
478    pub fn start_waiting(&self) {
479        match self {
480            Self::Deterministic(executor) => {
481                executor.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
482            }
483            _ => panic!("this method can only be called on a deterministic executor"),
484        }
485    }
486
487    pub fn finish_waiting(&self) {
488        match self {
489            Self::Deterministic(executor) => {
490                executor.state.lock().waiting_backtrace.take();
491            }
492            _ => panic!("this method can only be called on a deterministic executor"),
493        }
494    }
495
496    pub fn forbid_parking(&self) {
497        match self {
498            Self::Deterministic(executor) => {
499                let mut state = executor.state.lock();
500                state.forbid_parking = true;
501                state.rng = StdRng::seed_from_u64(state.seed);
502            }
503            _ => panic!("this method can only be called on a deterministic executor"),
504        }
505    }
506
507    pub async fn timer(&self, duration: Duration) {
508        match self {
509            Self::Deterministic(executor) => {
510                let (tx, mut rx) = barrier::channel();
511                {
512                    let mut state = executor.state.lock();
513                    let wakeup_at = state.now + duration;
514                    state.pending_timers.push((wakeup_at, tx));
515                }
516                rx.recv().await;
517            }
518            _ => {
519                Timer::after(duration).await;
520            }
521        }
522    }
523
524    pub fn advance_clock(&self, duration: Duration) {
525        match self {
526            Self::Deterministic(executor) => {
527                executor.run_until_parked();
528
529                let mut state = executor.state.lock();
530                state.now += duration;
531                let now = state.now;
532                let mut pending_timers = mem::take(&mut state.pending_timers);
533                drop(state);
534
535                pending_timers.retain(|(wakeup, _)| *wakeup > now);
536                executor.state.lock().pending_timers.extend(pending_timers);
537            }
538            _ => panic!("this method can only be called on a deterministic executor"),
539        }
540    }
541
542    pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
543        match self {
544            Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,
545            _ => panic!("this method can only be called on a deterministic executor"),
546        }
547    }
548}
549
550impl Background {
551    pub fn new() -> Self {
552        let executor = Arc::new(Executor::new());
553        let stop = channel::unbounded::<()>();
554
555        for i in 0..2 * num_cpus::get() {
556            let executor = executor.clone();
557            let stop = stop.1.clone();
558            thread::Builder::new()
559                .name(format!("background-executor-{}", i))
560                .spawn(move || smol::block_on(executor.run(stop.recv())))
561                .unwrap();
562        }
563
564        Self::Production {
565            executor,
566            _stop: stop.0,
567        }
568    }
569
570    pub fn num_cpus(&self) -> usize {
571        num_cpus::get()
572    }
573
574    pub fn spawn<T, F>(&self, future: F) -> Task<T>
575    where
576        T: 'static + Send,
577        F: Send + Future<Output = T> + 'static,
578    {
579        let future = any_future(future);
580        let any_task = match self {
581            Self::Production { executor, .. } => executor.spawn(future),
582            Self::Deterministic { executor, .. } => executor.spawn(future),
583        };
584        Task::send(any_task)
585    }
586
587    pub fn block_with_timeout<F, T>(
588        &self,
589        timeout: Duration,
590        future: F,
591    ) -> Result<T, impl Future<Output = T>>
592    where
593        T: 'static,
594        F: 'static + Unpin + Future<Output = T>,
595    {
596        let mut future = any_local_future(future);
597        if !timeout.is_zero() {
598            let output = match self {
599                Self::Production { .. } => smol::block_on(util::timeout(timeout, &mut future)).ok(),
600                Self::Deterministic { executor, .. } => executor.block_on(&mut future),
601            };
602            if let Some(output) = output {
603                return Ok(*output.downcast().unwrap());
604            }
605        }
606        Err(async { *future.await.downcast().unwrap() })
607    }
608
609    pub async fn scoped<'scope, F>(&self, scheduler: F)
610    where
611        F: FnOnce(&mut Scope<'scope>),
612    {
613        let mut scope = Scope {
614            futures: Default::default(),
615            _phantom: PhantomData,
616        };
617        (scheduler)(&mut scope);
618        let spawned = scope
619            .futures
620            .into_iter()
621            .map(|f| self.spawn(f))
622            .collect::<Vec<_>>();
623        for task in spawned {
624            task.await;
625        }
626    }
627}
628
/// Collects futures for `Background::scoped`, which spawns them and awaits them all.
pub struct Scope<'a> {
    // Typed `'static` because `Scope::spawn` erases the lifetime; the futures may actually
    // borrow data that only lives for `'a`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    _phantom: PhantomData<&'a ()>,
}
633
impl<'a> Scope<'a> {
    /// Queues `f` to be spawned on the background executor by `Background::scoped`.
    ///
    /// `f` may borrow data that lives for `'a`.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        // SAFETY: extends the boxed future's lifetime from `'a` to `'static` so it can be
        // spawned. This relies on `Background::scoped` awaiting every spawned task before
        // it returns, so that borrowed `'a` data outlives the futures.
        // NOTE(review): that only holds if the `scoped` future is polled to completion.
        // If it is dropped or leaked part-way, spawned tasks may still be running or
        // queued after the `'a` data is gone. This looks unsound and should be revisited.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(f))
        };
        self.futures.push(f);
    }
}
648
649pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
650    let executor = Arc::new(Deterministic::new(seed));
651    (
652        Rc::new(Foreground::Deterministic(executor.clone())),
653        Arc::new(Background::Deterministic { executor }),
654    )
655}
656
657impl<T> Task<T> {
658    pub fn ready(value: T) -> Self {
659        Self::Ready(Some(value))
660    }
661
662    fn local(any_task: AnyLocalTask) -> Self {
663        Self::Local {
664            any_task,
665            result_type: PhantomData,
666        }
667    }
668
669    pub fn detach(self) {
670        match self {
671            Task::Ready(_) => {}
672            Task::Local { any_task, .. } => any_task.detach(),
673            Task::Send { any_task, .. } => any_task.detach(),
674        }
675    }
676}
677
678impl<T: 'static, E: 'static + Display> Task<Result<T, E>> {
679    pub fn detach_and_log_err(self, cx: &mut MutableAppContext) {
680        cx.spawn(|_| async move {
681            if let Err(err) = self.await {
682                log::error!("{}", err);
683            }
684        })
685        .detach();
686    }
687}
688
689impl<T: Send> Task<T> {
690    fn send(any_task: AnyTask) -> Self {
691        Self::Send {
692            any_task,
693            result_type: PhantomData,
694        }
695    }
696}
697
698impl<T: fmt::Debug> fmt::Debug for Task<T> {
699    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
700        match self {
701            Task::Ready(value) => value.fmt(f),
702            Task::Local { any_task, .. } => any_task.fmt(f),
703            Task::Send { any_task, .. } => any_task.fmt(f),
704        }
705    }
706}
707
708impl<T: 'static> Future for Task<T> {
709    type Output = T;
710
711    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
712        match unsafe { self.get_unchecked_mut() } {
713            Task::Ready(value) => Poll::Ready(value.take().unwrap()),
714            Task::Local { any_task, .. } => {
715                any_task.poll(cx).map(|value| *value.downcast().unwrap())
716            }
717            Task::Send { any_task, .. } => {
718                any_task.poll(cx).map(|value| *value.downcast().unwrap())
719            }
720        }
721    }
722}
723
724fn any_future<T, F>(future: F) -> AnyFuture
725where
726    T: 'static + Send,
727    F: Future<Output = T> + Send + 'static,
728{
729    async { Box::new(future.await) as Box<dyn Any + Send> }.boxed()
730}
731
732fn any_local_future<T, F>(future: F) -> AnyLocalFuture
733where
734    T: 'static,
735    F: Future<Output = T> + 'static,
736{
737    async { Box::new(future.await) as Box<dyn Any> }.boxed_local()
738}