executor.rs

  1use anyhow::{anyhow, Result};
  2use async_task::Runnable;
  3use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
  4use parking_lot::Mutex;
  5use postage::{barrier, prelude::Stream as _};
  6use rand::prelude::*;
  7use smol::{channel, prelude::*, Executor, Timer};
  8use std::{
  9    any::Any,
 10    fmt::{self, Debug},
 11    marker::PhantomData,
 12    mem,
 13    ops::RangeInclusive,
 14    pin::Pin,
 15    rc::Rc,
 16    sync::{
 17        atomic::{AtomicBool, Ordering::SeqCst},
 18        Arc,
 19    },
 20    task::{Context, Poll},
 21    thread,
 22    time::{Duration, Instant},
 23};
 24use waker_fn::waker_fn;
 25
 26use crate::{platform, util};
 27
 28pub enum Foreground {
 29    Platform {
 30        dispatcher: Arc<dyn platform::Dispatcher>,
 31        _not_send_or_sync: PhantomData<Rc<()>>,
 32    },
 33    Test(smol::LocalExecutor<'static>),
 34    Deterministic(Arc<Deterministic>),
 35}
 36
 37pub enum Background {
 38    Deterministic(Arc<Deterministic>),
 39    Production {
 40        executor: Arc<smol::Executor<'static>>,
 41        _stop: channel::Sender<()>,
 42    },
 43}
 44
 45type AnyLocalFuture = Pin<Box<dyn 'static + Future<Output = Box<dyn Any + 'static>>>>;
 46type AnyFuture = Pin<Box<dyn 'static + Send + Future<Output = Box<dyn Any + Send + 'static>>>>;
 47type AnyTask = async_task::Task<Box<dyn Any + Send + 'static>>;
 48type AnyLocalTask = async_task::Task<Box<dyn Any + 'static>>;
 49
 50pub enum Task<T> {
 51    Local {
 52        any_task: AnyLocalTask,
 53        result_type: PhantomData<T>,
 54    },
 55    Send {
 56        any_task: AnyTask,
 57        result_type: PhantomData<T>,
 58    },
 59}
 60
 // SAFETY (review): this impl also covers `Task::Local`, whose `AnyLocalTask`
// holds a `Box<dyn Any>` and so is presumably not `Send` by itself. It appears
// to rely on the boxed value always being a `T` (as built by `any_local_future`
// in `Foreground::spawn`), which is `Send` under this bound.
// NOTE(review): confirm that moving a `Local` handle to another thread, and
// polling or dropping it there, is sound under async_task's `spawn_local` rules.
unsafe impl<T: Send> Send for Task<T> {}
 62
 63struct DeterministicState {
 64    rng: StdRng,
 65    seed: u64,
 66    scheduled_from_foreground: Vec<(Runnable, Backtrace)>,
 67    scheduled_from_background: Vec<(Runnable, Backtrace)>,
 68    spawned_from_foreground: Vec<(Runnable, Backtrace)>,
 69    forbid_parking: bool,
 70    block_on_ticks: RangeInclusive<usize>,
 71    now: Instant,
 72    pending_timers: Vec<(Instant, barrier::Sender)>,
 73}
 74
 75pub struct Deterministic {
 76    state: Arc<Mutex<DeterministicState>>,
 77    parker: Mutex<parking::Parker>,
 78}
 79
 impl Deterministic {
    /// Creates an executor whose RNG is seeded from `seed`.
    ///
    /// `block_on` budgets default to `0..=1000` ticks. The virtual clock starts
    /// at the real `Instant::now()`.
    fn new(seed: u64) -> Self {
        Self {
            state: Arc::new(Mutex::new(DeterministicState {
                rng: StdRng::seed_from_u64(seed),
                seed,
                scheduled_from_foreground: Default::default(),
                scheduled_from_background: Default::default(),
                spawned_from_foreground: Default::default(),
                forbid_parking: false,
                block_on_ticks: 0..=1000,
                now: Instant::now(),
                pending_timers: Default::default(),
            })),
            parker: Default::default(),
        }
    }

    /// Spawns a future that need not be `Send` and returns its task handle.
    ///
    /// The runnable's first schedule (the `runnable.schedule()` call below) is
    /// queued on `spawned_from_foreground`. Every later schedule (a wake-up) is
    /// queued on `scheduled_from_foreground`. Each entry carries the spawn-site
    /// backtrace for `Trace`.
    fn spawn_from_foreground(&self, future: AnyLocalFuture) -> AnyLocalTask {
        let backtrace = Backtrace::new_unresolved();
        // Flips to `true` on the first schedule, so only that one counts as the spawn.
        let scheduled_once = AtomicBool::new(false);
        let state = self.state.clone();
        let unparker = self.parker.lock().unparker();
        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
            let mut state = state.lock();
            let backtrace = backtrace.clone();
            if scheduled_once.fetch_or(true, SeqCst) {
                state.scheduled_from_foreground.push((runnable, backtrace));
            } else {
                state.spawned_from_foreground.push((runnable, backtrace));
            }
            // Wake the driving thread in case it is parked waiting for work.
            unparker.unpark();
        });
        runnable.schedule();
        task
    }

    /// Spawns a `Send` future. Every schedule of its runnable, including the
    /// first, is queued on `scheduled_from_background` with the spawn-site
    /// backtrace.
    fn spawn(&self, future: AnyFuture) -> AnyTask {
        let backtrace = Backtrace::new_unresolved();
        let state = self.state.clone();
        let unparker = self.parker.lock().unparker();
        let (runnable, task) = async_task::spawn(future, move |runnable| {
            let mut state = state.lock();
            state
                .scheduled_from_background
                .push((runnable, backtrace.clone()));
            unparker.unpark();
        });
        runnable.schedule();
        task
    }

    /// Drives `future` to completion, running queued runnables in RNG-chosen
    /// order in between, and returns its type-erased output.
    ///
    /// When nothing is runnable and `future` is still pending, the thread parks
    /// until a waker or a newly scheduled runnable unparks it.
    ///
    /// # Panics
    /// If `forbid_parking` is set and the executor is about to park even though
    /// `future` was not woken during the last pass.
    fn run(&self, mut future: AnyLocalFuture) -> Box<dyn Any> {
        // Set by the waker that `run_internal` hands to `future`; cleared before each park.
        let woken = Arc::new(AtomicBool::new(false));
        loop {
            if let Some(result) = self.run_internal(woken.clone(), &mut future) {
                return result;
            }

            if !woken.load(SeqCst) && self.state.lock().forbid_parking {
                panic!("deterministic executor parked after a call to forbid_parking");
            }

            woken.store(false, SeqCst);
            // NOTE(review): the `parker` mutex guard lives for this whole statement,
            // so it stays locked while parked. A concurrent `spawn` or
            // `spawn_from_foreground` from another thread (both lock `parker`)
            // would block until the park returns. Confirm spawns never race with a park.
            self.parker.lock().park();
        }
    }

    /// Runs queued runnables until every queue is empty.
    ///
    /// A never-completing future stands in for a real one, so `run_internal`
    /// returns the first time it polls that future and finds no runnable left.
    fn run_until_parked(&self) {
        let woken = Arc::new(AtomicBool::new(false));
        let mut future = any_local_future(std::future::pending::<()>());
        self.run_internal(woken, &mut future);
    }

    /// One scheduling pass shared by `run` and `run_until_parked`.
    ///
    /// Each step draws `ix` uniformly from `0..=runnable_count`:
    /// - an index inside `scheduled_from_foreground`, then inside
    ///   `scheduled_from_background`, runs that runnable;
    /// - any index in the `spawned_from_foreground` range runs that queue's
    ///   oldest entry, so spawned tasks start in spawn order;
    /// - the final slot polls `future`.
    ///
    /// Returns `Some(output)` once `future` completes. Returns `None` when `future`
    /// is pending and all three queues are empty right after polling it. The
    /// waker given to `future` sets `woken` and unparks the thread.
    fn run_internal(
        &self,
        woken: Arc<AtomicBool>,
        future: &mut AnyLocalFuture,
    ) -> Option<Box<dyn Any>> {
        let unparker = self.parker.lock().unparker();
        let waker = waker_fn(move || {
            woken.store(true, SeqCst);
            unparker.unpark();
        });

        let mut cx = Context::from_waker(&waker);
        // Dropped when this pass returns; `Trace`'s `Drop` may print it depending on
        // the EXECUTOR_TRACE_* environment variables.
        let mut trace = Trace::default();
        loop {
            let mut state = self.state.lock();
            let runnable_count = state.scheduled_from_foreground.len()
                + state.scheduled_from_background.len()
                + state.spawned_from_foreground.len();

            // The extra slot (`== runnable_count`) means "poll the main future".
            let ix = state.rng.gen_range(0..=runnable_count);
            if ix < state.scheduled_from_foreground.len() {
                let (_, backtrace) = &state.scheduled_from_foreground[ix];
                trace.record(&state, backtrace.clone());
                let runnable = state.scheduled_from_foreground.remove(ix).0;
                // Release the lock first: running may reschedule tasks, and the
                // schedule callbacks lock `state` themselves.
                drop(state);
                runnable.run();
            // Cannot underflow: this branch is only reached when
            // `ix >= scheduled_from_foreground.len()`.
            } else if ix - state.scheduled_from_foreground.len()
                < state.scheduled_from_background.len()
            {
                let ix = ix - state.scheduled_from_foreground.len();
                let (_, backtrace) = &state.scheduled_from_background[ix];
                trace.record(&state, backtrace.clone());
                let runnable = state.scheduled_from_background.remove(ix).0;
                drop(state);
                runnable.run();
            } else if ix < runnable_count {
                // Always the front entry (not `ix`), preserving spawn order.
                let (_, backtrace) = &state.spawned_from_foreground[0];
                trace.record(&state, backtrace.clone());
                let runnable = state.spawned_from_foreground.remove(0).0;
                drop(state);
                runnable.run();
            } else {
                drop(state);
                if let Poll::Ready(result) = future.poll(&mut cx) {
                    return Some(result);
                }

                // Polling may have scheduled more work; only stop when nothing is queued.
                let state = self.state.lock();
                if state.scheduled_from_foreground.is_empty()
                    && state.scheduled_from_background.is_empty()
                    && state.spawned_from_foreground.is_empty()
                {
                    return None;
                }
            }
        }
    }

    /// Polls `future` for at most a random number of ticks drawn from
    /// `block_on_ticks`, running only background runnables in between.
    ///
    /// Foreground queues are left untouched. Each tick either runs one
    /// background runnable or polls `future`. After a pending poll with no
    /// background work queued, the thread parks, and that park still consumes
    /// the tick. Returns `None` if the budget runs out before `future` completes.
    ///
    /// # Panics
    /// If it is about to park while `forbid_parking` is set.
    fn block_on(&self, future: &mut AnyLocalFuture) -> Option<Box<dyn Any>> {
        let unparker = self.parker.lock().unparker();
        let waker = waker_fn(move || {
            unparker.unpark();
        });
        let max_ticks = {
            let mut state = self.state.lock();
            let range = state.block_on_ticks.clone();
            state.rng.gen_range(range)
        };

        let mut cx = Context::from_waker(&waker);
        let mut trace = Trace::default();
        for _ in 0..max_ticks {
            let mut state = self.state.lock();
            let runnable_count = state.scheduled_from_background.len();
            // The extra slot (`== runnable_count`) means "poll `future`".
            let ix = state.rng.gen_range(0..=runnable_count);
            if ix < state.scheduled_from_background.len() {
                let (_, backtrace) = &state.scheduled_from_background[ix];
                trace.record(&state, backtrace.clone());
                let runnable = state.scheduled_from_background.remove(ix).0;
                drop(state);
                runnable.run();
            } else {
                drop(state);
                if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
                    return Some(result);
                }
                let state = self.state.lock();
                if state.scheduled_from_background.is_empty() {
                    if state.forbid_parking {
                        panic!("deterministic executor parked after a call to forbid_parking");
                    }
                    drop(state);
                    // NOTE(review): holds the `parker` mutex while parked (see `run`).
                    self.parker.lock().park();
                }

                continue;
            }
        }

        None
    }
}
256
257#[derive(Default)]
258struct Trace {
259    executed: Vec<Backtrace>,
260    scheduled: Vec<Vec<Backtrace>>,
261    spawned_from_foreground: Vec<Vec<Backtrace>>,
262}
263
264impl Trace {
265    fn record(&mut self, state: &DeterministicState, executed: Backtrace) {
266        self.scheduled.push(
267            state
268                .scheduled_from_foreground
269                .iter()
270                .map(|(_, backtrace)| backtrace.clone())
271                .collect(),
272        );
273        self.spawned_from_foreground.push(
274            state
275                .spawned_from_foreground
276                .iter()
277                .map(|(_, backtrace)| backtrace.clone())
278                .collect(),
279        );
280        self.executed.push(executed);
281    }
282
283    fn resolve(&mut self) {
284        for backtrace in &mut self.executed {
285            backtrace.resolve();
286        }
287
288        for backtraces in &mut self.scheduled {
289            for backtrace in backtraces {
290                backtrace.resolve();
291            }
292        }
293
294        for backtraces in &mut self.spawned_from_foreground {
295            for backtrace in backtraces {
296                backtrace.resolve();
297            }
298        }
299    }
300}
301
302impl Debug for Trace {
303    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304        struct FirstCwdFrameInBacktrace<'a>(&'a Backtrace);
305
306        impl<'a> Debug for FirstCwdFrameInBacktrace<'a> {
307            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
308                let cwd = std::env::current_dir().unwrap();
309                let mut print_path = |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
310                    fmt::Display::fmt(&path, fmt)
311                };
312                let mut fmt = BacktraceFmt::new(f, backtrace::PrintFmt::Full, &mut print_path);
313                for frame in self.0.frames() {
314                    let mut formatted_frame = fmt.frame();
315                    if frame
316                        .symbols()
317                        .iter()
318                        .any(|s| s.filename().map_or(false, |f| f.starts_with(&cwd)))
319                    {
320                        formatted_frame.backtrace_frame(frame)?;
321                        break;
322                    }
323                }
324                fmt.finish()
325            }
326        }
327
328        for ((backtrace, scheduled), spawned_from_foreground) in self
329            .executed
330            .iter()
331            .zip(&self.scheduled)
332            .zip(&self.spawned_from_foreground)
333        {
334            writeln!(f, "Scheduled")?;
335            for backtrace in scheduled {
336                writeln!(f, "- {:?}", FirstCwdFrameInBacktrace(backtrace))?;
337            }
338            if scheduled.is_empty() {
339                writeln!(f, "None")?;
340            }
341            writeln!(f, "==========")?;
342
343            writeln!(f, "Spawned from foreground")?;
344            for backtrace in spawned_from_foreground {
345                writeln!(f, "- {:?}", FirstCwdFrameInBacktrace(backtrace))?;
346            }
347            if spawned_from_foreground.is_empty() {
348                writeln!(f, "None")?;
349            }
350            writeln!(f, "==========")?;
351
352            writeln!(f, "Run: {:?}", FirstCwdFrameInBacktrace(backtrace))?;
353            writeln!(f, "+++++++++++++++++++")?;
354        }
355
356        Ok(())
357    }
358}
359
360impl Drop for Trace {
361    fn drop(&mut self) {
362        let trace_on_panic = if let Ok(trace_on_panic) = std::env::var("EXECUTOR_TRACE_ON_PANIC") {
363            trace_on_panic == "1" || trace_on_panic == "true"
364        } else {
365            false
366        };
367        let trace_always = if let Ok(trace_always) = std::env::var("EXECUTOR_TRACE_ALWAYS") {
368            trace_always == "1" || trace_always == "true"
369        } else {
370            false
371        };
372
373        if trace_always || (trace_on_panic && thread::panicking()) {
374            self.resolve();
375            dbg!(self);
376        }
377    }
378}
379
380impl Foreground {
381    pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
382        if dispatcher.is_main_thread() {
383            Ok(Self::Platform {
384                dispatcher,
385                _not_send_or_sync: PhantomData,
386            })
387        } else {
388            Err(anyhow!("must be constructed on main thread"))
389        }
390    }
391
392    pub fn test() -> Self {
393        Self::Test(smol::LocalExecutor::new())
394    }
395
396    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
397        let future = any_local_future(future);
398        let any_task = match self {
399            Self::Deterministic(executor) => executor.spawn_from_foreground(future),
400            Self::Platform { dispatcher, .. } => {
401                let dispatcher = dispatcher.clone();
402                let schedule = move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
403                let (runnable, task) = async_task::spawn_local(future, schedule);
404                runnable.schedule();
405                task
406            }
407            Self::Test(executor) => executor.spawn(future),
408        };
409        Task::local(any_task)
410    }
411
412    pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
413        let future = any_local_future(future);
414        let any_value = match self {
415            Self::Deterministic(executor) => executor.run(future),
416            Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
417            Self::Test(executor) => smol::block_on(executor.run(future)),
418        };
419        *any_value.downcast().unwrap()
420    }
421
422    pub fn forbid_parking(&self) {
423        match self {
424            Self::Deterministic(executor) => {
425                let mut state = executor.state.lock();
426                state.forbid_parking = true;
427                state.rng = StdRng::seed_from_u64(state.seed);
428            }
429            _ => panic!("this method can only be called on a deterministic executor"),
430        }
431    }
432
433    pub async fn timer(&self, duration: Duration) {
434        match self {
435            Self::Deterministic(executor) => {
436                let (tx, mut rx) = barrier::channel();
437                {
438                    let mut state = executor.state.lock();
439                    let wakeup_at = state.now + duration;
440                    state.pending_timers.push((wakeup_at, tx));
441                }
442                rx.recv().await;
443            }
444            _ => {
445                Timer::after(duration).await;
446            }
447        }
448    }
449
450    pub fn advance_clock(&self, duration: Duration) {
451        match self {
452            Self::Deterministic(executor) => {
453                executor.run_until_parked();
454
455                let mut state = executor.state.lock();
456                state.now += duration;
457                let now = state.now;
458                let mut pending_timers = mem::take(&mut state.pending_timers);
459                drop(state);
460
461                pending_timers.retain(|(wakeup, _)| *wakeup > now);
462                executor.state.lock().pending_timers.extend(pending_timers);
463            }
464            _ => panic!("this method can only be called on a deterministic executor"),
465        }
466    }
467
468    pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
469        match self {
470            Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,
471            _ => panic!("this method can only be called on a deterministic executor"),
472        }
473    }
474}
475
476impl Background {
477    pub fn new() -> Self {
478        let executor = Arc::new(Executor::new());
479        let stop = channel::unbounded::<()>();
480
481        for i in 0..2 * num_cpus::get() {
482            let executor = executor.clone();
483            let stop = stop.1.clone();
484            thread::Builder::new()
485                .name(format!("background-executor-{}", i))
486                .spawn(move || smol::block_on(executor.run(stop.recv())))
487                .unwrap();
488        }
489
490        Self::Production {
491            executor,
492            _stop: stop.0,
493        }
494    }
495
496    pub fn num_cpus(&self) -> usize {
497        num_cpus::get()
498    }
499
500    pub fn spawn<T, F>(&self, future: F) -> Task<T>
501    where
502        T: 'static + Send,
503        F: Send + Future<Output = T> + 'static,
504    {
505        let future = any_future(future);
506        let any_task = match self {
507            Self::Production { executor, .. } => executor.spawn(future),
508            Self::Deterministic(executor) => executor.spawn(future),
509        };
510        Task::send(any_task)
511    }
512
513    pub fn block_with_timeout<F, T>(
514        &self,
515        timeout: Duration,
516        future: F,
517    ) -> Result<T, impl Future<Output = T>>
518    where
519        T: 'static,
520        F: 'static + Unpin + Future<Output = T>,
521    {
522        let mut future = any_local_future(future);
523        if !timeout.is_zero() {
524            let output = match self {
525                Self::Production { .. } => smol::block_on(util::timeout(timeout, &mut future)).ok(),
526                Self::Deterministic(executor) => executor.block_on(&mut future),
527            };
528            if let Some(output) = output {
529                return Ok(*output.downcast().unwrap());
530            }
531        }
532        Err(async { *future.await.downcast().unwrap() })
533    }
534
535    pub async fn scoped<'scope, F>(&self, scheduler: F)
536    where
537        F: FnOnce(&mut Scope<'scope>),
538    {
539        let mut scope = Scope {
540            futures: Default::default(),
541            _phantom: PhantomData,
542        };
543        (scheduler)(&mut scope);
544        let spawned = scope
545            .futures
546            .into_iter()
547            .map(|f| self.spawn(f))
548            .collect::<Vec<_>>();
549        for task in spawned {
550            task.await;
551        }
552    }
553}
554
/// Collects futures for `Background::scoped` to spawn.
///
/// The futures are stored with their lifetime erased to `'static`. `_phantom`
/// keeps the real borrow lifetime `'a` attached to the scope.
pub struct Scope<'a> {
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    _phantom: PhantomData<&'a ()>,
}
559
560impl<'a> Scope<'a> {
561    pub fn spawn<F>(&mut self, f: F)
562    where
563        F: Future<Output = ()> + Send + 'a,
564    {
565        let f = unsafe {
566            mem::transmute::<
567                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
568                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
569            >(Box::pin(f))
570        };
571        self.futures.push(f);
572    }
573}
574
575pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
576    let executor = Arc::new(Deterministic::new(seed));
577    (
578        Rc::new(Foreground::Deterministic(executor.clone())),
579        Arc::new(Background::Deterministic(executor)),
580    )
581}
582
583impl<T> Task<T> {
584    fn local(any_task: AnyLocalTask) -> Self {
585        Self::Local {
586            any_task,
587            result_type: PhantomData,
588        }
589    }
590
591    pub fn detach(self) {
592        match self {
593            Task::Local { any_task, .. } => any_task.detach(),
594            Task::Send { any_task, .. } => any_task.detach(),
595        }
596    }
597}
598
599impl<T: Send> Task<T> {
600    fn send(any_task: AnyTask) -> Self {
601        Self::Send {
602            any_task,
603            result_type: PhantomData,
604        }
605    }
606}
607
608impl<T: fmt::Debug> fmt::Debug for Task<T> {
609    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
610        match self {
611            Task::Local { any_task, .. } => any_task.fmt(f),
612            Task::Send { any_task, .. } => any_task.fmt(f),
613        }
614    }
615}
616
617impl<T: 'static> Future for Task<T> {
618    type Output = T;
619
620    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
621        match unsafe { self.get_unchecked_mut() } {
622            Task::Local { any_task, .. } => {
623                any_task.poll(cx).map(|value| *value.downcast().unwrap())
624            }
625            Task::Send { any_task, .. } => {
626                any_task.poll(cx).map(|value| *value.downcast().unwrap())
627            }
628        }
629    }
630}
631
632fn any_future<T, F>(future: F) -> AnyFuture
633where
634    T: 'static + Send,
635    F: Future<Output = T> + Send + 'static,
636{
637    async { Box::new(future.await) as Box<dyn Any + Send> }.boxed()
638}
639
640fn any_local_future<T, F>(future: F) -> AnyLocalFuture
641where
642    T: 'static,
643    F: Future<Output = T> + 'static,
644{
645    async { Box::new(future.await) as Box<dyn Any> }.boxed_local()
646}