executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3use smol::{channel, prelude::*, Executor, Timer};
  4use std::{
  5    any::Any,
  6    fmt::{self, Display},
  7    marker::PhantomData,
  8    mem,
  9    pin::Pin,
 10    rc::Rc,
 11    sync::Arc,
 12    task::{Context, Poll},
 13    thread,
 14    time::Duration,
 15};
 16
 17use crate::{
 18    platform::{self, Dispatcher},
 19    util, MutableAppContext,
 20};
 21
/// Executor handle for futures that are spawned without a `Send` bound
/// (see `Foreground::spawn`).
pub enum Foreground {
    /// Schedules runnables through the platform dispatcher's
    /// `run_on_main_thread`.
    Platform {
        dispatcher: Arc<dyn platform::Dispatcher>,
        // `Rc` is neither `Send` nor `Sync`, so this marker prevents the
        // compiler from deriving those auto traits for the enum.
        _not_send_or_sync: PhantomData<Rc<()>>,
    },
    /// Test-only variant backed by a shared `Deterministic` executor.
    #[cfg(any(test, feature = "test-support"))]
    Deterministic {
        // Key under which this context's runnables are queued on the executor.
        cx_id: usize,
        executor: Arc<Deterministic>,
    },
}
 33
/// Executor handle for `Send` futures (see `Background::spawn`).
pub enum Background {
    /// Test-only variant backed by a shared `Deterministic` executor.
    #[cfg(any(test, feature = "test-support"))]
    Deterministic { executor: Arc<Deterministic> },
    /// A `smol` executor driven by the worker threads started in
    /// `Background::new`.
    Production {
        executor: Arc<smol::Executor<'static>>,
        // Nothing in this file sends on this channel; the workers wait on the
        // receiving half. NOTE(review): presumably dropping this sender is what
        // lets the workers' `recv()` finish and the threads exit — confirm.
        _stop: channel::Sender<()>,
    },
}
 42
// Type-erased futures and task handles: outputs are boxed as `dyn Any` and
// downcast back to the concrete type when a `Task<T>` is polled.

/// Boxed future with a type-erased output and no `Send` bound.
type AnyLocalFuture = Pin<Box<dyn 'static + Future<Output = Box<dyn Any + 'static>>>>;
/// Boxed `Send` future with a type-erased `Send` output.
type AnyFuture = Pin<Box<dyn 'static + Send + Future<Output = Box<dyn Any + Send + 'static>>>>;
/// Task handle produced for an `AnyFuture`.
type AnyTask = async_task::Task<Box<dyn Any + Send + 'static>>;
/// Task handle produced for an `AnyLocalFuture`.
type AnyLocalTask = async_task::Task<Box<dyn Any + 'static>>;
 47
/// Handle to a value that is either already available or still being produced
/// by a spawned future. Awaiting the handle yields the `T`.
#[must_use]
pub enum Task<T> {
    /// An immediately available value; the `Option` is emptied by the first
    /// poll.
    Ready(Option<T>),
    /// A task whose type-erased output has no `Send` bound.
    Local {
        any_task: AnyLocalTask,
        // Records the concrete output type the erased result is downcast to.
        result_type: PhantomData<T>,
    },
    /// A task whose type-erased output is `Send`.
    Send {
        any_task: AnyTask,
        // Records the concrete output type the erased result is downcast to.
        result_type: PhantomData<T>,
    },
}
 60
// NOTE(review): this asserts `Send` for every variant, including `Local`,
// whose erased output (`Box<dyn Any>`) carries no `Send` bound. Confirm that
// local tasks are never polled or dropped away from the thread that spawned
// them.
unsafe impl<T: Send> Send for Task<T> {}
 62
