executor.rs

  1use crate::{AppContext, PlatformDispatcher};
  2use futures::{channel::mpsc, pin_mut, FutureExt};
  3use smol::prelude::*;
  4use std::{
  5    fmt::Debug,
  6    marker::PhantomData,
  7    mem,
  8    num::NonZeroUsize,
  9    pin::Pin,
 10    rc::Rc,
 11    sync::{
 12        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
 13        Arc,
 14    },
 15    task::{Context, Poll},
 16    time::Duration,
 17};
 18use util::TryFutureExt;
 19use waker_fn::waker_fn;
 20
 21#[cfg(any(test, feature = "test-support"))]
 22use rand::rngs::StdRng;
 23
 24#[derive(Clone)]
 25pub struct BackgroundExecutor {
 26    dispatcher: Arc<dyn PlatformDispatcher>,
 27}
 28
 29#[derive(Clone)]
 30pub struct ForegroundExecutor {
 31    dispatcher: Arc<dyn PlatformDispatcher>,
 32    not_send: PhantomData<Rc<()>>,
 33}
 34
 35#[must_use]
 36#[derive(Debug)]
 37pub enum Task<T> {
 38    Ready(Option<T>),
 39    Spawned(async_task::Task<T>),
 40}
 41
 42impl<T> Task<T> {
 43    pub fn ready(val: T) -> Self {
 44        Task::Ready(Some(val))
 45    }
 46
 47    pub fn detach(self) {
 48        match self {
 49            Task::Ready(_) => {}
 50            Task::Spawned(task) => task.detach(),
 51        }
 52    }
 53}
 54
 55impl<E, T> Task<Result<T, E>>
 56where
 57    T: 'static,
 58    E: 'static + Debug,
 59{
 60    #[track_caller]
 61    pub fn detach_and_log_err(self, cx: &mut AppContext) {
 62        let location = core::panic::Location::caller();
 63        cx.foreground_executor()
 64            .spawn(self.log_tracked_err(*location))
 65            .detach();
 66    }
 67}
 68
 69impl<T> Future for Task<T> {
 70    type Output = T;
 71
 72    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 73        match unsafe { self.get_unchecked_mut() } {
 74            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
 75            Task::Spawned(task) => task.poll(cx),
 76        }
 77    }
 78}
 79
 80#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
 81pub struct TaskLabel(NonZeroUsize);
 82
 83impl TaskLabel {
 84    pub fn new() -> Self {
 85        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
 86        Self(NEXT_TASK_LABEL.fetch_add(1, SeqCst).try_into().unwrap())
 87    }
 88}
 89
 90type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;
 91
 92type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;
 93
 94impl BackgroundExecutor {
 95    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
 96        Self { dispatcher }
 97    }
 98
 99    /// Enqueues the given future to be run to completion on a background thread.
