executor.rs

  1use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant};
  2use async_task::Runnable;
  3use futures::channel::mpsc;
  4use parking_lot::{Condvar, Mutex};
  5use smol::prelude::*;
  6use std::{
  7    fmt::Debug,
  8    marker::PhantomData,
  9    mem::{self, ManuallyDrop},
 10    num::NonZeroUsize,
 11    panic::Location,
 12    pin::Pin,
 13    rc::Rc,
 14    sync::{
 15        Arc,
 16        atomic::{AtomicUsize, Ordering},
 17    },
 18    task::{Context, Poll},
 19    thread::{self, ThreadId},
 20    time::{Duration, Instant},
 21};
 22use util::TryFutureExt;
 23use waker_fn::waker_fn;
 24
 25#[cfg(any(test, feature = "test-support"))]
 26use rand::rngs::StdRng;
 27
 28/// A pointer to the executor that is currently running,
 29/// for spawning background tasks.
 30#[derive(Clone)]
 31pub struct BackgroundExecutor {
 32    #[doc(hidden)]
 33    pub dispatcher: Arc<dyn PlatformDispatcher>,
 34}
 35
 36/// A pointer to the executor that is currently running,
 37/// for spawning tasks on the main thread.
 38///
 39/// This is intentionally `!Send` via the `not_send` marker field. This is because
 40/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
 41/// only polled from the same thread it was spawned from. These checks would fail when spawning
 42/// foreground tasks from background threads.
 43#[derive(Clone)]
 44pub struct ForegroundExecutor {
 45    #[doc(hidden)]
 46    pub dispatcher: Arc<dyn PlatformDispatcher>,
 47    not_send: PhantomData<Rc<()>>,
 48}
 49
 50/// Task is a primitive that allows work to happen in the background.
 51///
 52/// It implements [`Future`] so you can `.await` on it.
 53///
 54/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
 55/// the task to continue running, but with no way to return a value.
 56#[must_use]
 57#[derive(Debug)]
 58pub struct Task<T>(TaskState<T>);
 59
 60#[derive(Debug)]
 61enum TaskState<T> {
 62    /// A task that is ready to return a value
 63    Ready(Option<T>),
 64
 65    /// A task that is currently running.
 66    Spawned(async_task::Task<T, RunnableMeta>),
 67}
 68
 69impl<T> Task<T> {
 70    /// Creates a new task that will resolve with the value
 71    pub fn ready(val: T) -> Self {
 72        Task(TaskState::Ready(Some(val)))
 73    }
 74
 75    /// Detaching a task runs it to completion in the background
 76    pub fn detach(self) {
 77        match self {
 78            Task(TaskState::Ready(_)) => {}
 79            Task(TaskState::Spawned(task)) => task.detach(),
 80        }
 81    }
 82}
 83
 84impl<E, T> Task<Result<T, E>>
 85where
 86    T: 'static,
 87    E: 'static + Debug,
 88{
 89    /// Run the task to completion in the background and log any
 90    /// errors that occur.
 91    #[track_caller]
 92    pub fn detach_and_log_err(self, cx: &App) {
 93        let location = core::panic::Location::caller();
 94        cx.foreground_executor()
 95            .spawn(self.log_tracked_err(*location))
 96            .detach();
 97    }
 98}
 99
100impl<T> Future for Task<T> {
101    type Output = T;
102
103    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
104        match unsafe { self.get_unchecked_mut() } {
105            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
106            Task(TaskState::Spawned(task)) => task.poll(cx),
107        }
108    }
109}
110
/// An opaque identifier for a task, used to refer to specific tasks in tests
/// (see `BackgroundExecutor::spawn_labeled`).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    /// Equivalent to [`TaskLabel::new`]: every default label is a fresh one.
    fn default() -> Self {
        TaskLabel::new()
    }
}

impl TaskLabel {
    /// Construct a new task label.
    ///
    /// Labels are drawn from a process-wide counter that starts at 1; each call takes the
    /// next value, so it differs from the labels returned by earlier calls.
    pub fn new() -> Self {
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        let raw = NEXT_TASK_LABEL.fetch_add(1, Ordering::SeqCst);
        let id: NonZeroUsize = raw.try_into().unwrap();
        Self(id)
    }
}
134
/// A boxed, pinned future that may be `!Send`; used for main-thread tasks.
type AnyLocalFuture<R> = Pin<Box<dyn Future<Output = R> + 'static>>;