/// Mutable state of the deterministic test executor, guarded by the mutex in
/// `Deterministic`.
#[cfg(any(test, feature = "test-support"))]
struct DeterministicState {
    // Source of every scheduling decision; created from `seed`.
    rng: rand::prelude::StdRng,
    // Kept so that `Foreground::forbid_parking` can re-seed `rng`.
    seed: u64,
    // Runnables queued per foreground context id, in the order scheduled.
    scheduled_from_foreground: collections::HashMap<usize, Vec<ForegroundRunnable>>,
    // Runnables queued by `Deterministic::spawn`.
    scheduled_from_background: Vec<Runnable>,
    // When set, `will_park` panics instead of letting the executor park.
    forbid_parking: bool,
    // Range from which `Background::block_with_timeout` draws its tick budget.
    block_on_ticks: std::ops::RangeInclusive<usize>,
    // Simulated current time; moved forward by `Deterministic::advance_clock`.
    now: std::time::Instant,
    // Id handed to the next timer registered by `Foreground::timer`.
    next_timer_id: usize,
    // One `(timer id, wake-up instant, barrier sender)` entry per unfired timer.
    pending_timers: Vec<(usize, std::time::Instant, postage::barrier::Sender)>,
    // Captured by `start_waiting`, cleared by `finish_waiting`, and included in
    // the panic message produced by `will_park`.
    waiting_backtrace: Option<backtrace::Backtrace>,
}
 76
/// A runnable queued for a foreground context of the deterministic executor.
#[cfg(any(test, feature = "test-support"))]
struct ForegroundRunnable {
    runnable: Runnable,
    // True when the runnable belongs to a "main" future; after running it,
    // `run_internal` polls the main task for completion.
    main: bool,
}
 82
