executor.rs

  1use crate::{AppContext, PlatformDispatcher};
  2use futures::{channel::mpsc, pin_mut, FutureExt};
  3use smol::prelude::*;
  4use std::{
  5    fmt::Debug,
  6    marker::PhantomData,
  7    mem,
  8    pin::Pin,
  9    rc::Rc,
 10    sync::{
 11        atomic::{AtomicBool, Ordering::SeqCst},
 12        Arc,
 13    },
 14    task::{Context, Poll},
 15    time::Duration,
 16};
 17use util::TryFutureExt;
 18use waker_fn::waker_fn;
 19
 20#[cfg(any(test, feature = "test-support"))]
 21use rand::rngs::StdRng;
 22
 23#[derive(Clone)]
 24pub struct BackgroundExecutor {
 25    dispatcher: Arc<dyn PlatformDispatcher>,
 26}
 27
 28#[derive(Clone)]
 29pub struct ForegroundExecutor {
 30    dispatcher: Arc<dyn PlatformDispatcher>,
 31    not_send: PhantomData<Rc<()>>,
 32}
 33
 34#[must_use]
 35#[derive(Debug)]
 36pub enum Task<T> {
 37    Ready(Option<T>),
 38    Spawned(async_task::Task<T>),
 39}
 40
 41impl<T> Task<T> {
 42    pub fn ready(val: T) -> Self {
 43        Task::Ready(Some(val))
 44    }
 45
 46    pub fn detach(self) {
 47        match self {
 48            Task::Ready(_) => {}
 49            Task::Spawned(task) => task.detach(),
 50        }
 51    }
 52}
 53
 54impl<E, T> Task<Result<T, E>>
 55where
 56    T: 'static,
 57    E: 'static + Debug,
 58{
 59    pub fn detach_and_log_err(self, cx: &mut AppContext) {
 60        cx.foreground_executor().spawn(self.log_err()).detach();
 61    }
 62}
 63
 64impl<T> Future for Task<T> {
 65    type Output = T;
 66
 67    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 68        match unsafe { self.get_unchecked_mut() } {
 69            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
 70            Task::Spawned(task) => task.poll(cx),
 71        }
 72    }
 73}
 74type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;
 75type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;
 76impl BackgroundExecutor {
 77    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
 78        Self { dispatcher }
 79    }
 80
 81    /// Enqueues the given closure to be run on any thread. The closure returns
 82    /// a future which will be run to completion on any available thread.
 83    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
 84    where
 85        R: Send + 'static,
 86    {
 87        let dispatcher = self.dispatcher.clone();
 88        fn inner<R: Send + 'static>(
 89            dispatcher: Arc<dyn PlatformDispatcher>,
 90            future: AnyFuture<R>,
 91        ) -> Task<R> {
 92            let (runnable, task) =
 93                async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
 94            runnable.schedule();
 95            Task::Spawned(task)
 96        }
 97        inner::<R>(dispatcher, Box::pin(future))
 98    }
 99
