executor.rs

  1use crate::{AppContext, PlatformDispatcher};
  2use futures::{channel::mpsc, pin_mut, FutureExt};
  3use smol::prelude::*;
  4use std::{
  5    fmt::Debug,
  6    marker::PhantomData,
  7    mem,
  8    pin::Pin,
  9    rc::Rc,
 10    sync::{
 11        atomic::{AtomicBool, Ordering::SeqCst},
 12        Arc,
 13    },
 14    task::{Context, Poll},
 15    time::Duration,
 16};
 17use util::TryFutureExt;
 18use waker_fn::waker_fn;
 19
 20#[derive(Clone)]
 21pub struct BackgroundExecutor {
 22    dispatcher: Arc<dyn PlatformDispatcher>,
 23}
 24
 25#[derive(Clone)]
 26pub struct ForegroundExecutor {
 27    dispatcher: Arc<dyn PlatformDispatcher>,
 28    not_send: PhantomData<Rc<()>>,
 29}
 30
 31#[must_use]
 32#[derive(Debug)]
 33pub enum Task<T> {
 34    Ready(Option<T>),
 35    Spawned(async_task::Task<T>),
 36}
 37
 38impl<T> Task<T> {
 39    pub fn ready(val: T) -> Self {
 40        Task::Ready(Some(val))
 41    }
 42
 43    pub fn detach(self) {
 44        match self {
 45            Task::Ready(_) => {}
 46            Task::Spawned(task) => task.detach(),
 47        }
 48    }
 49}
 50
 51impl<E, T> Task<Result<T, E>>
 52where
 53    T: 'static,
 54    E: 'static + Debug,
 55{
 56    pub fn detach_and_log_err(self, cx: &mut AppContext) {
 57        cx.foreground_executor().spawn(self.log_err()).detach();
 58    }
 59}
 60
 61impl<T> Future for Task<T> {
 62    type Output = T;
 63
 64    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 65        match unsafe { self.get_unchecked_mut() } {
 66            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
 67            Task::Spawned(task) => task.poll(cx),
 68        }
 69    }
 70}
 71type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;
 72type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;
 73impl BackgroundExecutor {
 74    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
 75        Self { dispatcher }
 76    }
 77
 78    /// Enqueues the given closure to be run on any thread. The closure returns
 79    /// a future which will be run to completion on any available thread.
 80    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
 81    where
 82        R: Send + 'static,
 83    {
 84        let dispatcher = self.dispatcher.clone();
 85        fn inner<R: Send + 'static>(
 86            dispatcher: Arc<dyn PlatformDispatcher>,
 87            future: AnyFuture<R>,
 88        ) -> Task<R> {
 89            let (runnable, task) =
 90                async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
 91            runnable.schedule();
 92            Task::Spawned(task)
 93        }
 94        inner::<R>(dispatcher, Box::pin(future))
 95    }
 96
 97    #[cfg(any(test, feature = "test-support"))]
 98    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
 99        self.block_internal(false, future)