/// Test executor that runs foreground and background work on the calling
/// thread, picking what to run next with a seeded RNG.
#[cfg(any(test, feature = "test-support"))]
pub struct Deterministic {
    // Shared with the schedule callbacks of spawned tasks, hence the `Arc`.
    state: Arc<parking_lot::Mutex<DeterministicState>>,
    // Used to park while waiting for work; schedule callbacks and wakers hold
    // unparkers obtained from it.
    parker: parking_lot::Mutex<parking::Parker>,
}
 88
 89#[cfg(any(test, feature = "test-support"))]
 90impl Deterministic {
 91    pub fn new(seed: u64) -> Arc<Self> {
 92        use rand::prelude::*;
 93
 94        Arc::new(Self {
 95            state: Arc::new(parking_lot::Mutex::new(DeterministicState {
 96                rng: StdRng::seed_from_u64(seed),
 97                seed,
 98                scheduled_from_foreground: Default::default(),
 99                scheduled_from_background: Default::default(),
100                forbid_parking: false,
101                block_on_ticks: 0..=1000,
102                now: std::time::Instant::now(),
103                next_timer_id: Default::default(),
104                pending_timers: Default::default(),
105                waiting_backtrace: None,
106            })),
107            parker: Default::default(),
108        })
109    }
110
111    pub fn build_background(self: &Arc<Self>) -> Arc<Background> {
112        Arc::new(Background::Deterministic {
113            executor: self.clone(),
114        })
115    }
116
117    pub fn build_foreground(self: &Arc<Self>, id: usize) -> Rc<Foreground> {
118        Rc::new(Foreground::Deterministic {
119            cx_id: id,
120            executor: self.clone(),
121        })
122    }
123
124    fn spawn_from_foreground(
125        &self,
126        cx_id: usize,
127        future: AnyLocalFuture,
128        main: bool,
129    ) -> AnyLocalTask {
130        let state = self.state.clone();
131        let unparker = self.parker.lock().unparker();
132        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
133            let mut state = state.lock();
134            state
135                .scheduled_from_foreground
136                .entry(cx_id)
137                .or_default()
138                .push(ForegroundRunnable { runnable, main });
139            unparker.unpark();
140        });
141        runnable.schedule();
142        task
143    }
144
145    fn spawn(&self, future: AnyFuture) -> AnyTask {
146        let state = self.state.clone();
147        let unparker = self.parker.lock().unparker();
148        let (runnable, task) = async_task::spawn(future, move |runnable| {
149            let mut state = state.lock();
150            state.scheduled_from_background.push(runnable);
151            unparker.unpark();
152        });
153        runnable.schedule();
154        task
155    }
156
157    fn run<'a>(
158        &self,
159        cx_id: usize,
160        main_future: Pin<Box<dyn 'a + Future<Output = Box<dyn Any>>>>,
161    ) -> Box<dyn Any> {
162        use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
163
164        let woken = Arc::new(AtomicBool::new(false));
165
166        let state = self.state.clone();
167        let unparker = self.parker.lock().unparker();
168        let (runnable, mut main_task) = unsafe {
169            async_task::spawn_unchecked(main_future, move |runnable| {
170                let mut state = state.lock();
171                state
172                    .scheduled_from_foreground
173                    .entry(cx_id)
174                    .or_default()
175                    .push(ForegroundRunnable {
176                        runnable,
177                        main: true,
178                    });
179                unparker.unpark();
180            })
181        };
182        runnable.schedule();
183
184        loop {
185            if let Some(result) = self.run_internal(woken.clone(), Some(&mut main_task)) {
186                return result;
187            }
188
189            if !woken.load(SeqCst) {
190                self.state.lock().will_park();
191            }
192
193            woken.store(false, SeqCst);
194            self.parker.lock().park();
195        }
196    }
197
198    pub fn run_until_parked(&self) {
199        use std::sync::atomic::AtomicBool;
200        let woken = Arc::new(AtomicBool::new(false));
201        self.run_internal(woken, None);
202    }
203
    /// Runs queued runnables until both queues are empty or the main task (if
    /// one was supplied) produces its output.
    ///
    /// `woken` is set to `true` by the waker that is used to poll `main_task`.
    /// Returns `Some(output)` as soon as `main_task` is ready, and `None` when
    /// nothing is left to run and the main task (if any) is still pending.
    fn run_internal(
        &self,
        woken: Arc<std::sync::atomic::AtomicBool>,
        mut main_task: Option<&mut AnyLocalTask>,
    ) -> Option<Box<dyn Any>> {
        use rand::prelude::*;
        use std::sync::atomic::Ordering::SeqCst;

        // Waker handed to `main_task` polls: record the wake-up and unpark the
        // executor.
        let unparker = self.parker.lock().unparker();
        let waker = waker_fn::waker_fn(move || {
            woken.store(true, SeqCst);
            unparker.unpark();
        });

        let mut cx = Context::from_waker(&waker);
        loop {
            let mut state = self.state.lock();

            // Nothing queued: give the main task one last poll, then report
            // whether it finished.
            if state.scheduled_from_foreground.is_empty()
                && state.scheduled_from_background.is_empty()
            {
                if let Some(main_task) = main_task {
                    if let Poll::Ready(result) = main_task.poll(&mut cx) {
                        return Some(result);
                    }
                }

                return None;
            }

            // With background work queued, a random boolean decides whether a
            // randomly chosen background runnable goes next. If the draw is
            // false and no foreground work exists, the loop simply iterates
            // again.
            if !state.scheduled_from_background.is_empty() && state.rng.gen() {
                let background_len = state.scheduled_from_background.len();
                let ix = state.rng.gen_range(0..background_len);
                let runnable = state.scheduled_from_background.remove(ix);
                // Release the lock before running: the runnable may schedule
                // more work, which needs the lock.
                drop(state);
                runnable.run();
            } else if !state.scheduled_from_foreground.is_empty() {
                // Pick a random context that has work and take its oldest
                // runnable.
                let available_cx_ids = state
                    .scheduled_from_foreground
                    .keys()
                    .copied()
                    .collect::<Vec<_>>();
                let cx_id_to_run = *available_cx_ids.iter().choose(&mut state.rng).unwrap();
                let scheduled_from_cx = state
                    .scheduled_from_foreground
                    .get_mut(&cx_id_to_run)
                    .unwrap();
                let foreground_runnable = scheduled_from_cx.remove(0);
                // Remove drained entries so the emptiness check and the key
                // selection above only see contexts that still have work.
                if scheduled_from_cx.is_empty() {
                    state.scheduled_from_foreground.remove(&cx_id_to_run);
                }

                drop(state);

                foreground_runnable.runnable.run();
                // A runnable flagged `main` may have completed the main task,
                // so check it right away.
                if let Some(main_task) = main_task.as_mut() {
                    if foreground_runnable.main {
                        if let Poll::Ready(result) = main_task.poll(&mut cx) {
                            return Some(result);
                        }
                    }
                }
            }
        }
    }