100    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
101    where
102        R: Send + 'static,
103    {
104        self.spawn_internal::<R>(Box::pin(future), None)
105    }
106
107    /// Enqueues the given future to be run to completion on a background thread.
108    /// The given label can be used to control the priority of the task in tests.
109    pub fn spawn_labeled<R>(
110        &self,
111        label: TaskLabel,
112        future: impl Future<Output = R> + Send + 'static,
113    ) -> Task<R>
114    where
115        R: Send + 'static,
116    {
117        self.spawn_internal::<R>(Box::pin(future), Some(label))
118    }
119
120    fn spawn_internal<R: Send + 'static>(
121        &self,
122        future: AnyFuture<R>,
123        label: Option<TaskLabel>,
124    ) -> Task<R> {
125        let dispatcher = self.dispatcher.clone();
126        let (runnable, task) =
127            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
128        runnable.schedule();
129        Task::Spawned(task)
130    }
131
132    #[cfg(any(test, feature = "test-support"))]
133    #[track_caller]
134    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
135        if let Ok(value) = self.block_internal(false, future, usize::MAX) {
136            value
137        } else {
138            unreachable!()
139        }
140    }
141
142    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
143        if let Ok(value) = self.block_internal(true, future, usize::MAX) {
144            value
145        } else {
146            unreachable!()
147        }
148    }
149
150    #[track_caller]
151    pub(crate) fn block_internal<R>(
152        &self,
153        background_only: bool,
154        future: impl Future<Output = R>,
155        mut max_ticks: usize,
156    ) -> Result<R, ()> {
157        pin_mut!(future);
158        let unparker = self.dispatcher.unparker();
159        let awoken = Arc::new(AtomicBool::new(false));
160
161        let waker = waker_fn({
162            let awoken = awoken.clone();
163            move || {
164                awoken.store(true, SeqCst);
165                unparker.unpark();
166            }
167        });
168        let mut cx = std::task::Context::from_waker(&waker);
169
170        loop {
171            match future.as_mut().poll(&mut cx) {
172                Poll::Ready(result) => return Ok(result),
173                Poll::Pending => {
174                    if max_ticks == 0 {
175                        return Err(());
176                    }
177                    max_ticks -= 1;
178
179                    if !self.dispatcher.tick(background_only) {
180                        if awoken.swap(false, SeqCst) {
181                            continue;
182                        }
183
184                        #[cfg(any(test, feature = "test-support"))]
185                        if let Some(test) = self.dispatcher.as_test() {
186                            if !test.parking_allowed() {
187                                let mut backtrace_message = String::new();
188                                if let Some(backtrace) = test.waiting_backtrace() {
189                                    backtrace_message =
190                                        format!("\nbacktrace of waiting future:\n{:?}", backtrace);
191                                }
192                                panic!("parked with nothing left to run\n{:?}", backtrace_message)
193                            }
194                        }
195
196                        self.dispatcher.park();
197                    }
198                }
199            }
200        }
201    }
202
203    pub fn block_with_timeout<R>(
204        &self,
205        duration: Duration,
206        future: impl Future<Output = R>,
207    ) -> Result<R, impl Future<Output = R>> {
208        let mut future = Box::pin(future.fuse());
209        if duration.is_zero() {
210            return Err(future);
211        }
212
213        #[cfg(any(test, feature = "test-support"))]
214        let max_ticks = self
215            .dispatcher
216            .as_test()
217            .map_or(usize::MAX, |dispatcher| dispatcher.gen_block_on_ticks());
218        #[cfg(not(any(test, feature = "test-support")))]
219        let max_ticks = usize::MAX;
220
221        let mut timer = self.timer(duration).fuse();
222
223        let timeout = async {
224            futures::select_biased! {
225                value = future => Ok(value),
226                _ = timer => Err(()),
227            }
228        };
229        match self.block_internal(true, timeout, max_ticks) {
230            Ok(Ok(value)) => Ok(value),
231            _ => Err(future),
232        }
233    }
234
235    pub async fn scoped<'scope, F>(&self, scheduler: F)
236    where
237        F: FnOnce(&mut Scope<'scope>),
238    {
239        let mut scope = Scope::new(self.clone());
240        (scheduler)(&mut scope);
241        let spawned = mem::take(&mut scope.futures)
242            .into_iter()
243            .map(|f| self.spawn(f))
244            .collect::<Vec<_>>();
245        for task in spawned {
246            task.await;
247        }
248    }
249
250    pub fn timer(&self, duration: Duration) -> Task<()> {
251        let (runnable, task) = async_task::spawn(async move {}, {
252            let dispatcher = self.dispatcher.clone();
253            move |runnable| dispatcher.dispatch_after(duration, runnable)
254        });
255        runnable.schedule();
256        Task::Spawned(task)
257    }
258
259    #[cfg(any(test, feature = "test-support"))]
260    pub fn start_waiting(&self) {
261        self.dispatcher.as_test().unwrap().start_waiting();
262    }
263
264    #[cfg(any(test, feature = "test-support"))]
265    pub fn finish_waiting(&self) {
266        self.dispatcher.as_test().unwrap().finish_waiting();
267    }
268
269    #[cfg(any(test, feature = "test-support"))]
270    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
271        self.dispatcher.as_test().unwrap().simulate_random_delay()
272    }
273
274    #[cfg(any(test, feature = "test-support"))]
275    pub fn deprioritize(&self, task_label: TaskLabel) {
276        self.dispatcher.as_test().unwrap().deprioritize(task_label)
277    }
278
279    #[cfg(any(test, feature = "test-support"))]
280    pub fn advance_clock(&self, duration: Duration) {
281        self.dispatcher.as_test().unwrap().advance_clock(duration)
282    }
283
284    #[cfg(any(test, feature = "test-support"))]
285    pub fn tick(&self) -> bool {
286        self.dispatcher.as_test().unwrap().tick(false)
287    }
288
289    #[cfg(any(test, feature = "test-support"))]
290    pub fn run_until_parked(&self) {
291        self.dispatcher.as_test().unwrap().run_until_parked()
292    }
293
294    #[cfg(any(test, feature = "test-support"))]
295    pub fn allow_parking(&self) {
296        self.dispatcher.as_test().unwrap().allow_parking();
297    }
298
299    #[cfg(any(test, feature = "test-support"))]
300    pub fn rng(&self) -> StdRng {
301        self.dispatcher.as_test().unwrap().rng()
302    }
303
304    pub fn num_cpus(&self) -> usize {
305        num_cpus::get()
306    }
307
308    pub fn is_main_thread(&self) -> bool {
309        self.dispatcher.is_main_thread()
310    }
311
312    #[cfg(any(test, feature = "test-support"))]
313    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
314        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
315    }
316}
317
318impl ForegroundExecutor {
319    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
320        Self {
321            dispatcher,
322            not_send: PhantomData,
323        }
324    }
325
326    /// Enqueues the given closure to be run on any thread. The closure returns
327    /// a future which will be run to completion on any available thread.
328    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
329    where
330        R: 'static,
331    {
332        let dispatcher = self.dispatcher.clone();
333        fn inner<R: 'static>(
334            dispatcher: Arc<dyn PlatformDispatcher>,
335            future: AnyLocalFuture<R>,
336        ) -> Task<R> {
337            let (runnable, task) = async_task::spawn_local(future, move |runnable| {
338                dispatcher.dispatch_on_main_thread(runnable)
339            });
340            runnable.schedule();
341            Task::Spawned(task)
342        }
343        inner::<R>(dispatcher, Box::pin(future))
344    }
345}
346
347pub struct Scope<'a> {
348    executor: BackgroundExecutor,
349    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
350    tx: Option<mpsc::Sender<()>>,
351    rx: mpsc::Receiver<()>,
352    lifetime: PhantomData<&'a ()>,
353}
354
355impl<'a> Scope<'a> {
356    fn new(executor: BackgroundExecutor) -> Self {
357        let (tx, rx) = mpsc::channel(1);
358        Self {
359            executor,
360            tx: Some(tx),
361            rx,
362            futures: Default::default(),
363            lifetime: PhantomData,
364        }
365    }
366
367    pub fn spawn<F>(&mut self, f: F)
368    where
369        F: Future<Output = ()> + Send + 'a,
370    {
371        let tx = self.tx.clone().unwrap();
372
373        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
374        // dropping this `Scope` blocks until all of the futures have resolved.
375        let f = unsafe {
376            mem::transmute::<
377                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
378                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
379            >(Box::pin(async move {
380                f.await;
381                drop(tx);
382            }))
383        };
384        self.futures.push(f);
385    }
386}
387
388impl<'a> Drop for Scope<'a> {
389    fn drop(&mut self) {
390        self.tx.take().unwrap();
391
392        // Wait until the channel is closed, which means that all of the spawned
393        // futures have resolved.
394        self.executor.block(self.rx.next());
395    }
396}