executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
  4use collections::HashMap;
  5use parking_lot::Mutex;
  6use postage::{barrier, prelude::Stream as _};
  7use rand::prelude::*;
  8use smol::{channel, future::yield_now, prelude::*, Executor, Timer};
  9use std::{
 10    any::Any,
 11    fmt::{self, Debug, Display},
 12    marker::PhantomData,
 13    mem,
 14    ops::RangeInclusive,
 15    pin::Pin,
 16    rc::Rc,
 17    sync::{
 18        atomic::{AtomicBool, Ordering::SeqCst},
 19        Arc,
 20    },
 21    task::{Context, Poll},
 22    thread,
 23    time::{Duration, Instant},
 24};
 25use waker_fn::waker_fn;
 26
 27use crate::{
 28    platform::{self, Dispatcher},
 29    util, MutableAppContext,
 30};
 31
 32pub enum Foreground {
 33    Platform {
 34        dispatcher: Arc<dyn platform::Dispatcher>,
 35        _not_send_or_sync: PhantomData<Rc<()>>,
 36    },
 37    Deterministic {
 38        cx_id: usize,
 39        executor: Arc<Deterministic>,
 40    },
 41}
 42
 43pub enum Background {
 44    Deterministic {
 45        executor: Arc<Deterministic>,
 46    },
 47    Production {
 48        executor: Arc<smol::Executor<'static>>,
 49        _stop: channel::Sender<()>,
 50    },
 51}
 52
 53type AnyLocalFuture = Pin<Box<dyn 'static + Future<Output = Box<dyn Any + 'static>>>>;
 54type AnyFuture = Pin<Box<dyn 'static + Send + Future<Output = Box<dyn Any + Send + 'static>>>>;
 55type AnyTask = async_task::Task<Box<dyn Any + Send + 'static>>;
 56type AnyLocalTask = async_task::Task<Box<dyn Any + 'static>>;
 57
 58#[must_use]
 59pub enum Task<T> {
 60    Ready(Option<T>),
 61    Local {
 62        any_task: AnyLocalTask,
 63        result_type: PhantomData<T>,
 64    },
 65    Send {
 66        any_task: AnyTask,
 67        result_type: PhantomData<T>,
 68    },
 69}
 70
 71unsafe impl<T: Send> Send for Task<T> {}
 72
 73struct DeterministicState {
 74    rng: StdRng,
 75    seed: u64,
 76    scheduled_from_foreground: HashMap<usize, Vec<ForegroundRunnable>>,
 77    scheduled_from_background: Vec<Runnable>,
 78    forbid_parking: bool,
 79    block_on_ticks: RangeInclusive<usize>,
 80    now: Instant,
 81    pending_timers: Vec<(Instant, barrier::Sender)>,
 82    waiting_backtrace: Option<Backtrace>,
 83}
 84
 85struct ForegroundRunnable {
 86    runnable: Runnable,
 87    main: bool,
 88}
 89
 90pub struct Deterministic {
 91    state: Arc<Mutex<DeterministicState>>,
 92    parker: Mutex<parking::Parker>,
 93}
 94
 95impl Deterministic {
 96    pub fn new(seed: u64) -> Arc<Self> {
 97        Arc::new(Self {
 98            state: Arc::new(Mutex::new(DeterministicState {
 99                rng: StdRng::seed_from_u64(seed),
100                seed,
101                scheduled_from_foreground: Default::default(),
102                scheduled_from_background: Default::default(),
103                forbid_parking: false,
104                block_on_ticks: 0..=1000,
105                now: Instant::now(),
106                pending_timers: Default::default(),
107                waiting_backtrace: None,
108            })),
109            parker: Default::default(),
110        })
111    }
112
113    pub fn build_background(self: &Arc<Self>) -> Arc<Background> {
114        Arc::new(Background::Deterministic {
115            executor: self.clone(),
116        })
117    }
118
119    pub fn build_foreground(self: &Arc<Self>, id: usize) -> Rc<Foreground> {
120        Rc::new(Foreground::Deterministic {
121            cx_id: id,
122            executor: self.clone(),
123        })
124    }
125
126    fn spawn_from_foreground(
127        &self,
128        cx_id: usize,
129        future: AnyLocalFuture,
130        main: bool,
131    ) -> AnyLocalTask {
132        let state = self.state.clone();
133        let unparker = self.parker.lock().unparker();
134        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
135            let mut state = state.lock();
136            state
137                .scheduled_from_foreground
138                .entry(cx_id)
139                .or_default()
140                .push(ForegroundRunnable { runnable, main });
141            unparker.unpark();
142        });
143        runnable.schedule();
144        task
145    }
146
147    fn spawn(&self, future: AnyFuture) -> AnyTask {
148        let state = self.state.clone();
149        let unparker = self.parker.lock().unparker();
150        let (runnable, task) = async_task::spawn(future, move |runnable| {
151            let mut state = state.lock();
152            state.scheduled_from_background.push(runnable);
153            unparker.unpark();
154        });
155        runnable.schedule();
156        task
157    }
158
159    fn run(&self, cx_id: usize, main_future: AnyLocalFuture) -> Box<dyn Any> {
160        let woken = Arc::new(AtomicBool::new(false));
161        let mut main_task = self.spawn_from_foreground(cx_id, main_future, true);
162
163        loop {
164            if let Some(result) = self.run_internal(woken.clone(), Some(&mut main_task)) {
165                return result;
166            }
167
168            if !woken.load(SeqCst) {
169                self.state.lock().will_park();
170            }
171
172            woken.store(false, SeqCst);
173            self.parker.lock().park();
174        }
175    }
176
177    fn run_until_parked(&self) {
178        let woken = Arc::new(AtomicBool::new(false));
179        self.run_internal(woken, None);
180    }
181
182    fn run_internal(
183        &self,
184        woken: Arc<AtomicBool>,
185        mut main_task: Option<&mut AnyLocalTask>,
186    ) -> Option<Box<dyn Any>> {
187        let unparker = self.parker.lock().unparker();
188        let waker = waker_fn(move || {
189            woken.store(true, SeqCst);
190            unparker.unpark();
191        });
192
193        let mut cx = Context::from_waker(&waker);
194        loop {
195            let mut state = self.state.lock();
196
197            if state.scheduled_from_foreground.is_empty()
198                && state.scheduled_from_background.is_empty()
199            {
200                return None;
201            }
202
203            if !state.scheduled_from_background.is_empty() && state.rng.gen() {
204                let background_len = state.scheduled_from_background.len();
205                let ix = state.rng.gen_range(0..background_len);
206                let runnable = state.scheduled_from_background.remove(ix);
207                drop(state);
208                runnable.run();
209            } else if !state.scheduled_from_foreground.is_empty() {
210                let available_cx_ids = state
211                    .scheduled_from_foreground
212                    .keys()
213                    .copied()
214                    .collect::<Vec<_>>();
215                let cx_id_to_run = *available_cx_ids.iter().choose(&mut state.rng).unwrap();
216                let scheduled_from_cx = state
217                    .scheduled_from_foreground
218                    .get_mut(&cx_id_to_run)
219                    .unwrap();
220                let foreground_runnable = scheduled_from_cx.remove(0);
221                if scheduled_from_cx.is_empty() {
222                    state.scheduled_from_foreground.remove(&cx_id_to_run);
223                }
224
225                drop(state);
226
227                foreground_runnable.runnable.run();
228                if let Some(main_task) = main_task.as_mut() {
229                    if foreground_runnable.main {
230                        if let Poll::Ready(result) = main_task.poll(&mut cx) {
231                            return Some(result);
232                        }
233                    }
234                }
235            }
236        }
237    }
238
239    fn block<F, T>(&self, future: &mut F, max_ticks: usize) -> Option<T>
240    where
241        F: Unpin + Future<Output = T>,
242    {
243        let unparker = self.parker.lock().unparker();
244        let waker = waker_fn(move || {
245            unparker.unpark();
246        });
247
248        let mut cx = Context::from_waker(&waker);
249        for _ in 0..max_ticks {
250            let mut state = self.state.lock();
251            let runnable_count = state.scheduled_from_background.len();
252            let ix = state.rng.gen_range(0..=runnable_count);
253            if ix < state.scheduled_from_background.len() {
254                let runnable = state.scheduled_from_background.remove(ix);
255                drop(state);
256                runnable.run();
257            } else {
258                drop(state);
259                if let Poll::Ready(result) = future.poll(&mut cx) {
260                    return Some(result);
261                }
262                let mut state = self.state.lock();
263                if state.scheduled_from_background.is_empty() {
264                    state.will_park();
265                    drop(state);
266                    self.parker.lock().park();
267                }
268
269                continue;
270            }
271        }
272
273        None
274    }
275}
276
277impl DeterministicState {
278    fn will_park(&mut self) {
279        if self.forbid_parking {
280            let mut backtrace_message = String::new();
281            if let Some(backtrace) = self.waiting_backtrace.as_mut() {
282                backtrace.resolve();
283                backtrace_message = format!(
284                    "\nbacktrace of waiting future:\n{:?}",
285                    CwdBacktrace::new(backtrace)
286                );
287            }
288
289            panic!(
290                "deterministic executor parked after a call to forbid_parking{}",
291                backtrace_message
292            );
293        }
294    }
295}
296
297struct CwdBacktrace<'a> {
298    backtrace: &'a Backtrace,
299}
300
301impl<'a> CwdBacktrace<'a> {
302    fn new(backtrace: &'a Backtrace) -> Self {
303        Self { backtrace }
304    }
305}
306
307impl<'a> Debug for CwdBacktrace<'a> {
308    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
309        let cwd = std::env::current_dir().unwrap();
310        let mut print_path = |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
311            fmt::Display::fmt(&path, fmt)
312        };
313        let mut fmt = BacktraceFmt::new(f, backtrace::PrintFmt::Full, &mut print_path);
314        for frame in self.backtrace.frames() {
315            let mut formatted_frame = fmt.frame();
316            if frame
317                .symbols()
318                .iter()
319                .any(|s| s.filename().map_or(false, |f| f.starts_with(&cwd)))
320            {
321                formatted_frame.backtrace_frame(frame)?;
322            }
323        }
324        fmt.finish()
325    }
326}
327
328impl Foreground {
329    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
330        if dispatcher.is_main_thread() {
331            Ok(Self::Platform {
332                dispatcher,
333                _not_send_or_sync: PhantomData,
334            })
335        } else {
336            Err(anyhow!("must be constructed on main thread"))
337        }
338    }
339
340    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
341        let future = any_local_future(future);
342        let any_task = match self {
343            Self::Deterministic { cx_id, executor } => {
344                executor.spawn_from_foreground(*cx_id, future, false)
345            }
346            Self::Platform { dispatcher, .. } => {
347                fn spawn_inner(
348                    future: AnyLocalFuture,
349                    dispatcher: &Arc<dyn Dispatcher>,
350                ) -> AnyLocalTask {
351                    let dispatcher = dispatcher.clone();
352                    let schedule =
353                        move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
354                    let (runnable, task) = async_task::spawn_local(future, schedule);
355                    runnable.schedule();
356                    task
357                }
358                spawn_inner(future, dispatcher)
359            }
360        };
361        Task::local(any_task)
362    }
363
364    pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
365        let future = any_local_future(future);
366        let any_value = match self {
367            Self::Deterministic { cx_id, executor } => executor.run(*cx_id, future),
368            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
369        };
370        *any_value.downcast().unwrap()
371    }
372
373    pub fn parking_forbidden(&self) -> bool {
374        match self {
375            Self::Deterministic { executor, .. } => executor.state.lock().forbid_parking,
376            _ => panic!("this method can only be called on a deterministic executor"),
377        }
378    }
379
380    pub fn start_waiting(&self) {
381        match self {
382            Self::Deterministic { executor, .. } => {
383                executor.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
384            }
385            _ => panic!("this method can only be called on a deterministic executor"),
386        }
387    }
388
389    pub fn finish_waiting(&self) {
390        match self {
391            Self::Deterministic { executor, .. } => {
392                executor.state.lock().waiting_backtrace.take();
393            }
394            _ => panic!("this method can only be called on a deterministic executor"),
395        }
396    }
397
398    pub fn forbid_parking(&self) {
399        match self {
400            Self::Deterministic { executor, .. } => {
401                let mut state = executor.state.lock();
402                state.forbid_parking = true;
403                state.rng = StdRng::seed_from_u64(state.seed);
404            }
405            _ => panic!("this method can only be called on a deterministic executor"),
406        }
407    }
408
409    pub async fn timer(&self, duration: Duration) {
410        match self {
411            Self::Deterministic { executor, .. } => {
412                let (tx, mut rx) = barrier::channel();
413                {
414                    let mut state = executor.state.lock();
415                    let wakeup_at = state.now + duration;
416                    state.pending_timers.push((wakeup_at, tx));
417                }
418                rx.recv().await;
419            }
420            _ => {
421                Timer::after(duration).await;
422            }
423        }
424    }
425
426    pub fn advance_clock(&self, duration: Duration) {
427        match self {
428            Self::Deterministic { executor, .. } => {
429                executor.run_until_parked();
430
431                let mut state = executor.state.lock();
432                state.now += duration;
433                let now = state.now;
434                let mut pending_timers = mem::take(&mut state.pending_timers);
435                drop(state);
436
437                pending_timers.retain(|(wakeup, _)| *wakeup > now);
438                executor.state.lock().pending_timers.extend(pending_timers);
439            }
440            _ => panic!("this method can only be called on a deterministic executor"),
441        }
442    }
443
444    pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
445        match self {
446            Self::Deterministic { executor, .. } => executor.state.lock().block_on_ticks = range,
447            _ => panic!("this method can only be called on a deterministic executor"),
448        }
449    }
450}
451
452impl Background {
453    pub fn new() -> Self {
454        let executor = Arc::new(Executor::new());
455        let stop = channel::unbounded::<()>();
456
457        for i in 0..2 * num_cpus::get() {
458            let executor = executor.clone();
459            let stop = stop.1.clone();
460            thread::Builder::new()
461                .name(format!("background-executor-{}", i))
462                .spawn(move || smol::block_on(executor.run(stop.recv())))
463                .unwrap();
464        }
465
466        Self::Production {
467            executor,
468            _stop: stop.0,
469        }
470    }
471
472    pub fn num_cpus(&self) -> usize {
473        num_cpus::get()
474    }
475
476    pub fn spawn<T, F>(&self, future: F) -> Task<T>
477    where
478        T: 'static + Send,
479        F: Send + Future<Output = T> + 'static,
480    {
481        let future = any_future(future);
482        let any_task = match self {
483            Self::Production { executor, .. } => executor.spawn(future),
484            Self::Deterministic { executor } => executor.spawn(future),
485        };
486        Task::send(any_task)
487    }
488
489    pub fn block<F, T>(&self, future: F) -> T
490    where
491        F: Future<Output = T>,
492    {
493        smol::pin!(future);
494        match self {
495            Self::Production { .. } => smol::block_on(&mut future),
496            Self::Deterministic { executor, .. } => {
497                executor.block(&mut future, usize::MAX).unwrap()
498            }
499        }
500    }
501
502    pub fn block_with_timeout<F, T>(
503        &self,
504        timeout: Duration,
505        future: F,
506    ) -> Result<T, impl Future<Output = T>>
507    where
508        T: 'static,
509        F: 'static + Unpin + Future<Output = T>,
510    {
511        let mut future = any_local_future(future);
512        if !timeout.is_zero() {
513            let output = match self {
514                Self::Production { .. } => smol::block_on(util::timeout(timeout, &mut future)).ok(),
515                Self::Deterministic { executor, .. } => {
516                    let max_ticks = {
517                        let mut state = executor.state.lock();
518                        let range = state.block_on_ticks.clone();
519                        state.rng.gen_range(range)
520                    };
521                    executor.block(&mut future, max_ticks)
522                }
523            };
524            if let Some(output) = output {
525                return Ok(*output.downcast().unwrap());
526            }
527        }
528        Err(async { *future.await.downcast().unwrap() })
529    }
530
531    pub async fn scoped<'scope, F>(&self, scheduler: F)
532    where
533        F: FnOnce(&mut Scope<'scope>),
534    {
535        let mut scope = Scope {
536            futures: Default::default(),
537            _phantom: PhantomData,
538        };
539        (scheduler)(&mut scope);
540        let spawned = scope
541            .futures
542            .into_iter()
543            .map(|f| self.spawn(f))
544            .collect::<Vec<_>>();
545        for task in spawned {
546            task.await;
547        }
548    }
549
550    pub async fn simulate_random_delay(&self) {
551        match self {
552            Self::Deterministic { executor, .. } => {
553                if executor.state.lock().rng.gen_bool(0.2) {
554                    let yields = executor.state.lock().rng.gen_range(1..=10);
555                    for _ in 0..yields {
556                        yield_now().await;
557                    }
558                }
559            }
560            _ => panic!("this method can only be called on a deterministic executor"),
561        }
562    }
563}
564
/// Collects futures that may borrow data for `'a`, to be spawned by `Background::scoped`.
pub struct Scope<'a> {
    /// Registered futures, with their lifetime erased to `'static` by `Scope::spawn`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    _phantom: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    /// Registers `f` to be spawned once the enclosing `Background::scoped` call proceeds.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let boxed: Pin<Box<dyn Future<Output = ()> + Send + 'a>> = Box::pin(f);
        // SAFETY: the source and target types differ only in lifetime, so the layout is
        // identical. Soundness also relies on `Background::scoped` awaiting every future
        // before returning, so none outlives `'a`. NOTE(review): that does not hold if the
        // `scoped` future is dropped early.
        let erased = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(boxed)
        };
        self.futures.push(erased);
    }
}
584
585impl<T> Task<T> {
586    pub fn ready(value: T) -> Self {
587        Self::Ready(Some(value))
588    }
589
590    fn local(any_task: AnyLocalTask) -> Self {
591        Self::Local {
592            any_task,
593            result_type: PhantomData,
594        }
595    }
596
597    pub fn detach(self) {
598        match self {
599            Task::Ready(_) => {}
600            Task::Local { any_task, .. } => any_task.detach(),
601            Task::Send { any_task, .. } => any_task.detach(),
602        }
603    }
604}
605
606impl<T: 'static, E: 'static + Display> Task<Result<T, E>> {
607    pub fn detach_and_log_err(self, cx: &mut MutableAppContext) {
608        cx.spawn(|_| async move {
609            if let Err(err) = self.await {
610                log::error!("{}", err);
611            }
612        })
613        .detach();
614    }
615}
616
617impl<T: Send> Task<T> {
618    fn send(any_task: AnyTask) -> Self {
619        Self::Send {
620            any_task,
621            result_type: PhantomData,
622        }
623    }
624}
625
626impl<T: fmt::Debug> fmt::Debug for Task<T> {
627    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
628        match self {
629            Task::Ready(value) => value.fmt(f),
630            Task::Local { any_task, .. } => any_task.fmt(f),
631            Task::Send { any_task, .. } => any_task.fmt(f),
632        }
633    }
634}
635
636impl<T: 'static> Future for Task<T> {
637    type Output = T;
638
639    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
640        match unsafe { self.get_unchecked_mut() } {
641            Task::Ready(value) => Poll::Ready(value.take().unwrap()),
642            Task::Local { any_task, .. } => {
643                any_task.poll(cx).map(|value| *value.downcast().unwrap())
644            }
645            Task::Send { any_task, .. } => {
646                any_task.poll(cx).map(|value| *value.downcast().unwrap())
647            }
648        }
649    }
650}
651
652fn any_future<T, F>(future: F) -> AnyFuture
653where
654    T: 'static + Send,
655    F: Future<Output = T> + Send + 'static,
656{
657    async { Box::new(future.await) as Box<dyn Any + Send> }.boxed()
658}
659
660fn any_local_future<T, F>(future: F) -> AnyLocalFuture
661where
662    T: 'static,
663    F: Future<Output = T> + 'static,
664{
665    async { Box::new(future.await) as Box<dyn Any> }.boxed_local()
666}