1use crate::{App, PlatformDispatcher};
  2use async_task::Runnable;
  3use futures::channel::mpsc;
  4use smol::prelude::*;
  5use std::{
  6    fmt::Debug,
  7    marker::PhantomData,
  8    mem::{self, ManuallyDrop},
  9    num::NonZeroUsize,
 10    panic::Location,
 11    pin::Pin,
 12    rc::Rc,
 13    sync::{
 14        Arc,
 15        atomic::{AtomicUsize, Ordering},
 16    },
 17    task::{Context, Poll},
 18    thread::{self, ThreadId},
 19    time::{Duration, Instant},
 20};
 21use util::TryFutureExt;
 22use waker_fn::waker_fn;
 23
 24#[cfg(any(test, feature = "test-support"))]
 25use rand::rngs::StdRng;
 26
 27/// A pointer to the executor that is currently running,
 28/// for spawning background tasks.
 29#[derive(Clone)]
 30pub struct BackgroundExecutor {
 31    #[doc(hidden)]
 32    pub dispatcher: Arc<dyn PlatformDispatcher>,
 33}
 34
 35/// A pointer to the executor that is currently running,
 36/// for spawning tasks on the main thread.
 37///
 38/// This is intentionally `!Send` via the `not_send` marker field. This is because
 39/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
 40/// only polled from the same thread it was spawned from. These checks would fail when spawning
 41/// foreground tasks from background threads.
 42#[derive(Clone)]
 43pub struct ForegroundExecutor {
 44    #[doc(hidden)]
 45    pub dispatcher: Arc<dyn PlatformDispatcher>,
 46    not_send: PhantomData<Rc<()>>,
 47}
 48
 49/// Task is a primitive that allows work to happen in the background.
 50///
 51/// It implements [`Future`] so you can `.await` on it.
 52///
 53/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
 54/// the task to continue running, but with no way to return a value.
 55#[must_use]
 56#[derive(Debug)]
 57pub struct Task<T>(TaskState<T>);
 58
 59#[derive(Debug)]
 60enum TaskState<T> {
 61    /// A task that is ready to return a value
 62    Ready(Option<T>),
 63
 64    /// A task that is currently running.
 65    Spawned(async_task::Task<T>),
 66}
 67
 68impl<T> Task<T> {
 69    /// Creates a new task that will resolve with the value
 70    pub fn ready(val: T) -> Self {
 71        Task(TaskState::Ready(Some(val)))
 72    }
 73
 74    /// Detaching a task runs it to completion in the background
 75    pub fn detach(self) {
 76        match self {
 77            Task(TaskState::Ready(_)) => {}
 78            Task(TaskState::Spawned(task)) => task.detach(),
 79        }
 80    }
 81}
 82
 83impl<E, T> Task<Result<T, E>>
 84where
 85    T: 'static,
 86    E: 'static + Debug,
 87{
 88    /// Run the task to completion in the background and log any
 89    /// errors that occur.
 90    #[track_caller]
 91    pub fn detach_and_log_err(self, cx: &App) {
 92        let location = core::panic::Location::caller();
 93        cx.foreground_executor()
 94            .spawn(self.log_tracked_err(*location))
 95            .detach();
 96    }
 97}
 98
 99impl<T> Future for Task<T> {
100    type Output = T;
101
102    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
103        match unsafe { self.get_unchecked_mut() } {
104            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
105            Task(TaskState::Spawned(task)) => task.poll(cx),
106        }
107    }
108}
109
/// Opaque identifier attached to a task so that tests can refer to it,
/// for example to deprioritize it.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);
114
115impl Default for TaskLabel {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
121impl TaskLabel {
122    /// Construct a new task label.
123    pub fn new() -> Self {
124        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
125        Self(
126            NEXT_TASK_LABEL
127                .fetch_add(1, Ordering::SeqCst)
128                .try_into()
129                .unwrap(),
130        )
131    }
132}
133
/// Boxed, pinned, type-erased future that need not be `Send` (used for main-thread tasks).
type AnyLocalFuture<R> = Pin<Box<dyn Future<Output = R> + 'static>>;

/// Boxed, pinned, type-erased future that may be moved to another thread.
type AnyFuture<R> = Pin<Box<dyn Future<Output = R> + Send + 'static>>;
137
138/// BackgroundExecutor lets you run things on background threads.
139/// In production this is a thread pool with no ordering guarantees.
140/// In tests this is simulated by running tasks one by one in a deterministic
141/// (but arbitrary) order controlled by the `SEED` environment variable.
142impl BackgroundExecutor {
143    #[doc(hidden)]
144    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
145        Self { dispatcher }
146    }
147
148    /// Enqueues the given future to be run to completion on a background thread.
149    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
150    where
151        R: Send + 'static,
152    {
153        self.spawn_internal::<R>(Box::pin(future), None)
154    }
155
156    /// Enqueues the given future to be run to completion on a background thread.
157    /// The given label can be used to control the priority of the task in tests.
158    pub fn spawn_labeled<R>(
159        &self,
160        label: TaskLabel,
161        future: impl Future<Output = R> + Send + 'static,
162    ) -> Task<R>
163    where
164        R: Send + 'static,
165    {
166        self.spawn_internal::<R>(Box::pin(future), Some(label))
167    }
168
169    fn spawn_internal<R: Send + 'static>(
170        &self,
171        future: AnyFuture<R>,
172        label: Option<TaskLabel>,
173    ) -> Task<R> {
174        let dispatcher = self.dispatcher.clone();
175        let (runnable, task) =
176            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
177        runnable.schedule();
178        Task(TaskState::Spawned(task))
179    }
180
181    /// Used by the test harness to run an async test in a synchronous fashion.
182    #[cfg(any(test, feature = "test-support"))]
183    #[track_caller]
184    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
185        if let Ok(value) = self.block_internal(false, future, None) {
186            value
187        } else {
188            unreachable!()
189        }
190    }
191
192    /// Block the current thread until the given future resolves.
193    /// Consider using `block_with_timeout` instead.
194    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
195        if let Ok(value) = self.block_internal(true, future, None) {
196            value
197        } else {
198            unreachable!()
199        }
200    }
201
202    #[cfg(not(any(test, feature = "test-support")))]
203    pub(crate) fn block_internal<Fut: Future>(
204        &self,
205        _background_only: bool,
206        future: Fut,
207        timeout: Option<Duration>,
208    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
209        use std::time::Instant;
210
211        let mut future = Box::pin(future);
212        if timeout == Some(Duration::ZERO) {
213            return Err(future);
214        }
215        let deadline = timeout.map(|timeout| Instant::now() + timeout);
216
217        let parker = parking::Parker::new();
218        let unparker = parker.unparker();
219        let waker = waker_fn(move || {
220            unparker.unpark();
221        });
222        let mut cx = std::task::Context::from_waker(&waker);
223
224        loop {
225            match future.as_mut().poll(&mut cx) {
226                Poll::Ready(result) => return Ok(result),
227                Poll::Pending => {
228                    let timeout =
229                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
230                    if let Some(timeout) = timeout {
231                        if !parker.park_timeout(timeout)
232                            && deadline.is_some_and(|deadline| deadline < Instant::now())
233                        {
234                            return Err(future);
235                        }
236                    } else {
237                        parker.park();
238                    }
239                }
240            }
241        }
242    }
243
244    #[cfg(any(test, feature = "test-support"))]
245    #[track_caller]
246    pub(crate) fn block_internal<Fut: Future>(
247        &self,
248        background_only: bool,
249        future: Fut,
250        timeout: Option<Duration>,
251    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
252        use std::sync::atomic::AtomicBool;
253
254        use parking::Parker;
255
256        let mut future = Box::pin(future);
257        if timeout == Some(Duration::ZERO) {
258            return Err(future);
259        }
260        let Some(dispatcher) = self.dispatcher.as_test() else {
261            return Err(future);
262        };
263
264        let mut max_ticks = if timeout.is_some() {
265            dispatcher.gen_block_on_ticks()
266        } else {
267            usize::MAX
268        };
269
270        let parker = Parker::new();
271        let unparker = parker.unparker();
272
273        let awoken = Arc::new(AtomicBool::new(false));
274        let waker = waker_fn({
275            let awoken = awoken.clone();
276            let unparker = unparker.clone();
277            move || {
278                awoken.store(true, Ordering::SeqCst);
279                unparker.unpark();
280            }
281        });
282        let mut cx = std::task::Context::from_waker(&waker);
283
284        let duration = Duration::from_secs(500);
285        let mut test_should_end_by = Instant::now() + duration;
286
287        loop {
288            match future.as_mut().poll(&mut cx) {
289                Poll::Ready(result) => return Ok(result),
290                Poll::Pending => {
291                    if max_ticks == 0 {
292                        return Err(future);
293                    }
294                    max_ticks -= 1;
295
296                    if !dispatcher.tick(background_only) {
297                        if awoken.swap(false, Ordering::SeqCst) {
298                            continue;
299                        }
300
301                        if !dispatcher.parking_allowed() {
302                            if dispatcher.advance_clock_to_next_delayed() {
303                                continue;
304                            }
305                            let mut backtrace_message = String::new();
306                            let mut waiting_message = String::new();
307                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
308                                backtrace_message =
309                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
310                            }
311                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
312                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
313                            }
314                            panic!(
315                                "parked with nothing left to run{waiting_message}{backtrace_message}",
316                            )
317                        }
318                        dispatcher.set_unparker(unparker.clone());
319                        parker.park_timeout(
320                            test_should_end_by.saturating_duration_since(Instant::now()),
321                        );
322                        if Instant::now() > test_should_end_by {
323                            panic!("test timed out after {duration:?} with allow_parking")
324                        }
325                    }
326                }
327            }
328        }
329    }
330
331    /// Block the current thread until the given future resolves
332    /// or `duration` has elapsed.
333    pub fn block_with_timeout<Fut: Future>(
334        &self,
335        duration: Duration,
336        future: Fut,
337    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
338        self.block_internal(true, future, Some(duration))
339    }
340
341    /// Scoped lets you start a number of tasks and waits
342    /// for all of them to complete before returning.
343    pub async fn scoped<'scope, F>(&self, scheduler: F)
344    where
345        F: FnOnce(&mut Scope<'scope>),
346    {
347        let mut scope = Scope::new(self.clone());
348        (scheduler)(&mut scope);
349        let spawned = mem::take(&mut scope.futures)
350            .into_iter()
351            .map(|f| self.spawn(f))
352            .collect::<Vec<_>>();
353        for task in spawned {
354            task.await;
355        }
356    }
357
358    /// Get the current time.
359    ///
360    /// Calling this instead of `std::time::Instant::now` allows the use
361    /// of fake timers in tests.
362    pub fn now(&self) -> Instant {
363        self.dispatcher.now()
364    }
365
366    /// Returns a task that will complete after the given duration.
367    /// Depending on other concurrent tasks the elapsed duration may be longer
368    /// than requested.
369    pub fn timer(&self, duration: Duration) -> Task<()> {
370        if duration.is_zero() {
371            return Task::ready(());
372        }
373        let (runnable, task) = async_task::spawn(async move {}, {
374            let dispatcher = self.dispatcher.clone();
375            move |runnable| dispatcher.dispatch_after(duration, runnable)
376        });
377        runnable.schedule();
378        Task(TaskState::Spawned(task))
379    }
380
381    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
382    #[cfg(any(test, feature = "test-support"))]
383    pub fn start_waiting(&self) {
384        self.dispatcher.as_test().unwrap().start_waiting();
385    }
386
387    /// in tests, removes the debugging data added by start_waiting
388    #[cfg(any(test, feature = "test-support"))]
389    pub fn finish_waiting(&self) {
390        self.dispatcher.as_test().unwrap().finish_waiting();
391    }
392
393    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
394    #[cfg(any(test, feature = "test-support"))]
395    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
396        self.dispatcher.as_test().unwrap().simulate_random_delay()
397    }
398
399    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
400    #[cfg(any(test, feature = "test-support"))]
401    pub fn deprioritize(&self, task_label: TaskLabel) {
402        self.dispatcher.as_test().unwrap().deprioritize(task_label)
403    }
404
405    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
406    #[cfg(any(test, feature = "test-support"))]
407    pub fn advance_clock(&self, duration: Duration) {
408        self.dispatcher.as_test().unwrap().advance_clock(duration)
409    }
410
411    /// in tests, run one task.
412    #[cfg(any(test, feature = "test-support"))]
413    pub fn tick(&self) -> bool {
414        self.dispatcher.as_test().unwrap().tick(false)
415    }
416
417    /// in tests, run all tasks that are ready to run. If after doing so
418    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
419    #[cfg(any(test, feature = "test-support"))]
420    pub fn run_until_parked(&self) {
421        self.dispatcher.as_test().unwrap().run_until_parked()
422    }
423
424    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
425    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
426    /// do take real async time to run.
427    #[cfg(any(test, feature = "test-support"))]
428    pub fn allow_parking(&self) {
429        self.dispatcher.as_test().unwrap().allow_parking();
430    }
431
432    /// undoes the effect of [`Self::allow_parking`].
433    #[cfg(any(test, feature = "test-support"))]
434    pub fn forbid_parking(&self) {
435        self.dispatcher.as_test().unwrap().forbid_parking();
436    }
437
438    /// adds detail to the "parked with nothing let to run" message.
439    #[cfg(any(test, feature = "test-support"))]
440    pub fn set_waiting_hint(&self, msg: Option<String>) {
441        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
442    }
443
444    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
445    #[cfg(any(test, feature = "test-support"))]
446    pub fn rng(&self) -> StdRng {
447        self.dispatcher.as_test().unwrap().rng()
448    }
449
450    /// How many CPUs are available to the dispatcher.
451    pub fn num_cpus(&self) -> usize {
452        #[cfg(any(test, feature = "test-support"))]
453        return 4;
454
455        #[cfg(not(any(test, feature = "test-support")))]
456        return num_cpus::get();
457    }
458
459    /// Whether we're on the main thread.
460    pub fn is_main_thread(&self) -> bool {
461        self.dispatcher.is_main_thread()
462    }
463
464    #[cfg(any(test, feature = "test-support"))]
465    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
466    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
467        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
468    }
469}
470
471/// ForegroundExecutor runs things on the main thread.
472impl ForegroundExecutor {
473    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
474    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
475        Self {
476            dispatcher,
477            not_send: PhantomData,
478        }
479    }
480
481    /// Enqueues the given Task to run on the main thread at some point in the future.
482    #[track_caller]
483    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
484    where
485        R: 'static,
486    {
487        let dispatcher = self.dispatcher.clone();
488
489        #[track_caller]
490        fn inner<R: 'static>(
491            dispatcher: Arc<dyn PlatformDispatcher>,
492            future: AnyLocalFuture<R>,
493        ) -> Task<R> {
494            let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
495                dispatcher.dispatch_on_main_thread(runnable)
496            });
497            runnable.schedule();
498            Task(TaskState::Spawned(task))
499        }
500        inner::<R>(dispatcher, Box::pin(future))
501    }
502}
503
/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// Copy-modified from:
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
///
/// The future is wrapped so that polling or dropping it on any thread other than the one
/// that called this function panics, and the panic names the caller's source location.
/// That runtime check is what allows the future itself to be `!Send`.
#[track_caller]
fn spawn_local_with_source_location<Fut, S>(
    future: Fut,
    schedule: S,
) -> (Runnable<()>, async_task::Task<Fut::Output, ()>)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<()> + Send + Sync + 'static,
{
    /// Returns the current thread's id, cached in a thread-local.
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        // `try_with` fails once the thread-local has been destroyed (during thread teardown);
        // fall back to querying the current thread directly in that case.
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    /// Wrapper that remembers which thread spawned the future and where it was spawned.
    struct Checked<F> {
        id: ThreadId,
        // `ManuallyDrop` so the inner future is only dropped after the thread check in `Drop`.
        inner: ManuallyDrop<F>,
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is dropped exactly once, here, and is never accessed afterwards.
            unsafe { ManuallyDrop::drop(&mut self.inner) };
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `Checked` never moves `inner` out of its pinned location (its `Drop`
            // drops it in place), so projecting the pin onto `inner` is sound.
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    // Wrap the future into one that checks which thread it's on.
    let future = Checked {
        id: thread_id(),
        inner: ManuallyDrop::new(future),
        location: Location::caller(),
    };

    // SAFETY: `spawn_unchecked` is used because `Fut` is not required to be `Send`. `Checked`
    // panics if the future is polled or dropped on any thread but this one, and the future,
    // its output and the schedule function are all `'static`.
    unsafe { async_task::spawn_unchecked(future, schedule) }
}
566
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
///
/// NOTE: field order determines drop order; keep `futures` and `tx` declared before `rx`.
pub struct Scope<'a> {
    // Executor used to block in `Drop` until every spawned future has finished.
    executor: BackgroundExecutor,
    // Futures queued by `Scope::spawn`; their `'a` borrows are transmuted to `'static` there.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    // The scope's own sender. Each queued future holds a clone and drops it when it completes.
    tx: Option<mpsc::Sender<()>>,
    // Yields `None` once every sender is gone, which is what `Drop` waits for.
    rx: mpsc::Receiver<()>,
    // Ties the scope to the lifetime of the data its futures may borrow.
    lifetime: PhantomData<&'a ()>,
}
575
576impl<'a> Scope<'a> {
577    fn new(executor: BackgroundExecutor) -> Self {
578        let (tx, rx) = mpsc::channel(1);
579        Self {
580            executor,
581            tx: Some(tx),
582            rx,
583            futures: Default::default(),
584            lifetime: PhantomData,
585        }
586    }
587
588    /// How many CPUs are available to the dispatcher.
589    pub fn num_cpus(&self) -> usize {
590        self.executor.num_cpus()
591    }
592
593    /// Spawn a future into this scope.
594    pub fn spawn<F>(&mut self, f: F)
595    where
596        F: Future<Output = ()> + Send + 'a,
597    {
598        let tx = self.tx.clone().unwrap();
599
600        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
601        // dropping this `Scope` blocks until all of the futures have resolved.
602        let f = unsafe {
603            mem::transmute::<
604                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
605                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
606            >(Box::pin(async move {
607                f.await;
608                drop(tx);
609            }))
610        };
611        self.futures.push(f);
612    }
613}
614
615impl Drop for Scope<'_> {
616    fn drop(&mut self) {
617        self.tx.take().unwrap();
618
619        // Wait until the channel is closed, which means that all of the spawned
620        // futures have resolved.
621        self.executor.block(self.rx.next());
622    }
623}