executor.rs

  1use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant, TaskTiming, profiler};
  2use async_task::Runnable;
  3use futures::channel::mpsc;
  4use parking_lot::{Condvar, Mutex};
  5use smol::prelude::*;
  6use std::{
  7    fmt::Debug,
  8    marker::PhantomData,
  9    mem::{self, ManuallyDrop},
 10    num::NonZeroUsize,
 11    panic::Location,
 12    pin::Pin,
 13    rc::Rc,
 14    sync::{
 15        Arc,
 16        atomic::{AtomicUsize, Ordering},
 17    },
 18    task::{Context, Poll},
 19    thread::{self, ThreadId},
 20    time::{Duration, Instant},
 21};
 22use util::TryFutureExt;
 23use waker_fn::waker_fn;
 24
 25#[cfg(any(test, feature = "test-support"))]
 26use rand::rngs::StdRng;
 27
 28/// A pointer to the executor that is currently running,
 29/// for spawning background tasks.
 30#[derive(Clone)]
 31pub struct BackgroundExecutor {
 32    #[doc(hidden)]
 33    pub dispatcher: Arc<dyn PlatformDispatcher>,
 34}
 35
 36/// A pointer to the executor that is currently running,
 37/// for spawning tasks on the main thread.
 38///
 39/// This is intentionally `!Send` via the `not_send` marker field. This is because
 40/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
 41/// only polled from the same thread it was spawned from. These checks would fail when spawning
 42/// foreground tasks from background threads.
 43#[derive(Clone)]
 44pub struct ForegroundExecutor {
 45    #[doc(hidden)]
 46    pub dispatcher: Arc<dyn PlatformDispatcher>,
 47    not_send: PhantomData<Rc<()>>,
 48}
 49
 50/// Realtime task priority
 51#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
 52#[repr(u8)]
 53pub enum RealtimePriority {
 54    /// Audio task
 55    Audio,
 56    /// Other realtime task
 57    #[default]
 58    Other,
 59}
 60
 61/// Task priority
 62#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
 63#[repr(u8)]
 64pub enum Priority {
 65    /// Realtime priority
 66    ///
 67    /// Spawning a task with this priority will spin it off on a separate thread dedicated just to that task.
 68    Realtime(RealtimePriority),
 69    /// High priority
 70    ///
 71    /// Only use for tasks that are critical to the user experience / responsiveness of the editor.
 72    High,
 73    /// Medium priority, probably suits most of your use cases.
 74    #[default]
 75    Medium,
 76    /// Low priority
 77    ///
 78    /// Prioritize this for background work that can come in large quantities
 79    /// to not starve the executor of resources for high priority tasks
 80    Low,
 81}
 82
 83impl Priority {
 84    #[allow(dead_code)]
 85    pub(crate) const fn probability(&self) -> u32 {
 86        match self {
 87            // realtime priorities are not considered for probability scheduling
 88            Priority::Realtime(_) => 0,
 89            Priority::High => 60,
 90            Priority::Medium => 30,
 91            Priority::Low => 10,
 92        }
 93    }
 94}
 95
 96/// Task is a primitive that allows work to happen in the background.
 97///
 98/// It implements [`Future`] so you can `.await` on it.
 99///
100/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
101/// the task to continue running, but with no way to return a value.
102#[must_use]
103#[derive(Debug)]
104pub struct Task<T>(TaskState<T>);
105
106#[derive(Debug)]
107enum TaskState<T> {
108    /// A task that is ready to return a value
109    Ready(Option<T>),
110
111    /// A task that is currently running.
112    Spawned(async_task::Task<T, RunnableMeta>),
113}
114
115impl<T> Task<T> {
116    /// Creates a new task that will resolve with the value
117    pub fn ready(val: T) -> Self {
118        Task(TaskState::Ready(Some(val)))
119    }
120
121    /// Detaching a task runs it to completion in the background
122    pub fn detach(self) {
123        match self {
124            Task(TaskState::Ready(_)) => {}
125            Task(TaskState::Spawned(task)) => task.detach(),
126        }
127    }
128}
129
130impl<E, T> Task<Result<T, E>>
131where
132    T: 'static,
133    E: 'static + Debug,
134{
135    /// Run the task to completion in the background and log any
136    /// errors that occur.
137    #[track_caller]
138    pub fn detach_and_log_err(self, cx: &App) {
139        let location = core::panic::Location::caller();
140        cx.foreground_executor()
141            .spawn(self.log_tracked_err(*location))
142            .detach();
143    }
144}
145
146impl<T> Future for Task<T> {
147    type Output = T;
148
149    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
150        match unsafe { self.get_unchecked_mut() } {
151            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
152            Task(TaskState::Spawned(task)) => task.poll(cx),
153        }
154    }
155}
156
/// Opaque identifier that tests can use to refer to a specific task, for
/// example to deprioritize it.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    /// Equivalent to [`TaskLabel::new`]: every default label is fresh.
    fn default() -> Self {
        TaskLabel::new()
    }
}