100    }
101
102    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
103        self.block_internal(true, future)
104    }
105
106    pub(crate) fn block_internal<R>(
107        &self,
108        background_only: bool,
109        future: impl Future<Output = R>,
110    ) -> R {
111        pin_mut!(future);
112        let unparker = self.dispatcher.unparker();
113        let awoken = Arc::new(AtomicBool::new(false));
114
115        let waker = waker_fn({
116            let awoken = awoken.clone();
117            move || {
118                awoken.store(true, SeqCst);
119                unparker.unpark();
120            }
121        });
122        let mut cx = std::task::Context::from_waker(&waker);
123
124        loop {
125            match future.as_mut().poll(&mut cx) {
126                Poll::Ready(result) => return result,
127                Poll::Pending => {
128                    if !self.dispatcher.poll(background_only) {
129                        if awoken.swap(false, SeqCst) {
130                            continue;
131                        }
132
133                        #[cfg(any(test, feature = "test-support"))]
134                        if let Some(test) = self.dispatcher.as_test() {
135                            if !test.parking_allowed() {
136                                let mut backtrace_message = String::new();
137                                if let Some(backtrace) = test.waiting_backtrace() {
138                                    backtrace_message =
139                                        format!("\nbacktrace of waiting future:\n{:?}", backtrace);
140                                }
141                                panic!("parked with nothing left to run\n{:?}", backtrace_message)
142                            }
143                        }
144
145                        self.dispatcher.park();
146                    }
147                }
148            }
149        }
150    }
151
152    pub fn block_with_timeout<R>(
153        &self,
154        duration: Duration,
155        future: impl Future<Output = R>,
156    ) -> Result<R, impl Future<Output = R>> {
157        let mut future = Box::pin(future.fuse());
158        if duration.is_zero() {
159            return Err(future);
160        }
161
162        let mut timer = self.timer(duration).fuse();
163        let timeout = async {
164            futures::select_biased! {
165                value = future => Ok(value),
166                _ = timer => Err(()),
167            }
168        };
169        match self.block(timeout) {
170            Ok(value) => Ok(value),
171            Err(_) => Err(future),
172        }
173    }
174
175    pub async fn scoped<'scope, F>(&self, scheduler: F)
176    where
177        F: FnOnce(&mut Scope<'scope>),
178    {
179        let mut scope = Scope::new(self.clone());
180        (scheduler)(&mut scope);
181        let spawned = mem::take(&mut scope.futures)
182            .into_iter()
183            .map(|f| self.spawn(f))
184            .collect::<Vec<_>>();
185        for task in spawned {
186            task.await;
187        }
188    }
189
190    pub fn timer(&self, duration: Duration) -> Task<()> {
191        let (runnable, task) = async_task::spawn(async move {}, {
192            let dispatcher = self.dispatcher.clone();
193            move |runnable| dispatcher.dispatch_after(duration, runnable)
194        });
195        runnable.schedule();
196        Task::Spawned(task)
197    }
198
199    #[cfg(any(test, feature = "test-support"))]
200    pub fn start_waiting(&self) {
201        self.dispatcher.as_test().unwrap().start_waiting();
202    }
203
204    #[cfg(any(test, feature = "test-support"))]
205    pub fn finish_waiting(&self) {
206        self.dispatcher.as_test().unwrap().finish_waiting();
207    }
208
209    #[cfg(any(test, feature = "test-support"))]
210    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
211        self.dispatcher.as_test().unwrap().simulate_random_delay()
212    }
213
214    #[cfg(any(test, feature = "test-support"))]
215    pub fn advance_clock(&self, duration: Duration) {
216        self.dispatcher.as_test().unwrap().advance_clock(duration)
217    }
218
219    #[cfg(any(test, feature = "test-support"))]
220    pub fn run_until_parked(&self) {
221        self.dispatcher.as_test().unwrap().run_until_parked()
222    }
223
224    #[cfg(any(test, feature = "test-support"))]
225    pub fn allow_parking(&self) {
226        self.dispatcher.as_test().unwrap().allow_parking();
227    }
228
229    pub fn num_cpus(&self) -> usize {
230        num_cpus::get()
231    }
232
233    pub fn is_main_thread(&self) -> bool {
234        self.dispatcher.is_main_thread()
235    }
236}
237
238impl ForegroundExecutor {
239    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
240        Self {
241            dispatcher,
242            not_send: PhantomData,
243        }
244    }
245
246    /// Enqueues the given closure to be run on any thread. The closure returns
247    /// a future which will be run to completion on any available thread.
248    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
249    where
250        R: 'static,
251    {
252        let dispatcher = self.dispatcher.clone();
253        fn inner<R: 'static>(
254            dispatcher: Arc<dyn PlatformDispatcher>,
255            future: AnyLocalFuture<R>,
256        ) -> Task<R> {
257            let (runnable, task) = async_task::spawn_local(future, move |runnable| {
258                dispatcher.dispatch_on_main_thread(runnable)
259            });
260            runnable.schedule();
261            Task::Spawned(task)
262        }
263        inner::<R>(dispatcher, Box::pin(future))
264    }
265}
266
267pub struct Scope<'a> {
268    executor: BackgroundExecutor,
269    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
270    tx: Option<mpsc::Sender<()>>,
271    rx: mpsc::Receiver<()>,
272    lifetime: PhantomData<&'a ()>,
273}
274
275impl<'a> Scope<'a> {
276    fn new(executor: BackgroundExecutor) -> Self {
277        let (tx, rx) = mpsc::channel(1);
278        Self {
279            executor,
280            tx: Some(tx),
281            rx,
282            futures: Default::default(),
283            lifetime: PhantomData,
284        }
285    }
286
287    pub fn spawn<F>(&mut self, f: F)
288    where
289        F: Future<Output = ()> + Send + 'a,
290    {
291        let tx = self.tx.clone().unwrap();
292
293        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
294        // dropping this `Scope` blocks until all of the futures have resolved.
295        let f = unsafe {
296            mem::transmute::<
297                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
298                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
299            >(Box::pin(async move {
300                f.await;
301                drop(tx);
302            }))
303        };
304        self.futures.push(f);
305    }
306}
307
308impl<'a> Drop for Scope<'a> {
309    fn drop(&mut self) {
310        self.tx.take().unwrap();
311
312        // Wait until the channel is closed, which means that all of the spawned
313        // futures have resolved.
314        self.executor.block(self.rx.next());
315    }
316}