/// A boxed, pinned future that can be moved across threads; used for background tasks.
type AnyFuture<R> = Pin<Box<dyn Future<Output = R> + Send + 'static>>;
138
139/// BackgroundExecutor lets you run things on background threads.
140/// In production this is a thread pool with no ordering guarantees.
141/// In tests this is simulated by running tasks one by one in a deterministic
142/// (but arbitrary) order controlled by the `SEED` environment variable.
143impl BackgroundExecutor {
144    #[doc(hidden)]
145    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
146        Self { dispatcher }
147    }
148
149    /// Enqueues the given future to be run to completion on a background thread.
150    #[track_caller]
151    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
152    where
153        R: Send + 'static,
154    {
155        self.spawn_internal::<R>(Box::pin(future), None)
156    }
157
158    /// Enqueues the given future to be run to completion on a background thread and blocking the current task on it.
159    ///
160    /// This allows to spawn background work that borrows from its scope. Note that the supplied future will run to
161    /// completion before the current task is resumed, even if the current task is slated for cancellation.
162    pub async fn await_on_background<R>(&self, future: impl Future<Output = R> + Send) -> R
163    where
164        R: Send,
165    {
166        // We need to ensure that cancellation of the parent task does not drop the environment
167        // before the our own task has completed or got cancelled.
168        struct NotifyOnDrop<'a>(&'a (Condvar, Mutex<bool>));
169
170        impl Drop for NotifyOnDrop<'_> {
171            fn drop(&mut self) {
172                *self.0.1.lock() = true;
173                self.0.0.notify_all();
174            }
175        }
176
177        struct WaitOnDrop<'a>(&'a (Condvar, Mutex<bool>));
178
179        impl Drop for WaitOnDrop<'_> {
180            fn drop(&mut self) {
181                let mut done = self.0.1.lock();
182                if !*done {
183                    self.0.0.wait(&mut done);
184                }
185            }
186        }
187
188        let dispatcher = self.dispatcher.clone();
189        let location = core::panic::Location::caller();
190
191        let pair = &(Condvar::new(), Mutex::new(false));
192        let _wait_guard = WaitOnDrop(pair);
193
194        let (runnable, task) = unsafe {
195            async_task::Builder::new()
196                .metadata(RunnableMeta { location })
197                .spawn_unchecked(
198                    move |_| async {
199                        let _notify_guard = NotifyOnDrop(pair);
200                        future.await
201                    },
202                    move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), None),
203                )
204        };
205        runnable.schedule();
206        task.await
207    }
208
209    /// Enqueues the given future to be run to completion on a background thread.
210    /// The given label can be used to control the priority of the task in tests.
211    #[track_caller]
212    pub fn spawn_labeled<R>(
213        &self,
214        label: TaskLabel,
215        future: impl Future<Output = R> + Send + 'static,
216    ) -> Task<R>
217    where
218        R: Send + 'static,
219    {
220        self.spawn_internal::<R>(Box::pin(future), Some(label))
221    }
222
223    #[track_caller]
224    fn spawn_internal<R: Send + 'static>(
225        &self,
226        future: AnyFuture<R>,
227        label: Option<TaskLabel>,
228    ) -> Task<R> {
229        let dispatcher = self.dispatcher.clone();
230        let location = core::panic::Location::caller();
231        let (runnable, task) = async_task::Builder::new()
232            .metadata(RunnableMeta { location })
233            .spawn(
234                move |_| future,
235                move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), label),
236            );
237        runnable.schedule();
238        Task(TaskState::Spawned(task))
239    }
240
241    /// Used by the test harness to run an async test in a synchronous fashion.
242    #[cfg(any(test, feature = "test-support"))]
243    #[track_caller]
244    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
245        if let Ok(value) = self.block_internal(false, future, None) {
246            value
247        } else {
248            unreachable!()
249        }
250    }
251
252    /// Block the current thread until the given future resolves.
253    /// Consider using `block_with_timeout` instead.
254    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
255        if let Ok(value) = self.block_internal(true, future, None) {
256            value
257        } else {
258            unreachable!()
259        }
260    }
261
262    #[cfg(not(any(test, feature = "test-support")))]
263    pub(crate) fn block_internal<Fut: Future>(
264        &self,
265        _background_only: bool,
266        future: Fut,
267        timeout: Option<Duration>,
268    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
269        use std::time::Instant;
270
271        let mut future = Box::pin(future);
272        if timeout == Some(Duration::ZERO) {
273            return Err(future);
274        }
275        let deadline = timeout.map(|timeout| Instant::now() + timeout);
276
277        let parker = parking::Parker::new();
278        let unparker = parker.unparker();
279        let waker = waker_fn(move || {
280            unparker.unpark();
281        });
282        let mut cx = std::task::Context::from_waker(&waker);
283
284        loop {
285            match future.as_mut().poll(&mut cx) {
286                Poll::Ready(result) => return Ok(result),
287                Poll::Pending => {
288                    let timeout =
289                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
290                    if let Some(timeout) = timeout {
291                        if !parker.park_timeout(timeout)
292                            && deadline.is_some_and(|deadline| deadline < Instant::now())
293                        {
294                            return Err(future);
295                        }
296                    } else {
297                        parker.park();
298                    }
299                }
300            }
301        }
302    }
303
304    #[cfg(any(test, feature = "test-support"))]
305    #[track_caller]
306    pub(crate) fn block_internal<Fut: Future>(
307        &self,
308        background_only: bool,
309        future: Fut,
310        timeout: Option<Duration>,
311    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
312        use std::sync::atomic::AtomicBool;
313
314        use parking::Parker;
315
316        let mut future = Box::pin(future);
317        if timeout == Some(Duration::ZERO) {
318            return Err(future);
319        }
320        let Some(dispatcher) = self.dispatcher.as_test() else {
321            return Err(future);
322        };
323
324        let mut max_ticks = if timeout.is_some() {
325            dispatcher.gen_block_on_ticks()
326        } else {
327            usize::MAX
328        };
329
330        let parker = Parker::new();
331        let unparker = parker.unparker();
332
333        let awoken = Arc::new(AtomicBool::new(false));
334        let waker = waker_fn({
335            let awoken = awoken.clone();
336            let unparker = unparker.clone();
337            move || {
338                awoken.store(true, Ordering::SeqCst);
339                unparker.unpark();
340            }
341        });
342        let mut cx = std::task::Context::from_waker(&waker);
343
344        let duration = Duration::from_secs(
345            option_env!("GPUI_TEST_TIMEOUT")
346                .and_then(|s| s.parse::<u64>().ok())
347                .unwrap_or(180),
348        );
349        let mut test_should_end_by = Instant::now() + duration;
350
351        loop {
352            match future.as_mut().poll(&mut cx) {
353                Poll::Ready(result) => return Ok(result),
354                Poll::Pending => {
355                    if max_ticks == 0 {
356                        return Err(future);
357                    }
358                    max_ticks -= 1;
359
360                    if !dispatcher.tick(background_only) {
361                        if awoken.swap(false, Ordering::SeqCst) {
362                            continue;
363                        }
364
365                        if !dispatcher.parking_allowed() {
366                            if dispatcher.advance_clock_to_next_delayed() {
367                                continue;
368                            }
369                            let mut backtrace_message = String::new();
370                            let mut waiting_message = String::new();
371                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
372                                backtrace_message =
373                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
374                            }
375                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
376                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
377                            }
378                            panic!(
379                                "parked with nothing left to run{waiting_message}{backtrace_message}",
380                            )
381                        }
382                        dispatcher.push_unparker(unparker.clone());
383                        parker.park_timeout(Duration::from_millis(1));
384                        if Instant::now() > test_should_end_by {
385                            panic!("test timed out after {duration:?} with allow_parking")
386                        }
387                    }
388                }
389            }
390        }
391    }
392
393    /// Block the current thread until the given future resolves
394    /// or `duration` has elapsed.
395    pub fn block_with_timeout<Fut: Future>(
396        &self,
397        duration: Duration,
398        future: Fut,
399    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
400        self.block_internal(true, future, Some(duration))
401    }
402
403    /// Scoped lets you start a number of tasks and waits
404    /// for all of them to complete before returning.
405    pub async fn scoped<'scope, F>(&self, scheduler: F)
406    where
407        F: FnOnce(&mut Scope<'scope>),
408    {
409        let mut scope = Scope::new(self.clone());
410        (scheduler)(&mut scope);
411        let spawned = mem::take(&mut scope.futures)
412            .into_iter()
413            .map(|f| self.spawn(f))
414            .collect::<Vec<_>>();
415        for task in spawned {
416            task.await;
417        }
418    }
419
420    /// Get the current time.
421    ///
422    /// Calling this instead of `std::time::Instant::now` allows the use
423    /// of fake timers in tests.
424    pub fn now(&self) -> Instant {
425        self.dispatcher.now()
426    }
427
428    /// Returns a task that will complete after the given duration.
429    /// Depending on other concurrent tasks the elapsed duration may be longer
430    /// than requested.
431    pub fn timer(&self, duration: Duration) -> Task<()> {
432        if duration.is_zero() {
433            return Task::ready(());
434        }
435        let location = core::panic::Location::caller();
436        let (runnable, task) = async_task::Builder::new()
437            .metadata(RunnableMeta { location })
438            .spawn(move |_| async move {}, {
439                let dispatcher = self.dispatcher.clone();
440                move |runnable| dispatcher.dispatch_after(duration, RunnableVariant::Meta(runnable))
441            });
442        runnable.schedule();
443        Task(TaskState::Spawned(task))
444    }
445
446    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
447    #[cfg(any(test, feature = "test-support"))]
448    pub fn start_waiting(&self) {
449        self.dispatcher.as_test().unwrap().start_waiting();
450    }
451
452    /// in tests, removes the debugging data added by start_waiting
453    #[cfg(any(test, feature = "test-support"))]
454    pub fn finish_waiting(&self) {
455        self.dispatcher.as_test().unwrap().finish_waiting();
456    }
457
458    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
459    #[cfg(any(test, feature = "test-support"))]
460    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
461        self.dispatcher.as_test().unwrap().simulate_random_delay()
462    }
463
464    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
465    #[cfg(any(test, feature = "test-support"))]
466    pub fn deprioritize(&self, task_label: TaskLabel) {
467        self.dispatcher.as_test().unwrap().deprioritize(task_label)
468    }
469
470    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
471    #[cfg(any(test, feature = "test-support"))]
472    pub fn advance_clock(&self, duration: Duration) {
473        self.dispatcher.as_test().unwrap().advance_clock(duration)
474    }
475
476    /// in tests, run one task.
477    #[cfg(any(test, feature = "test-support"))]
478    pub fn tick(&self) -> bool {
479        self.dispatcher.as_test().unwrap().tick(false)
480    }
481
482    /// in tests, run all tasks that are ready to run. If after doing so
483    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
484    #[cfg(any(test, feature = "test-support"))]
485    pub fn run_until_parked(&self) {
486        self.dispatcher.as_test().unwrap().run_until_parked()
487    }
488
489    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
490    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
491    /// do take real async time to run.
492    #[cfg(any(test, feature = "test-support"))]
493    pub fn allow_parking(&self) {
494        self.dispatcher.as_test().unwrap().allow_parking();
495    }
496
497    /// undoes the effect of [`Self::allow_parking`].
498    #[cfg(any(test, feature = "test-support"))]
499    pub fn forbid_parking(&self) {
500        self.dispatcher.as_test().unwrap().forbid_parking();
501    }
502
503    /// adds detail to the "parked with nothing let to run" message.
504    #[cfg(any(test, feature = "test-support"))]
505    pub fn set_waiting_hint(&self, msg: Option<String>) {
506        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
507    }
508
509    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
510    #[cfg(any(test, feature = "test-support"))]
511    pub fn rng(&self) -> StdRng {
512        self.dispatcher.as_test().unwrap().rng()
513    }
514
515    /// How many CPUs are available to the dispatcher.
516    pub fn num_cpus(&self) -> usize {
517        #[cfg(any(test, feature = "test-support"))]
518        return 4;
519
520        #[cfg(not(any(test, feature = "test-support")))]
521        return num_cpus::get();
522    }
523
524    /// Whether we're on the main thread.
525    pub fn is_main_thread(&self) -> bool {
526        self.dispatcher.is_main_thread()
527    }
528
529    #[cfg(any(test, feature = "test-support"))]
530    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
531    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
532        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
533    }
534}
535
536/// ForegroundExecutor runs things on the main thread.
537impl ForegroundExecutor {
538    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
539    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
540        Self {
541            dispatcher,
542            not_send: PhantomData,
543        }
544    }
545
546    /// Enqueues the given Task to run on the main thread at some point in the future.
547    #[track_caller]
548    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
549    where
550        R: 'static,
551    {
552        let dispatcher = self.dispatcher.clone();
553        let location = core::panic::Location::caller();
554
555        #[track_caller]
556        fn inner<R: 'static>(
557            dispatcher: Arc<dyn PlatformDispatcher>,
558            future: AnyLocalFuture<R>,
559            location: &'static core::panic::Location<'static>,
560        ) -> Task<R> {
561            let (runnable, task) = spawn_local_with_source_location(
562                future,
563                move |runnable| dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable)),
564                RunnableMeta { location },
565            );
566            runnable.schedule();
567            Task(TaskState::Spawned(task))
568        }
569        inner::<R>(dispatcher, Box::pin(future), location)
570    }
571}
572
573/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
574///
575/// Copy-modified from:
576/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
577#[track_caller]
578fn spawn_local_with_source_location<Fut, S, M>(
579    future: Fut,
580    schedule: S,
581    metadata: M,
582) -> (Runnable<M>, async_task::Task<Fut::Output, M>)
583where
584    Fut: Future + 'static,
585    Fut::Output: 'static,
586    S: async_task::Schedule<M> + Send + Sync + 'static,
587    M: 'static,
588{
589    #[inline]
590    fn thread_id() -> ThreadId {
591        std::thread_local! {
592            static ID: ThreadId = thread::current().id();
593        }
594        ID.try_with(|id| *id)
595            .unwrap_or_else(|_| thread::current().id())
596    }
597
598    struct Checked<F> {
599        id: ThreadId,
600        inner: ManuallyDrop<F>,
601        location: &'static Location<'static>,
602    }
603
604    impl<F> Drop for Checked<F> {
605        fn drop(&mut self) {
606            assert!(
607                self.id == thread_id(),
608                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
609                self.location
610            );
611            unsafe { ManuallyDrop::drop(&mut self.inner) };
612        }
613    }
614
615    impl<F: Future> Future for Checked<F> {
616        type Output = F::Output;
617
618        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
619            assert!(
620                self.id == thread_id(),
621                "local task polled by a thread that didn't spawn it. Task spawned at {}",
622                self.location
623            );
624            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
625        }
626    }
627
628    // Wrap the future into one that checks which thread it's on.
629    let future = Checked {
630        id: thread_id(),
631        inner: ManuallyDrop::new(future),
632        location: Location::caller(),
633    };
634
635    unsafe {
636        async_task::Builder::new()
637            .metadata(metadata)
638            .spawn_unchecked(move |_| future, schedule)
639    }
640}
641
642/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
643pub struct Scope<'a> {
644    executor: BackgroundExecutor,
645    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
646    tx: Option<mpsc::Sender<()>>,
647    rx: mpsc::Receiver<()>,
648    lifetime: PhantomData<&'a ()>,
649}
650
651impl<'a> Scope<'a> {
652    fn new(executor: BackgroundExecutor) -> Self {
653        let (tx, rx) = mpsc::channel(1);
654        Self {
655            executor,
656            tx: Some(tx),
657            rx,
658            futures: Default::default(),
659            lifetime: PhantomData,
660        }
661    }
662
663    /// How many CPUs are available to the dispatcher.
664    pub fn num_cpus(&self) -> usize {
665        self.executor.num_cpus()
666    }
667
668    /// Spawn a future into this scope.
669    #[track_caller]
670    pub fn spawn<F>(&mut self, f: F)
671    where
672        F: Future<Output = ()> + Send + 'a,
673    {
674        let tx = self.tx.clone().unwrap();
675
676        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
677        // dropping this `Scope` blocks until all of the futures have resolved.
678        let f = unsafe {
679            mem::transmute::<
680                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
681                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
682            >(Box::pin(async move {
683                f.await;
684                drop(tx);
685            }))
686        };
687        self.futures.push(f);
688    }
689}
690
691impl Drop for Scope<'_> {
692    fn drop(&mut self) {
693        self.tx.take().unwrap();
694
695        // Wait until the channel is closed, which means that all of the spawned
696        // futures have resolved.
697        self.executor.block(self.rx.next());
698    }
699}