executor.rs

use crate::{AppContext, PlatformDispatcher};
use futures::{channel::mpsc, pin_mut, FutureExt};
use smol::prelude::*;
use std::{
    fmt::Debug,
    marker::PhantomData,
    mem,
    pin::Pin,
    rc::Rc,
    sync::{
        atomic::{AtomicBool, Ordering::SeqCst},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};
use util::TryFutureExt;
use waker_fn::waker_fn;

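/// An executor that runs futures off the main thread by handing them to the
/// platform dispatcher.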
#[derive(Clone)]
pub struct BackgroundExecutor {
    dispatcher: Arc<dyn PlatformDispatcher>,
}

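/// An executor that runs futures on the main thread. The `PhantomData<Rc<()>>`
/// field keeps this type from being `Send`, so it cannot be moved to another
/// thread.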
#[derive(Clone)]
pub struct ForegroundExecutor {
    dispatcher: Arc<dyn PlatformDispatcher>,
    not_send: PhantomData<Rc<()>>,
}

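/// A handle to a unit of asynchronous work. A task is either immediately ready
/// with a value or backed by a spawned `async_task::Task`; dropping a spawned
/// task cancels it unless it has been detached.
///
/// # Example
///
/// A sketch of typical usage, assuming a `BackgroundExecutor` named `executor`
/// is in scope:
///
/// ```ignore
/// // A task that resolves immediately.
/// let ready: Task<i32> = Task::ready(42);
///
/// // A spawned task, detached so it keeps running without the handle.
/// executor.spawn(async { /* background work */ }).detach();
/// ```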
#[must_use]
pub enum Task<T> {
    Ready(Option<T>),
    Spawned(async_task::Task<T>),
}

impl<T> Task<T> {
    pub fn ready(val: T) -> Self {
        Task::Ready(Some(val))
    }

    pub fn detach(self) {
        match self {
            Task::Ready(_) => {}
            Task::Spawned(task) => task.detach(),
        }
    }
}

impl<E, T> Task<Result<T, E>>
where
    T: 'static + Send,
    E: 'static + Send + Debug,
{
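    /// Detaches the task, spawning a wrapper on the background executor that
    /// logs the error if the underlying task resolves to `Err`.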
    pub fn detach_and_log_err(self, cx: &mut AppContext) {
        cx.background_executor().spawn(self.log_err()).detach();
    }
}

impl<T> Future for Task<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match unsafe { self.get_unchecked_mut() } {
            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
            Task::Spawned(task) => task.poll(cx),
        }
    }
}

impl BackgroundExecutor {
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on any available
    /// background thread.
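    ///
    /// # Example
    ///
    /// A sketch of typical usage; `cx` is assumed to be an `&mut AppContext`
    /// available to the caller:
    ///
    /// ```ignore
    /// let task = cx.background_executor().spawn(async {
    ///     // `Send` work that can run off the main thread.
    ///     1 + 1
    /// });
    /// // Either await `task` elsewhere or detach it to let it run on its own.
    /// task.detach();
    /// ```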
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) =
            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
        runnable.schedule();
        Task::Spawned(task)
    }

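    /// Test-only variant of [`Self::block`] that does not restrict the
    /// dispatcher to background work while waiting (it calls `block_internal`
    /// with `background_only = false`).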
    #[cfg(any(test, feature = "test-support"))]
    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
        self.block_internal(false, future)
    }

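    /// Blocks the current thread until the given future resolves, polling the
    /// platform dispatcher for pending work (and parking when there is none)
    /// while it waits.
    ///
    /// # Example
    ///
    /// A sketch of typical usage, assuming a `BackgroundExecutor` named
    /// `executor` is in scope:
    ///
    /// ```ignore
    /// let task = executor.spawn(async { 2 + 2 });
    /// assert_eq!(executor.block(task), 4);
    /// ```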
    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        self.block_internal(true, future)
    }

    pub(crate) fn block_internal<R>(
        &self,
        background_only: bool,
        future: impl Future<Output = R>,
    ) -> R {
        pin_mut!(future);
        let unparker = self.dispatcher.unparker();
        let awoken = Arc::new(AtomicBool::new(false));

        let waker = waker_fn({
            let awoken = awoken.clone();
            move || {
                awoken.store(true, SeqCst);
                unparker.unpark();
            }
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return result,
                Poll::Pending => {
                    if !self.dispatcher.poll(background_only) {
                        if awoken.swap(false, SeqCst) {
                            continue;
                        }

                        #[cfg(any(test, feature = "test-support"))]
                        if let Some(test) = self.dispatcher.as_test() {
                            if !test.parking_allowed() {
                                let mut backtrace_message = String::new();
                                if let Some(backtrace) = test.waiting_backtrace() {
                                    backtrace_message =
                                        format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                                }
                                panic!("parked with nothing left to run\n{}", backtrace_message)
                            }
                        }

                        self.dispatcher.park();
                    }
                }
            }
        }
    }

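    /// Blocks the current thread until the given future resolves or `duration`
    /// elapses, whichever happens first. On timeout (or if `duration` is zero),
    /// the partially polled future is handed back so the caller can keep
    /// awaiting it.
    ///
    /// # Example
    ///
    /// A sketch of typical usage, assuming a `BackgroundExecutor` named
    /// `executor` is in scope; `load_data` stands in for the caller's own
    /// async function:
    ///
    /// ```ignore
    /// match executor.block_with_timeout(Duration::from_millis(100), load_data()) {
    ///     Ok(data) => { /* resolved within the timeout */ }
    ///     Err(future) => { /* timed out; continue polling `future` elsewhere */ }
    /// }
    /// ```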
    pub fn block_with_timeout<R>(
        &self,
        duration: Duration,
        future: impl Future<Output = R>,
    ) -> Result<R, impl Future<Output = R>> {
        let mut future = Box::pin(future.fuse());
        if duration.is_zero() {
            return Err(future);
        }

        let mut timer = self.timer(duration).fuse();
        let timeout = async {
            futures::select_biased! {
                value = future => Ok(value),
                _ = timer => Err(()),
            }
        };
        match self.block(timeout) {
            Ok(value) => Ok(value),
            Err(_) => Err(future),
        }
    }

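    /// Spawns the futures scheduled by `scheduler` on the background executor
    /// and waits for all of them to finish before returning. Because the whole
    /// batch completes before this call resolves, the futures may borrow from
    /// the caller's scope.
    ///
    /// # Example
    ///
    /// A sketch of typical usage inside an async context, assuming a
    /// `BackgroundExecutor` named `executor` is in scope; `process` stands in
    /// for the caller's own function:
    ///
    /// ```ignore
    /// let chunks = vec![vec![1, 2], vec![3, 4]];
    /// executor
    ///     .scoped(|scope| {
    ///         for chunk in &chunks {
    ///             scope.spawn(async move {
    ///                 // `chunk` is borrowed from the enclosing stack frame.
    ///                 process(chunk);
    ///             });
    ///         }
    ///     })
    ///     .await;
    /// ```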
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone());
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn(f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }

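    /// Returns a task that resolves after the given duration has elapsed.
    ///
    /// # Example
    ///
    /// A sketch of typical usage inside an async context, assuming a
    /// `BackgroundExecutor` named `executor` is in scope:
    ///
    /// ```ignore
    /// executor.timer(Duration::from_millis(50)).await;
    /// // At least 50ms have elapsed once the timer resolves.
    /// ```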
    pub fn timer(&self, duration: Duration) -> Task<()> {
        let (runnable, task) = async_task::spawn(async move {}, {
            let dispatcher = self.dispatcher.clone();
            move |runnable| dispatcher.dispatch_after(duration, runnable)
        });
        runnable.schedule();
        Task::Spawned(task)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        self.dispatcher.as_test().unwrap().start_waiting();
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        self.dispatcher.as_test().unwrap().finish_waiting();
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn allow_parking(&self) {
        self.dispatcher.as_test().unwrap().allow_parking();
    }

    pub fn num_cpus(&self) -> usize {
        num_cpus::get()
    }

    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }
}

impl ForegroundExecutor {
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self {
            dispatcher,
            not_send: PhantomData,
        }
    }

    /// Enqueues the given future to be run to completion on the main thread.
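    ///
    /// # Example
    ///
    /// A sketch of typical usage, assuming a `ForegroundExecutor` named
    /// `foreground` is in scope on the main thread:
    ///
    /// ```ignore
    /// foreground
    ///     .spawn(async {
    ///         // Work that is not `Send` and must stay on the main thread.
    ///     })
    ///     .detach();
    /// ```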
    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
    where
        R: 'static,
    {
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
            dispatcher.dispatch_on_main_thread(runnable)
        });
        runnable.schedule();
        Task::Spawned(task)
    }
}

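/// A set of futures scheduled together via [`BackgroundExecutor::scoped`]. The
/// futures may borrow data with lifetime `'a`; dropping the `Scope` blocks until
/// every spawned future has resolved, which is what makes that borrow sound.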
pub struct Scope<'a> {
    executor: BackgroundExecutor,
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    tx: Option<mpsc::Sender<()>>,
    rx: mpsc::Receiver<()>,
    lifetime: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    fn new(executor: BackgroundExecutor) -> Self {
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

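    /// Adds a future to the scope. The future may borrow data that lives at
    /// least as long as `'a`; it is boxed here and actually spawned by
    /// [`BackgroundExecutor::scoped`].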
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let tx = self.tx.clone().unwrap();

        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}

impl<'a> Drop for Scope<'a> {
    fn drop(&mut self) {
        self.tx.take().unwrap();

        // Wait until the channel is closed, which means that all of the spawned
        // futures have resolved.
        self.executor.block(self.rx.next());
    }
}