executor.rs

  1use crate::{AppContext, PlatformDispatcher};
  2use futures::{channel::mpsc, pin_mut, FutureExt};
  3use smol::prelude::*;
  4use std::{
  5    fmt::Debug,
  6    marker::PhantomData,
  7    mem,
  8    pin::Pin,
  9    sync::{
 10        atomic::{AtomicBool, Ordering::SeqCst},
 11        Arc,
 12    },
 13    task::{Context, Poll},
 14    time::Duration,
 15};
 16use util::TryFutureExt;
 17use waker_fn::waker_fn;
 18
 19#[derive(Clone)]
 20pub struct Executor {
 21    dispatcher: Arc<dyn PlatformDispatcher>,
 22}
 23
 24#[must_use]
 25pub enum Task<T> {
 26    Ready(Option<T>),
 27    Spawned(async_task::Task<T>),
 28}
 29
 30impl<T> Task<T> {
 31    pub fn ready(val: T) -> Self {
 32        Task::Ready(Some(val))
 33    }
 34
 35    pub fn detach(self) {
 36        match self {
 37            Task::Ready(_) => {}
 38            Task::Spawned(task) => task.detach(),
 39        }
 40    }
 41}
 42
 43impl<E, T> Task<Result<T, E>>
 44where
 45    T: 'static + Send,
 46    E: 'static + Send + Debug,
 47{
 48    pub fn detach_and_log_err(self, cx: &mut AppContext) {
 49        cx.executor().spawn(self.log_err()).detach();
 50    }
 51}
 52
 53impl<T> Future for Task<T> {
 54    type Output = T;
 55
 56    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 57        match unsafe { self.get_unchecked_mut() } {
 58            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
 59            Task::Spawned(task) => task.poll(cx),
 60        }
 61    }
 62}
 63
 64impl Executor {
 65    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
 66        Self { dispatcher }
 67    }
 68
 69    /// Enqueues the given closure to be run on any thread. The closure returns
 70    /// a future which will be run to completion on any available thread.
 71    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
 72    where
 73        R: Send + 'static,
 74    {
 75        let dispatcher = self.dispatcher.clone();
 76        let (runnable, task) =
 77            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
 78        runnable.schedule();
 79        Task::Spawned(task)
 80    }
 81
 82    /// Enqueues the given closure to run on the application's event loop.
 83    /// Returns the result asynchronously.
 84    pub fn run_on_main<F, R>(&self, func: F) -> Task<R>
 85    where
 86        F: FnOnce() -> R + Send + 'static,
 87        R: Send + 'static,
 88    {
 89        if self.dispatcher.is_main_thread() {
 90            Task::ready(func())
 91        } else {
 92            self.spawn_on_main(move || async move { func() })
 93        }
 94    }
 95
 96    /// Enqueues the given closure to be run on the application's event loop. The
 97    /// closure returns a future which will be run to completion on the main thread.
 98    pub fn spawn_on_main<F, R>(&self, func: impl FnOnce() -> F + Send + 'static) -> Task<R>
 99    where
100        F: Future<Output = R> + 'static,
101        R: Send + 'static,
102    {
103        let (runnable, task) = async_task::spawn(
104            {
105                let this = self.clone();
106                async move {
107                    let task = this.spawn_on_main_local(func());
108                    task.await
109                }
110            },
111            {
112                let dispatcher = self.dispatcher.clone();
113                move |runnable| dispatcher.dispatch_on_main_thread(runnable)
114            },
115        );
116        runnable.schedule();
117        Task::Spawned(task)
118    }
119
120    /// Enqueues the given closure to be run on the application's event loop. Must
121    /// be called on the main thread.
122    pub fn spawn_on_main_local<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
123    where
124        R: 'static,
125    {
126        assert!(
127            self.dispatcher.is_main_thread(),
128            "must be called on main thread"
129        );
130
131        let dispatcher = self.dispatcher.clone();
132        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
133            dispatcher.dispatch_on_main_thread(runnable)
134        });
135        runnable.schedule();
136        Task::Spawned(task)
137    }
138
139    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
140        pin_mut!(future);
141        let (parker, unparker) = parking::pair();
142        let awoken = Arc::new(AtomicBool::new(false));
143        let awoken2 = awoken.clone();
144
145        let waker = waker_fn(move || {
146            awoken2.store(true, SeqCst);
147            unparker.unpark();
148        });
149        let mut cx = std::task::Context::from_waker(&waker);
150
151        loop {
152            match future.as_mut().poll(&mut cx) {
153                Poll::Ready(result) => return result,
154                Poll::Pending => {
155                    if !self.dispatcher.poll() {
156                        if awoken.swap(false, SeqCst) {
157                            continue;
158                        }
159
160                        #[cfg(any(test, feature = "test-support"))]
161                        if let Some(test) = self.dispatcher.as_test() {
162                            if !test.parking_allowed() {
163                                let mut backtrace_message = String::new();
164                                if let Some(backtrace) = test.waiting_backtrace() {
165                                    backtrace_message =
166                                        format!("\nbacktrace of waiting future:\n{:?}", backtrace);
167                                }
168                                panic!("parked with nothing left to run\n{:?}", backtrace_message)
169                            }
170                        }
171                        parker.park();
172                    }
173                }
174            }
175        }
176    }
177
178    pub fn block_with_timeout<R>(
179        &self,
180        duration: Duration,
181        future: impl Future<Output = R>,
182    ) -> Result<R, impl Future<Output = R>> {
183        let mut future = Box::pin(future.fuse());
184        if duration.is_zero() {
185            return Err(future);
186        }
187
188        let mut timer = self.timer(duration).fuse();
189        let timeout = async {
190            futures::select_biased! {
191                value = future => Ok(value),
192                _ = timer => Err(()),
193            }
194        };
195        match self.block(timeout) {
196            Ok(value) => Ok(value),
197            Err(_) => Err(future),
198        }
199    }
200
201    pub async fn scoped<'scope, F>(&self, scheduler: F)
202    where
203        F: FnOnce(&mut Scope<'scope>),
204    {
205        let mut scope = Scope::new(self.clone());
206        (scheduler)(&mut scope);
207        let spawned = mem::take(&mut scope.futures)
208            .into_iter()
209            .map(|f| self.spawn(f))
210            .collect::<Vec<_>>();
211        for task in spawned {
212            task.await;
213        }
214    }
215
216    pub fn timer(&self, duration: Duration) -> Task<()> {
217        let (runnable, task) = async_task::spawn(async move {}, {
218            let dispatcher = self.dispatcher.clone();
219            move |runnable| dispatcher.dispatch_after(duration, runnable)
220        });
221        runnable.schedule();
222        Task::Spawned(task)
223    }
224
225    #[cfg(any(test, feature = "test-support"))]
226    pub fn start_waiting(&self) {
227        self.dispatcher.as_test().unwrap().start_waiting();
228    }
229
230    #[cfg(any(test, feature = "test-support"))]
231    pub fn finish_waiting(&self) {
232        self.dispatcher.as_test().unwrap().finish_waiting();
233    }
234
235    #[cfg(any(test, feature = "test-support"))]
236    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
237        self.spawn(self.dispatcher.as_test().unwrap().simulate_random_delay())
238    }
239
240    #[cfg(any(test, feature = "test-support"))]
241    pub fn advance_clock(&self, duration: Duration) {
242        self.dispatcher.as_test().unwrap().advance_clock(duration)
243    }
244
245    #[cfg(any(test, feature = "test-support"))]
246    pub fn run_until_parked(&self) {
247        self.dispatcher.as_test().unwrap().run_until_parked()
248    }
249
250    #[cfg(any(test, feature = "test-support"))]
251    pub fn allow_parking(&self) {
252        self.dispatcher.as_test().unwrap().allow_parking();
253    }
254
255    pub fn num_cpus(&self) -> usize {
256        num_cpus::get()
257    }
258
259    pub fn is_main_thread(&self) -> bool {
260        self.dispatcher.is_main_thread()
261    }
262}
263
264pub struct Scope<'a> {
265    executor: Executor,
266    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
267    tx: Option<mpsc::Sender<()>>,
268    rx: mpsc::Receiver<()>,
269    lifetime: PhantomData<&'a ()>,
270}
271
272impl<'a> Scope<'a> {
273    fn new(executor: Executor) -> Self {
274        let (tx, rx) = mpsc::channel(1);
275        Self {
276            executor,
277            tx: Some(tx),
278            rx,
279            futures: Default::default(),
280            lifetime: PhantomData,
281        }
282    }
283
284    pub fn spawn<F>(&mut self, f: F)
285    where
286        F: Future<Output = ()> + Send + 'a,
287    {
288        let tx = self.tx.clone().unwrap();
289
290        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
291        // dropping this `Scope` blocks until all of the futures have resolved.
292        let f = unsafe {
293            mem::transmute::<
294                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
295                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
296            >(Box::pin(async move {
297                f.await;
298                drop(tx);
299            }))
300        };
301        self.futures.push(f);
302    }
303}
304
305impl<'a> Drop for Scope<'a> {
306    fn drop(&mut self) {
307        self.tx.take().unwrap();
308
309        // Wait until the channel is closed, which means that all of the spawned
310        // futures have resolved.
311        self.executor.block(self.rx.next());
312    }
313}