executor.rs

  1use crate::{App, PlatformDispatcher};
  2use async_task::Runnable;
  3use futures::channel::mpsc;
  4use smol::prelude::*;
  5use std::{
  6    fmt::Debug,
  7    marker::PhantomData,
  8    mem::{self, ManuallyDrop},
  9    num::NonZeroUsize,
 10    panic::Location,
 11    pin::Pin,
 12    rc::Rc,
 13    sync::{
 14        Arc,
 15        atomic::{AtomicUsize, Ordering},
 16    },
 17    task::{Context, Poll},
 18    thread::{self, ThreadId},
 19    time::{Duration, Instant},
 20};
 21use util::TryFutureExt;
 22use waker_fn::waker_fn;
 23
 24#[cfg(any(test, feature = "test-support"))]
 25use rand::rngs::StdRng;
 26
/// A pointer to the executor that is currently running,
/// for spawning background tasks.
///
/// Cloning is cheap: the only state is a shared, reference-counted
/// handle to the platform dispatcher.
#[derive(Clone)]
pub struct BackgroundExecutor {
    // The dispatcher every spawn, timer and test hook on this executor delegates to.
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
}
 34
/// A pointer to the executor that is currently running,
/// for spawning tasks on the main thread.
///
/// This is intentionally `!Send` via the `not_send` marker field. This is because
/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
/// only polled from the same thread it was spawned from. These checks would fail when spawning
/// foreground tasks from background threads.
#[derive(Clone)]
pub struct ForegroundExecutor {
    // The dispatcher used to schedule runnables on the main thread.
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
    // Zero-sized marker: `Rc<()>` is `!Send`, which makes this struct `!Send` as well.
    not_send: PhantomData<Rc<()>>,
}
 48
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
///
/// The wrapped [`TaskState`] is private; tasks are created with [`Task::ready`] or by
/// spawning on one of the executors in this module.
#[must_use]
#[derive(Debug)]
pub struct Task<T>(TaskState<T>);
 58
/// Internal representation of a [`Task`].
#[derive(Debug)]
enum TaskState<T> {
    /// A task that is ready to return a value.
    ///
    /// The value is wrapped in an `Option` so that polling can move it out;
    /// it is `None` once the task has been polled to completion.
    Ready(Option<T>),

    /// A task that is currently running, backed by an `async_task` handle.
    Spawned(async_task::Task<T>),
}
 67
 68impl<T> Task<T> {
 69    /// Creates a new task that will resolve with the value
 70    pub fn ready(val: T) -> Self {
 71        Task(TaskState::Ready(Some(val)))
 72    }
 73
 74    /// Detaching a task runs it to completion in the background
 75    pub fn detach(self) {
 76        match self {
 77            Task(TaskState::Ready(_)) => {}
 78            Task(TaskState::Spawned(task)) => task.detach(),
 79        }
 80    }
 81}
 82
 83impl<E, T> Task<Result<T, E>>
 84where
 85    T: 'static,
 86    E: 'static + Debug,
 87{
 88    /// Run the task to completion in the background and log any
 89    /// errors that occur.
 90    #[track_caller]
 91    pub fn detach_and_log_err(self, cx: &App) {
 92        let location = core::panic::Location::caller();
 93        cx.foreground_executor()
 94            .spawn(self.log_tracked_err(*location))
 95            .detach();
 96    }
 97}
 98
 99impl<T> Future for Task<T> {
100    type Output = T;
101
102    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
103        match unsafe { self.get_unchecked_mut() } {
104            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
105            Task(TaskState::Spawned(task)) => task.poll(cx),
106        }
107    }
108}
109
/// An opaque identifier that can be used to refer to a task in tests.
///
/// Every label produced by [`TaskLabel::new`] (or `Default`) carries the next
/// value of a global counter, so labels are distinct from one another.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl TaskLabel {
    /// Construct a new task label.
    pub fn new() -> Self {
        // Starts at 1 so the counter value always fits a `NonZeroUsize`.
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);

        let id = NEXT_TASK_LABEL.fetch_add(1, Ordering::SeqCst);
        Self(NonZeroUsize::try_from(id).unwrap())
    }
}

