executor.rs

use crate::{AppContext, PlatformDispatcher};
use futures::channel::mpsc;
use smol::prelude::*;
use std::{
    fmt::Debug,
    marker::PhantomData,
    mem,
    num::NonZeroUsize,
    pin::Pin,
    rc::Rc,
    sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    },
    task::{Context, Poll},
    time::{Duration, Instant},
};
use util::TryFutureExt;
use waker_fn::waker_fn;

#[cfg(any(test, feature = "test-support"))]
use rand::rngs::StdRng;

/// A pointer to the executor that is currently running,
/// for spawning background tasks.
#[derive(Clone)]
pub struct BackgroundExecutor {
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
}

/// A pointer to the executor that is currently running,
/// for spawning tasks on the main thread.
#[derive(Clone)]
pub struct ForegroundExecutor {
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
    not_send: PhantomData<Rc<()>>,
}

/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
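///
/// # Example
///
/// A minimal illustrative sketch, assuming a [`BackgroundExecutor`] named
/// `executor` (for instance one obtained from an `AppContext`):
///
/// ```ignore
/// let task = executor.spawn(async { 2 + 2 });
/// // Awaiting (or blocking on) the task yields its value...
/// assert_eq!(executor.block(task), 4);
/// // ...while detaching lets it run to completion without returning one.
/// executor.spawn(async { /* background work */ }).detach();
/// ```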
#[must_use]
#[derive(Debug)]
pub enum Task<T> {
    /// A task that is ready to return a value
    Ready(Option<T>),

    /// A task that is currently running.
    Spawned(async_task::Task<T>),
}

impl<T> Task<T> {
    /// Creates a new task that will resolve with the value
    pub fn ready(val: T) -> Self {
        Task::Ready(Some(val))
    }

    /// Detaching a task runs it to completion in the background
    pub fn detach(self) {
        match self {
            Task::Ready(_) => {}
            Task::Spawned(task) => task.detach(),
        }
    }
}

impl<E, T> Task<Result<T, E>>
where
    T: 'static,
    E: 'static + Debug,
{
    /// Run the task to completion in the background and log any
    /// errors that occur.
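    ///
    /// A minimal sketch, assuming a fallible task and an `AppContext` reference
    /// `cx` (the file path below is purely illustrative):
    ///
    /// ```ignore
    /// let task: Task<anyhow::Result<()>> = cx.background_executor().spawn(async {
    ///     std::fs::write("/tmp/example.txt", "hello")?; // may fail
    ///     Ok(())
    /// });
    /// // Keep the task running; if it errors, the error is logged along with
    /// // the caller's location.
    /// task.detach_and_log_err(cx);
    /// ```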
    #[track_caller]
    pub fn detach_and_log_err(self, cx: &AppContext) {
        let location = core::panic::Location::caller();
        cx.foreground_executor()
            .spawn(self.log_tracked_err(*location))
            .detach();
    }
}

impl<T> Future for Task<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match unsafe { self.get_unchecked_mut() } {
            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
            Task::Spawned(task) => task.poll(cx),
        }
    }
}

/// A task label is an opaque identifier that you can use to
/// refer to a task in tests.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    fn default() -> Self {
        Self::new()
    }
}

impl TaskLabel {
    /// Construct a new task label.
    pub fn new() -> Self {
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        Self(NEXT_TASK_LABEL.fetch_add(1, SeqCst).try_into().unwrap())
    }
}

type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;

type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;

/// BackgroundExecutor lets you run things on background threads.
/// In production this is a thread pool with no ordering guarantees.
/// In tests this is simulated by running tasks one by one in a deterministic
/// (but arbitrary) order controlled by the `SEED` environment variable.
impl BackgroundExecutor {
    #[doc(hidden)]
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on a background thread.
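    ///
    /// A minimal sketch, assuming an `AppContext` reference `cx`:
    ///
    /// ```ignore
    /// let task = cx.background_executor().spawn(async move {
    ///     (0..1000u64).sum::<u64>()
    /// });
    /// // The task is cancelled if dropped; `.await` it or call `detach` to keep it running.
    /// ```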
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), None)
    }

    /// Enqueues the given future to be run to completion on a background thread.
    /// The given label can be used to control the priority of the task in tests.
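    ///
    /// An illustrative sketch of pairing a label with [`BackgroundExecutor::deprioritize`]
    /// in a test (`executor` is assumed to be a test `BackgroundExecutor`):
    ///
    /// ```ignore
    /// let label = TaskLabel::new();
    /// executor.deprioritize(label); // run this task after everything else
    /// executor.spawn_labeled(label, async { /* background work */ }).detach();
    /// ```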
    pub fn spawn_labeled<R>(
        &self,
        label: TaskLabel,
        future: impl Future<Output = R> + Send + 'static,
    ) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), Some(label))
    }

    fn spawn_internal<R: Send + 'static>(
        &self,
        future: AnyFuture<R>,
        label: Option<TaskLabel>,
    ) -> Task<R> {
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) =
            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
        runnable.schedule();
        Task::Spawned(task)
    }

    /// Used by the test harness to run an async test in a synchronous fashion.
    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(false, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    /// Block the current thread until the given future resolves.
    /// Consider using `block_with_timeout` instead.
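    ///
    /// A minimal sketch (`executor` is assumed to be a `BackgroundExecutor`):
    ///
    /// ```ignore
    /// let answer = executor.block(async { 40 + 2 });
    /// assert_eq!(answer, 42);
    /// ```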
    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(true, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    #[cfg(not(any(test, feature = "test-support")))]
    pub(crate) fn block_internal<R>(
        &self,
        _background_only: bool,
        future: impl Future<Output = R>,
        timeout: Option<Duration>,
    ) -> Result<R, impl Future<Output = R>> {
        use std::time::Instant;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let deadline = timeout.map(|timeout| Instant::now() + timeout);

        let unparker = self.dispatcher.unparker();
        let waker = waker_fn(move || {
            unparker.unpark();
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    let timeout =
                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
                    if !self.dispatcher.park(timeout)
                        && deadline.is_some_and(|deadline| deadline < Instant::now())
                    {
                        return Err(future);
                    }
                }
            }
        }
    }

    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub(crate) fn block_internal<R>(
        &self,
        background_only: bool,
        future: impl Future<Output = R>,
        timeout: Option<Duration>,
    ) -> Result<R, impl Future<Output = R>> {
        use std::sync::atomic::AtomicBool;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let Some(dispatcher) = self.dispatcher.as_test() else {
            return Err(future);
        };

        let mut max_ticks = if timeout.is_some() {
            dispatcher.gen_block_on_ticks()
        } else {
            usize::MAX
        };
        let unparker = self.dispatcher.unparker();
        let awoken = Arc::new(AtomicBool::new(false));
        let waker = waker_fn({
            let awoken = awoken.clone();
            move || {
                awoken.store(true, SeqCst);
                unparker.unpark();
            }
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    if max_ticks == 0 {
                        return Err(future);
                    }
                    max_ticks -= 1;

                    if !dispatcher.tick(background_only) {
                        if awoken.swap(false, SeqCst) {
                            continue;
                        }

                        if !dispatcher.parking_allowed() {
                            let mut backtrace_message = String::new();
                            let mut waiting_message = String::new();
                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
                                backtrace_message =
                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                            }
                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
                            }
                            panic!(
                                "parked with nothing left to run{waiting_message}{backtrace_message}",
                            )
                        }
                        self.dispatcher.park(None);
                    }
                }
            }
        }
    }

    /// Block the current thread until the given future resolves
    /// or `duration` has elapsed.
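    ///
    /// Returns `Ok(value)` if the future finished in time, otherwise `Err(future)`
    /// so the caller can decide what to do with the unfinished work. A minimal
    /// sketch (`slow_future` is an illustrative placeholder):
    ///
    /// ```ignore
    /// match executor.block_with_timeout(Duration::from_millis(100), slow_future) {
    ///     Ok(value) => println!("finished in time: {value:?}"),
    ///     Err(_unfinished) => println!("timed out; the future is handed back"),
    /// }
    /// ```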
    pub fn block_with_timeout<R>(
        &self,
        duration: Duration,
        future: impl Future<Output = R>,
    ) -> Result<R, impl Future<Output = R>> {
        self.block_internal(true, future, Some(duration))
    }

    /// Scoped lets you start a number of tasks and wait
    /// for all of them to complete before returning.
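    ///
    /// An illustrative sketch of fanning work out across the pool; the spawned
    /// futures may borrow from the enclosing scope:
    ///
    /// ```ignore
    /// let chunks = vec![vec![1, 2], vec![3, 4]];
    /// let mut sums = vec![0; chunks.len()];
    /// executor
    ///     .scoped(|scope| {
    ///         for (chunk, sum) in chunks.iter().zip(&mut sums) {
    ///             scope.spawn(async move { *sum = chunk.iter().sum() });
    ///         }
    ///     })
    ///     .await;
    /// assert_eq!(sums, vec![3, 7]);
    /// ```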
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone());
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn(f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }

    /// Get the current time.
    ///
    /// Calling this instead of `std::time::Instant::now` allows the use
    /// of fake timers in tests.
    pub fn now(&self) -> Instant {
        self.dispatcher.now()
    }

    /// Returns a task that will complete after the given duration.
    /// Depending on other concurrent tasks the elapsed duration may be longer
    /// than requested.
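    ///
    /// A minimal sketch:
    ///
    /// ```ignore
    /// executor.timer(Duration::from_millis(50)).await;
    /// // at least 50ms of (real or simulated) time have now elapsed
    /// ```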
    pub fn timer(&self, duration: Duration) -> Task<()> {
        if duration.is_zero() {
            return Task::ready(());
        }
        let (runnable, task) = async_task::spawn(async move {}, {
            let dispatcher = self.dispatcher.clone();
            move |runnable| dispatcher.dispatch_after(duration, runnable)
        });
        runnable.schedule();
        Task::Spawned(task)
    }

    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        self.dispatcher.as_test().unwrap().start_waiting();
    }

    /// in tests, removes the debugging data added by start_waiting
    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        self.dispatcher.as_test().unwrap().finish_waiting();
    }

    /// in tests, returns a future that yields a random number of times (determined by the `SEED` environment variable), letting other tasks run in between
    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
    #[cfg(any(test, feature = "test-support"))]
    pub fn deprioritize(&self, task_label: TaskLabel) {
        self.dispatcher.as_test().unwrap().deprioritize(task_label)
    }

    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
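    ///
    /// An illustrative test sketch: a one-second timer resolves once the clock
    /// has been advanced past it, without waiting in real time.
    ///
    /// ```ignore
    /// let timer = executor.timer(Duration::from_secs(1));
    /// executor.advance_clock(Duration::from_secs(1));
    /// executor.block(timer);
    /// ```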
    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    /// in tests, run one task.
    #[cfg(any(test, feature = "test-support"))]
    pub fn tick(&self) -> bool {
        self.dispatcher.as_test().unwrap().tick(false)
    }

    /// in tests, run all tasks that are ready to run. If after doing so
    /// the test still has outstanding tasks, this will panic. (See also `allow_parking`)
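    ///
    /// A typical test pattern (illustrative; `flag` is assumed to be an
    /// `Arc<AtomicBool>`):
    ///
    /// ```ignore
    /// executor
    ///     .spawn({
    ///         let flag = flag.clone();
    ///         async move { flag.store(true, SeqCst) }
    ///     })
    ///     .detach();
    /// executor.run_until_parked();
    /// assert!(flag.load(SeqCst));
    /// ```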
    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }

    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
    /// do take real async time to run.
    #[cfg(any(test, feature = "test-support"))]
    pub fn allow_parking(&self) {
        self.dispatcher.as_test().unwrap().allow_parking();
    }

    /// undoes the effect of [`allow_parking`].
    #[cfg(any(test, feature = "test-support"))]
    pub fn forbid_parking(&self) {
        self.dispatcher.as_test().unwrap().forbid_parking();
    }

    /// adds detail to the "parked with nothing left to run" message.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_waiting_hint(&self, msg: Option<String>) {
        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
    }

    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
    #[cfg(any(test, feature = "test-support"))]
    pub fn rng(&self) -> StdRng {
        self.dispatcher.as_test().unwrap().rng()
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        #[cfg(any(test, feature = "test-support"))]
        return 4;

        #[cfg(not(any(test, feature = "test-support")))]
        return num_cpus::get();
    }

    /// Whether we're on the main thread.
    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }

    #[cfg(any(test, feature = "test-support"))]
    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
    }
}

/// ForegroundExecutor runs things on the main thread.
impl ForegroundExecutor {
    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self {
            dispatcher,
            not_send: PhantomData,
        }
    }

    /// Enqueues the given Task to run on the main thread at some point in the future.
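    ///
    /// A minimal sketch, assuming an `AppContext` reference `cx`. The future does
    /// not need to be `Send` because it never leaves the main thread:
    ///
    /// ```ignore
    /// let value = Rc::new(1); // `!Send` data is fine here
    /// cx.foreground_executor()
    ///     .spawn(async move { drop(value) })
    ///     .detach();
    /// ```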
    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
    where
        R: 'static,
    {
        let dispatcher = self.dispatcher.clone();
        fn inner<R: 'static>(
            dispatcher: Arc<dyn PlatformDispatcher>,
            future: AnyLocalFuture<R>,
        ) -> Task<R> {
            let (runnable, task) = async_task::spawn_local(future, move |runnable| {
                dispatcher.dispatch_on_main_thread(runnable)
            });
            runnable.schedule();
            Task::Spawned(task)
        }
        inner::<R>(dispatcher, Box::pin(future))
    }
}

/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
    executor: BackgroundExecutor,
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    tx: Option<mpsc::Sender<()>>,
    rx: mpsc::Receiver<()>,
    lifetime: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    fn new(executor: BackgroundExecutor) -> Self {
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        self.executor.num_cpus()
    }

    /// Spawn a future into this scope.
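    ///
    /// An illustrative sketch (inside a [`BackgroundExecutor::scoped`] call); the
    /// spawned future may borrow from the enclosing scope:
    ///
    /// ```ignore
    /// let data = vec![1, 2, 3];
    /// executor
    ///     .scoped(|scope| {
    ///         scope.spawn(async { assert_eq!(data.len(), 3) }); // borrows `data`
    ///     })
    ///     .await;
    /// ```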
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let tx = self.tx.clone().unwrap();

        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}

impl<'a> Drop for Scope<'a> {
    fn drop(&mut self) {
        self.tx.take().unwrap();

        // Wait until the channel is closed, which means that all of the spawned
        // futures have resolved.
        self.executor.block(self.rx.next());
    }
}