executor.rs

  1use crate::{App, PlatformDispatcher};
  2use async_task::Runnable;
  3use futures::channel::mpsc;
  4use smol::prelude::*;
  5use std::mem::ManuallyDrop;
  6use std::panic::Location;
  7use std::thread::{self, ThreadId};
  8use std::{
  9    fmt::Debug,
 10    marker::PhantomData,
 11    mem,
 12    num::NonZeroUsize,
 13    pin::Pin,
 14    rc::Rc,
 15    sync::{
 16        Arc,
 17        atomic::{AtomicUsize, Ordering::SeqCst},
 18    },
 19    task::{Context, Poll},
 20    time::{Duration, Instant},
 21};
 22use util::TryFutureExt;
 23use waker_fn::waker_fn;
 24
 25#[cfg(any(test, feature = "test-support"))]
 26use rand::rngs::StdRng;
 27
 28/// A pointer to the executor that is currently running,
 29/// for spawning background tasks.
 30#[derive(Clone)]
 31pub struct BackgroundExecutor {
 32    #[doc(hidden)]
 33    pub dispatcher: Arc<dyn PlatformDispatcher>,
 34}
 35
 36/// A pointer to the executor that is currently running,
 37/// for spawning tasks on the main thread.
 38///
 39/// This is intentionally `!Send` via the `not_send` marker field. This is because
 40/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
 41/// only polled from the same thread it was spawned from. These checks would fail when spawning
 42/// foreground tasks from from background threads.
 43#[derive(Clone)]
 44pub struct ForegroundExecutor {
 45    #[doc(hidden)]
 46    pub dispatcher: Arc<dyn PlatformDispatcher>,
 47    not_send: PhantomData<Rc<()>>,
 48}
 49
 50/// Task is a primitive that allows work to happen in the background.
 51///
 52/// It implements [`Future`] so you can `.await` on it.
 53///
 54/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
 55/// the task to continue running, but with no way to return a value.
 56#[must_use]
 57#[derive(Debug)]
 58pub struct Task<T>(TaskState<T>);
 59
 60#[derive(Debug)]
 61enum TaskState<T> {
 62    /// A task that is ready to return a value
 63    Ready(Option<T>),
 64
 65    /// A task that is currently running.
 66    Spawned(async_task::Task<T>),
 67}
 68
 69impl<T> Task<T> {
 70    /// Creates a new task that will resolve with the value
 71    pub fn ready(val: T) -> Self {
 72        Task(TaskState::Ready(Some(val)))
 73    }
 74
 75    /// Detaching a task runs it to completion in the background
 76    pub fn detach(self) {
 77        match self {
 78            Task(TaskState::Ready(_)) => {}
 79            Task(TaskState::Spawned(task)) => task.detach(),
 80        }
 81    }
 82}
 83
 84impl<E, T> Task<Result<T, E>>
 85where
 86    T: 'static,
 87    E: 'static + Debug,
 88{
 89    /// Run the task to completion in the background and log any
 90    /// errors that occur.
 91    #[track_caller]
 92    pub fn detach_and_log_err(self, cx: &App) {
 93        let location = core::panic::Location::caller();
 94        cx.foreground_executor()
 95            .spawn(self.log_tracked_err(*location))
 96            .detach();
 97    }
 98}
 99
100impl<T> Future for Task<T> {
101    type Output = T;
102
103    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
104        match unsafe { self.get_unchecked_mut() } {
105            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
106            Task(TaskState::Spawned(task)) => {
107                let name = std::ffi::CString::new("Fiber").unwrap();
108                unsafe {
109                    tracy_client_sys::___tracy_fiber_enter(name.as_ptr());
110                }
111                let res = task.poll(cx);
112                unsafe {
113                    tracy_client_sys::___tracy_fiber_leave();
114                }
115                res
116            }
117        }
118    }
119}
120
/// An opaque identifier attached to a task, which tests can use to refer to it
/// (for example to deprioritize it).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    /// Same as [`TaskLabel::new`]: every default label is a fresh one.
    fn default() -> Self {
        TaskLabel::new()
    }
}

impl TaskLabel {
    /// Allocates a new label, distinct from all labels created before it.
    pub fn new() -> Self {
        // Process-wide counter. It starts at 1 so every value fits in a `NonZeroUsize`.
        static COUNTER: AtomicUsize = AtomicUsize::new(1);
        let id = COUNTER.fetch_add(1, SeqCst);
        TaskLabel(NonZeroUsize::try_from(id).unwrap())
    }
}
139
/// A boxed, pinned, `'static` future that is not required to be `Send`.
type AnyLocalFuture<R> = Pin<Box<dyn Future<Output = R> + 'static>>;

