executor.rs

use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant, TaskTiming, profiler};
use async_task::Runnable;
use futures::channel::mpsc;
use parking_lot::{Condvar, Mutex};
use smol::prelude::*;
use std::{
    fmt::Debug,
    marker::PhantomData,
    mem::{self, ManuallyDrop},
    num::NonZeroUsize,
    panic::Location,
    pin::Pin,
    rc::Rc,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    task::{Context, Poll},
    thread::{self, ThreadId},
    time::{Duration, Instant},
};
use util::TryFutureExt;
use waker_fn::waker_fn;

#[cfg(any(test, feature = "test-support"))]
use rand::rngs::StdRng;

/// A pointer to the executor that is currently running,
/// for spawning background tasks.
#[derive(Clone)]
pub struct BackgroundExecutor {
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
}

/// A pointer to the executor that is currently running,
/// for spawning tasks on the main thread.
///
/// This is intentionally `!Send` via the `not_send` marker field. This is because
/// `ForegroundExecutor::spawn` does not require `Send` but checks at runtime that the future is
/// only polled from the same thread it was spawned from. These checks would fail when spawning
/// foreground tasks from background threads.
#[derive(Clone)]
pub struct ForegroundExecutor {
    #[doc(hidden)]
    pub dispatcher: Arc<dyn PlatformDispatcher>,
    not_send: PhantomData<Rc<()>>,
}

/// Realtime task priority
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum RealtimePriority {
    /// Audio task
    Audio,
    /// Other realtime task
    #[default]
    Other,
}

/// Task priority
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum Priority {
    /// Realtime priority
    ///
    /// Spawning a task with this priority will spin it off on a separate thread dedicated just to that task.
    Realtime(RealtimePriority),
    /// High priority
    ///
    /// Only use for tasks that are critical to the user experience / responsiveness of the editor.
    High,
    /// Medium priority, which suits most use cases.
    #[default]
    Medium,
    /// Low priority
    ///
    /// Use this for background work that can arrive in large quantities, so that it
    /// does not starve the executor of resources needed for higher-priority tasks.
    Low,
}

impl Priority {
    #[allow(dead_code)]
    pub(crate) const fn probability(&self) -> u32 {
        match self {
            // realtime priorities are not considered for probability scheduling
            Priority::Realtime(_) => 0,
            Priority::High => 60,
            Priority::Medium => 30,
            Priority::Low => 10,
        }
    }
}

/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
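///
/// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest);
/// it assumes a `BackgroundExecutor` is available as `executor`:
///
/// ```ignore
/// // Awaiting the task yields its value.
/// let task: Task<i32> = executor.spawn(async { 2 + 2 });
/// let four = executor.block(task);
///
/// // Detaching lets the work keep running after the handle is dropped.
/// executor.spawn(async { /* fire-and-forget work */ }).detach();
/// ```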
#[must_use]
#[derive(Debug)]
pub struct Task<T>(TaskState<T>);

#[derive(Debug)]
enum TaskState<T> {
    /// A task that is ready to return a value
    Ready(Option<T>),

    /// A task that is currently running.
    Spawned(async_task::Task<T, RunnableMeta>),
}

impl<T> Task<T> {
    /// Creates a new task that will immediately resolve with the given value.
    pub fn ready(val: T) -> Self {
        Task(TaskState::Ready(Some(val)))
    }

    /// Detaching a task runs it to completion in the background.
    pub fn detach(self) {
        match self {
            Task(TaskState::Ready(_)) => {}
            Task(TaskState::Spawned(task)) => task.detach(),
        }
    }
}

impl<E, T> Task<Result<T, E>>
where
    T: 'static,
    E: 'static + Debug,
{
    /// Run the task to completion in the background and log any
    /// errors that occur.
    #[track_caller]
    pub fn detach_and_log_err(self, cx: &App) {
        let location = core::panic::Location::caller();
        cx.foreground_executor()
            .spawn(self.log_tracked_err(*location))
            .detach();
    }
}

impl<T> Future for Task<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match unsafe { self.get_unchecked_mut() } {
            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
            Task(TaskState::Spawned(task)) => task.poll(cx),
        }
    }
}

/// A task label is an opaque identifier that you can use to
/// refer to a task in tests.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    fn default() -> Self {
        Self::new()
    }
}

impl TaskLabel {
    /// Construct a new task label.
    pub fn new() -> Self {
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        Self(
            NEXT_TASK_LABEL
                .fetch_add(1, Ordering::SeqCst)
                .try_into()
                .unwrap(),
        )
    }
}

type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;

type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;

/// BackgroundExecutor lets you run things on background threads.
/// In production this is a thread pool with no ordering guarantees.
/// In tests this is simulated by running tasks one by one in a deterministic
/// (but arbitrary) order controlled by the `SEED` environment variable.
impl BackgroundExecutor {
    #[doc(hidden)]
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on a background thread.
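    ///
    /// A minimal sketch (not compiled as a doctest), assuming an
    /// `executor: BackgroundExecutor` handle and a hypothetical `expensive_computation`:
    ///
    /// ```ignore
    /// let task = executor.spawn(async move {
    ///     // `Send` work runs off the main thread.
    ///     expensive_computation()
    /// });
    /// let result = task.await; // await the result from another async context
    /// ```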
    #[track_caller]
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_with_priority(Priority::default(), future)
    }

    /// Enqueues the given future to be run to completion on a background thread, with the given priority.
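    ///
    /// Sketch (not compiled as a doctest): bulk work can be spawned at `Priority::Low`
    /// so it does not crowd out latency-sensitive tasks. `executor` and
    /// `index_everything` are placeholders.
    ///
    /// ```ignore
    /// executor
    ///     .spawn_with_priority(Priority::Low, async move {
    ///         index_everything().await; // hypothetical bulk background work
    ///     })
    ///     .detach();
    /// ```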
    #[track_caller]
    pub fn spawn_with_priority<R>(
        &self,
        priority: Priority,
        future: impl Future<Output = R> + Send + 'static,
    ) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), None, priority)
    }

    /// Enqueues the given future to be run to completion on a background thread and blocks the current task on it.
    ///
    /// This allows spawning background work that borrows from its scope. Note that the supplied future will run to
    /// completion before the current task is resumed, even if the current task is slated for cancellation.
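    ///
    /// Sketch (not compiled as a doctest) of borrowing data from the calling scope;
    /// `executor` is assumed to be a `BackgroundExecutor`:
    ///
    /// ```ignore
    /// let data = vec![1, 2, 3];
    /// // The future may borrow `data` because the caller stays suspended until it finishes.
    /// let sum: i32 = executor
    ///     .await_on_background(async { data.iter().sum() })
    ///     .await;
    /// ```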
    pub async fn await_on_background<R>(&self, future: impl Future<Output = R> + Send) -> R
    where
        R: Send,
    {
        // We need to ensure that cancellation of the parent task does not drop the environment
        // before our own task has completed or been cancelled.
        struct NotifyOnDrop<'a>(&'a (Condvar, Mutex<bool>));

        impl Drop for NotifyOnDrop<'_> {
            fn drop(&mut self) {
                *self.0.1.lock() = true;
                self.0.0.notify_all();
            }
        }

        struct WaitOnDrop<'a>(&'a (Condvar, Mutex<bool>));

        impl Drop for WaitOnDrop<'_> {
            fn drop(&mut self) {
                let mut done = self.0.1.lock();
                if !*done {
                    self.0.0.wait(&mut done);
                }
            }
        }

        let dispatcher = self.dispatcher.clone();
        let location = core::panic::Location::caller();

        let pair = &(Condvar::new(), Mutex::new(false));
        let _wait_guard = WaitOnDrop(pair);

        let (runnable, task) = unsafe {
            async_task::Builder::new()
                .metadata(RunnableMeta { location })
                .spawn_unchecked(
                    move |_| async {
                        let _notify_guard = NotifyOnDrop(pair);
                        future.await
                    },
                    move |runnable| {
                        dispatcher.dispatch(
                            RunnableVariant::Meta(runnable),
                            None,
                            Priority::default(),
                        )
                    },
                )
        };
        runnable.schedule();
        task.await
    }

    /// Enqueues the given future to be run to completion on a background thread.
    /// The given label can be used to control the priority of the task in tests.
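    ///
    /// Sketch (not compiled as a doctest); `executor` is assumed to be a
    /// `BackgroundExecutor` driven by a test dispatcher:
    ///
    /// ```ignore
    /// let label = TaskLabel::new();
    /// executor.spawn_labeled(label, async { /* work under test */ }).detach();
    /// // In tests, deprioritizing the label forces this task to run last.
    /// executor.deprioritize(label);
    /// ```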
    #[track_caller]
    pub fn spawn_labeled<R>(
        &self,
        label: TaskLabel,
        future: impl Future<Output = R> + Send + 'static,
    ) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), Some(label), Priority::default())
    }

    #[track_caller]
    fn spawn_internal<R: Send + 'static>(
        &self,
        future: AnyFuture<R>,
        label: Option<TaskLabel>,
        #[cfg_attr(
            target_os = "windows",
            expect(
                unused_variables,
                reason = "Multi priority scheduler is broken on windows"
            )
        )]
        priority: Priority,
    ) -> Task<R> {
        let dispatcher = self.dispatcher.clone();
        #[cfg(target_os = "windows")]
        let priority = Priority::Medium; // multi-prio scheduler is broken on windows

        let (runnable, task) = if let Priority::Realtime(realtime) = priority {
            let location = core::panic::Location::caller();
            let (mut tx, rx) = flume::bounded::<Runnable<RunnableMeta>>(1);

            dispatcher.spawn_realtime(
                realtime,
                Box::new(move || {
                    while let Ok(runnable) = rx.recv() {
                        let start = Instant::now();
                        let location = runnable.metadata().location;
                        let mut timing = TaskTiming {
                            location,
                            start,
                            end: None,
                        };
                        profiler::add_task_timing(timing);

                        runnable.run();

                        let end = Instant::now();
                        timing.end = Some(end);
                        profiler::add_task_timing(timing);
                    }
                }),
            );

            async_task::Builder::new()
                .metadata(RunnableMeta { location })
                .spawn(
                    move |_| future,
                    move |runnable| {
                        let _ = tx.send(runnable);
                    },
                )
        } else {
            let location = core::panic::Location::caller();
            async_task::Builder::new()
                .metadata(RunnableMeta { location })
                .spawn(
                    move |_| future,
                    move |runnable| {
                        dispatcher.dispatch(RunnableVariant::Meta(runnable), label, priority)
                    },
                )
        };

        runnable.schedule();
        Task(TaskState::Spawned(task))
    }

    /// Used by the test harness to run an async test in a synchronous fashion.
    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(false, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    /// Block the current thread until the given future resolves.
    /// Consider using `block_with_timeout` instead.
    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(true, future, None) {
            value
        } else {
            unreachable!()
        }
    }

    #[cfg(not(any(test, feature = "test-support")))]
    pub(crate) fn block_internal<Fut: Future>(
        &self,
        _background_only: bool,
        future: Fut,
        timeout: Option<Duration>,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        use std::time::Instant;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let deadline = timeout.map(|timeout| Instant::now() + timeout);

        let parker = parking::Parker::new();
        let unparker = parker.unparker();
        let waker = waker_fn(move || {
            unparker.unpark();
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    let timeout =
                        deadline.map(|deadline| deadline.saturating_duration_since(Instant::now()));
                    if let Some(timeout) = timeout {
                        if !parker.park_timeout(timeout)
                            && deadline.is_some_and(|deadline| deadline < Instant::now())
                        {
                            return Err(future);
                        }
                    } else {
                        parker.park();
                    }
                }
            }
        }
    }

    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub(crate) fn block_internal<Fut: Future>(
        &self,
        background_only: bool,
        future: Fut,
        timeout: Option<Duration>,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        use std::sync::atomic::AtomicBool;

        use parking::Parker;

        let mut future = Box::pin(future);
        if timeout == Some(Duration::ZERO) {
            return Err(future);
        }
        let Some(dispatcher) = self.dispatcher.as_test() else {
            return Err(future);
        };

        let mut max_ticks = if timeout.is_some() {
            dispatcher.gen_block_on_ticks()
        } else {
            usize::MAX
        };

        let parker = Parker::new();
        let unparker = parker.unparker();

        let awoken = Arc::new(AtomicBool::new(false));
        let waker = waker_fn({
            let awoken = awoken.clone();
            let unparker = unparker.clone();
            move || {
                awoken.store(true, Ordering::SeqCst);
                unparker.unpark();
            }
        });
        let mut cx = std::task::Context::from_waker(&waker);

        let duration = Duration::from_secs(
            option_env!("GPUI_TEST_TIMEOUT")
                .and_then(|s| s.parse::<u64>().ok())
                .unwrap_or(180),
        );
        let mut test_should_end_by = Instant::now() + duration;

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    if max_ticks == 0 {
                        return Err(future);
                    }
                    max_ticks -= 1;

                    if !dispatcher.tick(background_only) {
                        if awoken.swap(false, Ordering::SeqCst) {
                            continue;
                        }

                        if !dispatcher.parking_allowed() {
                            if dispatcher.advance_clock_to_next_delayed() {
                                continue;
                            }
                            let mut backtrace_message = String::new();
                            let mut waiting_message = String::new();
                            if let Some(backtrace) = dispatcher.waiting_backtrace() {
                                backtrace_message =
                                    format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                            }
                            if let Some(waiting_hint) = dispatcher.waiting_hint() {
                                waiting_message = format!("\n  waiting on: {}\n", waiting_hint);
                            }
                            panic!(
                                "parked with nothing left to run{waiting_message}{backtrace_message}",
                            )
                        }
                        dispatcher.push_unparker(unparker.clone());
                        parker.park_timeout(Duration::from_millis(1));
                        if Instant::now() > test_should_end_by {
                            panic!("test timed out after {duration:?} with allow_parking")
                        }
                    }
                }
            }
        }
    }

    /// Block the current thread until the given future resolves
    /// or `duration` has elapsed.
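    ///
    /// On timeout the still-pending future is handed back so the caller can keep
    /// polling it later. A sketch (not compiled as a doctest), with `executor` and
    /// `some_future` as placeholders:
    ///
    /// ```ignore
    /// match executor.block_with_timeout(Duration::from_millis(100), some_future) {
    ///     Ok(value) => println!("finished in time: {value:?}"),
    ///     Err(unfinished) => {
    ///         // Not done yet; keep blocking without a timeout (or drop it to cancel).
    ///         let _value = executor.block(unfinished);
    ///     }
    /// }
    /// ```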
    pub fn block_with_timeout<Fut: Future>(
        &self,
        duration: Duration,
        future: Fut,
    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
        self.block_internal(true, future, Some(duration))
    }

    /// Scoped lets you start a number of tasks and waits
    /// for all of them to complete before returning.
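    ///
    /// Sketch (not compiled as a doctest) of fanning out over borrowed data;
    /// `executor`, `chunks`, and `process` are placeholders:
    ///
    /// ```ignore
    /// executor
    ///     .scoped(|scope| {
    ///         for chunk in &chunks {
    ///             scope.spawn(async move { process(chunk) });
    ///         }
    ///     })
    ///     .await;
    /// // All of the scope's futures have completed by this point.
    /// ```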
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone(), Priority::default());
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn_with_priority(scope.priority, f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }

    /// Like [`Self::scoped`], but the scope's tasks are spawned
    /// with the given priority.
    pub async fn scoped_priority<'scope, F>(&self, priority: Priority, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone(), priority);
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn_with_priority(scope.priority, f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }

    /// Get the current time.
    ///
    /// Calling this instead of `std::time::Instant::now` allows the use
    /// of fake timers in tests.
    pub fn now(&self) -> Instant {
        self.dispatcher.now()
    }

    /// Returns a task that will complete after the given duration.
    /// Depending on other concurrent tasks the elapsed duration may be longer
    /// than requested.
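    ///
    /// Sketch (not compiled as a doctest), assuming an async context with an
    /// `executor: BackgroundExecutor` in scope:
    ///
    /// ```ignore
    /// executor.timer(Duration::from_millis(50)).await;
    /// // In tests, advance_clock(..) makes the timer fire without real waiting.
    /// ```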
    pub fn timer(&self, duration: Duration) -> Task<()> {
        if duration.is_zero() {
            return Task::ready(());
        }
        let location = core::panic::Location::caller();
        let (runnable, task) = async_task::Builder::new()
            .metadata(RunnableMeta { location })
            .spawn(move |_| async move {}, {
                let dispatcher = self.dispatcher.clone();
                move |runnable| dispatcher.dispatch_after(duration, RunnableVariant::Meta(runnable))
            });
        runnable.schedule();
        Task(TaskState::Spawned(task))
    }

    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        self.dispatcher.as_test().unwrap().start_waiting();
    }

    /// in tests, removes the debugging data added by start_waiting
    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        self.dispatcher.as_test().unwrap().finish_waiting();
    }

    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> + use<> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
    #[cfg(any(test, feature = "test-support"))]
    pub fn deprioritize(&self, task_label: TaskLabel) {
        self.dispatcher.as_test().unwrap().deprioritize(task_label)
    }

    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
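    ///
    /// Sketch (not compiled as a doctest) of pairing this with [`Self::timer`] in a
    /// test, where `executor` is the test `BackgroundExecutor`:
    ///
    /// ```ignore
    /// let timer = executor.timer(Duration::from_secs(1));
    /// executor.advance_clock(Duration::from_secs(1)); // the timer is now ready
    /// executor.block_test(timer);                     // resolves without real waiting
    /// ```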
    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    /// in tests, run one task.
    #[cfg(any(test, feature = "test-support"))]
    pub fn tick(&self) -> bool {
        self.dispatcher.as_test().unwrap().tick(false)
    }

    /// in tests, run all tasks that are ready to run. If after doing so
    /// the test still has outstanding tasks, this will panic. (See also [`Self::allow_parking`])
    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }

    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
    /// do take real async time to run.
    #[cfg(any(test, feature = "test-support"))]
    pub fn allow_parking(&self) {
        self.dispatcher.as_test().unwrap().allow_parking();
    }

    /// undoes the effect of [`Self::allow_parking`].
    #[cfg(any(test, feature = "test-support"))]
    pub fn forbid_parking(&self) {
        self.dispatcher.as_test().unwrap().forbid_parking();
    }

    /// adds detail to the "parked with nothing left to run" message.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_waiting_hint(&self, msg: Option<String>) {
        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
    }

    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
    #[cfg(any(test, feature = "test-support"))]
    pub fn rng(&self) -> StdRng {
        self.dispatcher.as_test().unwrap().rng()
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        #[cfg(any(test, feature = "test-support"))]
        return 4;

        #[cfg(not(any(test, feature = "test-support")))]
        return num_cpus::get();
    }

    /// Whether we're on the main thread.
    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }

    #[cfg(any(test, feature = "test-support"))]
    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
    }
}

/// ForegroundExecutor runs things on the main thread.
impl ForegroundExecutor {
    /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self {
            dispatcher,
            not_send: PhantomData,
        }
    }

    /// Enqueues the given future to run on the main thread at some point in the future.
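    ///
    /// Foreground tasks need not be `Send`, so they may capture `Rc` or other
    /// main-thread-only state. A sketch (not compiled as a doctest), assuming a
    /// `foreground: ForegroundExecutor` handle:
    ///
    /// ```ignore
    /// let shared = std::rc::Rc::new(42);
    /// foreground
    ///     .spawn(async move {
    ///         // Polled only on the thread that spawned it.
    ///         println!("value on the main thread: {shared}");
    ///     })
    ///     .detach();
    /// ```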
    #[track_caller]
    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
    where
        R: 'static,
    {
        self.spawn_with_priority(Priority::default(), future)
    }

    /// Enqueues the given future to run on the main thread at some point in the future, with the given priority.
    #[track_caller]
    pub fn spawn_with_priority<R>(
        &self,
        priority: Priority,
        future: impl Future<Output = R> + 'static,
    ) -> Task<R>
    where
        R: 'static,
    {
        let dispatcher = self.dispatcher.clone();
        let location = core::panic::Location::caller();

        #[track_caller]
        fn inner<R: 'static>(
            dispatcher: Arc<dyn PlatformDispatcher>,
            future: AnyLocalFuture<R>,
            location: &'static core::panic::Location<'static>,
            priority: Priority,
        ) -> Task<R> {
            let (runnable, task) = spawn_local_with_source_location(
                future,
                move |runnable| {
                    dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable), priority)
                },
                RunnableMeta { location },
            );
            runnable.schedule();
            Task(TaskState::Spawned(task))
        }
        inner::<R>(dispatcher, Box::pin(future), location, priority)
    }
}

/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// Copy-modified from:
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
#[track_caller]
fn spawn_local_with_source_location<Fut, S, M>(
    future: Fut,
    schedule: S,
    metadata: M,
) -> (Runnable<M>, async_task::Task<Fut::Output, M>)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<M> + Send + Sync + 'static,
    M: 'static,
{
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    struct Checked<F> {
        id: ThreadId,
        inner: ManuallyDrop<F>,
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            unsafe { ManuallyDrop::drop(&mut self.inner) };
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    // Wrap the future into one that checks which thread it's on.
    let future = Checked {
        id: thread_id(),
        inner: ManuallyDrop::new(future),
        location: Location::caller(),
    };

    unsafe {
        async_task::Builder::new()
            .metadata(metadata)
            .spawn_unchecked(move |_| future, schedule)
    }
}

/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
    executor: BackgroundExecutor,
    priority: Priority,
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    tx: Option<mpsc::Sender<()>>,
    rx: mpsc::Receiver<()>,
    lifetime: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    fn new(executor: BackgroundExecutor, priority: Priority) -> Self {
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            priority,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

    /// How many CPUs are available to the dispatcher.
    pub fn num_cpus(&self) -> usize {
        self.executor.num_cpus()
    }

    /// Spawn a future into this scope.
    #[track_caller]
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let tx = self.tx.clone().unwrap();

        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}

impl Drop for Scope<'_> {
    fn drop(&mut self) {
        self.tx.take().unwrap();

        // Wait until the channel is closed, which means that all of the spawned
        // futures have resolved.
        self.executor.block(self.rx.next());
    }
}