executor.rs

use crate::{App, PlatformDispatcher};
use async_task::Runnable;
use futures::channel::mpsc;
use smol::prelude::*;
use std::mem::ManuallyDrop;
use std::panic::Location;
use std::thread::{self, ThreadId};
use std::{
    fmt::Debug,
    marker::PhantomData,
    mem,
    num::NonZeroUsize,
    pin::Pin,
    rc::Rc,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
    task::{Context, Poll},
    time::{Duration, Instant},
};
use util::TryFutureExt;
use waker_fn::waker_fn;

#[cfg(any(test, feature = "test-support"))]
use rand::rngs::StdRng;

/// A pointer to the executor that is currently running,
/// for spawning background tasks.
#[derive(Clone)]
pub struct BackgroundExecutor {
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
}

/// A pointer to the executor that is currently running,
/// for spawning tasks on the main thread.
///
/// This is intentionally `!Send` via the `not_send` marker field. This is because
/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
/// only polled from the same thread it was spawned from. These checks would fail when spawning
/// foreground tasks from background threads.
#[derive(Clone)]
pub struct ForegroundExecutor {
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
    not_send: PhantomData<Rc<()>>,
}

/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
#[must_use]
#[derive(Debug)]
pub struct Task<T>(TaskState<T>);

#[derive(Debug)]
enum TaskState<T> {
    /// A task that is ready to return a value
    Ready(Option<T>),

    /// A task that is currently running.
    Spawned(async_task::Task<T>),
}

impl<T> Task<T> {
    /// Creates a new task that will resolve with the value
    pub fn ready(val: T) -> Self {
        Task(TaskState::Ready(Some(val)))
    }

    /// Detaching a task runs it to completion in the background
    pub fn detach(self) {
        match self {
            Task(TaskState::Ready(_)) => {}
            Task(TaskState::Spawned(task)) => task.detach(),
        }
    }
}

impl<E, T> Task<Result<T, E>>
where
    T: 'static,
    E: 'static + Debug,
{
    /// Run the task to completion in the background and log any
    /// errors that occur.
    #[track_caller]
    pub fn detach_and_log_err(self, cx: &App) {
        let location = core::panic::Location::caller();
        cx.foreground_executor()
            .spawn(self.log_tracked_err(*location))
            .detach();
    }
}

impl<T> Future for Task<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match unsafe { self.get_unchecked_mut() } {
            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
            Task(TaskState::Spawned(task)) => task.poll(cx),
        }
    }
}
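
// Illustrative sketch (not part of the original file): typical `Task` handling. Awaiting a
// task yields its value, dropping it cancels the underlying future, and `detach` lets it
// keep running when the result is not needed. The function and its `executor` parameter
// are hypothetical; real callers obtain executors from the app context.
#[allow(dead_code)]
async fn task_usage_sketch(executor: &BackgroundExecutor) {
    // Await the spawned task to get its output.
    let sum = executor.spawn(async { (1..=10).sum::<u32>() }).await;
    assert_eq!(sum, 55);

    // Dropping a task cancels it.
    let cancelled = executor.spawn(async { (1..=10_000).sum::<u64>() });
    drop(cancelled);

    // Detaching lets the task run to completion with no way to read the result.
    executor.spawn(async { (1..=10).sum::<u32>() }).detach();
}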

/// A task label is an opaque identifier that you can use to
/// refer to a task in tests.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    fn default() -> Self {
        Self::new()
    }
}

impl TaskLabel {
    /// Construct a new task label.
    pub fn new() -> Self {
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        Self(NEXT_TASK_LABEL.fetch_add(1, SeqCst).try_into().unwrap())
    }
}
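
// Illustrative sketch (not part of the original file, test-only): a label gives a test a
// handle on one specific task, which the deterministic dispatcher can then deprioritize so
// it runs after all other work. The function and its parameter are hypothetical.
#[cfg(any(test, feature = "test-support"))]
#[allow(dead_code)]
fn task_label_sketch(executor: &BackgroundExecutor) {
    let label = TaskLabel::new();
    // Ask the test dispatcher to schedule this labeled task only once everything else is done.
    executor.deprioritize(label);
    executor
        .spawn_labeled(label, async { /* work that should run last */ })
        .detach();
    executor.run_until_parked();
}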

type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;

type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;

/// BackgroundExecutor lets you run things on background threads.
/// In production this is a thread pool with no ordering guarantees.
/// In tests this is simulated by running tasks one by one in a deterministic
/// (but arbitrary) order controlled by the `SEED` environment variable.
impl BackgroundExecutor {
    #[doc(hidden)]
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on a background thread.
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), None)
    }

    /// Enqueues the given future to be run to completion on a background thread.
    /// The given label can be used to control the priority of the task in tests.
    pub fn spawn_labeled<R>(
        &self,
        label: TaskLabel,
        future: impl Future<Output = R> + Send + 'static,
    ) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), Some(label))
    }

    fn spawn_internal<R: Send + 'static>(
        &self,
        future: AnyFuture<R>,
        label: Option<TaskLabel>,
    ) -> Task<R> {
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) =
            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
        runnable.schedule();
        Task(TaskState::Spawned(task))
    }

    /// Used by the test harness to run an async test in a synchronous fashion.
    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(false, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    /// Block the current thread until the given future resolves.
    /// Consider using `block_with_timeout` instead.
    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(true, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    #[cfg(not(any(test, feature = "test-support")))]
    pub(crate) fn block_internal<Fut: Future>(
        &self,
        _background_only: bool,
        future: Fut,
        timeout: Option<Duration>,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        use std::time::Instant;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let deadline = timeout.map(|timeout| Instant::now() + timeout);

        let unparker = self.dispatcher.unparker();
        let waker = waker_fn(move || {
            unparker.unpark();
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    let timeout =
                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
                    if !self.dispatcher.park(timeout)
                        && deadline.is_some_and(|deadline| deadline < Instant::now())
                    {
                        return Err(future);
                    }
                }
            }
        }
    }
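
    // Illustrative sketch (not part of the original file): the poll/park pattern that this
    // executor's blocking methods build on, reduced to `std::thread` primitives. `waker_fn`
    // turns a closure into a `Waker`; when the future signals readiness, the closure unparks
    // the blocked thread so it polls again. The real implementation parks via the platform
    // dispatcher and also honors deadlines (and, in tests, tick limits).
    #[allow(dead_code)]
    fn block_on_sketch<F: Future>(future: F) -> F::Output {
        let mut future = Box::pin(future);
        let parked_thread = thread::current();
        let waker = waker_fn(move || parked_thread.unpark());
        let mut cx = Context::from_waker(&waker);
        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(value) => return value,
                // Sleep until some waker fires; a spurious wakeup only costs an extra poll.
                Poll::Pending => thread::park(),
            }
        }
    }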

    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub(crate) fn block_internal<Fut: Future>(
        &self,
        background_only: bool,
        future: Fut,
        timeout: Option<Duration>,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        use std::sync::atomic::AtomicBool;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let Some(dispatcher) = self.dispatcher.as_test() else {
            return Err(future);
        };

        let mut max_ticks = if timeout.is_some() {
            dispatcher.gen_block_on_ticks()
        } else {
            usize::MAX
        };
        let unparker = self.dispatcher.unparker();
        let awoken = Arc::new(AtomicBool::new(false));
        let waker = waker_fn({
            let awoken = awoken.clone();
            move || {
                awoken.store(true, SeqCst);
                unparker.unpark();
            }
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    if max_ticks == 0 {
                        return Err(future);
                    }
                    max_ticks -= 1;

                    if !dispatcher.tick(background_only) {
                        if awoken.swap(false, SeqCst) {
                            continue;
                        }

                        if !dispatcher.parking_allowed() {
                            let mut backtrace_message = String::new();
                            let mut waiting_message = String::new();
                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
                                backtrace_message =
                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                            }
                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
                            }
                            panic!(
                                "parked with nothing left to run{waiting_message}{backtrace_message}",
                            )
                        }
                        self.dispatcher.park(None);
                    }
                }
            }
        }
    }

    /// Block the current thread until the given future resolves
    /// or `duration` has elapsed.
    pub fn block_with_timeout<Fut: Future>(
        &self,
        duration: Duration,
        future: Fut,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        self.block_internal(true, future, Some(duration))
    }
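
    // Illustrative sketch (not part of the original file): waiting briefly for a background
    // result. On timeout the still-pending future is handed back, so the caller can decide
    // what to do with it. The method itself is hypothetical.
    #[allow(dead_code)]
    fn block_with_timeout_sketch(&self) {
        let task = self.spawn(async { 2 + 2 });
        match self.block_with_timeout(Duration::from_millis(50), task) {
            Ok(value) => assert_eq!(value, 4),
            Err(pending) => {
                // Timed out; fall back to blocking without a deadline.
                assert_eq!(self.block(pending), 4);
            }
        }
    }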

    /// Scoped lets you start a number of tasks and waits
    /// for all of them to complete before returning.
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone());
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn(f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }
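
    // Illustrative sketch (not part of the original file): `scoped` lets spawned futures
    // borrow from the enclosing stack frame, because the scope waits for every future to
    // resolve before returning. The data below is arbitrary; the method is hypothetical.
    #[allow(dead_code)]
    async fn scoped_sketch(&self) {
        let chunks = vec![vec![1, 2], vec![3, 4]];
        let mut sums: Vec<i32> = vec![0; chunks.len()];
        self.scoped(|scope| {
            // Each future gets a disjoint `&mut` slot plus a shared borrow of its chunk.
            for (sum, chunk) in sums.iter_mut().zip(&chunks) {
                scope.spawn(async move {
                    *sum = chunk.iter().sum();
                });
            }
        })
        .await;
        assert_eq!(sums, vec![3, 7]);
    }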

    /// Get the current time.
    ///
    /// Calling this instead of `std::time::Instant::now` allows the use
    /// of fake timers in tests.
    pub fn now(&self) -> Instant {
        self.dispatcher.now()
    }

    /// Returns a task that will complete after the given duration.
    /// Depending on other concurrent tasks the elapsed duration may be longer
    /// than requested.
    pub fn timer(&self, duration: Duration) -> Task<()> {
        if duration.is_zero() {
            return Task::ready(());
        }
        let (runnable, task) = async_task::spawn(async move {}, {
            let dispatcher = self.dispatcher.clone();
            move |runnable| dispatcher.dispatch_after(duration, runnable)
        });
        runnable.schedule();
        Task(TaskState::Spawned(task))
    }
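
    // Illustrative sketch (not part of the original file, test-only): timers follow the
    // executor's clock, so a test can fast-forward them with `advance_clock` instead of
    // sleeping for real. The method is hypothetical.
    #[cfg(any(test, feature = "test-support"))]
    #[allow(dead_code)]
    fn timer_sketch(&self) {
        let timer = self.timer(Duration::from_secs(60));
        // Advance simulated time past the deadline; no real minute elapses.
        self.advance_clock(Duration::from_secs(61));
        // The timer task is now runnable, so blocking on it completes promptly.
        self.block(timer);
    }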

    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        self.dispatcher.as_test().unwrap().start_waiting();
    }

    /// in tests, removes the debugging data added by start_waiting
    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        self.dispatcher.as_test().unwrap().finish_waiting();
    }

    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
    #[cfg(any(test, feature = "test-support"))]
    pub fn deprioritize(&self, task_label: TaskLabel) {
        self.dispatcher.as_test().unwrap().deprioritize(task_label)
    }

    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    /// in tests, run one task.
    #[cfg(any(test, feature = "test-support"))]
    pub fn tick(&self) -> bool {
        self.dispatcher.as_test().unwrap().tick(false)
    }

    /// in tests, run all tasks that are ready to run. If after doing so
    /// the test still has outstanding tasks, this will panic. (See also `allow_parking`)
    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }

    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
    /// do take real async time to run.
    #[cfg(any(test, feature = "test-support"))]
    pub fn allow_parking(&self) {
        self.dispatcher.as_test().unwrap().allow_parking();
    }

    /// undoes the effect of [`allow_parking`].
    #[cfg(any(test, feature = "test-support"))]
    pub fn forbid_parking(&self) {
        self.dispatcher.as_test().unwrap().forbid_parking();
    }

    /// adds detail to the "parked with nothing left to run" message.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_waiting_hint(&self, msg: Option<String>) {
        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
    }

    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
    #[cfg(any(test, feature = "test-support"))]
    pub fn rng(&self) -> StdRng {
        self.dispatcher.as_test().unwrap().rng()
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        #[cfg(any(test, feature = "test-support"))]
        return 4;

        #[cfg(not(any(test, feature = "test-support")))]
        return num_cpus::get();
    }

    /// Whether we're on the main thread.
    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }

    #[cfg(any(test, feature = "test-support"))]
    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
    }
}
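
// Illustrative sketch (not part of the original file, test-only): the usual shape of a
// deterministic test against this executor. Work is spawned, the scheduler is drained with
// `run_until_parked`, and only then do assertions run. Names are hypothetical.
#[cfg(any(test, feature = "test-support"))]
#[allow(dead_code)]
fn deterministic_test_sketch(executor: &BackgroundExecutor) {
    use std::sync::atomic::AtomicBool;

    let flag = Arc::new(AtomicBool::new(false));
    executor
        .spawn({
            let flag = flag.clone();
            async move { flag.store(true, SeqCst) }
        })
        .detach();
    // Runs every ready task in the seeded, deterministic order.
    executor.run_until_parked();
    assert!(flag.load(SeqCst));
}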

/// ForegroundExecutor runs things on the main thread.
impl ForegroundExecutor {
    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self {
            dispatcher,
            not_send: PhantomData,
        }
    }

    /// Enqueues the given future to be run to completion on the main thread.
    #[track_caller]
    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
    where
        R: 'static,
    {
        let dispatcher = self.dispatcher.clone();

        #[track_caller]
        fn inner<R: 'static>(
            dispatcher: Arc<dyn PlatformDispatcher>,
            future: AnyLocalFuture<R>,
        ) -> Task<R> {
            let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
                dispatcher.dispatch_on_main_thread(runnable)
            });
            runnable.schedule();
            Task(TaskState::Spawned(task))
        }
        inner::<R>(dispatcher, Box::pin(future))
    }
}
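
// Illustrative sketch (not part of the original file): foreground tasks may capture `!Send`
// state such as `Rc`, because they are only ever polled on the main thread. Spawning from
// another thread, or letting the task migrate, trips the runtime checks installed by
// `spawn_local_with_source_location` below. The function and its parameter are hypothetical.
#[allow(dead_code)]
fn foreground_sketch(foreground: &ForegroundExecutor) {
    let counter = Rc::new(std::cell::Cell::new(0));
    foreground
        .spawn({
            let counter = counter.clone();
            async move { counter.set(counter.get() + 1) }
        })
        .detach();
}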

/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// Copy-modified from:
/// https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405
#[track_caller]
fn spawn_local_with_source_location<Fut, S>(
    future: Fut,
    schedule: S,
) -> (Runnable<()>, async_task::Task<Fut::Output, ()>)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<()> + Send + Sync + 'static,
{
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    struct Checked<F> {
        id: ThreadId,
        inner: ManuallyDrop<F>,
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            unsafe {
                ManuallyDrop::drop(&mut self.inner);
            }
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    // Wrap the future into one that checks which thread it's on.
    let future = Checked {
        id: thread_id(),
        inner: ManuallyDrop::new(future),
        location: Location::caller(),
    };

    unsafe { async_task::spawn_unchecked(future, schedule) }
}

/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
    executor: BackgroundExecutor,
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    tx: Option<mpsc::Sender<()>>,
    rx: mpsc::Receiver<()>,
    lifetime: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    fn new(executor: BackgroundExecutor) -> Self {
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        self.executor.num_cpus()
    }

    /// Spawn a future into this scope.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let tx = self.tx.clone().unwrap();

        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}

impl Drop for Scope<'_> {
    fn drop(&mut self) {
        self.tx.take().unwrap();

        // Wait until the channel is closed, which means that all of the spawned
        // futures have resolved.
        self.executor.block(self.rx.next());
    }
}