/// A boxed, pinned, `'static` future that can be moved across threads.
type AnyFuture<R> = Pin<Box<dyn Future<Output = R> + Send + 'static>>;
143
144/// BackgroundExecutor lets you run things on background threads.
145/// In production this is a thread pool with no ordering guarantees.
146/// In tests this is simulated by running tasks one by one in a deterministic
147/// (but arbitrary) order controlled by the `SEED` environment variable.
148impl BackgroundExecutor {
149    #[doc(hidden)]
150    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
151        Self { dispatcher }
152    }
153
154    /// Enqueues the given future to be run to completion on a background thread.
155    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
156    where
157        R: Send + 'static,
158    {
159        self.spawn_internal::<R>(Box::pin(future), None)
160    }
161
162    /// Enqueues the given future to be run to completion on a background thread.
163    /// The given label can be used to control the priority of the task in tests.
164    pub fn spawn_labeled<R>(
165        &self,
166        label: TaskLabel,
167        future: impl Future<Output = R> + Send + 'static,
168    ) -> Task<R>
169    where
170        R: Send + 'static,
171    {
172        self.spawn_internal::<R>(Box::pin(future), Some(label))
173    }
174
175    fn spawn_internal<R: Send + 'static>(
176        &self,
177        future: AnyFuture<R>,
178        label: Option<TaskLabel>,
179    ) -> Task<R> {
180        let dispatcher = self.dispatcher.clone();
181        let (runnable, task) =
182            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
183        runnable.schedule();
184        Task(TaskState::Spawned(task))
185    }
186
187    /// Used by the test harness to run an async test in a synchronous fashion.
188    #[cfg(any(test, feature = "test-support"))]
189    #[track_caller]
190    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
191        if let Ok(value) = self.block_internal(false, future, None) {
192            value
193        } else {
194            unreachable!()
195        }
196    }
197
198    /// Block the current thread until the given future resolves.
199    /// Consider using `block_with_timeout` instead.
200    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
201        if let Ok(value) = self.block_internal(true, future, None) {
202            value
203        } else {
204            unreachable!()
205        }
206    }
207
208    #[cfg(not(any(test, feature = "test-support")))]
209    pub(crate) fn block_internal<Fut: Future>(
210        &self,
211        _background_only: bool,
212        future: Fut,
213        timeout: Option<Duration>,
214    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
215        use std::time::Instant;
216
217        let mut future = Box::pin(future);
218        if timeout == Some(Duration::ZERO) {
219            return Err(future);
220        }
221        let deadline = timeout.map(|timeout| Instant::now() + timeout);
222
223        let unparker = self.dispatcher.unparker();
224        let waker = waker_fn(move || {
225            unparker.unpark();
226        });
227        let mut cx = std::task::Context::from_waker(&waker);
228
229        loop {
230            match future.as_mut().poll(&mut cx) {
231                Poll::Ready(result) => return Ok(result),
232                Poll::Pending => {
233                    let timeout =
234                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
235                    if !self.dispatcher.park(timeout)
236                        && deadline.is_some_and(|deadline| deadline < Instant::now())
237                    {
238                        return Err(future);
239                    }
240                }
241            }
242        }
243    }
244
245    #[cfg(any(test, feature = "test-support"))]
246    #[track_caller]
247    pub(crate) fn block_internal<Fut: Future>(
248        &self,
249        background_only: bool,
250        future: Fut,
251        timeout: Option<Duration>,
252    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
253        use std::sync::atomic::AtomicBool;
254
255        let mut future = Box::pin(future);
256        if timeout == Some(Duration::ZERO) {
257            return Err(future);
258        }
259        let Some(dispatcher) = self.dispatcher.as_test() else {
260            return Err(future);
261        };
262
263        let mut max_ticks = if timeout.is_some() {
264            dispatcher.gen_block_on_ticks()
265        } else {
266            usize::MAX
267        };
268        let unparker = self.dispatcher.unparker();
269        let awoken = Arc::new(AtomicBool::new(false));
270        let waker = waker_fn({
271            let awoken = awoken.clone();
272            move || {
273                awoken.store(true, SeqCst);
274                unparker.unpark();
275            }
276        });
277        let mut cx = std::task::Context::from_waker(&waker);
278
279        loop {
280            match future.as_mut().poll(&mut cx) {
281                Poll::Ready(result) => return Ok(result),
282                Poll::Pending => {
283                    if max_ticks == 0 {
284                        return Err(future);
285                    }
286                    max_ticks -= 1;
287
288                    if !dispatcher.tick(background_only) {
289                        if awoken.swap(false, SeqCst) {
290                            continue;
291                        }
292
293                        if !dispatcher.parking_allowed() {
294                            if dispatcher.advance_clock_to_next_delayed() {
295                                continue;
296                            }
297                            let mut backtrace_message = String::new();
298                            let mut waiting_message = String::new();
299                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
300                                backtrace_message =
301                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
302                            }
303                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
304                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
305                            }
306                            panic!(
307                                "parked with nothing left to run{waiting_message}{backtrace_message}",
308                            )
309                        }
310                        self.dispatcher.park(None);
311                    }
312                }
313            }
314        }
315    }
316
317    /// Block the current thread until the given future resolves
318    /// or `duration` has elapsed.
319    pub fn block_with_timeout<Fut: Future>(
320        &self,
321        duration: Duration,
322        future: Fut,
323    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
324        self.block_internal(true, future, Some(duration))
325    }
326
327    /// Scoped lets you start a number of tasks and waits
328    /// for all of them to complete before returning.
329    pub async fn scoped<'scope, F>(&self, scheduler: F)
330    where
331        F: FnOnce(&mut Scope<'scope>),
332    {
333        let mut scope = Scope::new(self.clone());
334        (scheduler)(&mut scope);
335        let spawned = mem::take(&mut scope.futures)
336            .into_iter()
337            .map(|f| self.spawn(f))
338            .collect::<Vec<_>>();
339        for task in spawned {
340            task.await;
341        }
342    }
343
344    /// Get the current time.
345    ///
346    /// Calling this instead of `std::time::Instant::now` allows the use
347    /// of fake timers in tests.
348    pub fn now(&self) -> Instant {
349        self.dispatcher.now()
350    }
351
352    /// Returns a task that will complete after the given duration.
353    /// Depending on other concurrent tasks the elapsed duration may be longer
354    /// than requested.
355    pub fn timer(&self, duration: Duration) -> Task<()> {
356        if duration.is_zero() {
357            return Task::ready(());
358        }
359        let (runnable, task) = async_task::spawn(async move {}, {
360            let dispatcher = self.dispatcher.clone();
361            move |runnable| dispatcher.dispatch_after(duration, runnable)
362        });
363        runnable.schedule();
364        Task(TaskState::Spawned(task))
365    }
366
367    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
368    #[cfg(any(test, feature = "test-support"))]
369    pub fn start_waiting(&self) {
370        self.dispatcher.as_test().unwrap().start_waiting();
371    }
372
373    /// in tests, removes the debugging data added by start_waiting
374    #[cfg(any(test, feature = "test-support"))]
375    pub fn finish_waiting(&self) {
376        self.dispatcher.as_test().unwrap().finish_waiting();
377    }
378
379    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
380    #[cfg(any(test, feature = "test-support"))]
381    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
382        self.dispatcher.as_test().unwrap().simulate_random_delay()
383    }
384
385    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
386    #[cfg(any(test, feature = "test-support"))]
387    pub fn deprioritize(&self, task_label: TaskLabel) {
388        self.dispatcher.as_test().unwrap().deprioritize(task_label)
389    }
390
391    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
392    #[cfg(any(test, feature = "test-support"))]
393    pub fn advance_clock(&self, duration: Duration) {
394        self.dispatcher.as_test().unwrap().advance_clock(duration)
395    }
396
397    /// in tests, run one task.
398    #[cfg(any(test, feature = "test-support"))]
399    pub fn tick(&self) -> bool {
400        self.dispatcher.as_test().unwrap().tick(false)
401    }
402
403    /// in tests, run all tasks that are ready to run. If after doing so
404    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
405    #[cfg(any(test, feature = "test-support"))]
406    pub fn run_until_parked(&self) {
407        self.dispatcher.as_test().unwrap().run_until_parked()
408    }
409
410    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
411    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
412    /// do take real async time to run.
413    #[cfg(any(test, feature = "test-support"))]
414    pub fn allow_parking(&self) {
415        self.dispatcher.as_test().unwrap().allow_parking();
416    }
417
418    /// undoes the effect of [`Self::allow_parking`].
419    #[cfg(any(test, feature = "test-support"))]
420    pub fn forbid_parking(&self) {
421        self.dispatcher.as_test().unwrap().forbid_parking();
422    }
423
424    /// adds detail to the "parked with nothing let to run" message.
425    #[cfg(any(test, feature = "test-support"))]
426    pub fn set_waiting_hint(&self, msg: Option<String>) {
427        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
428    }
429
430    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
431    #[cfg(any(test, feature = "test-support"))]
432    pub fn rng(&self) -> StdRng {
433        self.dispatcher.as_test().unwrap().rng()
434    }
435
436    /// How many CPUs are available to the dispatcher.
437    pub fn num_cpus(&self) -> usize {
438        #[cfg(any(test, feature = "test-support"))]
439        return 4;
440
441        #[cfg(not(any(test, feature = "test-support")))]
442        return num_cpus::get();
443    }
444
445    /// Whether we're on the main thread.
446    pub fn is_main_thread(&self) -> bool {
447        self.dispatcher.is_main_thread()
448    }
449
450    #[cfg(any(test, feature = "test-support"))]
451    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
452    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
453        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
454    }
455}
456
457/// ForegroundExecutor runs things on the main thread.
458impl ForegroundExecutor {
459    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
460    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
461        Self {
462            dispatcher,
463            not_send: PhantomData,
464        }
465    }
466
467    /// Enqueues the given Task to run on the main thread at some point in the future.
468    #[track_caller]
469    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
470    where
471        R: 'static,
472    {
473        let dispatcher = self.dispatcher.clone();
474
475        #[track_caller]
476        fn inner<R: 'static>(
477            dispatcher: Arc<dyn PlatformDispatcher>,
478            future: AnyLocalFuture<R>,
479        ) -> Task<R> {
480            let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
481                dispatcher.dispatch_on_main_thread(runnable)
482            });
483            runnable.schedule();
484            Task(TaskState::Spawned(task))
485        }
486        inner::<R>(dispatcher, Box::pin(future))
487    }
488}
489
490/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
491///
492/// Copy-modified from:
493/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
494#[track_caller]
495fn spawn_local_with_source_location<Fut, S>(
496    future: Fut,
497    schedule: S,
498) -> (Runnable<()>, async_task::Task<Fut::Output, ()>)
499where
500    Fut: Future + 'static,
501    Fut::Output: 'static,
502    S: async_task::Schedule<()> + Send + Sync + 'static,
503{
504    #[inline]
505    fn thread_id() -> ThreadId {
506        std::thread_local! {
507            static ID: ThreadId = thread::current().id();
508        }
509        ID.try_with(|id| *id)
510            .unwrap_or_else(|_| thread::current().id())
511    }
512
513    struct Checked<F> {
514        id: ThreadId,
515        inner: ManuallyDrop<F>,
516        location: &'static Location<'static>,
517    }
518
519    impl<F> Drop for Checked<F> {
520        fn drop(&mut self) {
521            assert!(
522                self.id == thread_id(),
523                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
524                self.location
525            );
526            unsafe {
527                ManuallyDrop::drop(&mut self.inner);
528            }
529        }
530    }
531
532    impl<F: Future> Future for Checked<F> {
533        type Output = F::Output;
534
535        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
536            assert!(
537                self.id == thread_id(),
538                "local task polled by a thread that didn't spawn it. Task spawned at {}",
539                self.location
540            );
541            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
542        }
543    }
544
545    // Wrap the future into one that checks which thread it's on.
546    let future = Checked {
547        id: thread_id(),
548        inner: ManuallyDrop::new(future),
549        location: Location::caller(),
550    };
551
552    unsafe { async_task::spawn_unchecked(future, schedule) }
553}
554
555/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
556pub struct Scope<'a> {
557    executor: BackgroundExecutor,
558    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
559    tx: Option<mpsc::Sender<()>>,
560    rx: mpsc::Receiver<()>,
561    lifetime: PhantomData<&'a ()>,
562}
563
564impl<'a> Scope<'a> {
565    fn new(executor: BackgroundExecutor) -> Self {
566        let (tx, rx) = mpsc::channel(1);
567        Self {
568            executor,
569            tx: Some(tx),
570            rx,
571            futures: Default::default(),
572            lifetime: PhantomData,
573        }
574    }
575
576    /// How many CPUs are available to the dispatcher.
577    pub fn num_cpus(&self) -> usize {
578        self.executor.num_cpus()
579    }
580
581    /// Spawn a future into this scope.
582    pub fn spawn<F>(&mut self, f: F)
583    where
584        F: Future<Output = ()> + Send + 'a,
585    {
586        let tx = self.tx.clone().unwrap();
587
588        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
589        // dropping this `Scope` blocks until all of the futures have resolved.
590        let f = unsafe {
591            mem::transmute::<
592                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
593                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
594            >(Box::pin(async move {
595                f.await;
596                drop(tx);
597            }))
598        };
599        self.futures.push(f);
600    }
601}
602
603impl Drop for Scope<'_> {
604    fn drop(&mut self) {
605        self.tx.take().unwrap();
606
607        // Wait until the channel is closed, which means that all of the spawned
608        // futures have resolved.
609        self.executor.block(self.rx.next());
610    }
611}