impl TaskLabel {
    /// Allocates a new label, distinct from every label previously created in
    /// this process (until the global counter would wrap).
    pub fn new() -> Self {
        // The counter starts at 1 so every value fits in a `NonZeroUsize`.
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        let raw = NEXT_TASK_LABEL.fetch_add(1, Ordering::SeqCst);
        let id: NonZeroUsize = raw.try_into().unwrap();
        TaskLabel(id)
    }
}
180
/// Boxed, pinned future that may hold non-`Send` state; used for main-thread tasks.
type AnyLocalFuture<R> = Pin<Box<dyn Future<Output = R> + 'static>>;

/// Boxed, pinned future that can be moved across threads; used for background tasks.
type AnyFuture<R> = Pin<Box<dyn Future<Output = R> + Send + 'static>>;
184
185/// BackgroundExecutor lets you run things on background threads.
186/// In production this is a thread pool with no ordering guarantees.
187/// In tests this is simulated by running tasks one by one in a deterministic
188/// (but arbitrary) order controlled by the `SEED` environment variable.
189impl BackgroundExecutor {
190    #[doc(hidden)]
191    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
192        Self { dispatcher }
193    }
194
195    /// Enqueues the given future to be run to completion on a background thread.
196    #[track_caller]
197    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
198    where
199        R: Send + 'static,
200    {
201        self.spawn_with_priority(Priority::default(), future)
202    }
203
204    /// Enqueues the given future to be run to completion on a background thread.
205    #[track_caller]
206    pub fn spawn_with_priority<R>(
207        &self,
208        priority: Priority,
209        future: impl Future<Output = R> + Send + 'static,
210    ) -> Task<R>
211    where
212        R: Send + 'static,
213    {
214        self.spawn_internal::<R>(Box::pin(future), None, priority)
215    }
216
217    /// Enqueues the given future to be run to completion on a background thread and blocking the current task on it.
218    ///
219    /// This allows to spawn background work that borrows from its scope. Note that the supplied future will run to
220    /// completion before the current task is resumed, even if the current task is slated for cancellation.
221    pub async fn await_on_background<R>(&self, future: impl Future<Output = R> + Send) -> R
222    where
223        R: Send,
224    {
225        // We need to ensure that cancellation of the parent task does not drop the environment
226        // before the our own task has completed or got cancelled.
227        struct NotifyOnDrop<'a>(&'a (Condvar, Mutex<bool>));
228
229        impl Drop for NotifyOnDrop<'_> {
230            fn drop(&mut self) {
231                *self.0.1.lock() = true;
232                self.0.0.notify_all();
233            }
234        }
235
236        struct WaitOnDrop<'a>(&'a (Condvar, Mutex<bool>));
237
238        impl Drop for WaitOnDrop<'_> {
239            fn drop(&mut self) {
240                let mut done = self.0.1.lock();
241                if !*done {
242                    self.0.0.wait(&mut done);
243                }
244            }
245        }
246
247        let dispatcher = self.dispatcher.clone();
248        let location = core::panic::Location::caller();
249
250        let pair = &(Condvar::new(), Mutex::new(false));
251        let _wait_guard = WaitOnDrop(pair);
252
253        let (runnable, task) = unsafe {
254            async_task::Builder::new()
255                .metadata(RunnableMeta { location })
256                .spawn_unchecked(
257                    move |_| async {
258                        let _notify_guard = NotifyOnDrop(pair);
259                        future.await
260                    },
261                    move |runnable| {
262                        dispatcher.dispatch(
263                            RunnableVariant::Meta(runnable),
264                            None,
265                            Priority::default(),
266                        )
267                    },
268                )
269        };
270        runnable.schedule();
271        task.await
272    }
273
274    /// Enqueues the given future to be run to completion on a background thread.
275    /// The given label can be used to control the priority of the task in tests.
276    #[track_caller]
277    pub fn spawn_labeled<R>(
278        &self,
279        label: TaskLabel,
280        future: impl Future<Output = R> + Send + 'static,
281    ) -> Task<R>
282    where
283        R: Send + 'static,
284    {
285        self.spawn_internal::<R>(Box::pin(future), Some(label), Priority::default())
286    }
287
288    #[track_caller]
289    fn spawn_internal<R: Send + 'static>(
290        &self,
291        future: AnyFuture<R>,
292        label: Option<TaskLabel>,
293        priority: Priority,
294    ) -> Task<R> {
295        let dispatcher = self.dispatcher.clone();
296        let (runnable, task) = if let Priority::Realtime(realtime) = priority {
297            let location = core::panic::Location::caller();
298            let (mut tx, rx) = flume::bounded::<Runnable<RunnableMeta>>(1);
299
300            dispatcher.spawn_realtime(
301                realtime,
302                Box::new(move || {
303                    while let Ok(runnable) = rx.recv() {
304                        let start = Instant::now();
305                        let location = runnable.metadata().location;
306                        let mut timing = TaskTiming {
307                            location,
308                            start,
309                            end: None,
310                        };
311                        profiler::add_task_timing(timing);
312
313                        runnable.run();
314
315                        let end = Instant::now();
316                        timing.end = Some(end);
317                        profiler::add_task_timing(timing);
318                    }
319                }),
320            );
321
322            async_task::Builder::new()
323                .metadata(RunnableMeta { location })
324                .spawn(
325                    move |_| future,
326                    move |runnable| {
327                        let _ = tx.send(runnable);
328                    },
329                )
330        } else {
331            let location = core::panic::Location::caller();
332            async_task::Builder::new()
333                .metadata(RunnableMeta { location })
334                .spawn(
335                    move |_| future,
336                    move |runnable| {
337                        dispatcher.dispatch(RunnableVariant::Meta(runnable), label, priority)
338                    },
339                )
340        };
341
342        runnable.schedule();
343        Task(TaskState::Spawned(task))
344    }
345
346    /// Used by the test harness to run an async test in a synchronous fashion.
347    #[cfg(any(test, feature = "test-support"))]
348    #[track_caller]
349    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
350        if let Ok(value) = self.block_internal(false, future, None) {
351            value
352        } else {
353            unreachable!()
354        }
355    }
356
357    /// Block the current thread until the given future resolves.
358    /// Consider using `block_with_timeout` instead.
359    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
360        if let Ok(value) = self.block_internal(true, future, None) {
361            value
362        } else {
363            unreachable!()
364        }
365    }
366
367    #[cfg(not(any(test, feature = "test-support")))]
368    pub(crate) fn block_internal<Fut: Future>(
369        &self,
370        _background_only: bool,
371        future: Fut,
372        timeout: Option<Duration>,
373    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
374        use std::time::Instant;
375
376        let mut future = Box::pin(future);
377        if timeout == Some(Duration::ZERO) {
378            return Err(future);
379        }
380        let deadline = timeout.map(|timeout| Instant::now() + timeout);
381
382        let parker = parking::Parker::new();
383        let unparker = parker.unparker();
384        let waker = waker_fn(move || {
385            unparker.unpark();
386        });
387        let mut cx = std::task::Context::from_waker(&waker);
388
389        loop {
390            match future.as_mut().poll(&mut cx) {
391                Poll::Ready(result) => return Ok(result),
392                Poll::Pending => {
393                    let timeout =
394                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
395                    if let Some(timeout) = timeout {
396                        if !parker.park_timeout(timeout)
397                            && deadline.is_some_and(|deadline| deadline < Instant::now())
398                        {
399                            return Err(future);
400                        }
401                    } else {
402                        parker.park();
403                    }
404                }
405            }
406        }
407    }
408
409    #[cfg(any(test, feature = "test-support"))]
410    #[track_caller]
411    pub(crate) fn block_internal<Fut: Future>(
412        &self,
413        background_only: bool,
414        future: Fut,
415        timeout: Option<Duration>,
416    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
417        use std::sync::atomic::AtomicBool;
418
419        use parking::Parker;
420
421        let mut future = Box::pin(future);
422        if timeout == Some(Duration::ZERO) {
423            return Err(future);
424        }
425        let Some(dispatcher) = self.dispatcher.as_test() else {
426            return Err(future);
427        };
428
429        let mut max_ticks = if timeout.is_some() {
430            dispatcher.gen_block_on_ticks()
431        } else {
432            usize::MAX
433        };
434
435        let parker = Parker::new();
436        let unparker = parker.unparker();
437
438        let awoken = Arc::new(AtomicBool::new(false));
439        let waker = waker_fn({
440            let awoken = awoken.clone();
441            let unparker = unparker.clone();
442            move || {
443                awoken.store(true, Ordering::SeqCst);
444                unparker.unpark();
445            }
446        });
447        let mut cx = std::task::Context::from_waker(&waker);
448
449        let duration = Duration::from_secs(
450            option_env!("GPUI_TEST_TIMEOUT")
451                .and_then(|s| s.parse::<u64>().ok())
452                .unwrap_or(180),
453        );
454        let mut test_should_end_by = Instant::now() + duration;
455
456        loop {
457            match future.as_mut().poll(&mut cx) {
458                Poll::Ready(result) => return Ok(result),
459                Poll::Pending => {
460                    if max_ticks == 0 {
461                        return Err(future);
462                    }
463                    max_ticks -= 1;
464
465                    if !dispatcher.tick(background_only) {
466                        if awoken.swap(false, Ordering::SeqCst) {
467                            continue;
468                        }
469
470                        if !dispatcher.parking_allowed() {
471                            if dispatcher.advance_clock_to_next_delayed() {
472                                continue;
473                            }
474                            let mut backtrace_message = String::new();
475                            let mut waiting_message = String::new();
476                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
477                                backtrace_message =
478                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
479                            }
480                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
481                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
482                            }
483                            panic!(
484                                "parked with nothing left to run{waiting_message}{backtrace_message}",
485                            )
486                        }
487                        dispatcher.push_unparker(unparker.clone());
488                        parker.park_timeout(Duration::from_millis(1));
489                        if Instant::now() > test_should_end_by {
490                            panic!("test timed out after {duration:?} with allow_parking")
491                        }
492                    }
493                }
494            }
495        }
496    }
497
498    /// Block the current thread until the given future resolves
499    /// or `duration` has elapsed.
500    pub fn block_with_timeout<Fut: Future>(
501        &self,
502        duration: Duration,
503        future: Fut,
504    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
505        self.block_internal(true, future, Some(duration))
506    }
507
508    /// Scoped lets you start a number of tasks and waits
509    /// for all of them to complete before returning.
510    pub async fn scoped<'scope, F>(&self, scheduler: F)
511    where
512        F: FnOnce(&mut Scope<'scope>),
513    {
514        let mut scope = Scope::new(self.clone(), Priority::default());
515        (scheduler)(&mut scope);
516        let spawned = mem::take(&mut scope.futures)
517            .into_iter()
518            .map(|f| self.spawn_with_priority(scope.priority, f))
519            .collect::<Vec<_>>();
520        for task in spawned {
521            task.await;
522        }
523    }
524
525    /// Scoped lets you start a number of tasks and waits
526    /// for all of them to complete before returning.
527    pub async fn scoped_priority<'scope, F>(&self, priority: Priority, scheduler: F)
528    where
529        F: FnOnce(&mut Scope<'scope>),
530    {
531        let mut scope = Scope::new(self.clone(), priority);
532        (scheduler)(&mut scope);
533        let spawned = mem::take(&mut scope.futures)
534            .into_iter()
535            .map(|f| self.spawn_with_priority(scope.priority, f))
536            .collect::<Vec<_>>();
537        for task in spawned {
538            task.await;
539        }
540    }
541
542    /// Get the current time.
543    ///
544    /// Calling this instead of `std::time::Instant::now` allows the use
545    /// of fake timers in tests.
546    pub fn now(&self) -> Instant {
547        self.dispatcher.now()
548    }
549
550    /// Returns a task that will complete after the given duration.
551    /// Depending on other concurrent tasks the elapsed duration may be longer
552    /// than requested.
553    pub fn timer(&self, duration: Duration) -> Task<()> {
554        if duration.is_zero() {
555            return Task::ready(());
556        }
557        let location = core::panic::Location::caller();
558        let (runnable, task) = async_task::Builder::new()
559            .metadata(RunnableMeta { location })
560            .spawn(move |_| async move {}, {
561                let dispatcher = self.dispatcher.clone();
562                move |runnable| dispatcher.dispatch_after(duration, RunnableVariant::Meta(runnable))
563            });
564        runnable.schedule();
565        Task(TaskState::Spawned(task))
566    }
567
568    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
569    #[cfg(any(test, feature = "test-support"))]
570    pub fn start_waiting(&self) {
571        self.dispatcher.as_test().unwrap().start_waiting();
572    }
573
574    /// in tests, removes the debugging data added by start_waiting
575    #[cfg(any(test, feature = "test-support"))]
576    pub fn finish_waiting(&self) {
577        self.dispatcher.as_test().unwrap().finish_waiting();
578    }
579
580    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
581    #[cfg(any(test, feature = "test-support"))]
582    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
583        self.dispatcher.as_test().unwrap().simulate_random_delay()
584    }
585
586    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
587    #[cfg(any(test, feature = "test-support"))]
588    pub fn deprioritize(&self, task_label: TaskLabel) {
589        self.dispatcher.as_test().unwrap().deprioritize(task_label)
590    }
591
592    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
593    #[cfg(any(test, feature = "test-support"))]
594    pub fn advance_clock(&self, duration: Duration) {
595        self.dispatcher.as_test().unwrap().advance_clock(duration)
596    }
597
598    /// in tests, run one task.
599    #[cfg(any(test, feature = "test-support"))]
600    pub fn tick(&self) -> bool {
601        self.dispatcher.as_test().unwrap().tick(false)
602    }
603
604    /// in tests, run all tasks that are ready to run. If after doing so
605    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
606    #[cfg(any(test, feature = "test-support"))]
607    pub fn run_until_parked(&self) {
608        self.dispatcher.as_test().unwrap().run_until_parked()
609    }
610
611    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
612    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
613    /// do take real async time to run.
614    #[cfg(any(test, feature = "test-support"))]
615    pub fn allow_parking(&self) {
616        self.dispatcher.as_test().unwrap().allow_parking();
617    }
618
619    /// undoes the effect of [`Self::allow_parking`].
620    #[cfg(any(test, feature = "test-support"))]
621    pub fn forbid_parking(&self) {
622        self.dispatcher.as_test().unwrap().forbid_parking();
623    }
624
625    /// adds detail to the "parked with nothing let to run" message.
626    #[cfg(any(test, feature = "test-support"))]
627    pub fn set_waiting_hint(&self, msg: Option<String>) {
628        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
629    }
630
631    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
632    #[cfg(any(test, feature = "test-support"))]
633    pub fn rng(&self) -> StdRng {
634        self.dispatcher.as_test().unwrap().rng()
635    }
636
637    /// How many CPUs are available to the dispatcher.
638    pub fn num_cpus(&self) -> usize {
639        #[cfg(any(test, feature = "test-support"))]
640        return 4;
641
642        #[cfg(not(any(test, feature = "test-support")))]
643        return num_cpus::get();
644    }
645
646    /// Whether we're on the main thread.
647    pub fn is_main_thread(&self) -> bool {
648        self.dispatcher.is_main_thread()
649    }
650
651    #[cfg(any(test, feature = "test-support"))]
652    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
653    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
654        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
655    }
656}
657
658/// ForegroundExecutor runs things on the main thread.
659impl ForegroundExecutor {
660    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
661    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
662        Self {
663            dispatcher,
664            not_send: PhantomData,
665        }
666    }
667
668    /// Enqueues the given Task to run on the main thread at some point in the future.
669    #[track_caller]
670    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
671    where
672        R: 'static,
673    {
674        self.spawn_with_priority(Priority::default(), future)
675    }
676
677    /// Enqueues the given Task to run on the main thread at some point in the future.
678    #[track_caller]
679    pub fn spawn_with_priority<R>(
680        &self,
681        priority: Priority,
682        future: impl Future<Output = R> + 'static,
683    ) -> Task<R>
684    where
685        R: 'static,
686    {
687        let dispatcher = self.dispatcher.clone();
688        let location = core::panic::Location::caller();
689
690        #[track_caller]
691        fn inner<R: 'static>(
692            dispatcher: Arc<dyn PlatformDispatcher>,
693            future: AnyLocalFuture<R>,
694            location: &'static core::panic::Location<'static>,
695            priority: Priority,
696        ) -> Task<R> {
697            let (runnable, task) = spawn_local_with_source_location(
698                future,
699                move |runnable| {
700                    dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable), priority)
701                },
702                RunnableMeta { location },
703            );
704            runnable.schedule();
705            Task(TaskState::Spawned(task))
706        }
707        inner::<R>(dispatcher, Box::pin(future), location, priority)
708    }
709}
710
/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// Copy-modified from:
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
///
/// Returns the task's first `Runnable`, which the caller must schedule, and the
/// `async_task::Task` handle. `future` does not need to be `Send`. Instead, it is
/// wrapped so that polling or dropping it on any thread other than the spawning
/// one panics, and the panic message names the caller's source location.
///
/// # Panics
///
/// The spawned task panics if its future is polled or dropped on a thread other
/// than the one that called this function.
#[track_caller]
fn spawn_local_with_source_location<Fut, S, M>(
    future: Fut,
    schedule: S,
    metadata: M,
) -> (Runnable<M>, async_task::Task<Fut::Output, M>)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<M> + Send + Sync + 'static,
    M: 'static,
{
    // Per-thread cached id. Falls back to `thread::current()` when the thread-local
    // is no longer accessible (`try_with` fails, e.g. during thread teardown).
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    // Wrapper that remembers the spawning thread and where the spawn happened.
    struct Checked<F> {
        id: ThreadId,
        // `ManuallyDrop` so the inner future is only dropped after the thread check in
        // `Drop`. If that check fails, the panic skips the drop, and the future is leaked
        // rather than dropped on the wrong thread.
        inner: ManuallyDrop<F>,
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is dropped exactly once, here, and never accessed afterwards.
            unsafe { ManuallyDrop::drop(&mut self.inner) };
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is structurally pinned. It is never moved out of `Checked`;
            // the `Drop` impl drops it in place.
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    // Wrap the future into one that checks which thread it's on.
    let future = Checked {
        id: thread_id(),
        inner: ManuallyDrop::new(future),
        location: Location::caller(),
    };

    // SAFETY: `spawn_unchecked` lifts the `Send` requirement on the future. That
    // requirement is enforced at runtime instead: `Checked` panics rather than poll or
    // drop the inner future on any thread other than the spawning one.
    unsafe {
        async_task::Builder::new()
            .metadata(metadata)
            .spawn_unchecked(move |_| future, schedule)
    }
}
779
780/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
781pub struct Scope<'a> {
782    executor: BackgroundExecutor,
783    priority: Priority,
784    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
785    tx: Option<mpsc::Sender<()>>,
786    rx: mpsc::Receiver<()>,
787    lifetime: PhantomData<&'a ()>,
788}
789
790impl<'a> Scope<'a> {
791    fn new(executor: BackgroundExecutor, priority: Priority) -> Self {
792        let (tx, rx) = mpsc::channel(1);
793        Self {
794            executor,
795            priority,
796            tx: Some(tx),
797            rx,
798            futures: Default::default(),
799            lifetime: PhantomData,
800        }
801    }
802
803    /// How many CPUs are available to the dispatcher.
804    pub fn num_cpus(&self) -> usize {
805        self.executor.num_cpus()
806    }
807
808    /// Spawn a future into this scope.
809    #[track_caller]
810    pub fn spawn<F>(&mut self, f: F)
811    where
812        F: Future<Output = ()> + Send + 'a,
813    {
814        let tx = self.tx.clone().unwrap();
815
816        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
817        // dropping this `Scope` blocks until all of the futures have resolved.
818        let f = unsafe {
819            mem::transmute::<
820                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
821                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
822            >(Box::pin(async move {
823                f.await;
824                drop(tx);
825            }))
826        };
827        self.futures.push(f);
828    }
829}
830
831impl Drop for Scope<'_> {
832    fn drop(&mut self) {
833        self.tx.take().unwrap();
834
835        // Wait until the channel is closed, which means that all of the spawned
836        // futures have resolved.
837        self.executor.block(self.rx.next());
838    }
839}