executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3use smol::{channel, prelude::*, Executor, Timer};
  4use std::{
  5    any::Any,
  6    fmt::{self, Display},
  7    marker::PhantomData,
  8    mem,
  9    pin::Pin,
 10    rc::Rc,
 11    sync::Arc,
 12    task::{Context, Poll},
 13    thread,
 14    time::Duration,
 15};
 16
 17use crate::{
 18    platform::{self, Dispatcher},
 19    util::{self, post_inc},
 20    MutableAppContext,
 21};
 22
 23pub enum Foreground {
 24    Platform {
 25        dispatcher: Arc<dyn platform::Dispatcher>,
 26        _not_send_or_sync: PhantomData<Rc<()>>,
 27    },
 28    #[cfg(any(test, feature = "test-support"))]
 29    Deterministic {
 30        cx_id: usize,
 31        executor: Arc<Deterministic>,
 32    },
 33}
 34
 35pub enum Background {
 36    #[cfg(any(test, feature = "test-support"))]
 37    Deterministic { executor: Arc<Deterministic> },
 38    Production {
 39        executor: Arc<smol::Executor<'static>>,
 40        _stop: channel::Sender<()>,
 41    },
 42}
 43
 44type AnyLocalFuture = Pin<Box<dyn 'static + Future<Output = Box<dyn Any + 'static>>>>;
 45type AnyFuture = Pin<Box<dyn 'static + Send + Future<Output = Box<dyn Any + Send + 'static>>>>;
 46type AnyTask = async_task::Task<Box<dyn Any + Send + 'static>>;
 47type AnyLocalTask = async_task::Task<Box<dyn Any + 'static>>;
 48
 49#[must_use]
 50pub enum Task<T> {
 51    Ready(Option<T>),
 52    Local {
 53        any_task: AnyLocalTask,
 54        result_type: PhantomData<T>,
 55    },
 56    Send {
 57        any_task: AnyTask,
 58        result_type: PhantomData<T>,
 59    },
 60}
 61
 62unsafe impl<T: Send> Send for Task<T> {}
 63
 64#[cfg(any(test, feature = "test-support"))]
 65struct DeterministicState {
 66    rng: rand::prelude::StdRng,
 67    seed: u64,
 68    scheduled_from_foreground: collections::HashMap<usize, Vec<ForegroundRunnable>>,
 69    scheduled_from_background: Vec<Runnable>,
 70    forbid_parking: bool,
 71    block_on_ticks: std::ops::RangeInclusive<usize>,
 72    now: std::time::Instant,
 73    next_timer_id: usize,
 74    pending_timers: Vec<(usize, std::time::Instant, postage::barrier::Sender)>,
 75    waiting_backtrace: Option<backtrace::Backtrace>,
 76}
 77
 78#[cfg(any(test, feature = "test-support"))]
 79struct ForegroundRunnable {
 80    runnable: Runnable,
 81    main: bool,
 82}
 83
 84#[cfg(any(test, feature = "test-support"))]
 85pub struct Deterministic {
 86    state: Arc<parking_lot::Mutex<DeterministicState>>,
 87    parker: parking_lot::Mutex<parking::Parker>,
 88}
 89
 90#[cfg(any(test, feature = "test-support"))]
 91impl Deterministic {
 92    pub fn new(seed: u64) -> Arc<Self> {
 93        use rand::prelude::*;
 94
 95        Arc::new(Self {
 96            state: Arc::new(parking_lot::Mutex::new(DeterministicState {
 97                rng: StdRng::seed_from_u64(seed),
 98                seed,
 99                scheduled_from_foreground: Default::default(),
100                scheduled_from_background: Default::default(),
101                forbid_parking: false,
102                block_on_ticks: 0..=1000,
103                now: std::time::Instant::now(),
104                next_timer_id: Default::default(),
105                pending_timers: Default::default(),
106                waiting_backtrace: None,
107            })),
108            parker: Default::default(),
109        })
110    }
111
112    pub fn build_background(self: &Arc<Self>) -> Arc<Background> {
113        Arc::new(Background::Deterministic {
114            executor: self.clone(),
115        })
116    }
117
118    pub fn build_foreground(self: &Arc<Self>, id: usize) -> Rc<Foreground> {
119        Rc::new(Foreground::Deterministic {
120            cx_id: id,
121            executor: self.clone(),
122        })
123    }
124
125    fn spawn_from_foreground(
126        &self,
127        cx_id: usize,
128        future: AnyLocalFuture,
129        main: bool,
130    ) -> AnyLocalTask {
131        let state = self.state.clone();
132        let unparker = self.parker.lock().unparker();
133        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
134            let mut state = state.lock();
135            state
136                .scheduled_from_foreground
137                .entry(cx_id)
138                .or_default()
139                .push(ForegroundRunnable { runnable, main });
140            unparker.unpark();
141        });
142        runnable.schedule();
143        task
144    }
145
146    fn spawn(&self, future: AnyFuture) -> AnyTask {
147        let state = self.state.clone();
148        let unparker = self.parker.lock().unparker();
149        let (runnable, task) = async_task::spawn(future, move |runnable| {
150            let mut state = state.lock();
151            state.scheduled_from_background.push(runnable);
152            unparker.unpark();
153        });
154        runnable.schedule();
155        task
156    }
157
158    fn run<'a>(
159        &self,
160        cx_id: usize,
161        main_future: Pin<Box<dyn 'a + Future<Output = Box<dyn Any>>>>,
162    ) -> Box<dyn Any> {
163        use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
164
165        let woken = Arc::new(AtomicBool::new(false));
166
167        let state = self.state.clone();
168        let unparker = self.parker.lock().unparker();
169        let (runnable, mut main_task) = unsafe {
170            async_task::spawn_unchecked(main_future, move |runnable| {
171                let mut state = state.lock();
172                state
173                    .scheduled_from_foreground
174                    .entry(cx_id)
175                    .or_default()
176                    .push(ForegroundRunnable {
177                        runnable,
178                        main: true,
179                    });
180                unparker.unpark();
181            })
182        };
183        runnable.schedule();
184
185        loop {
186            if let Some(result) = self.run_internal(woken.clone(), Some(&mut main_task)) {
187                return result;
188            }
189
190            if !woken.load(SeqCst) {
191                self.state.lock().will_park();
192            }
193
194            woken.store(false, SeqCst);
195            self.parker.lock().park();
196        }
197    }
198
199    pub fn run_until_parked(&self) {
200        use std::sync::atomic::AtomicBool;
201        let woken = Arc::new(AtomicBool::new(false));
202        self.run_internal(woken, None);
203    }
204
205    fn run_internal(
206        &self,
207        woken: Arc<std::sync::atomic::AtomicBool>,
208        mut main_task: Option<&mut AnyLocalTask>,
209    ) -> Option<Box<dyn Any>> {
210        use rand::prelude::*;
211        use std::sync::atomic::Ordering::SeqCst;
212
213        let unparker = self.parker.lock().unparker();
214        let waker = waker_fn::waker_fn(move || {
215            woken.store(true, SeqCst);
216            unparker.unpark();
217        });
218
219        let mut cx = Context::from_waker(&waker);
220        loop {
221            let mut state = self.state.lock();
222
223            if state.scheduled_from_foreground.is_empty()
224                && state.scheduled_from_background.is_empty()
225            {
226                if let Some(main_task) = main_task {
227                    if let Poll::Ready(result) = main_task.poll(&mut cx) {
228                        return Some(result);
229                    }
230                }
231
232                return None;
233            }
234
235            if !state.scheduled_from_background.is_empty() && state.rng.gen() {
236                let background_len = state.scheduled_from_background.len();
237                let ix = state.rng.gen_range(0..background_len);
238                let runnable = state.scheduled_from_background.remove(ix);
239                drop(state);
240                runnable.run();
241            } else if !state.scheduled_from_foreground.is_empty() {
242                let available_cx_ids = state
243                    .scheduled_from_foreground
244                    .keys()
245                    .copied()
246                    .collect::<Vec<_>>();
247                let cx_id_to_run = *available_cx_ids.iter().choose(&mut state.rng).unwrap();
248                let scheduled_from_cx = state
249                    .scheduled_from_foreground
250                    .get_mut(&cx_id_to_run)
251                    .unwrap();
252                let foreground_runnable = scheduled_from_cx.remove(0);
253                if scheduled_from_cx.is_empty() {
254                    state.scheduled_from_foreground.remove(&cx_id_to_run);
255                }
256
257                drop(state);
258
259                foreground_runnable.runnable.run();
260                if let Some(main_task) = main_task.as_mut() {
261                    if foreground_runnable.main {
262                        if let Poll::Ready(result) = main_task.poll(&mut cx) {
263                            return Some(result);
264                        }
265                    }
266                }
267            }
268        }
269    }
270
271    fn block<F, T>(&self, future: &mut F, max_ticks: usize) -> Option<T>
272    where
273        F: Unpin + Future<Output = T>,
274    {
275        use rand::prelude::*;
276
277        let unparker = self.parker.lock().unparker();
278        let waker = waker_fn::waker_fn(move || {
279            unparker.unpark();
280        });
281
282        let mut cx = Context::from_waker(&waker);
283        for _ in 0..max_ticks {
284            let mut state = self.state.lock();
285            let runnable_count = state.scheduled_from_background.len();
286            let ix = state.rng.gen_range(0..=runnable_count);
287            if ix < state.scheduled_from_background.len() {
288                let runnable = state.scheduled_from_background.remove(ix);
289                drop(state);
290                runnable.run();
291            } else {
292                drop(state);
293                if let Poll::Ready(result) = future.poll(&mut cx) {
294                    return Some(result);
295                }
296                let mut state = self.state.lock();
297                if state.scheduled_from_background.is_empty() {
298                    state.will_park();
299                    drop(state);
300                    self.parker.lock().park();
301                }
302
303                continue;
304            }
305        }
306
307        None
308    }
309
310    pub fn advance_clock(&self, duration: Duration) {
311        let mut state = self.state.lock();
312        state.now += duration;
313        let now = state.now;
314        let mut pending_timers = mem::take(&mut state.pending_timers);
315        drop(state);
316
317        pending_timers.retain(|(_, wakeup, _)| *wakeup > now);
318        self.state.lock().pending_timers.extend(pending_timers);
319    }
320}
321
#[cfg(any(test, feature = "test-support"))]
impl DeterministicState {
    /// Hook invoked right before the deterministic executor parks its thread.
    ///
    /// # Panics
    ///
    /// Panics if `forbid_parking` is set. When `Foreground::start_waiting`
    /// recorded a backtrace, it is resolved and appended to the message.
    fn will_park(&mut self) {
        if !self.forbid_parking {
            return;
        }

        let backtrace_message = match self.waiting_backtrace.as_mut() {
            Some(backtrace) => {
                backtrace.resolve();
                format!(
                    "\nbacktrace of waiting future:\n{:?}",
                    util::CwdBacktrace(backtrace)
                )
            }
            None => String::new(),
        };

        panic!(
            "deterministic executor parked after a call to forbid_parking{}",
            backtrace_message
        );
    }
}
343
344impl Foreground {
345    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
346        if dispatcher.is_main_thread() {
347            Ok(Self::Platform {
348                dispatcher,
349                _not_send_or_sync: PhantomData,
350            })
351        } else {
352            Err(anyhow!("must be constructed on main thread"))
353        }
354    }
355
356    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
357        let future = any_local_future(future);
358        let any_task = match self {
359            #[cfg(any(test, feature = "test-support"))]
360            Self::Deterministic { cx_id, executor } => {
361                executor.spawn_from_foreground(*cx_id, future, false)
362            }
363            Self::Platform { dispatcher, .. } => {
364                fn spawn_inner(
365                    future: AnyLocalFuture,
366                    dispatcher: &Arc<dyn Dispatcher>,
367                ) -> AnyLocalTask {
368                    let dispatcher = dispatcher.clone();
369                    let schedule =
370                        move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
371                    let (runnable, task) = async_task::spawn_local(future, schedule);
372                    runnable.schedule();
373                    task
374                }
375                spawn_inner(future, dispatcher)
376            }
377        };
378        Task::local(any_task)
379    }
380
381    #[cfg(any(test, feature = "test-support"))]
382    pub fn run<T: 'static>(&self, future: impl Future<Output = T>) -> T {
383        let future = async move { Box::new(future.await) as Box<dyn Any> }.boxed_local();
384        let result = match self {
385            Self::Deterministic { cx_id, executor } => executor.run(*cx_id, future),
386            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
387        };
388        *result.downcast().unwrap()
389    }
390
391    #[cfg(any(test, feature = "test-support"))]
392    pub fn run_until_parked(&self) {
393        match self {
394            Self::Deterministic { executor, .. } => executor.run_until_parked(),
395            _ => panic!("this method can only be called on a deterministic executor"),
396        }
397    }
398
399    #[cfg(any(test, feature = "test-support"))]
400    pub fn parking_forbidden(&self) -> bool {
401        match self {
402            Self::Deterministic { executor, .. } => executor.state.lock().forbid_parking,
403            _ => panic!("this method can only be called on a deterministic executor"),
404        }
405    }
406
407    #[cfg(any(test, feature = "test-support"))]
408    pub fn start_waiting(&self) {
409        match self {
410            Self::Deterministic { executor, .. } => {
411                executor.state.lock().waiting_backtrace =
412                    Some(backtrace::Backtrace::new_unresolved());
413            }
414            _ => panic!("this method can only be called on a deterministic executor"),
415        }
416    }
417
418    #[cfg(any(test, feature = "test-support"))]
419    pub fn finish_waiting(&self) {
420        match self {
421            Self::Deterministic { executor, .. } => {
422                executor.state.lock().waiting_backtrace.take();
423            }
424            _ => panic!("this method can only be called on a deterministic executor"),
425        }
426    }
427
428    #[cfg(any(test, feature = "test-support"))]
429    pub fn forbid_parking(&self) {
430        use rand::prelude::*;
431
432        match self {
433            Self::Deterministic { executor, .. } => {
434                let mut state = executor.state.lock();
435                state.forbid_parking = true;
436                state.rng = StdRng::seed_from_u64(state.seed);
437            }
438            _ => panic!("this method can only be called on a deterministic executor"),
439        }
440    }
441
442    pub async fn timer(&self, duration: Duration) {
443        match self {
444            #[cfg(any(test, feature = "test-support"))]
445            Self::Deterministic { executor, .. } => {
446                use postage::prelude::Stream as _;
447
448                let (tx, mut rx) = postage::barrier::channel();
449                let timer_id;
450                {
451                    let mut state = executor.state.lock();
452                    let wakeup_at = state.now + duration;
453                    timer_id = post_inc(&mut state.next_timer_id);
454                    state.pending_timers.push((timer_id, wakeup_at, tx));
455                }
456
457                struct DropTimer<'a>(usize, &'a Foreground);
458                impl<'a> Drop for DropTimer<'a> {
459                    fn drop(&mut self) {
460                        match self.1 {
461                            Foreground::Deterministic { executor, .. } => {
462                                executor
463                                    .state
464                                    .lock()
465                                    .pending_timers
466                                    .retain(|(timer_id, _, _)| *timer_id != self.0);
467                            }
468                            _ => unreachable!(),
469                        }
470                    }
471                }
472
473                let _guard = DropTimer(timer_id, self);
474                rx.recv().await;
475            }
476            _ => {
477                Timer::after(duration).await;
478            }
479        }
480    }
481
482    #[cfg(any(test, feature = "test-support"))]
483    pub fn advance_clock(&self, duration: Duration) {
484        match self {
485            Self::Deterministic { executor, .. } => {
486                executor.run_until_parked();
487                executor.advance_clock(duration);
488            }
489            _ => panic!("this method can only be called on a deterministic executor"),
490        }
491    }
492
493    #[cfg(any(test, feature = "test-support"))]
494    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
495        match self {
496            Self::Deterministic { executor, .. } => executor.state.lock().block_on_ticks = range,
497            _ => panic!("this method can only be called on a deterministic executor"),
498        }
499    }
500}
501
502impl Background {
503    pub fn new() -> Self {
504        let executor = Arc::new(Executor::new());
505        let stop = channel::unbounded::<()>();
506
507        for i in 0..2 * num_cpus::get() {
508            let executor = executor.clone();
509            let stop = stop.1.clone();
510            thread::Builder::new()
511                .name(format!("background-executor-{}", i))
512                .spawn(move || smol::block_on(executor.run(stop.recv())))
513                .unwrap();
514        }
515
516        Self::Production {
517            executor,
518            _stop: stop.0,
519        }
520    }
521
522    pub fn num_cpus(&self) -> usize {
523        num_cpus::get()
524    }
525
526    pub fn spawn<T, F>(&self, future: F) -> Task<T>
527    where
528        T: 'static + Send,
529        F: Send + Future<Output = T> + 'static,
530    {
531        let future = any_future(future);
532        let any_task = match self {
533            Self::Production { executor, .. } => executor.spawn(future),
534            #[cfg(any(test, feature = "test-support"))]
535            Self::Deterministic { executor } => executor.spawn(future),
536        };
537        Task::send(any_task)
538    }
539
540    pub fn block<F, T>(&self, future: F) -> T
541    where
542        F: Future<Output = T>,
543    {
544        smol::pin!(future);
545        match self {
546            Self::Production { .. } => smol::block_on(&mut future),
547            #[cfg(any(test, feature = "test-support"))]
548            Self::Deterministic { executor, .. } => {
549                executor.block(&mut future, usize::MAX).unwrap()
550            }
551        }
552    }
553
554    pub fn block_with_timeout<F, T>(
555        &self,
556        timeout: Duration,
557        future: F,
558    ) -> Result<T, impl Future<Output = T>>
559    where
560        T: 'static,
561        F: 'static + Unpin + Future<Output = T>,
562    {
563        let mut future = any_local_future(future);
564        if !timeout.is_zero() {
565            let output = match self {
566                Self::Production { .. } => smol::block_on(util::timeout(timeout, &mut future)).ok(),
567                #[cfg(any(test, feature = "test-support"))]
568                Self::Deterministic { executor, .. } => {
569                    use rand::prelude::*;
570                    let max_ticks = {
571                        let mut state = executor.state.lock();
572                        let range = state.block_on_ticks.clone();
573                        state.rng.gen_range(range)
574                    };
575                    executor.block(&mut future, max_ticks)
576                }
577            };
578            if let Some(output) = output {
579                return Ok(*output.downcast().unwrap());
580            }
581        }
582        Err(async { *future.await.downcast().unwrap() })
583    }
584
585    pub async fn scoped<'scope, F>(&self, scheduler: F)
586    where
587        F: FnOnce(&mut Scope<'scope>),
588    {
589        let mut scope = Scope {
590            futures: Default::default(),
591            _phantom: PhantomData,
592        };
593        (scheduler)(&mut scope);
594        let spawned = scope
595            .futures
596            .into_iter()
597            .map(|f| self.spawn(f))
598            .collect::<Vec<_>>();
599        for task in spawned {
600            task.await;
601        }
602    }
603
604    #[cfg(any(test, feature = "test-support"))]
605    pub async fn simulate_random_delay(&self) {
606        use rand::prelude::*;
607        use smol::future::yield_now;
608
609        match self {
610            Self::Deterministic { executor, .. } => {
611                if executor.state.lock().rng.gen_bool(0.2) {
612                    let yields = executor.state.lock().rng.gen_range(1..=10);
613                    for _ in 0..yields {
614                        yield_now().await;
615                    }
616
617                    let delay = Duration::from_millis(executor.state.lock().rng.gen_range(0..100));
618                    executor.advance_clock(delay);
619                }
620            }
621            _ => panic!("this method can only be called on a deterministic executor"),
622        }
623    }
624}
625
/// Collects futures that may borrow data living for `'a`, to be spawned by
/// `Background::scoped`.
pub struct Scope<'a> {
    /// Queued futures, with their `'a` borrows erased to `'static` by `spawn`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    /// Ties the scope to `'a`, so `spawn` only accepts futures that outlive it.
    _phantom: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    /// Queues `f`. It is spawned and awaited when the enclosing `scoped` call runs.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let future: Pin<Box<dyn Future<Output = ()> + Send + 'a>> = Box::pin(f);
        // SAFETY: only the lifetime changes; layout is identical. This is sound
        // only if the future finishes before `'a` ends. `Background::scoped`
        // awaits every spawned task before resolving. See the review note there
        // about dropping it early.
        let future = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(future)
        };
        self.futures.push(future);
    }
}
645
646impl<T> Task<T> {
647    pub fn ready(value: T) -> Self {
648        Self::Ready(Some(value))
649    }
650
651    fn local(any_task: AnyLocalTask) -> Self {
652        Self::Local {
653            any_task,
654            result_type: PhantomData,
655        }
656    }
657
658    pub fn detach(self) {
659        match self {
660            Task::Ready(_) => {}
661            Task::Local { any_task, .. } => any_task.detach(),
662            Task::Send { any_task, .. } => any_task.detach(),
663        }
664    }
665}
666
667impl<T: 'static, E: 'static + Display> Task<Result<T, E>> {
668    pub fn detach_and_log_err(self, cx: &mut MutableAppContext) {
669        cx.spawn(|_| async move {
670            if let Err(err) = self.await {
671                log::error!("{}", err);
672            }
673        })
674        .detach();
675    }
676}
677
678impl<T: Send> Task<T> {
679    fn send(any_task: AnyTask) -> Self {
680        Self::Send {
681            any_task,
682            result_type: PhantomData,
683        }
684    }
685}
686
687impl<T: fmt::Debug> fmt::Debug for Task<T> {
688    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
689        match self {
690            Task::Ready(value) => value.fmt(f),
691            Task::Local { any_task, .. } => any_task.fmt(f),
692            Task::Send { any_task, .. } => any_task.fmt(f),
693        }
694    }
695}
696
697impl<T: 'static> Future for Task<T> {
698    type Output = T;
699
700    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
701        match unsafe { self.get_unchecked_mut() } {
702            Task::Ready(value) => Poll::Ready(value.take().unwrap()),
703            Task::Local { any_task, .. } => {
704                any_task.poll(cx).map(|value| *value.downcast().unwrap())
705            }
706            Task::Send { any_task, .. } => {
707                any_task.poll(cx).map(|value| *value.downcast().unwrap())
708            }
709        }
710    }
711}
712
713fn any_future<T, F>(future: F) -> AnyFuture
714where
715    T: 'static + Send,
716    F: Future<Output = T> + Send + 'static,
717{
718    async { Box::new(future.await) as Box<dyn Any + Send> }.boxed()
719}
720
721fn any_local_future<T, F>(future: F) -> AnyLocalFuture
722where
723    T: 'static,
724    F: Future<Output = T> + 'static,
725{
726    async { Box::new(future.await) as Box<dyn Any> }.boxed_local()
727}