executor.rs

  1use crate::{App, PlatformDispatcher};
  2use async_task::Runnable;
  3use futures::channel::mpsc;
  4use smol::prelude::*;
  5use std::mem::ManuallyDrop;
  6use std::panic::Location;
  7use std::thread::{self, ThreadId};
  8use std::{
  9    fmt::Debug,
 10    marker::PhantomData,
 11    mem,
 12    num::NonZeroUsize,
 13    pin::Pin,
 14    rc::Rc,
 15    sync::{
 16        Arc,
 17        atomic::{AtomicUsize, Ordering::SeqCst},
 18    },
 19    task::{Context, Poll},
 20    time::{Duration, Instant},
 21};
 22use util::TryFutureExt;
 23use waker_fn::waker_fn;
 24
 25#[cfg(any(test, feature = "test-support"))]
 26use rand::rngs::StdRng;
 27
 28/// A pointer to the executor that is currently running,
 29/// for spawning background tasks.
 30#[derive(Clone)]
 31pub struct BackgroundExecutor {
 32    #[doc(hidden)]
 33    pub dispatcher: Arc<dyn PlatformDispatcher>,
 34}
 35
 36/// A pointer to the executor that is currently running,
 37/// for spawning tasks on the main thread.
 38///
 39/// This is intentionally `!Send` via the `not_send` marker field. This is because
 40/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
 41/// only polled from the same thread it was spawned from. These checks would fail when spawning
 42/// foreground tasks from from background threads.
 43#[derive(Clone)]
 44pub struct ForegroundExecutor {
 45    #[doc(hidden)]
 46    pub dispatcher: Arc<dyn PlatformDispatcher>,
 47    not_send: PhantomData<Rc<()>>,
 48}
 49
 50/// Task is a primitive that allows work to happen in the background.
 51///
 52/// It implements [`Future`] so you can `.await` on it.
 53///
 54/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
 55/// the task to continue running, but with no way to return a value.
 56#[must_use]
 57#[derive(Debug)]
 58pub struct Task<T>(TaskState<T>);
 59
 60#[derive(Debug)]
 61enum TaskState<T> {
 62    /// A task that is ready to return a value
 63    Ready(Option<T>),
 64
 65    /// A task that is currently running.
 66    Spawned(async_task::Task<T>),
 67}
 68
 69impl<T> Task<T> {
 70    /// Creates a new task that will resolve with the value
 71    pub fn ready(val: T) -> Self {
 72        Task(TaskState::Ready(Some(val)))
 73    }
 74
 75    /// Detaching a task runs it to completion in the background
 76    pub fn detach(self) {
 77        match self {
 78            Task(TaskState::Ready(_)) => {}
 79            Task(TaskState::Spawned(task)) => task.detach(),
 80        }
 81    }
 82}
 83
 84impl<E, T> Task<Result<T, E>>
 85where
 86    T: 'static,
 87    E: 'static + Debug,
 88{
 89    /// Run the task to completion in the background and log any
 90    /// errors that occur.
 91    #[track_caller]
 92    pub fn detach_and_log_err(self, cx: &App) {
 93        let location = core::panic::Location::caller();
 94        cx.foreground_executor()
 95            .spawn(self.log_tracked_err(*location))
 96            .detach();
 97    }
 98}
 99
100impl<T> Future for Task<T> {
101    type Output = T;
102
103    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
104        match unsafe { self.get_unchecked_mut() } {
105            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
106            Task(TaskState::Spawned(task)) => task.poll(cx),
107        }
108    }
109}
110
/// Opaque identifier for a task, usable to refer to specific tasks in tests.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    /// Same as [`TaskLabel::new`]: every default label is a fresh one.
    fn default() -> Self {
        TaskLabel::new()
    }
}

impl TaskLabel {
    /// Allocates a fresh label from a process-wide counter that starts at 1.
    ///
    /// # Panics
    ///
    /// Panics if the counter ever wraps around to zero.
    pub fn new() -> Self {
        static COUNTER: AtomicUsize = AtomicUsize::new(1);
        let raw = COUNTER.fetch_add(1, SeqCst);
        TaskLabel(NonZeroUsize::try_from(raw).unwrap())
    }
}
129
/// Boxed, pinned, type-erased future that is not required to be `Send`
/// (used for main-thread tasks).
type AnyLocalFuture<R> = Pin<Box<dyn Future<Output = R> + 'static>>;