269
270    fn block<F, T>(&self, future: &mut F, max_ticks: usize) -> Option<T>
271    where
272        F: Unpin + Future<Output = T>,
273    {
274        use rand::prelude::*;
275
276        let unparker = self.parker.lock().unparker();
277        let waker = waker_fn::waker_fn(move || {
278            unparker.unpark();
279        });
280
281        let mut cx = Context::from_waker(&waker);
282        for _ in 0..max_ticks {
283            let mut state = self.state.lock();
284            let runnable_count = state.scheduled_from_background.len();
285            let ix = state.rng.gen_range(0..=runnable_count);
286            if ix < state.scheduled_from_background.len() {
287                let runnable = state.scheduled_from_background.remove(ix);
288                drop(state);
289                runnable.run();
290            } else {
291                drop(state);
292                if let Poll::Ready(result) = future.poll(&mut cx) {
293                    return Some(result);
294                }
295                let mut state = self.state.lock();
296                if state.scheduled_from_background.is_empty() {
297                    state.will_park();
298                    drop(state);
299                    self.parker.lock().park();
300                }
301
302                continue;
303            }
304        }
305
306        None
307    }
308
309    pub fn advance_clock(&self, duration: Duration) {
310        let mut state = self.state.lock();
311        state.now += duration;
312        let now = state.now;
313        let mut pending_timers = mem::take(&mut state.pending_timers);
314        drop(state);
315
316        pending_timers.retain(|(_, wakeup, _)| *wakeup > now);
317        self.state.lock().pending_timers.extend(pending_timers);
318    }
319}
320
#[cfg(any(test, feature = "test-support"))]
impl DeterministicState {
    /// Called right before the executor parks.
    ///
    /// # Panics
    ///
    /// Panics when parking has been forbidden. If a waiting backtrace was
    /// recorded, it is resolved and appended to the panic message.
    fn will_park(&mut self) {
        if !self.forbid_parking {
            return;
        }

        let backtrace_message = match self.waiting_backtrace.as_mut() {
            Some(backtrace) => {
                backtrace.resolve();
                format!(
                    "\nbacktrace of waiting future:\n{:?}",
                    util::CwdBacktrace(backtrace)
                )
            }
            None => String::new(),
        };

        panic!(
            "deterministic executor parked after a call to forbid_parking{}",
            backtrace_message
        );
    }
}
342
343impl Foreground {
344    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
345        if dispatcher.is_main_thread() {
346            Ok(Self::Platform {
347                dispatcher,
348                _not_send_or_sync: PhantomData,
349            })
350        } else {
351            Err(anyhow!("must be constructed on main thread"))
352        }
353    }
354
355    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
356        let future = any_local_future(future);
357        let any_task = match self {
358            #[cfg(any(test, feature = "test-support"))]
359            Self::Deterministic { cx_id, executor } => {
360                executor.spawn_from_foreground(*cx_id, future, false)
361            }
362            Self::Platform { dispatcher, .. } => {
363                fn spawn_inner(
364                    future: AnyLocalFuture,
365                    dispatcher: &Arc<dyn Dispatcher>,
366                ) -> AnyLocalTask {
367                    let dispatcher = dispatcher.clone();
368                    let schedule =
369                        move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
370                    let (runnable, task) = async_task::spawn_local(future, schedule);
371                    runnable.schedule();
372                    task
373                }
374                spawn_inner(future, dispatcher)
375            }
376        };
377        Task::local(any_task)
378    }
379
380    #[cfg(any(test, feature = "test-support"))]
381    pub fn run<T: 'static>(&self, future: impl Future<Output = T>) -> T {
382        let future = async move { Box::new(future.await) as Box<dyn Any> }.boxed_local();
383        let result = match self {
384            Self::Deterministic { cx_id, executor } => executor.run(*cx_id, future),
385            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
386        };
387        *result.downcast().unwrap()
388    }
389
390    #[cfg(any(test, feature = "test-support"))]
391    pub fn run_until_parked(&self) {
392        match self {
393            Self::Deterministic { executor, .. } => executor.run_until_parked(),
394            _ => panic!("this method can only be called on a deterministic executor"),
395        }
396    }
397
398    #[cfg(any(test, feature = "test-support"))]
399    pub fn parking_forbidden(&self) -> bool {
400        match self {
401            Self::Deterministic { executor, .. } => executor.state.lock().forbid_parking,
402            _ => panic!("this method can only be called on a deterministic executor"),
403        }
404    }
405
406    #[cfg(any(test, feature = "test-support"))]
407    pub fn start_waiting(&self) {
408        match self {
409            Self::Deterministic { executor, .. } => {
410                executor.state.lock().waiting_backtrace =
411                    Some(backtrace::Backtrace::new_unresolved());
412            }
413            _ => panic!("this method can only be called on a deterministic executor"),
414        }
415    }
416
417    #[cfg(any(test, feature = "test-support"))]
418    pub fn finish_waiting(&self) {
419        match self {
420            Self::Deterministic { executor, .. } => {
421                executor.state.lock().waiting_backtrace.take();
422            }
423            _ => panic!("this method can only be called on a deterministic executor"),
424        }
425    }
426
427    #[cfg(any(test, feature = "test-support"))]
428    pub fn forbid_parking(&self) {
429        use rand::prelude::*;
430
431        match self {
432            Self::Deterministic { executor, .. } => {
433                let mut state = executor.state.lock();
434                state.forbid_parking = true;
435                state.rng = StdRng::seed_from_u64(state.seed);
436            }
437            _ => panic!("this method can only be called on a deterministic executor"),
438        }
439    }
440
441    pub async fn timer(&self, duration: Duration) {
442        match self {
443            #[cfg(any(test, feature = "test-support"))]
444            Self::Deterministic { executor, .. } => {
445                use postage::prelude::Stream as _;
446
447                let (tx, mut rx) = postage::barrier::channel();
448                let timer_id;
449                {
450                    let mut state = executor.state.lock();
451                    let wakeup_at = state.now + duration;
452                    timer_id = util::post_inc(&mut state.next_timer_id);
453                    state.pending_timers.push((timer_id, wakeup_at, tx));
454                }
455
456                struct DropTimer<'a>(usize, &'a Foreground);
457                impl<'a> Drop for DropTimer<'a> {
458                    fn drop(&mut self) {
459                        match self.1 {
460                            Foreground::Deterministic { executor, .. } => {
461                                executor
462                                    .state
463                                    .lock()
464                                    .pending_timers
465                                    .retain(|(timer_id, _, _)| *timer_id != self.0);
466                            }
467                            _ => unreachable!(),
468                        }
469                    }
470                }
471
472                let _guard = DropTimer(timer_id, self);
473                rx.recv().await;
474            }
475            _ => {
476                Timer::after(duration).await;
477            }
478        }
479    }
480
481    #[cfg(any(test, feature = "test-support"))]
482    pub fn advance_clock(&self, duration: Duration) {
483        match self {
484            Self::Deterministic { executor, .. } => {
485                executor.run_until_parked();
486                executor.advance_clock(duration);
487            }
488            _ => panic!("this method can only be called on a deterministic executor"),
489        }
490    }
491
492    #[cfg(any(test, feature = "test-support"))]
493    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
494        match self {
495            Self::Deterministic { executor, .. } => executor.state.lock().block_on_ticks = range,
496            _ => panic!("this method can only be called on a deterministic executor"),
497        }
498    }
499}
500
501impl Background {
502    pub fn new() -> Self {
503        let executor = Arc::new(Executor::new());
504        let stop = channel::unbounded::<()>();
505
506        for i in 0..2 * num_cpus::get() {
507            let executor = executor.clone();
508            let stop = stop.1.clone();
509            thread::Builder::new()
510                .name(format!("background-executor-{}", i))
511                .spawn(move || smol::block_on(executor.run(stop.recv())))
512                .unwrap();
513        }
514
515        Self::Production {
516            executor,
517            _stop: stop.0,
518        }
519    }
520
    /// Returns the CPU count reported by the `num_cpus` crate. The value does
    /// not depend on which executor variant `self` is.
    pub fn num_cpus(&self) -> usize {
        num_cpus::get()
    }
524
525    pub fn spawn<T, F>(&self, future: F) -> Task<T>
526    where
527        T: 'static + Send,
528        F: Send + Future<Output = T> + 'static,
529    {
530        let future = any_future(future);
531        let any_task = match self {
532            Self::Production { executor, .. } => executor.spawn(future),
533            #[cfg(any(test, feature = "test-support"))]
534            Self::Deterministic { executor } => executor.spawn(future),
535        };
536        Task::send(any_task)
537    }
538
539    pub fn block<F, T>(&self, future: F) -> T
540    where
541        F: Future<Output = T>,
542    {
543        smol::pin!(future);
544        match self {
545            Self::Production { .. } => smol::block_on(&mut future),
546            #[cfg(any(test, feature = "test-support"))]
547            Self::Deterministic { executor, .. } => {
548                executor.block(&mut future, usize::MAX).unwrap()
549            }
550        }
551    }
552
553    pub fn block_with_timeout<F, T>(
554        &self,
555        timeout: Duration,
556        future: F,
557    ) -> Result<T, impl Future<Output = T>>
558    where
559        T: 'static,
560        F: 'static + Unpin + Future<Output = T>,
561    {
562        let mut future = any_local_future(future);
563        if !timeout.is_zero() {
564            let output = match self {
565                Self::Production { .. } => smol::block_on(util::timeout(timeout, &mut future)).ok(),
566                #[cfg(any(test, feature = "test-support"))]
567                Self::Deterministic { executor, .. } => {
568                    use rand::prelude::*;
569                    let max_ticks = {
570                        let mut state = executor.state.lock();
571                        let range = state.block_on_ticks.clone();
572                        state.rng.gen_range(range)
573                    };
574                    executor.block(&mut future, max_ticks)
575                }
576            };
577            if let Some(output) = output {
578                return Ok(*output.downcast().unwrap());
579            }
580        }
581        Err(async { *future.await.downcast().unwrap() })
582    }
583
584    pub async fn scoped<'scope, F>(&self, scheduler: F)
585    where
586        F: FnOnce(&mut Scope<'scope>),
587    {
588        let mut scope = Scope {
589            futures: Default::default(),
590            _phantom: PhantomData,
591        };
592        (scheduler)(&mut scope);
593        let spawned = scope
594            .futures
595            .into_iter()
596            .map(|f| self.spawn(f))
597            .collect::<Vec<_>>();
598        for task in spawned {
599            task.await;
600        }
601    }
602
603    #[cfg(any(test, feature = "test-support"))]
604    pub async fn simulate_random_delay(&self) {
605        use rand::prelude::*;
606        use smol::future::yield_now;
607
608        match self {
609            Self::Deterministic { executor, .. } => {
610                if executor.state.lock().rng.gen_bool(0.2) {
611                    let yields = executor.state.lock().rng.gen_range(1..=10);
612                    for _ in 0..yields {
613                        yield_now().await;
614                    }
615
616                    let delay = Duration::from_millis(executor.state.lock().rng.gen_range(0..100));
617                    executor.advance_clock(delay);
618                }
619            }
620            _ => panic!("this method can only be called on a deterministic executor"),
621        }
622    }
623}
624
/// Collects futures that may borrow data for `'a` so they can be spawned
/// together (see `Background::scoped`).
pub struct Scope<'a> {
    // Stored with the borrow lifetime erased to `'static`; see `spawn`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    // Ties the scope to the lifetime of the data its futures borrow.
    _phantom: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    /// Queues `f`. Nothing is polled here; the future is only boxed and
    /// stored.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let scoped: Pin<Box<dyn Future<Output = ()> + Send + 'a>> = Box::pin(f);

        // The transmute changes nothing but the trait object's lifetime bound.
        // NOTE(review): soundness depends on every queued future finishing (or
        // being dropped) before `'a` ends. `Background::scoped` awaits each
        // spawned task; confirm this also holds when the `scoped` future
        // itself is dropped early.
        let erased = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(scoped)
        };
        self.futures.push(erased);
    }
}
644
645impl<T> Task<T> {
646    pub fn ready(value: T) -> Self {
647        Self::Ready(Some(value))
648    }
649
650    fn local(any_task: AnyLocalTask) -> Self {
651        Self::Local {
652            any_task,
653            result_type: PhantomData,
654        }
655    }
656
657    pub fn detach(self) {
658        match self {
659            Task::Ready(_) => {}
660            Task::Local { any_task, .. } => any_task.detach(),
661            Task::Send { any_task, .. } => any_task.detach(),
662        }
663    }
664}
665
666impl<T: 'static, E: 'static + Display> Task<Result<T, E>> {
667    pub fn detach_and_log_err(self, cx: &mut MutableAppContext) {
668        cx.spawn(|_| async move {
669            if let Err(err) = self.await {
670                log::error!("{}", err);
671            }
672        })
673        .detach();
674    }
675}
676
677impl<T: Send> Task<T> {
678    fn send(any_task: AnyTask) -> Self {
679        Self::Send {
680            any_task,
681            result_type: PhantomData,
682        }
683    }
684}
685
686impl<T: fmt::Debug> fmt::Debug for Task<T> {
687    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
688        match self {
689            Task::Ready(value) => value.fmt(f),
690            Task::Local { any_task, .. } => any_task.fmt(f),
691            Task::Send { any_task, .. } => any_task.fmt(f),
692        }
693    }
694}
695
696impl<T: 'static> Future for Task<T> {
697    type Output = T;
698
699    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
700        match unsafe { self.get_unchecked_mut() } {
701            Task::Ready(value) => Poll::Ready(value.take().unwrap()),
702            Task::Local { any_task, .. } => {
703                any_task.poll(cx).map(|value| *value.downcast().unwrap())
704            }
705            Task::Send { any_task, .. } => {
706                any_task.poll(cx).map(|value| *value.downcast().unwrap())
707            }
708        }
709    }
710}
711
712fn any_future<T, F>(future: F) -> AnyFuture
713where
714    T: 'static + Send,
715    F: Future<Output = T> + Send + 'static,
716{
717    async { Box::new(future.await) as Box<dyn Any + Send> }.boxed()
718}
719
720fn any_local_future<T, F>(future: F) -> AnyLocalFuture
721where
722    T: 'static,
723    F: Future<Output = T> + 'static,
724{
725    async { Box::new(future.await) as Box<dyn Any> }.boxed_local()
726}