executor.rs

  1use crate::{AppContext, PlatformDispatcher};
  2use futures::{channel::mpsc, pin_mut, FutureExt};
  3use smol::prelude::*;
  4use std::{
  5    fmt::Debug,
  6    marker::PhantomData,
  7    mem,
  8    num::NonZeroUsize,
  9    pin::Pin,
 10    rc::Rc,
 11    sync::{
 12        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
 13        Arc,
 14    },
 15    task::{Context, Poll},
 16    time::Duration,
 17};
 18use util::TryFutureExt;
 19use waker_fn::waker_fn;
 20
 21#[cfg(any(test, feature = "test-support"))]
 22use rand::rngs::StdRng;
 23
 24#[derive(Clone)]
 25pub struct BackgroundExecutor {
 26    dispatcher: Arc<dyn PlatformDispatcher>,
 27}
 28
 29#[derive(Clone)]
 30pub struct ForegroundExecutor {
 31    dispatcher: Arc<dyn PlatformDispatcher>,
 32    not_send: PhantomData<Rc<()>>,
 33}
 34
 35/// Task is a primitive that allows work to happen in the background.
 36///
 37/// It implements [`Future`] so you can `.await` on it.
 38///
 39/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
 40/// the task to continue running in the background, but with no way to return a value.
 41#[must_use]
 42#[derive(Debug)]
 43pub enum Task<T> {
 44    Ready(Option<T>),
 45    Spawned(async_task::Task<T>),
 46}
 47
 48impl<T> Task<T> {
 49    /// Create a new task that will resolve with the value
 50    pub fn ready(val: T) -> Self {
 51        Task::Ready(Some(val))
 52    }
 53
 54    /// Detaching a task runs it to completion in the background
 55    pub fn detach(self) {
 56        match self {
 57            Task::Ready(_) => {}
 58            Task::Spawned(task) => task.detach(),
 59        }
 60    }
 61}
 62
 63impl<E, T> Task<Result<T, E>>
 64where
 65    T: 'static,
 66    E: 'static + Debug,
 67{
 68    /// Run the task to completion in the background and log any
 69    /// errors that occur.
 70    #[track_caller]
 71    pub fn detach_and_log_err(self, cx: &mut AppContext) {
 72        let location = core::panic::Location::caller();
 73        cx.foreground_executor()
 74            .spawn(self.log_tracked_err(*location))
 75            .detach();
 76    }
 77}
 78
 79impl<T> Future for Task<T> {
 80    type Output = T;
 81
 82    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 83        match unsafe { self.get_unchecked_mut() } {
 84            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
 85            Task::Spawned(task) => task.poll(cx),
 86        }
 87    }
 88}
 89
 90#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 91pub struct TaskLabel(NonZeroUsize);
 92
 93impl Default for TaskLabel {
 94    fn default() -> Self {
 95        Self::new()
 96    }
 97}
 98
 99impl TaskLabel {
100    pub fn new() -> Self {
101        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
102        Self(NEXT_TASK_LABEL.fetch_add(1, SeqCst).try_into().unwrap())
103    }
104}
105
/// A boxed, type-erased future without a `Send` bound; used for work spawned onto the
/// main thread by `ForegroundExecutor::spawn`.
type AnyLocalFuture<R> = Pin<Box<dyn Future<Output = R> + 'static>>;
107
/// A boxed, type-erased future that is `Send`, so it can be driven on any thread; used for
/// work spawned onto the background executor.
type AnyFuture<R> = Pin<Box<dyn Future<Output = R> + Send + 'static>>;
109
110/// BackgroundExecutor lets you run things on background threads.
111/// In production this is a thread pool with no ordering guarantees.
112/// In tests this is simulated by running tasks one by one in a deterministic
113/// (but arbitrary) order controlled by the `SEED` environment variable.
114impl BackgroundExecutor {
115    #[doc(hidden)]
116    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
117        Self { dispatcher }
118    }
119
120    /// Enqueues the given future to be run to completion on a background thread.
121    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
122    where
123        R: Send + 'static,
124    {
125        self.spawn_internal::<R>(Box::pin(future), None)
126    }
127
128    /// Enqueues the given future to be run to completion on a background thread.
129    /// The given label can be used to control the priority of the task in tests.
130    pub fn spawn_labeled<R>(
131        &self,
132        label: TaskLabel,
133        future: impl Future<Output = R> + Send + 'static,
134    ) -> Task<R>
135    where
136        R: Send + 'static,
137    {
138        self.spawn_internal::<R>(Box::pin(future), Some(label))
139    }
140
141    fn spawn_internal<R: Send + 'static>(
142        &self,
143        future: AnyFuture<R>,
144        label: Option<TaskLabel>,
145    ) -> Task<R> {
146        let dispatcher = self.dispatcher.clone();
147        let (runnable, task) =
148            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
149        runnable.schedule();
150        Task::Spawned(task)
151    }
152
153    /// Used by the test harness to run an async test in a synchronous fashion.
154    #[cfg(any(test, feature = "test-support"))]
155    #[track_caller]
156    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
157        if let Ok(value) = self.block_internal(false, future, usize::MAX) {
158            value
159        } else {
160            unreachable!()
161        }
162    }
163
164    /// Block the current thread until the given future resolves.
165    /// Consider using `block_with_timeout` instead.
166    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
167        if let Ok(value) = self.block_internal(true, future, usize::MAX) {
168            value
169        } else {
170            unreachable!()
171        }
172    }
173
174    #[track_caller]
175    pub(crate) fn block_internal<R>(
176        &self,
177        background_only: bool,
178        future: impl Future<Output = R>,
179        mut max_ticks: usize,
180    ) -> Result<R, ()> {
181        pin_mut!(future);
182        let unparker = self.dispatcher.unparker();
183        let awoken = Arc::new(AtomicBool::new(false));
184
185        let waker = waker_fn({
186            let awoken = awoken.clone();
187            move || {
188                awoken.store(true, SeqCst);
189                unparker.unpark();
190            }
191        });
192        let mut cx = std::task::Context::from_waker(&waker);
193
194        loop {
195            match future.as_mut().poll(&mut cx) {
196                Poll::Ready(result) => return Ok(result),
197                Poll::Pending => {
198                    if max_ticks == 0 {
199                        return Err(());
200                    }
201                    max_ticks -= 1;
202
203                    if !self.dispatcher.tick(background_only) {
204                        if awoken.swap(false, SeqCst) {
205                            continue;
206                        }
207
208                        #[cfg(any(test, feature = "test-support"))]
209                        if let Some(test) = self.dispatcher.as_test() {
210                            if !test.parking_allowed() {
211                                let mut backtrace_message = String::new();
212                                if let Some(backtrace) = test.waiting_backtrace() {
213                                    backtrace_message =
214                                        format!("\nbacktrace of waiting future:\n{:?}", backtrace);
215                                }
216                                panic!("parked with nothing left to run\n{:?}", backtrace_message)
217                            }
218                        }
219
220                        self.dispatcher.park();
221                    }
222                }
223            }
224        }
225    }
226
227    /// Block the current thread until the given future resolves
228    /// or `duration` has elapsed.
229    pub fn block_with_timeout<R>(
230        &self,
231        duration: Duration,
232        future: impl Future<Output = R>,
233    ) -> Result<R, impl Future<Output = R>> {
234        let mut future = Box::pin(future.fuse());
235        if duration.is_zero() {
236            return Err(future);
237        }
238
239        #[cfg(any(test, feature = "test-support"))]
240        let max_ticks = self
241            .dispatcher
242            .as_test()
243            .map_or(usize::MAX, |dispatcher| dispatcher.gen_block_on_ticks());
244        #[cfg(not(any(test, feature = "test-support")))]
245        let max_ticks = usize::MAX;
246
247        let mut timer = self.timer(duration).fuse();
248
249        let timeout = async {
250            futures::select_biased! {
251                value = future => Ok(value),
252                _ = timer => Err(()),
253            }
254        };
255        match self.block_internal(true, timeout, max_ticks) {
256            Ok(Ok(value)) => Ok(value),
257            _ => Err(future),
258        }
259    }
260
261    /// Scoped lets you start a number of tasks and waits
262    /// for all of them to complete before returning.
263    pub async fn scoped<'scope, F>(&self, scheduler: F)
264    where
265        F: FnOnce(&mut Scope<'scope>),
266    {
267        let mut scope = Scope::new(self.clone());
268        (scheduler)(&mut scope);
269        let spawned = mem::take(&mut scope.futures)
270            .into_iter()
271            .map(|f| self.spawn(f))
272            .collect::<Vec<_>>();
273        for task in spawned {
274            task.await;
275        }
276    }
277
278    /// Returns a task that will complete after the given duration.
279    /// Depending on other concurrent tasks the elapsed duration may be longer
280    /// than requested.
281    pub fn timer(&self, duration: Duration) -> Task<()> {
282        let (runnable, task) = async_task::spawn(async move {}, {
283            let dispatcher = self.dispatcher.clone();
284            move |runnable| dispatcher.dispatch_after(duration, runnable)
285        });
286        runnable.schedule();
287        Task::Spawned(task)
288    }
289
290    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
291    #[cfg(any(test, feature = "test-support"))]
292    pub fn start_waiting(&self) {
293        self.dispatcher.as_test().unwrap().start_waiting();
294    }
295
296    /// in tests, removes the debugging data added by start_waiting
297    #[cfg(any(test, feature = "test-support"))]
298    pub fn finish_waiting(&self) {
299        self.dispatcher.as_test().unwrap().finish_waiting();
300    }
301
302    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
303    #[cfg(any(test, feature = "test-support"))]
304    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
305        self.dispatcher.as_test().unwrap().simulate_random_delay()
306    }
307
308    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
309    #[cfg(any(test, feature = "test-support"))]
310    pub fn deprioritize(&self, task_label: TaskLabel) {
311        self.dispatcher.as_test().unwrap().deprioritize(task_label)
312    }
313
314    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
315    #[cfg(any(test, feature = "test-support"))]
316    pub fn advance_clock(&self, duration: Duration) {
317        self.dispatcher.as_test().unwrap().advance_clock(duration)
318    }
319
320    /// in tests, run one task.
321    #[cfg(any(test, feature = "test-support"))]
322    pub fn tick(&self) -> bool {
323        self.dispatcher.as_test().unwrap().tick(false)
324    }
325
326    /// in tests, run all tasks that are ready to run. If after doing so
327    /// the test still has outstanding tasks, this will panic. (See also `allow_parking`)
328    #[cfg(any(test, feature = "test-support"))]
329    pub fn run_until_parked(&self) {
330        self.dispatcher.as_test().unwrap().run_until_parked()
331    }
332
333    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
334    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
335    /// do take real async time to run.
336    #[cfg(any(test, feature = "test-support"))]
337    pub fn allow_parking(&self) {
338        self.dispatcher.as_test().unwrap().allow_parking();
339    }
340
341    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
342    #[cfg(any(test, feature = "test-support"))]
343    pub fn rng(&self) -> StdRng {
344        self.dispatcher.as_test().unwrap().rng()
345    }
346
347    /// How many CPUs are available to the dispatcher
348    pub fn num_cpus(&self) -> usize {
349        num_cpus::get()
350    }
351
352    /// Whether we're on the main thread.
353    pub fn is_main_thread(&self) -> bool {
354        self.dispatcher.is_main_thread()
355    }
356
357    #[cfg(any(test, feature = "test-support"))]
358    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
359    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
360        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
361    }
362}
363
364/// ForegroundExecutor runs things on the main thread.
365impl ForegroundExecutor {
366    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
367        Self {
368            dispatcher,
369            not_send: PhantomData,
370        }
371    }
372
373    /// Enqueues the given Task to run on the main thread at some point in the future.
374    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
375    where
376        R: 'static,
377    {
378        let dispatcher = self.dispatcher.clone();
379        fn inner<R: 'static>(
380            dispatcher: Arc<dyn PlatformDispatcher>,
381            future: AnyLocalFuture<R>,
382        ) -> Task<R> {
383            let (runnable, task) = async_task::spawn_local(future, move |runnable| {
384                dispatcher.dispatch_on_main_thread(runnable)
385            });
386            runnable.schedule();
387            Task::Spawned(task)
388        }
389        inner::<R>(dispatcher, Box::pin(future))
390    }
391}
392
393/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
394pub struct Scope<'a> {
395    executor: BackgroundExecutor,
396    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
397    tx: Option<mpsc::Sender<()>>,
398    rx: mpsc::Receiver<()>,
399    lifetime: PhantomData<&'a ()>,
400}
401
402impl<'a> Scope<'a> {
403    fn new(executor: BackgroundExecutor) -> Self {
404        let (tx, rx) = mpsc::channel(1);
405        Self {
406            executor,
407            tx: Some(tx),
408            rx,
409            futures: Default::default(),
410            lifetime: PhantomData,
411        }
412    }
413
414    pub fn spawn<F>(&mut self, f: F)
415    where
416        F: Future<Output = ()> + Send + 'a,
417    {
418        let tx = self.tx.clone().unwrap();
419
420        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
421        // dropping this `Scope` blocks until all of the futures have resolved.
422        let f = unsafe {
423            mem::transmute::<
424                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
425                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
426            >(Box::pin(async move {
427                f.await;
428                drop(tx);
429            }))
430        };
431        self.futures.push(f);
432    }
433}
434
435impl<'a> Drop for Scope<'a> {
436    fn drop(&mut self) {
437        self.tx.take().unwrap();
438
439        // Wait until the channel is closed, which means that all of the spawned
440        // futures have resolved.
441        self.executor.block(self.rx.next());
442    }
443}