/// Boxed, pinned, type-erased future that can be sent across threads.
type AnyFuture<R> = Pin<Box<dyn Future<Output = R> + Send + 'static>>;
133
134/// BackgroundExecutor lets you run things on background threads.
135/// In production this is a thread pool with no ordering guarantees.
136/// In tests this is simulated by running tasks one by one in a deterministic
137/// (but arbitrary) order controlled by the `SEED` environment variable.
138impl BackgroundExecutor {
139    #[doc(hidden)]
140    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
141        Self { dispatcher }
142    }
143
144    /// Enqueues the given future to be run to completion on a background thread.
145    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
146    where
147        R: Send + 'static,
148    {
149        self.spawn_internal::<R>(Box::pin(future), None)
150    }
151
152    /// Enqueues the given future to be run to completion on a background thread.
153    /// The given label can be used to control the priority of the task in tests.
154    pub fn spawn_labeled<R>(
155        &self,
156        label: TaskLabel,
157        future: impl Future<Output = R> + Send + 'static,
158    ) -> Task<R>
159    where
160        R: Send + 'static,
161    {
162        self.spawn_internal::<R>(Box::pin(future), Some(label))
163    }
164
165    fn spawn_internal<R: Send + 'static>(
166        &self,
167        future: AnyFuture<R>,
168        label: Option<TaskLabel>,
169    ) -> Task<R> {
170        let dispatcher = self.dispatcher.clone();
171        let (runnable, task) =
172            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
173        runnable.schedule();
174        Task(TaskState::Spawned(task))
175    }
176
177    /// Used by the test harness to run an async test in a synchronous fashion.
178    #[cfg(any(test, feature = "test-support"))]
179    #[track_caller]
180    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
181        if let Ok(value) = self.block_internal(false, future, None) {
182            value
183        } else {
184            unreachable!()
185        }
186    }
187
188    /// Block the current thread until the given future resolves.
189    /// Consider using `block_with_timeout` instead.
190    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
191        if let Ok(value) = self.block_internal(true, future, None) {
192            value
193        } else {
194            unreachable!()
195        }
196    }
197
198    #[cfg(not(any(test, feature = "test-support")))]
199    pub(crate) fn block_internal<Fut: Future>(
200        &self,
201        _background_only: bool,
202        future: Fut,
203        timeout: Option<Duration>,
204    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
205        use std::time::Instant;
206
207        let mut future = Box::pin(future);
208        if timeout == Some(Duration::ZERO) {
209            return Err(future);
210        }
211        let deadline = timeout.map(|timeout| Instant::now() + timeout);
212
213        let unparker = self.dispatcher.unparker();
214        let waker = waker_fn(move || {
215            unparker.unpark();
216        });
217        let mut cx = std::task::Context::from_waker(&waker);
218
219        loop {
220            match future.as_mut().poll(&mut cx) {
221                Poll::Ready(result) => return Ok(result),
222                Poll::Pending => {
223                    let timeout =
224                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
225                    if !self.dispatcher.park(timeout)
226                        && deadline.is_some_and(|deadline| deadline < Instant::now())
227                    {
228                        return Err(future);
229                    }
230                }
231            }
232        }
233    }
234
235    #[cfg(any(test, feature = "test-support"))]
236    #[track_caller]
237    pub(crate) fn block_internal<Fut: Future>(
238        &self,
239        background_only: bool,
240        future: Fut,
241        timeout: Option<Duration>,
242    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
243        use std::sync::atomic::AtomicBool;
244
245        let mut future = Box::pin(future);
246        if timeout == Some(Duration::ZERO) {
247            return Err(future);
248        }
249        let Some(dispatcher) = self.dispatcher.as_test() else {
250            return Err(future);
251        };
252
253        let mut max_ticks = if timeout.is_some() {
254            dispatcher.gen_block_on_ticks()
255        } else {
256            usize::MAX
257        };
258        let unparker = self.dispatcher.unparker();
259        let awoken = Arc::new(AtomicBool::new(false));
260        let waker = waker_fn({
261            let awoken = awoken.clone();
262            move || {
263                awoken.store(true, SeqCst);
264                unparker.unpark();
265            }
266        });
267        let mut cx = std::task::Context::from_waker(&waker);
268
269        loop {
270            match future.as_mut().poll(&mut cx) {
271                Poll::Ready(result) => return Ok(result),
272                Poll::Pending => {
273                    if max_ticks == 0 {
274                        return Err(future);
275                    }
276                    max_ticks -= 1;
277
278                    if !dispatcher.tick(background_only) {
279                        if awoken.swap(false, SeqCst) {
280                            continue;
281                        }
282
283                        if !dispatcher.parking_allowed() {
284                            if dispatcher.advance_clock_to_next_delayed() {
285                                continue;
286                            }
287                            let mut backtrace_message = String::new();
288                            let mut waiting_message = String::new();
289                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
290                                backtrace_message =
291                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
292                            }
293                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
294                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
295                            }
296                            panic!(
297                                "parked with nothing left to run{waiting_message}{backtrace_message}",
298                            )
299                        }
300                        self.dispatcher.park(None);
301                    }
302                }
303            }
304        }
305    }
306
307    /// Block the current thread until the given future resolves
308    /// or `duration` has elapsed.
309    pub fn block_with_timeout<Fut: Future>(
310        &self,
311        duration: Duration,
312        future: Fut,
313    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
314        self.block_internal(true, future, Some(duration))
315    }
316
317    /// Scoped lets you start a number of tasks and waits
318    /// for all of them to complete before returning.
319    pub async fn scoped<'scope, F>(&self, scheduler: F)
320    where
321        F: FnOnce(&mut Scope<'scope>),
322    {
323        let mut scope = Scope::new(self.clone());
324        (scheduler)(&mut scope);
325        let spawned = mem::take(&mut scope.futures)
326            .into_iter()
327            .map(|f| self.spawn(f))
328            .collect::<Vec<_>>();
329        for task in spawned {
330            task.await;
331        }
332    }
333
334    /// Get the current time.
335    ///
336    /// Calling this instead of `std::time::Instant::now` allows the use
337    /// of fake timers in tests.
338    pub fn now(&self) -> Instant {
339        self.dispatcher.now()
340    }
341
342    /// Returns a task that will complete after the given duration.
343    /// Depending on other concurrent tasks the elapsed duration may be longer
344    /// than requested.
345    pub fn timer(&self, duration: Duration) -> Task<()> {
346        if duration.is_zero() {
347            return Task::ready(());
348        }
349        let (runnable, task) = async_task::spawn(async move {}, {
350            let dispatcher = self.dispatcher.clone();
351            move |runnable| dispatcher.dispatch_after(duration, runnable)
352        });
353        runnable.schedule();
354        Task(TaskState::Spawned(task))
355    }
356
357    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
358    #[cfg(any(test, feature = "test-support"))]
359    pub fn start_waiting(&self) {
360        self.dispatcher.as_test().unwrap().start_waiting();
361    }
362
363    /// in tests, removes the debugging data added by start_waiting
364    #[cfg(any(test, feature = "test-support"))]
365    pub fn finish_waiting(&self) {
366        self.dispatcher.as_test().unwrap().finish_waiting();
367    }
368
369    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
370    #[cfg(any(test, feature = "test-support"))]
371    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
372        self.dispatcher.as_test().unwrap().simulate_random_delay()
373    }
374
375    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
376    #[cfg(any(test, feature = "test-support"))]
377    pub fn deprioritize(&self, task_label: TaskLabel) {
378        self.dispatcher.as_test().unwrap().deprioritize(task_label)
379    }
380
381    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
382    #[cfg(any(test, feature = "test-support"))]
383    pub fn advance_clock(&self, duration: Duration) {
384        self.dispatcher.as_test().unwrap().advance_clock(duration)
385    }
386
387    /// in tests, run one task.
388    #[cfg(any(test, feature = "test-support"))]
389    pub fn tick(&self) -> bool {
390        self.dispatcher.as_test().unwrap().tick(false)
391    }
392
393    /// in tests, run all tasks that are ready to run. If after doing so
394    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
395    #[cfg(any(test, feature = "test-support"))]
396    pub fn run_until_parked(&self) {
397        self.dispatcher.as_test().unwrap().run_until_parked()
398    }
399
400    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
401    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
402    /// do take real async time to run.
403    #[cfg(any(test, feature = "test-support"))]
404    pub fn allow_parking(&self) {
405        self.dispatcher.as_test().unwrap().allow_parking();
406    }
407
408    /// undoes the effect of [`Self::allow_parking`].
409    #[cfg(any(test, feature = "test-support"))]
410    pub fn forbid_parking(&self) {
411        self.dispatcher.as_test().unwrap().forbid_parking();
412    }
413
414    /// adds detail to the "parked with nothing let to run" message.
415    #[cfg(any(test, feature = "test-support"))]
416    pub fn set_waiting_hint(&self, msg: Option<String>) {
417        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
418    }
419
420    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
421    #[cfg(any(test, feature = "test-support"))]
422    pub fn rng(&self) -> StdRng {
423        self.dispatcher.as_test().unwrap().rng()
424    }
425
426    /// How many CPUs are available to the dispatcher.
427    pub fn num_cpus(&self) -> usize {
428        #[cfg(any(test, feature = "test-support"))]
429        return 4;
430
431        #[cfg(not(any(test, feature = "test-support")))]
432        return num_cpus::get();
433    }
434
435    /// Whether we're on the main thread.
436    pub fn is_main_thread(&self) -> bool {
437        self.dispatcher.is_main_thread()
438    }
439
440    #[cfg(any(test, feature = "test-support"))]
441    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
442    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
443        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
444    }
445}
446
447/// ForegroundExecutor runs things on the main thread.
448impl ForegroundExecutor {
449    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
450    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
451        Self {
452            dispatcher,
453            not_send: PhantomData,
454        }
455    }
456
457    /// Enqueues the given Task to run on the main thread at some point in the future.
458    #[track_caller]
459    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
460    where
461        R: 'static,
462    {
463        let dispatcher = self.dispatcher.clone();
464
465        #[track_caller]
466        fn inner<R: 'static>(
467            dispatcher: Arc<dyn PlatformDispatcher>,
468            future: AnyLocalFuture<R>,
469        ) -> Task<R> {
470            let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
471                dispatcher.dispatch_on_main_thread(runnable)
472            });
473            runnable.schedule();
474            Task(TaskState::Spawned(task))
475        }
476        inner::<R>(dispatcher, Box::pin(future))
477    }
478}
479
/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// Copy-modified from:
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
///
/// The future is wrapped in a guard that records the spawning thread's id and the
/// caller's [`Location`]. Polling or dropping the wrapped future on any other thread
/// panics with a message naming that location. This runtime check is what allows a
/// possibly `!Send` future to be passed to `async_task::spawn_unchecked`.
#[track_caller]
fn spawn_local_with_source_location<Fut, S>(
    future: Fut,
    schedule: S,
) -> (Runnable<()>, async_task::Task<Fut::Output, ()>)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<()> + Send + Sync + 'static,
{
    // Returns the current thread's id, cached in a thread-local.
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        // `try_with` fails once the thread-local has been destroyed (e.g. during
        // thread teardown); fall back to querying the current thread directly.
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    // Wrapper that enforces "only touched on the spawning thread".
    struct Checked<F> {
        // Thread that spawned the task.
        id: ThreadId,
        // `ManuallyDrop` so the inner future is dropped only after the thread check
        // in `Drop` passes. If the check panics, the future is leaked rather than
        // having its destructor run on the wrong thread.
        inner: ManuallyDrop<F>,
        // Spawn site, reported in the panic messages.
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is dropped exactly once, here in `Drop`, and is never
            // accessed again afterwards.
            unsafe {
                ManuallyDrop::drop(&mut self.inner);
            }
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: pin projection onto `inner`. `Checked` never moves `inner`
            // out (it is only dropped in place in `Drop`), so treating it as
            // structurally pinned is sound.
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    // Wrap the future into one that checks which thread it's on.
    let future = Checked {
        id: thread_id(),
        inner: ManuallyDrop::new(future),
        // With `#[track_caller]` this resolves to the caller's location, not this line.
        location: Location::caller(),
    };

    // SAFETY: `spawn_unchecked` does not require `Fut: Send`. Soundness instead
    // relies on the future only being polled and dropped on the spawning thread,
    // which `Checked` enforces by panicking otherwise. This mirrors async-task's own
    // `spawn_local`. The `'static` bounds above rule out borrowed captures.
    unsafe { async_task::spawn_unchecked(future, schedule) }
}
544
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
///
/// Field order matters: fields are dropped in declaration order after `Drop::drop` runs.
pub struct Scope<'a> {
    // Executor the registered futures are spawned onto.
    executor: BackgroundExecutor,
    // Futures registered via `spawn` that have not yet been handed to the executor.
    // Their lifetime has been erased from `'a` to `'static` inside `Scope::spawn`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    // Each registered future owns a clone of this sender and drops it when it
    // finishes. `Drop` takes this original so the channel can close once all
    // clones are gone.
    tx: Option<mpsc::Sender<()>>,
    // No message is sent on the channel in this file; the receiver is only used
    // to observe the channel closing.
    rx: mpsc::Receiver<()>,
    // Ties the scope to the lifetime `'a` of the data its futures may borrow.
    lifetime: PhantomData<&'a ()>,
}
553
554impl<'a> Scope<'a> {
555    fn new(executor: BackgroundExecutor) -> Self {
556        let (tx, rx) = mpsc::channel(1);
557        Self {
558            executor,
559            tx: Some(tx),
560            rx,
561            futures: Default::default(),
562            lifetime: PhantomData,
563        }
564    }
565
566    /// How many CPUs are available to the dispatcher.
567    pub fn num_cpus(&self) -> usize {
568        self.executor.num_cpus()
569    }
570
571    /// Spawn a future into this scope.
572    pub fn spawn<F>(&mut self, f: F)
573    where
574        F: Future<Output = ()> + Send + 'a,
575    {
576        let tx = self.tx.clone().unwrap();
577
578        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
579        // dropping this `Scope` blocks until all of the futures have resolved.
580        let f = unsafe {
581            mem::transmute::<
582                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
583                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
584            >(Box::pin(async move {
585                f.await;
586                drop(tx);
587            }))
588        };
589        self.futures.push(f);
590    }
591}
592
593impl Drop for Scope<'_> {
594    fn drop(&mut self) {
595        self.tx.take().unwrap();
596
597        // Wait until the channel is closed, which means that all of the spawned
598        // futures have resolved.
599        self.executor.block(self.rx.next());
600    }
601}