executor.rs

use crate::{App, PlatformDispatcher};
use async_task::Runnable;
use futures::channel::mpsc;
use smol::prelude::*;
use std::{
    fmt::Debug,
    marker::PhantomData,
    mem::{self, ManuallyDrop},
    num::NonZeroUsize,
    panic::Location,
    pin::Pin,
    rc::Rc,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    task::{Context, Poll},
    thread::{self, ThreadId},
    time::{Duration, Instant},
};
use util::TryFutureExt;
use waker_fn::waker_fn;

#[cfg(any(test, feature = "test-support"))]
use rand::rngs::StdRng;

/// A pointer to the executor that is currently running,
/// for spawning background tasks.
#[derive(Clone)]
pub struct BackgroundExecutor {
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
}

/// A pointer to the executor that is currently running,
/// for spawning tasks on the main thread.
///
/// This is intentionally `!Send` via the `not_send` marker field. This is because
/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
/// only polled from the same thread it was spawned from. These checks would fail when spawning
/// foreground tasks from background threads.
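///
/// A minimal usage sketch, assuming an `&App` context (`cx`) is in scope on the main thread:
///
/// ```ignore
/// // Spawn UI-affecting work; it will be polled on the main thread.
/// cx.foreground_executor()
///     .spawn(async move {
///         // update state here
///     })
///     .detach();
/// ```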
#[derive(Clone)]
pub struct ForegroundExecutor {
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
    not_send: PhantomData<Rc<()>>,
}

/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
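///
/// A minimal usage sketch, assuming a `BackgroundExecutor` named `executor`:
///
/// ```ignore
/// // Spawn work on the thread pool and block on the result.
/// let task = executor.spawn(async { 2 + 2 });
/// assert_eq!(executor.block(task), 4);
///
/// // Or let the work run to completion without keeping a handle to it.
/// executor.spawn(async { /* fire and forget */ }).detach();
/// ```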
#[must_use]
#[derive(Debug)]
pub struct Task<T>(TaskState<T>);

#[derive(Debug)]
enum TaskState<T> {
    /// A task that is ready to return a value
    Ready(Option<T>),

    /// A task that is currently running.
    Spawned(async_task::Task<T>),
}

impl<T> Task<T> {
    /// Creates a new task that will immediately resolve with the given value.
    pub fn ready(val: T) -> Self {
        Task(TaskState::Ready(Some(val)))
    }

    /// Detaching a task runs it to completion in the background
    pub fn detach(self) {
        match self {
            Task(TaskState::Ready(_)) => {}
            Task(TaskState::Spawned(task)) => task.detach(),
        }
    }
}

impl<E, T> Task<Result<T, E>>
where
    T: 'static,
    E: 'static + Debug,
{
    /// Run the task to completion in the background and log any
    /// errors that occur.
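    ///
    /// A minimal usage sketch, assuming `task` is a `Task<Result<T, E>>` and `cx` is an `&App`:
    ///
    /// ```ignore
    /// // Let the task finish in the background; any error is logged instead of returned.
    /// task.detach_and_log_err(cx);
    /// ```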
    #[track_caller]
    pub fn detach_and_log_err(self, cx: &App) {
        let location = core::panic::Location::caller();
        cx.foreground_executor()
            .spawn(self.log_tracked_err(*location))
            .detach();
    }
}

impl<T> Future for Task<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match unsafe { self.get_unchecked_mut() } {
            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
            Task(TaskState::Spawned(task)) => task.poll(cx),
        }
    }
}

/// A task label is an opaque identifier that you can use to
/// refer to a task in tests.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    fn default() -> Self {
        Self::new()
    }
}

impl TaskLabel {
    /// Construct a new task label.
    pub fn new() -> Self {
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        Self(
            NEXT_TASK_LABEL
                .fetch_add(1, Ordering::SeqCst)
                .try_into()
                .unwrap(),
        )
    }
}

type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;

type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;

/// BackgroundExecutor lets you run things on background threads.
/// In production this is a thread pool with no ordering guarantees.
/// In tests this is simulated by running tasks one by one in a deterministic
/// (but arbitrary) order controlled by the `SEED` environment variable.
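///
/// A minimal sketch, assuming an `executor: BackgroundExecutor` and a hypothetical
/// `expensive_computation` future:
///
/// ```ignore
/// let task = executor.spawn(async { expensive_computation().await });
/// // Dropping `task` would cancel it; blocking on it returns the result instead.
/// let result = executor.block(task);
/// ```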
impl BackgroundExecutor {
    #[doc(hidden)]
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on a background thread.
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), None)
    }

    /// Enqueues the given future to be run to completion on a background thread.
    /// The given label can be used to control the priority of the task in tests.
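    ///
    /// A sketch of how a test can order labeled work, assuming `executor` is a test
    /// `BackgroundExecutor`:
    ///
    /// ```ignore
    /// let label = TaskLabel::new();
    /// executor.deprioritize(label); // test-only: run this task after everything else
    /// executor
    ///     .spawn_labeled(label, async { /* background work */ })
    ///     .detach();
    /// ```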
    pub fn spawn_labeled<R>(
        &self,
        label: TaskLabel,
        future: impl Future<Output = R> + Send + 'static,
    ) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), Some(label))
    }

    fn spawn_internal<R: Send + 'static>(
        &self,
        future: AnyFuture<R>,
        label: Option<TaskLabel>,
    ) -> Task<R> {
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) =
            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
        runnable.schedule();
        Task(TaskState::Spawned(task))
    }

    /// Used by the test harness to run an async test in a synchronous fashion.
    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(false, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    /// Block the current thread until the given future resolves.
    /// Consider using `block_with_timeout` instead.
    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(true, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    #[cfg(not(any(test, feature = "test-support")))]
    pub(crate) fn block_internal<Fut: Future>(
        &self,
        _background_only: bool,
        future: Fut,
        timeout: Option<Duration>,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        use std::time::Instant;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let deadline = timeout.map(|timeout| Instant::now() + timeout);

        let parker = parking::Parker::new();
        let unparker = parker.unparker();
        let waker = waker_fn(move || {
            unparker.unpark();
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    let timeout =
                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
                    if let Some(timeout) = timeout {
                        if !parker.park_timeout(timeout)
                            && deadline.is_some_and(|deadline| deadline < Instant::now())
                        {
                            return Err(future);
                        }
                    } else {
                        parker.park();
                    }
                }
            }
        }
    }

    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub(crate) fn block_internal<Fut: Future>(
        &self,
        background_only: bool,
        future: Fut,
        timeout: Option<Duration>,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        use std::sync::atomic::AtomicBool;

        use parking::Parker;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let Some(dispatcher) = self.dispatcher.as_test() else {
            return Err(future);
        };

        let mut max_ticks = if timeout.is_some() {
            dispatcher.gen_block_on_ticks()
        } else {
            usize::MAX
        };

        let parker = Parker::new();
        let unparker = parker.unparker();

        let awoken = Arc::new(AtomicBool::new(false));
        let waker = waker_fn({
            let awoken = awoken.clone();
            let unparker = unparker.clone();
            move || {
                awoken.store(true, Ordering::SeqCst);
                unparker.unpark();
            }
        });
        let mut cx = std::task::Context::from_waker(&waker);

        let mut test_should_end_by = Instant::now() + Duration::from_secs(500);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    if max_ticks == 0 {
                        return Err(future);
                    }
                    max_ticks -= 1;

                    if !dispatcher.tick(background_only) {
                        if awoken.swap(false, Ordering::SeqCst) {
                            continue;
                        }

                        if !dispatcher.parking_allowed() {
                            if dispatcher.advance_clock_to_next_delayed() {
                                continue;
                            }
                            let mut backtrace_message = String::new();
                            let mut waiting_message = String::new();
                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
                                backtrace_message =
                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                            }
                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
                            }
                            panic!(
                                "parked with nothing left to run{waiting_message}{backtrace_message}",
                            )
                        }
                        dispatcher.set_unparker(unparker.clone());
                        parker.park_timeout(
                            test_should_end_by.saturating_duration_since(Instant::now()),
                        );
                        if Instant::now() > test_should_end_by {
                            panic!("test timed out with allow_parking")
                        }
                    }
                }
            }
        }
    }

    /// Block the current thread until the given future resolves
    /// or `duration` has elapsed.
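    ///
    /// A sketch of handling both outcomes, assuming `executor` is a `BackgroundExecutor` and
    /// `request` is some hypothetical future:
    ///
    /// ```ignore
    /// match executor.block_with_timeout(Duration::from_millis(100), request) {
    ///     Ok(output) => println!("finished in time: {:?}", output),
    ///     Err(unfinished) => drop(unfinished), // timed out; the future is handed back
    /// }
    /// ```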
    pub fn block_with_timeout<Fut: Future>(
        &self,
        duration: Duration,
        future: Fut,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        self.block_internal(true, future, Some(duration))
    }

    /// Scoped lets you start a number of tasks and waits
    /// for all of them to complete before returning.
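    ///
    /// A sketch of fanning work out over borrowed data, assuming `executor` is a
    /// `BackgroundExecutor` and `load_chunks`/`process` are hypothetical helpers:
    ///
    /// ```ignore
    /// let chunks: Vec<Vec<u8>> = load_chunks();
    /// executor
    ///     .scoped(|scope| {
    ///         for chunk in &chunks {
    ///             // Futures may borrow from the enclosing scope.
    ///             scope.spawn(async move { process(chunk); });
    ///         }
    ///     })
    ///     .await;
    /// ```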
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: for<'a> FnOnce(&'a mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone());
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn(f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }

    /// Get the current time.
    ///
    /// Calling this instead of `std::time::Instant::now` allows the use
    /// of fake timers in tests.
    pub fn now(&self) -> Instant {
        self.dispatcher.now()
    }

    /// Returns a task that will complete after the given duration.
    /// Depending on other concurrent tasks the elapsed duration may be longer
    /// than requested.
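    ///
    /// A minimal sketch, assuming this runs inside a spawned future with an
    /// `executor: BackgroundExecutor` in scope:
    ///
    /// ```ignore
    /// executor.timer(Duration::from_secs(1)).await;
    /// // at least one second has elapsed
    /// ```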
    pub fn timer(&self, duration: Duration) -> Task<()> {
        if duration.is_zero() {
            return Task::ready(());
        }
        let (runnable, task) = async_task::spawn(async move {}, {
            let dispatcher = self.dispatcher.clone();
            move |runnable| dispatcher.dispatch_after(duration, runnable)
        });
        runnable.schedule();
        Task(TaskState::Spawned(task))
    }

    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        self.dispatcher.as_test().unwrap().start_waiting();
    }

    /// in tests, removes the debugging data added by start_waiting
    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        self.dispatcher.as_test().unwrap().finish_waiting();
    }

    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
    #[cfg(any(test, feature = "test-support"))]
    pub fn deprioritize(&self, task_label: TaskLabel) {
        self.dispatcher.as_test().unwrap().deprioritize(task_label)
    }

    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
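    ///
    /// A sketch of driving a timer from a test, assuming `executor` is a test
    /// `BackgroundExecutor`:
    ///
    /// ```ignore
    /// let timer = executor.timer(Duration::from_secs(60));
    /// executor.advance_clock(Duration::from_secs(60)); // the timer is now ready...
    /// executor.run_until_parked();                     // ...and this runs it to completion
    /// ```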
    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    /// in tests, run one task.
    #[cfg(any(test, feature = "test-support"))]
    pub fn tick(&self) -> bool {
        self.dispatcher.as_test().unwrap().tick(false)
    }

    /// in tests, run all tasks that are ready to run. If after doing so
    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }

    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
    /// do take real async time to run.
    #[cfg(any(test, feature = "test-support"))]
    pub fn allow_parking(&self) {
        self.dispatcher.as_test().unwrap().allow_parking();
    }

    /// undoes the effect of [`Self::allow_parking`].
    #[cfg(any(test, feature = "test-support"))]
    pub fn forbid_parking(&self) {
        self.dispatcher.as_test().unwrap().forbid_parking();
    }

    /// adds detail to the "parked with nothing left to run" message.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_waiting_hint(&self, msg: Option<String>) {
        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
    }

    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
    #[cfg(any(test, feature = "test-support"))]
    pub fn rng(&self) -> StdRng {
        self.dispatcher.as_test().unwrap().rng()
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        #[cfg(any(test, feature = "test-support"))]
        return 4;

        #[cfg(not(any(test, feature = "test-support")))]
        return num_cpus::get();
    }

    /// Whether we're on the main thread.
    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }

    #[cfg(any(test, feature = "test-support"))]
    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
    }
}

/// ForegroundExecutor runs things on the main thread.
impl ForegroundExecutor {
    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self {
            dispatcher,
            not_send: PhantomData,
        }
    }

    /// Enqueues the given future to be run to completion on the main thread.
    #[track_caller]
    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
    where
        R: 'static,
    {
        let dispatcher = self.dispatcher.clone();

        #[track_caller]
        fn inner<R: 'static>(
            dispatcher: Arc<dyn PlatformDispatcher>,
            future: AnyLocalFuture<R>,
        ) -> Task<R> {
            let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
                dispatcher.dispatch_on_main_thread(runnable)
            });
            runnable.schedule();
            Task(TaskState::Spawned(task))
        }
        inner::<R>(dispatcher, Box::pin(future))
    }
}

/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// Copy-modified from:
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
#[track_caller]
fn spawn_local_with_source_location<Fut, S>(
    future: Fut,
    schedule: S,
) -> (Runnable<()>, async_task::Task<Fut::Output, ()>)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<()> + Send + Sync + 'static,
{
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    struct Checked<F> {
        id: ThreadId,
        inner: ManuallyDrop<F>,
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            unsafe { ManuallyDrop::drop(&mut self.inner) };
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    // Wrap the future into one that checks which thread it's on.
    let future = Checked {
        id: thread_id(),
        inner: ManuallyDrop::new(future),
        location: Location::caller(),
    };

    unsafe { async_task::spawn_unchecked(future, schedule) }
}

/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
    executor: BackgroundExecutor,
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    tx: Option<mpsc::Sender<()>>,
    rx: mpsc::Receiver<()>,
    lifetime: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    fn new(executor: BackgroundExecutor) -> Self {
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        self.executor.num_cpus()
    }

    /// Spawn a future into this scope.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let tx = self.tx.clone().unwrap();

        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}

impl Drop for Scope<'_> {
    fn drop(&mut self) {
        self.tx.take().unwrap();

        // Wait until the channel is closed, which means that all of the spawned
        // futures have resolved.
        self.executor.block(self.rx.next());
    }
}