// executor.rs

  1use crate::{AppContext, PlatformDispatcher};
  2use futures::{channel::mpsc, pin_mut, FutureExt};
  3use smol::prelude::*;
  4use std::{
  5    fmt::Debug,
  6    marker::PhantomData,
  7    mem,
  8    pin::Pin,
  9    rc::Rc,
 10    sync::{
 11        atomic::{AtomicBool, Ordering::SeqCst},
 12        Arc,
 13    },
 14    task::{Context, Poll},
 15    time::Duration,
 16};
 17use util::TryFutureExt;
 18use waker_fn::waker_fn;
 19
 20#[cfg(any(test, feature = "test-support"))]
 21use rand::rngs::StdRng;
 22
 23#[derive(Clone)]
 24pub struct BackgroundExecutor {
 25    dispatcher: Arc<dyn PlatformDispatcher>,
 26}
 27
 28#[derive(Clone)]
 29pub struct ForegroundExecutor {
 30    dispatcher: Arc<dyn PlatformDispatcher>,
 31    not_send: PhantomData<Rc<()>>,
 32}
 33
 34#[must_use]
 35#[derive(Debug)]
 36pub enum Task<T> {
 37    Ready(Option<T>),
 38    Spawned(async_task::Task<T>),
 39}
 40
 41impl<T> Task<T> {
 42    pub fn ready(val: T) -> Self {
 43        Task::Ready(Some(val))
 44    }
 45
 46    pub fn detach(self) {
 47        match self {
 48            Task::Ready(_) => {}
 49            Task::Spawned(task) => task.detach(),
 50        }
 51    }
 52}
 53
 54impl<E, T> Task<Result<T, E>>
 55where
 56    T: 'static,
 57    E: 'static + Debug,
 58{
 59    pub fn detach_and_log_err(self, cx: &mut AppContext) {
 60        cx.foreground_executor().spawn(self.log_err()).detach();
 61    }
 62}
 63
 64impl<T> Future for Task<T> {
 65    type Output = T;
 66
 67    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 68        match unsafe { self.get_unchecked_mut() } {
 69            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
 70            Task::Spawned(task) => task.poll(cx),
 71        }
 72    }
 73}
 74
 75impl BackgroundExecutor {
 76    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
 77        Self { dispatcher }
 78    }
 79
 80    /// Enqueues the given closure to be run on any thread. The closure returns
 81    /// a future which will be run to completion on any available thread.
 82    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
 83    where
 84        R: Send + 'static,
 85    {
 86        let dispatcher = self.dispatcher.clone();
 87        let (runnable, task) =
 88            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
 89        runnable.schedule();
 90        Task::Spawned(task)
 91    }
 92
 93    #[cfg(any(test, feature = "test-support"))]
 94    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
 95        self.block_internal(false, future)
 96    }
 97
 98    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
 99        self.block_internal(true, future)