impl Default for TaskLabel {
    /// Equivalent to [`TaskLabel::new`]: yields a fresh label.
    fn default() -> Self {
        Self::new()
    }
}
133
/// A boxed, pinned, type-erased `'static` future that need not be `Send`;
/// used for tasks spawned on the foreground executor.
type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;
135
/// A boxed, pinned, type-erased `'static + Send` future;
/// used for tasks spawned on the background executor.
type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;
137
138/// BackgroundExecutor lets you run things on background threads.
139/// In production this is a thread pool with no ordering guarantees.
140/// In tests this is simulated by running tasks one by one in a deterministic
141/// (but arbitrary) order controlled by the `SEED` environment variable.
142impl BackgroundExecutor {
143    #[doc(hidden)]
144    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
145        Self { dispatcher }
146    }
147
148    /// Enqueues the given future to be run to completion on a background thread.
149    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
150    where
151        R: Send + 'static,
152    {
153        self.spawn_internal::<R>(Box::pin(future), None)
154    }
155
156    /// Enqueues the given future to be run to completion on a background thread.
157    /// The given label can be used to control the priority of the task in tests.
158    pub fn spawn_labeled<R>(
159        &self,
160        label: TaskLabel,
161        future: impl Future<Output = R> + Send + 'static,
162    ) -> Task<R>
163    where
164        R: Send + 'static,
165    {
166        self.spawn_internal::<R>(Box::pin(future), Some(label))
167    }
168
169    fn spawn_internal<R: Send + 'static>(
170        &self,
171        future: AnyFuture<R>,
172        label: Option<TaskLabel>,
173    ) -> Task<R> {
174        let dispatcher = self.dispatcher.clone();
175        let (runnable, task) =
176            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
177        runnable.schedule();
178        Task(TaskState::Spawned(task))
179    }
180
181    /// Used by the test harness to run an async test in a synchronous fashion.
182    #[cfg(any(test, feature = "test-support"))]
183    #[track_caller]
184    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
185        if let Ok(value) = self.block_internal(false, future, None) {
186            value
187        } else {
188            unreachable!()
189        }
190    }
191
192    /// Block the current thread until the given future resolves.
193    /// Consider using `block_with_timeout` instead.
194    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
195        if let Ok(value) = self.block_internal(true, future, None) {
196            value
197        } else {
198            unreachable!()
199        }
200    }
201
202    #[cfg(not(any(test, feature = "test-support")))]
203    pub(crate) fn block_internal<Fut: Future>(
204        &self,
205        _background_only: bool,
206        future: Fut,
207        timeout: Option<Duration>,
208    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
209        use std::time::Instant;
210
211        let mut future = Box::pin(future);
212        if timeout == Some(Duration::ZERO) {
213            return Err(future);
214        }
215        let deadline = timeout.map(|timeout| Instant::now() + timeout);
216
217        let parker = parking::Parker::new();
218        let unparker = parker.unparker();
219        let waker = waker_fn(move || {
220            unparker.unpark();
221        });
222        let mut cx = std::task::Context::from_waker(&waker);
223
224        loop {
225            match future.as_mut().poll(&mut cx) {
226                Poll::Ready(result) => return Ok(result),
227                Poll::Pending => {
228                    let timeout =
229                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
230                    if let Some(timeout) = timeout {
231                        if !parker.park_timeout(timeout)
232                            && deadline.is_some_and(|deadline| deadline < Instant::now())
233                        {
234                            return Err(future);
235                        }
236                    } else {
237                        parker.park();
238                    }
239                }
240            }
241        }
242    }
243
244    #[cfg(any(test, feature = "test-support"))]
245    #[track_caller]
246    pub(crate) fn block_internal<Fut: Future>(
247        &self,
248        background_only: bool,
249        future: Fut,
250        timeout: Option<Duration>,
251    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
252        use std::sync::atomic::AtomicBool;
253
254        use parking::Parker;
255
256        let mut future = Box::pin(future);
257        if timeout == Some(Duration::ZERO) {
258            return Err(future);
259        }
260        let Some(dispatcher) = self.dispatcher.as_test() else {
261            return Err(future);
262        };
263
264        let mut max_ticks = if timeout.is_some() {
265            dispatcher.gen_block_on_ticks()
266        } else {
267            usize::MAX
268        };
269
270        let parker = Parker::new();
271        let unparker = parker.unparker();
272
273        let awoken = Arc::new(AtomicBool::new(false));
274        let waker = waker_fn({
275            let awoken = awoken.clone();
276            let unparker = unparker.clone();
277            move || {
278                awoken.store(true, Ordering::SeqCst);
279                unparker.unpark();
280            }
281        });
282        let mut cx = std::task::Context::from_waker(&waker);
283
284        let duration = Duration::from_secs(180);
285        let mut test_should_end_by = Instant::now() + duration;
286
287        loop {
288            match future.as_mut().poll(&mut cx) {
289                Poll::Ready(result) => return Ok(result),
290                Poll::Pending => {
291                    if max_ticks == 0 {
292                        return Err(future);
293                    }
294                    max_ticks -= 1;
295
296                    if !dispatcher.tick(background_only) {
297                        if awoken.swap(false, Ordering::SeqCst) {
298                            continue;
299                        }
300
301                        if !dispatcher.parking_allowed() {
302                            if dispatcher.advance_clock_to_next_delayed() {
303                                continue;
304                            }
305                            let mut backtrace_message = String::new();
306                            let mut waiting_message = String::new();
307                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
308                                backtrace_message =
309                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
310                            }
311                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
312                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
313                            }
314                            panic!(
315                                "parked with nothing left to run{waiting_message}{backtrace_message}",
316                            )
317                        }
318                        dispatcher.push_unparker(unparker.clone());
319                        parker.park_timeout(Duration::from_millis(1));
320                        if Instant::now() > test_should_end_by {
321                            panic!("test timed out after {duration:?} with allow_parking")
322                        }
323                    }
324                }
325            }
326        }
327    }
328
329    /// Block the current thread until the given future resolves
330    /// or `duration` has elapsed.
331    pub fn block_with_timeout<Fut: Future>(
332        &self,
333        duration: Duration,
334        future: Fut,
335    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
336        self.block_internal(true, future, Some(duration))
337    }
338
339    /// Scoped lets you start a number of tasks and waits
340    /// for all of them to complete before returning.
341    pub async fn scoped<'scope, F>(&self, scheduler: F)
342    where
343        F: FnOnce(&mut Scope<'scope>),
344    {
345        let mut scope = Scope::new(self.clone());
346        (scheduler)(&mut scope);
347        let spawned = mem::take(&mut scope.futures)
348            .into_iter()
349            .map(|f| self.spawn(f))
350            .collect::<Vec<_>>();
351        for task in spawned {
352            task.await;
353        }
354    }
355
356    /// Get the current time.
357    ///
358    /// Calling this instead of `std::time::Instant::now` allows the use
359    /// of fake timers in tests.
360    pub fn now(&self) -> Instant {
361        self.dispatcher.now()
362    }
363
364    /// Returns a task that will complete after the given duration.
365    /// Depending on other concurrent tasks the elapsed duration may be longer
366    /// than requested.
367    pub fn timer(&self, duration: Duration) -> Task<()> {
368        if duration.is_zero() {
369            return Task::ready(());
370        }
371        let (runnable, task) = async_task::spawn(async move {}, {
372            let dispatcher = self.dispatcher.clone();
373            move |runnable| dispatcher.dispatch_after(duration, runnable)
374        });
375        runnable.schedule();
376        Task(TaskState::Spawned(task))
377    }
378
379    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
380    #[cfg(any(test, feature = "test-support"))]
381    pub fn start_waiting(&self) {
382        self.dispatcher.as_test().unwrap().start_waiting();
383    }
384
385    /// in tests, removes the debugging data added by start_waiting
386    #[cfg(any(test, feature = "test-support"))]
387    pub fn finish_waiting(&self) {
388        self.dispatcher.as_test().unwrap().finish_waiting();
389    }
390
391    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
392    #[cfg(any(test, feature = "test-support"))]
393    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
394        self.dispatcher.as_test().unwrap().simulate_random_delay()
395    }
396
397    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
398    #[cfg(any(test, feature = "test-support"))]
399    pub fn deprioritize(&self, task_label: TaskLabel) {
400        self.dispatcher.as_test().unwrap().deprioritize(task_label)
401    }
402
403    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
404    #[cfg(any(test, feature = "test-support"))]
405    pub fn advance_clock(&self, duration: Duration) {
406        self.dispatcher.as_test().unwrap().advance_clock(duration)
407    }
408
409    /// in tests, run one task.
410    #[cfg(any(test, feature = "test-support"))]
411    pub fn tick(&self) -> bool {
412        self.dispatcher.as_test().unwrap().tick(false)
413    }
414
415    /// in tests, run all tasks that are ready to run. If after doing so
416    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
417    #[cfg(any(test, feature = "test-support"))]
418    pub fn run_until_parked(&self) {
419        self.dispatcher.as_test().unwrap().run_until_parked()
420    }
421
422    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
423    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
424    /// do take real async time to run.
425    #[cfg(any(test, feature = "test-support"))]
426    pub fn allow_parking(&self) {
427        self.dispatcher.as_test().unwrap().allow_parking();
428    }
429
430    /// undoes the effect of [`Self::allow_parking`].
431    #[cfg(any(test, feature = "test-support"))]
432    pub fn forbid_parking(&self) {
433        self.dispatcher.as_test().unwrap().forbid_parking();
434    }
435
436    /// adds detail to the "parked with nothing let to run" message.
437    #[cfg(any(test, feature = "test-support"))]
438    pub fn set_waiting_hint(&self, msg: Option<String>) {
439        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
440    }
441
442    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
443    #[cfg(any(test, feature = "test-support"))]
444    pub fn rng(&self) -> StdRng {
445        self.dispatcher.as_test().unwrap().rng()
446    }
447
448    /// How many CPUs are available to the dispatcher.
449    pub fn num_cpus(&self) -> usize {
450        #[cfg(any(test, feature = "test-support"))]
451        return 4;
452
453        #[cfg(not(any(test, feature = "test-support")))]
454        return num_cpus::get();
455    }
456
457    /// Whether we're on the main thread.
458    pub fn is_main_thread(&self) -> bool {
459        self.dispatcher.is_main_thread()
460    }
461
462    #[cfg(any(test, feature = "test-support"))]
463    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
464    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
465        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
466    }
467}
468
469/// ForegroundExecutor runs things on the main thread.
470impl ForegroundExecutor {
471    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
472    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
473        Self {
474            dispatcher,
475            not_send: PhantomData,
476        }
477    }
478
479    /// Enqueues the given Task to run on the main thread at some point in the future.
480    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
481    where
482        R: 'static,
483    {
484        let dispatcher = self.dispatcher.clone();
485
486        #[track_caller]
487        fn inner<R: 'static>(
488            dispatcher: Arc<dyn PlatformDispatcher>,
489            future: AnyLocalFuture<R>,
490        ) -> Task<R> {
491            let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
492                dispatcher.dispatch_on_main_thread(runnable)
493            });
494            runnable.schedule();
495            Task(TaskState::Spawned(task))
496        }
497        inner::<R>(dispatcher, Box::pin(future))
498    }
499}
500
/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// The future is wrapped in a guard that remembers the spawning thread and
/// asserts, on every poll and on drop, that it is still on that thread.
/// Because this function is `#[track_caller]`, the location reported in those
/// assertion messages is that of the caller.
///
/// Copy-modified from:
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
#[track_caller]
fn spawn_local_with_source_location<Fut, S>(
    future: Fut,
    schedule: S,
) -> (Runnable<()>, async_task::Task<Fut::Output, ()>)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<()> + Send + Sync + 'static,
{
    // Returns the current thread's id, cached in a thread-local. Falls back
    // to querying the thread directly when the thread-local is inaccessible
    // (`try_with` returns an error).
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    // Wrapper that pins a future to the thread that created it.
    struct Checked<F> {
        // Thread on which the wrapper was constructed.
        id: ThreadId,
        // `ManuallyDrop` so the inner future is only dropped after the
        // thread check in `Drop` has passed.
        inner: ManuallyDrop<F>,
        // Where the task was spawned; reported in the assertion messages.
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: this is the only place `inner` is dropped, and `self`
            // is not used again afterwards.
            unsafe { ManuallyDrop::drop(&mut self.inner) };
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is projected structurally; it is never moved
            // out of the pinned `Checked`, only dropped in place in `Drop`.
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    // Wrap the future into one that checks which thread it's on.
    let future = Checked {
        id: thread_id(),
        inner: ManuallyDrop::new(future),
        location: Location::caller(),
    };

    // SAFETY (NOTE(review): confirm against async-task's `spawn_unchecked`
    // contract): `Fut` and `schedule` are `'static`, and `schedule` is
    // `Send + Sync`. The future itself need not be `Send`, which is why
    // `Checked` asserts at runtime that it is only polled and dropped on the
    // spawning thread.
    unsafe { async_task::spawn_unchecked(future, schedule) }
}
563
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
    // Executor used to block in `Drop` until the spawned futures finish.
    executor: BackgroundExecutor,
    // Futures queued by `spawn`; `BackgroundExecutor::scoped` takes them out and spawns them.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    // Sender cloned into every queued future; `Drop` takes it out, leaving `None`.
    tx: Option<mpsc::Sender<()>>,
    // Receiver that `Drop` waits on; nothing in this file sends on the
    // channel, so it is used purely to observe when the senders are gone.
    rx: mpsc::Receiver<()>,
    // Ties the scope to the lifetime `'a` that queued futures may borrow from.
    lifetime: PhantomData<&'a ()>,
}
572
573impl<'a> Scope<'a> {
574    fn new(executor: BackgroundExecutor) -> Self {
575        let (tx, rx) = mpsc::channel(1);
576        Self {
577            executor,
578            tx: Some(tx),
579            rx,
580            futures: Default::default(),
581            lifetime: PhantomData,
582        }
583    }
584
585    /// How many CPUs are available to the dispatcher.
586    pub fn num_cpus(&self) -> usize {
587        self.executor.num_cpus()
588    }
589
590    /// Spawn a future into this scope.
591    pub fn spawn<F>(&mut self, f: F)
592    where
593        F: Future<Output = ()> + Send + 'a,
594    {
595        let tx = self.tx.clone().unwrap();
596
597        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
598        // dropping this `Scope` blocks until all of the futures have resolved.
599        let f = unsafe {
600            mem::transmute::<
601                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
602                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
603            >(Box::pin(async move {
604                f.await;
605                drop(tx);
606            }))
607        };
608        self.futures.push(f);
609    }
610}
611
612impl Drop for Scope<'_> {
613    fn drop(&mut self) {
614        self.tx.take().unwrap();
615
616        // Wait until the channel is closed, which means that all of the spawned
617        // futures have resolved.
618        self.executor.block(self.rx.next());
619    }
620}