executor.rs

use crate::{App, PlatformDispatcher};
use async_task::Runnable;
use futures::channel::mpsc;
use smol::prelude::*;
use std::mem::ManuallyDrop;
use std::panic::Location;
use std::thread::{self, ThreadId};
use std::{
    fmt::Debug,
    marker::PhantomData,
    mem,
    num::NonZeroUsize,
    pin::Pin,
    rc::Rc,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering::SeqCst},
    },
    task::{Context, Poll},
    time::{Duration, Instant},
};
use util::TryFutureExt;
use waker_fn::waker_fn;

#[cfg(any(test, feature = "test-support"))]
use rand::rngs::StdRng;

/// A pointer to the executor that is currently running,
/// for spawning background tasks.
#[derive(Clone)]
pub struct BackgroundExecutor {
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
}

/// A pointer to the executor that is currently running,
/// for spawning tasks on the main thread.
///
/// This is intentionally `!Send` via the `not_send` marker field. This is because
/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
/// only polled from the same thread it was spawned from. These checks would fail when spawning
/// foreground tasks from background threads.
#[derive(Clone)]
pub struct ForegroundExecutor {
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
    not_send: PhantomData<Rc<()>>,
}

/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
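///
/// A sketch of typical usage (assuming an `executor: &BackgroundExecutor` in scope):
/// ```ignore
/// let task: Task<i32> = executor.spawn(async { 2 + 2 });
/// // Either await the task (`task.await` yields 4), or detach it to let it
/// // run to completion without observing the output:
/// task.detach();
/// ```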
#[must_use]
#[derive(Debug)]
pub struct Task<T>(TaskState<T>);

#[derive(Debug)]
enum TaskState<T> {
    /// A task that is ready to return a value
    Ready(Option<T>),

    /// A task that is currently running.
    Spawned(async_task::Task<T>),
}

impl<T> Task<T> {
    /// Creates a new task that will resolve with the value
    pub const fn ready(val: T) -> Self {
        Task(TaskState::Ready(Some(val)))
    }

    /// Detaching a task runs it to completion in the background
    pub fn detach(self) {
        match self {
            Task(TaskState::Ready(_)) => {}
            Task(TaskState::Spawned(task)) => task.detach(),
        }
    }
}

impl<E, T> Task<Result<T, E>>
where
    T: 'static,
    E: 'static + Debug,
{
    /// Run the task to completion in the background and log any
    /// errors that occur.
    #[track_caller]
    pub fn detach_and_log_err(self, cx: &App) {
        let location = core::panic::Location::caller();
        cx.foreground_executor()
            .spawn(self.log_tracked_err(*location))
            .detach();
    }
}

impl<T> Future for Task<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match unsafe { self.get_unchecked_mut() } {
            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
            Task(TaskState::Spawned(task)) => task.poll(cx),
        }
    }
}

/// A task label is an opaque identifier that you can use to
/// refer to a task in tests.
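///
/// A sketch of how a label is typically used in tests (assuming an
/// `executor: &BackgroundExecutor` backed by the test dispatcher):
/// ```ignore
/// let label = TaskLabel::new();
/// executor.deprioritize(label);
/// executor
///     .spawn_labeled(label, async { /* runs after other ready tasks */ })
///     .detach();
/// ```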
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    fn default() -> Self {
        Self::new()
    }
}

impl TaskLabel {
    /// Construct a new task label.
    pub fn new() -> Self {
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        Self(NEXT_TASK_LABEL.fetch_add(1, SeqCst).try_into().unwrap())
    }
}

type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;

type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;

/// BackgroundExecutor lets you run things on background threads.
/// In production this is a thread pool with no ordering guarantees.
/// In tests this is simulated by running tasks one by one in a deterministic
/// (but arbitrary) order controlled by the `SEED` environment variable.
impl BackgroundExecutor {
    #[doc(hidden)]
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on a background thread.
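    ///
    /// A sketch of typical usage (assuming an `executor: &BackgroundExecutor` and a
    /// hypothetical `heavy_computation` function):
    /// ```ignore
    /// let task = executor.spawn(async move { heavy_computation() });
    /// // The returned `Task` can be awaited from any async context, or detached.
    /// ```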
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), None)
    }

    /// Enqueues the given future to be run to completion on a background thread.
    /// The given label can be used to control the priority of the task in tests.
    pub fn spawn_labeled<R>(
        &self,
        label: TaskLabel,
        future: impl Future<Output = R> + Send + 'static,
    ) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), Some(label))
    }

    fn spawn_internal<R: Send + 'static>(
        &self,
        future: AnyFuture<R>,
        label: Option<TaskLabel>,
    ) -> Task<R> {
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) =
            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
        runnable.schedule();
        Task(TaskState::Spawned(task))
    }

    /// Used by the test harness to run an async test in a synchronous fashion.
    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(false, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    /// Block the current thread until the given future resolves.
    /// Consider using `block_with_timeout` instead.
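    ///
    /// A sketch of typical usage (assuming an `executor: &BackgroundExecutor`):
    /// ```ignore
    /// let answer = executor.block(async { 6 * 7 });
    /// assert_eq!(answer, 42);
    /// ```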
    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(true, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    #[cfg(not(any(test, feature = "test-support")))]
    pub(crate) fn block_internal<Fut: Future>(
        &self,
        _background_only: bool,
        future: Fut,
        timeout: Option<Duration>,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        use std::time::Instant;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let deadline = timeout.map(|timeout| Instant::now() + timeout);

        let parker = parking::Parker::new();
        let unparker = parker.unparker();
        let waker = waker_fn(move || {
            unparker.unpark();
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    let timeout =
                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
                    if let Some(timeout) = timeout {
                        if !parker.park_timeout(timeout)
                            && deadline.is_some_and(|deadline| deadline < Instant::now())
                        {
                            return Err(future);
                        }
                    } else {
                        parker.park();
                    }
                }
            }
        }
    }

    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub(crate) fn block_internal<Fut: Future>(
        &self,
        background_only: bool,
        future: Fut,
        timeout: Option<Duration>,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        use std::sync::atomic::AtomicBool;

        use parking::Parker;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let Some(dispatcher) = self.dispatcher.as_test() else {
            return Err(future);
        };

        let mut max_ticks = if timeout.is_some() {
            dispatcher.gen_block_on_ticks()
        } else {
            usize::MAX
        };

        let parker = Parker::new();
        let unparker = parker.unparker();

        let awoken = Arc::new(AtomicBool::new(false));
        let waker = waker_fn({
            let awoken = awoken.clone();
            let unparker = unparker.clone();
            move || {
                awoken.store(true, SeqCst);
                unparker.unpark();
            }
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    if max_ticks == 0 {
                        return Err(future);
                    }
                    max_ticks -= 1;

                    if !dispatcher.tick(background_only) {
                        if awoken.swap(false, SeqCst) {
                            continue;
                        }

                        if !dispatcher.parking_allowed() {
                            if dispatcher.advance_clock_to_next_delayed() {
                                continue;
                            }
                            let mut backtrace_message = String::new();
                            let mut waiting_message = String::new();
                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
                                backtrace_message =
                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                            }
                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
                            }
                            panic!(
                                "parked with nothing left to run{waiting_message}{backtrace_message}",
                            )
                        }
                        dispatcher.set_unparker(unparker.clone());
                        parker.park();
                    }
                }
            }
        }
    }

    /// Block the current thread until the given future resolves
    /// or `duration` has elapsed.
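    ///
    /// A sketch of typical usage (assuming an `executor: &BackgroundExecutor` and a
    /// hypothetical `slow_work` future):
    /// ```ignore
    /// match executor.block_with_timeout(Duration::from_millis(100), slow_work()) {
    ///     Ok(output) => println!("finished in time: {output:?}"),
    ///     Err(unfinished) => {
    ///         // Timed out; the partially-polled future is handed back and can be
    ///         // awaited later, or dropped to cancel the remaining work.
    ///         drop(unfinished);
    ///     }
    /// }
    /// ```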
    pub fn block_with_timeout<Fut: Future>(
        &self,
        duration: Duration,
        future: Fut,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        self.block_internal(true, future, Some(duration))
    }

    /// Scoped lets you start a number of tasks and waits
    /// for all of them to complete before returning.
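    ///
    /// A sketch of typical usage (assuming an `executor: &BackgroundExecutor`); the
    /// spawned futures may borrow from the enclosing stack frame:
    /// ```ignore
    /// let data = vec![1, 2, 3, 4];
    /// executor
    ///     .scoped(|scope| {
    ///         for chunk in data.chunks(2) {
    ///             scope.spawn(async move {
    ///                 let _sum: i32 = chunk.iter().sum();
    ///             });
    ///         }
    ///     })
    ///     .await;
    /// ```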
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone());
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn(f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }

    /// Get the current time.
    ///
    /// Calling this instead of `std::time::Instant::now` allows the use
    /// of fake timers in tests.
    pub fn now(&self) -> Instant {
        self.dispatcher.now()
    }

    /// Returns a task that will complete after the given duration.
    /// Depending on other concurrent tasks the elapsed duration may be longer
    /// than requested.
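    ///
    /// A sketch of typical usage (assuming an `executor: &BackgroundExecutor`):
    /// ```ignore
    /// executor.timer(Duration::from_secs(1)).await;
    /// // In tests, calling `executor.advance_clock(Duration::from_secs(1))` from
    /// // elsewhere makes this timer ready without waiting in real time.
    /// ```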
    pub fn timer(&self, duration: Duration) -> Task<()> {
        if duration.is_zero() {
            return Task::ready(());
        }
        let (runnable, task) = async_task::spawn(async move {}, {
            let dispatcher = self.dispatcher.clone();
            move |runnable| dispatcher.dispatch_after(duration, runnable)
        });
        runnable.schedule();
        Task(TaskState::Spawned(task))
    }

    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        self.dispatcher.as_test().unwrap().start_waiting();
    }

    /// in tests, removes the debugging data added by start_waiting
    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        self.dispatcher.as_test().unwrap().finish_waiting();
    }

    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
    #[cfg(any(test, feature = "test-support"))]
    pub fn deprioritize(&self, task_label: TaskLabel) {
        self.dispatcher.as_test().unwrap().deprioritize(task_label)
    }

    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    /// in tests, run one task.
    #[cfg(any(test, feature = "test-support"))]
    pub fn tick(&self) -> bool {
        self.dispatcher.as_test().unwrap().tick(false)
    }

    /// in tests, run all tasks that are ready to run. If after doing so
    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }

    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
    /// do take real async time to run.
    #[cfg(any(test, feature = "test-support"))]
    pub fn allow_parking(&self) {
        self.dispatcher.as_test().unwrap().allow_parking();
    }

    /// undoes the effect of [`Self::allow_parking`].
    #[cfg(any(test, feature = "test-support"))]
    pub fn forbid_parking(&self) {
        self.dispatcher.as_test().unwrap().forbid_parking();
    }

    /// adds detail to the "parked with nothing left to run" message.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_waiting_hint(&self, msg: Option<String>) {
        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
    }

    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
    #[cfg(any(test, feature = "test-support"))]
    pub fn rng(&self) -> StdRng {
        self.dispatcher.as_test().unwrap().rng()
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        #[cfg(any(test, feature = "test-support"))]
        return 4;

        #[cfg(not(any(test, feature = "test-support")))]
        return num_cpus::get();
    }

    /// Whether we're on the main thread.
    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }

    #[cfg(any(test, feature = "test-support"))]
    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
    }
}

/// ForegroundExecutor runs things on the main thread.
impl ForegroundExecutor {
    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self {
            dispatcher,
            not_send: PhantomData,
        }
    }

    /// Enqueues the given future to be run to completion on the main thread.
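    ///
    /// Unlike [`BackgroundExecutor::spawn`], the future does not need to be `Send`, but it
    /// must only be polled (and dropped) on the thread that spawned it.
    ///
    /// A sketch of typical usage (assuming an `executor: &ForegroundExecutor`):
    /// ```ignore
    /// let local_state = std::rc::Rc::new(1);
    /// executor
    ///     .spawn(async move {
    ///         // `Rc` is !Send; that is fine here because the task stays on the main thread.
    ///         let _ = *local_state + 1;
    ///     })
    ///     .detach();
    /// ```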
    #[track_caller]
    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
    where
        R: 'static,
    {
        let dispatcher = self.dispatcher.clone();

        #[track_caller]
        fn inner<R: 'static>(
            dispatcher: Arc<dyn PlatformDispatcher>,
            future: AnyLocalFuture<R>,
        ) -> Task<R> {
            let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
                dispatcher.dispatch_on_main_thread(runnable)
            });
            runnable.schedule();
            Task(TaskState::Spawned(task))
        }
        inner::<R>(dispatcher, Box::pin(future))
    }
}

/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// Copy-modified from:
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
#[track_caller]
fn spawn_local_with_source_location<Fut, S>(
    future: Fut,
    schedule: S,
) -> (Runnable<()>, async_task::Task<Fut::Output, ()>)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<()> + Send + Sync + 'static,
{
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    struct Checked<F> {
        id: ThreadId,
        inner: ManuallyDrop<F>,
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            unsafe { ManuallyDrop::drop(&mut self.inner) };
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    // Wrap the future into one that checks which thread it's on.
    let future = Checked {
        id: thread_id(),
        inner: ManuallyDrop::new(future),
        location: Location::caller(),
    };

    unsafe { async_task::spawn_unchecked(future, schedule) }
}

/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
    executor: BackgroundExecutor,
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    tx: Option<mpsc::Sender<()>>,
    rx: mpsc::Receiver<()>,
    lifetime: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    fn new(executor: BackgroundExecutor) -> Self {
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        self.executor.num_cpus()
    }

    /// Spawn a future into this scope.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let tx = self.tx.clone().unwrap();

        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}

impl Drop for Scope<'_> {
    fn drop(&mut self) {
        self.tx.take().unwrap();

        // Wait until the channel is closed, which means that all of the spawned
        // futures have resolved.
        self.executor.block(self.rx.next());
    }
}