100    }
101
102    pub(crate) fn block_internal<R>(
103        &self,
104        background_only: bool,
105        future: impl Future<Output = R>,
106    ) -> R {
107        pin_mut!(future);
108        let unparker = self.dispatcher.unparker();
109        let awoken = Arc::new(AtomicBool::new(false));
110
111        let waker = waker_fn({
112            let awoken = awoken.clone();
113            move || {
114                awoken.store(true, SeqCst);
115                unparker.unpark();
116            }
117        });
118        let mut cx = std::task::Context::from_waker(&waker);
119
120        loop {
121            match future.as_mut().poll(&mut cx) {
122                Poll::Ready(result) => return result,
123                Poll::Pending => {
124                    if !self.dispatcher.poll(background_only) {
125                        if awoken.swap(false, SeqCst) {
126                            continue;
127                        }
128
129                        #[cfg(any(test, feature = "test-support"))]
130                        if let Some(test) = self.dispatcher.as_test() {
131                            if !test.parking_allowed() {
132                                let mut backtrace_message = String::new();
133                                if let Some(backtrace) = test.waiting_backtrace() {
134                                    backtrace_message =
135                                        format!("\nbacktrace of waiting future:\n{:?}", backtrace);
136                                }
137                                panic!("parked with nothing left to run\n{:?}", backtrace_message)
138                            }
139                        }
140
141                        self.dispatcher.park();
142                    }
143                }
144            }
145        }
146    }
147
148    pub fn block_with_timeout<R>(
149        &self,
150        duration: Duration,
151        future: impl Future<Output = R>,
152    ) -> Result<R, impl Future<Output = R>> {
153        let mut future = Box::pin(future.fuse());
154        if duration.is_zero() {
155            return Err(future);
156        }
157
158        let mut timer = self.timer(duration).fuse();
159        let timeout = async {
160            futures::select_biased! {
161                value = future => Ok(value),
162                _ = timer => Err(()),
163            }
164        };
165        match self.block(timeout) {
166            Ok(value) => Ok(value),
167            Err(_) => Err(future),
168        }
169    }
170
171    pub async fn scoped<'scope, F>(&self, scheduler: F)
172    where
173        F: FnOnce(&mut Scope<'scope>),
174    {
175        let mut scope = Scope::new(self.clone());
176        (scheduler)(&mut scope);
177        let spawned = mem::take(&mut scope.futures)
178            .into_iter()
179            .map(|f| self.spawn(f))
180            .collect::<Vec<_>>();
181        for task in spawned {
182            task.await;
183        }
184    }
185
186    pub fn timer(&self, duration: Duration) -> Task<()> {
187        let (runnable, task) = async_task::spawn(async move {}, {
188            let dispatcher = self.dispatcher.clone();
189            move |runnable| dispatcher.dispatch_after(duration, runnable)
190        });
191        runnable.schedule();
192        Task::Spawned(task)
193    }
194
195    #[cfg(any(test, feature = "test-support"))]
196    pub fn start_waiting(&self) {
197        self.dispatcher.as_test().unwrap().start_waiting();
198    }
199
200    #[cfg(any(test, feature = "test-support"))]
201    pub fn finish_waiting(&self) {
202        self.dispatcher.as_test().unwrap().finish_waiting();
203    }
204
205    #[cfg(any(test, feature = "test-support"))]
206    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
207        self.dispatcher.as_test().unwrap().simulate_random_delay()
208    }
209
210    #[cfg(any(test, feature = "test-support"))]
211    pub fn advance_clock(&self, duration: Duration) {
212        self.dispatcher.as_test().unwrap().advance_clock(duration)
213    }
214
215    #[cfg(any(test, feature = "test-support"))]
216    pub fn run_until_parked(&self) {
217        self.dispatcher.as_test().unwrap().run_until_parked()
218    }
219
220    #[cfg(any(test, feature = "test-support"))]
221    pub fn allow_parking(&self) {
222        self.dispatcher.as_test().unwrap().allow_parking();
223    }
224
225    #[cfg(any(test, feature = "test-support"))]
226    pub fn record_backtrace(&self) {
227        self.dispatcher.as_test().unwrap().record_backtrace();
228    }
229
230    #[cfg(any(test, feature = "test-support"))]
231    pub fn rng(&self) -> StdRng {
232        self.dispatcher.as_test().unwrap().rng()
233    }
234
235    pub fn num_cpus(&self) -> usize {
236        num_cpus::get()
237    }
238
239    pub fn is_main_thread(&self) -> bool {
240        self.dispatcher.is_main_thread()
241    }
242}
243
244impl ForegroundExecutor {
245    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
246        Self {
247            dispatcher,
248            not_send: PhantomData,
249        }
250    }
251
252    /// Enqueues the given closure to be run on any thread. The closure returns
253    /// a future which will be run to completion on any available thread.
254    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
255    where
256        R: 'static,
257    {
258        let dispatcher = self.dispatcher.clone();
259        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
260            dispatcher.dispatch_on_main_thread(runnable)
261        });
262        runnable.schedule();
263        Task::Spawned(task)
264    }
265}
266
267pub struct Scope<'a> {
268    executor: BackgroundExecutor,
269    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
270    tx: Option<mpsc::Sender<()>>,
271    rx: mpsc::Receiver<()>,
272    lifetime: PhantomData<&'a ()>,
273}
274
275impl<'a> Scope<'a> {
276    fn new(executor: BackgroundExecutor) -> Self {
277        let (tx, rx) = mpsc::channel(1);
278        Self {
279            executor,
280            tx: Some(tx),
281            rx,
282            futures: Default::default(),
283            lifetime: PhantomData,
284        }
285    }
286
287    pub fn spawn<F>(&mut self, f: F)
288    where
289        F: Future<Output = ()> + Send + 'a,
290    {
291        let tx = self.tx.clone().unwrap();
292
293        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
294        // dropping this `Scope` blocks until all of the futures have resolved.
295        let f = unsafe {
296            mem::transmute::<
297                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
298                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
299            >(Box::pin(async move {
300                f.await;
301                drop(tx);
302            }))
303        };
304        self.futures.push(f);
305    }
306}
307
308impl<'a> Drop for Scope<'a> {
309    fn drop(&mut self) {
310        self.tx.take().unwrap();
311
312        // Wait until the channel is closed, which means that all of the spawned
313        // futures have resolved.
314        self.executor.block(self.rx.next());
315    }
316}