100    #[cfg(any(test, feature = "test-support"))]
101    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
102        self.block_internal(false, future)
103    }
104
105    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
106        self.block_internal(true, future)
107    }
108
109    pub(crate) fn block_internal<R>(
110        &self,
111        background_only: bool,
112        future: impl Future<Output = R>,
113    ) -> R {
114        pin_mut!(future);
115        let unparker = self.dispatcher.unparker();
116        let awoken = Arc::new(AtomicBool::new(false));
117
118        let waker = waker_fn({
119            let awoken = awoken.clone();
120            move || {
121                awoken.store(true, SeqCst);
122                unparker.unpark();
123            }
124        });
125        let mut cx = std::task::Context::from_waker(&waker);
126
127        loop {
128            match future.as_mut().poll(&mut cx) {
129                Poll::Ready(result) => return result,
130                Poll::Pending => {
131                    if !self.dispatcher.poll(background_only) {
132                        if awoken.swap(false, SeqCst) {
133                            continue;
134                        }
135
136                        #[cfg(any(test, feature = "test-support"))]
137                        if let Some(test) = self.dispatcher.as_test() {
138                            if !test.parking_allowed() {
139                                let mut backtrace_message = String::new();
140                                if let Some(backtrace) = test.waiting_backtrace() {
141                                    backtrace_message =
142                                        format!("\nbacktrace of waiting future:\n{:?}", backtrace);
143                                }
144                                panic!("parked with nothing left to run\n{:?}", backtrace_message)
145                            }
146                        }
147
148                        self.dispatcher.park();
149                    }
150                }
151            }
152        }
153    }
154
155    pub fn block_with_timeout<R>(
156        &self,
157        duration: Duration,
158        future: impl Future<Output = R>,
159    ) -> Result<R, impl Future<Output = R>> {
160        let mut future = Box::pin(future.fuse());
161        if duration.is_zero() {
162            return Err(future);
163        }
164
165        let mut timer = self.timer(duration).fuse();
166        let timeout = async {
167            futures::select_biased! {
168                value = future => Ok(value),
169                _ = timer => Err(()),
170            }
171        };
172        match self.block(timeout) {
173            Ok(value) => Ok(value),
174            Err(_) => Err(future),
175        }
176    }
177
178    pub async fn scoped<'scope, F>(&self, scheduler: F)
179    where
180        F: FnOnce(&mut Scope<'scope>),
181    {
182        let mut scope = Scope::new(self.clone());
183        (scheduler)(&mut scope);
184        let spawned = mem::take(&mut scope.futures)
185            .into_iter()
186            .map(|f| self.spawn(f))
187            .collect::<Vec<_>>();
188        for task in spawned {
189            task.await;
190        }
191    }
192
193    pub fn timer(&self, duration: Duration) -> Task<()> {
194        let (runnable, task) = async_task::spawn(async move {}, {
195            let dispatcher = self.dispatcher.clone();
196            move |runnable| dispatcher.dispatch_after(duration, runnable)
197        });
198        runnable.schedule();
199        Task::Spawned(task)
200    }
201
202    #[cfg(any(test, feature = "test-support"))]
203    pub fn start_waiting(&self) {
204        self.dispatcher.as_test().unwrap().start_waiting();
205    }
206
207    #[cfg(any(test, feature = "test-support"))]
208    pub fn finish_waiting(&self) {
209        self.dispatcher.as_test().unwrap().finish_waiting();
210    }
211
212    #[cfg(any(test, feature = "test-support"))]
213    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
214        self.dispatcher.as_test().unwrap().simulate_random_delay()
215    }
216
217    #[cfg(any(test, feature = "test-support"))]
218    pub fn advance_clock(&self, duration: Duration) {
219        self.dispatcher.as_test().unwrap().advance_clock(duration)
220    }
221
222    #[cfg(any(test, feature = "test-support"))]
223    pub fn run_until_parked(&self) {
224        self.dispatcher.as_test().unwrap().run_until_parked()
225    }
226
227    #[cfg(any(test, feature = "test-support"))]
228    pub fn allow_parking(&self) {
229        self.dispatcher.as_test().unwrap().allow_parking();
230    }
231
232    #[cfg(any(test, feature = "test-support"))]
233    pub fn rng(&self) -> StdRng {
234        self.dispatcher.as_test().unwrap().rng()
235    }
236
237    pub fn num_cpus(&self) -> usize {
238        num_cpus::get()
239    }
240
241    pub fn is_main_thread(&self) -> bool {
242        self.dispatcher.is_main_thread()
243    }
244}
245
246impl ForegroundExecutor {
247    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
248        Self {
249            dispatcher,
250            not_send: PhantomData,
251        }
252    }
253
254    /// Enqueues the given closure to be run on any thread. The closure returns
255    /// a future which will be run to completion on any available thread.
256    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
257    where
258        R: 'static,
259    {
260        let dispatcher = self.dispatcher.clone();
261        fn inner<R: 'static>(
262            dispatcher: Arc<dyn PlatformDispatcher>,
263            future: AnyLocalFuture<R>,
264        ) -> Task<R> {
265            let (runnable, task) = async_task::spawn_local(future, move |runnable| {
266                dispatcher.dispatch_on_main_thread(runnable)
267            });
268            runnable.schedule();
269            Task::Spawned(task)
270        }
271        inner::<R>(dispatcher, Box::pin(future))
272    }
273}
274
275pub struct Scope<'a> {
276    executor: BackgroundExecutor,
277    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
278    tx: Option<mpsc::Sender<()>>,
279    rx: mpsc::Receiver<()>,
280    lifetime: PhantomData<&'a ()>,
281}
282
283impl<'a> Scope<'a> {
284    fn new(executor: BackgroundExecutor) -> Self {
285        let (tx, rx) = mpsc::channel(1);
286        Self {
287            executor,
288            tx: Some(tx),
289            rx,
290            futures: Default::default(),
291            lifetime: PhantomData,
292        }
293    }
294
295    pub fn spawn<F>(&mut self, f: F)
296    where
297        F: Future<Output = ()> + Send + 'a,
298    {
299        let tx = self.tx.clone().unwrap();
300
301        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
302        // dropping this `Scope` blocks until all of the futures have resolved.
303        let f = unsafe {
304            mem::transmute::<
305                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
306                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
307            >(Box::pin(async move {
308                f.await;
309                drop(tx);
310            }))
311        };
312        self.futures.push(f);
313    }
314}
315
316impl<'a> Drop for Scope<'a> {
317    fn drop(&mut self) {
318        self.tx.take().unwrap();
319
320        // Wait until the channel is closed, which means that all of the spawned
321        // futures have resolved.
322        self.executor.block(self.rx.next());
323    }
324}