executor.rs

  1use crate::{AppContext, PlatformDispatcher};
  2use futures::{channel::mpsc, pin_mut};
  3use smol::prelude::*;
  4use std::{
  5    fmt::Debug,
  6    marker::PhantomData,
  7    mem,
  8    pin::Pin,
  9    sync::Arc,
 10    task::{Context, Poll},
 11    time::Duration,
 12};
 13use util::TryFutureExt;
 14use waker_fn::waker_fn;
 15
/// Handle used to schedule futures and closures through a platform dispatcher.
///
/// Clones share the same `Arc`-held dispatcher.
#[derive(Clone)]
pub struct Executor {
    /// Backend that every scheduling call in this type is forwarded to.
    dispatcher: Arc<dyn PlatformDispatcher>,
}
 20
/// A future yielding a `T`, backed either by a value that is already available
/// or by a task spawned through `async_task`.
///
/// The type is `#[must_use]`: await it, or call `detach` to give up the handle.
#[must_use]
pub enum Task<T> {
    /// A value that is already available; the `Option` lets `poll` move it out.
    Ready(Option<T>),
    /// Handle to a task created with `async_task`.
    Spawned(async_task::Task<T>),
}
 26
 27impl<T> Task<T> {
 28    pub fn ready(val: T) -> Self {
 29        Task::Ready(Some(val))
 30    }
 31
 32    pub fn detach(self) {
 33        match self {
 34            Task::Ready(_) => {}
 35            Task::Spawned(task) => task.detach(),
 36        }
 37    }
 38}
 39
 40impl<E, T> Task<Result<T, E>>
 41where
 42    T: 'static + Send,
 43    E: 'static + Send + Debug,
 44{
 45    pub fn detach_and_log_err(self, cx: &mut AppContext) {
 46        cx.executor().spawn(self.log_err()).detach();
 47    }
 48}
 49
 50impl<T> Future for Task<T> {
 51    type Output = T;
 52
 53    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 54        match unsafe { self.get_unchecked_mut() } {
 55            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
 56            Task::Spawned(task) => task.poll(cx),
 57        }
 58    }
 59}
 60
 61impl Executor {
 62    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
 63        Self { dispatcher }
 64    }
 65
 66    /// Enqueues the given closure to be run on any thread. The closure returns
 67    /// a future which will be run to completion on any available thread.
 68    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
 69    where
 70        R: Send + 'static,
 71    {
 72        let dispatcher = self.dispatcher.clone();
 73        let (runnable, task) =
 74            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
 75        runnable.schedule();
 76        Task::Spawned(task)
 77    }
 78
 79    /// Enqueues the given closure to run on the application's event loop.
 80    /// Returns the result asynchronously.
 81    pub fn run_on_main<F, R>(&self, func: F) -> Task<R>
 82    where
 83        F: FnOnce() -> R + Send + 'static,
 84        R: Send + 'static,
 85    {
 86        if self.dispatcher.is_main_thread() {
 87            Task::ready(func())
 88        } else {
 89            self.spawn_on_main(move || async move { func() })
 90        }
 91    }
 92
 93    /// Enqueues the given closure to be run on the application's event loop. The
 94    /// closure returns a future which will be run to completion on the main thread.
 95    pub fn spawn_on_main<F, R>(&self, func: impl FnOnce() -> F + Send + 'static) -> Task<R>
 96    where
 97        F: Future<Output = R> + 'static,
 98        R: Send + 'static,
 99    {
100        let (runnable, task) = async_task::spawn(
101            {
102                let this = self.clone();
103                async move {
104                    let task = this.spawn_on_main_local(func());
105                    task.await
106                }
107            },
108            {
109                let dispatcher = self.dispatcher.clone();
110                move |runnable| dispatcher.dispatch_on_main_thread(runnable)
111            },
112        );
113        runnable.schedule();
114        Task::Spawned(task)
115    }
116
117    /// Enqueues the given closure to be run on the application's event loop. Must
118    /// be called on the main thread.
119    pub fn spawn_on_main_local<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
120    where
121        R: 'static,
122    {
123        assert!(
124            self.dispatcher.is_main_thread(),
125            "must be called on main thread"
126        );
127
128        let dispatcher = self.dispatcher.clone();
129        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
130            dispatcher.dispatch_on_main_thread(runnable)
131        });
132        runnable.schedule();
133        Task::Spawned(task)
134    }
135
136    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
137        pin_mut!(future);
138        let (parker, unparker) = parking::pair();
139        let waker = waker_fn(move || {
140            unparker.unpark();
141        });
142        let mut cx = std::task::Context::from_waker(&waker);
143
144        loop {
145            match future.as_mut().poll(&mut cx) {
146                Poll::Ready(result) => return result,
147                Poll::Pending => {
148                    if !self.dispatcher.poll() {
149                        #[cfg(any(test, feature = "test-support"))]
150                        if let Some(_) = self.dispatcher.as_test() {
151                            panic!("blocked with nothing left to run")
152                        }
153                        parker.park();
154                    }
155                }
156            }
157        }
158    }
159
160    pub fn block_with_timeout<R>(
161        &self,
162        duration: Duration,
163        future: impl Future<Output = R>,
164    ) -> Result<R, impl Future<Output = R>> {
165        let mut future = Box::pin(future);
166        if duration.is_zero() {
167            return Err(future);
168        }
169
170        let timeout = {
171            let future = &mut future;
172            async {
173                let timer = async {
174                    self.timer(duration).await;
175                    Err(())
176                };
177                let future = async move { Ok(future.await) };
178                timer.race(future).await
179            }
180        };
181        match self.block(timeout) {
182            Ok(value) => Ok(value),
183            Err(_) => Err(future),
184        }
185    }
186
187    pub async fn scoped<'scope, F>(&self, scheduler: F)
188    where
189        F: FnOnce(&mut Scope<'scope>),
190    {
191        let mut scope = Scope::new(self.clone());
192        (scheduler)(&mut scope);
193        let spawned = mem::take(&mut scope.futures)
194            .into_iter()
195            .map(|f| self.spawn(f))
196            .collect::<Vec<_>>();
197        for task in spawned {
198            task.await;
199        }
200    }
201
202    pub fn timer(&self, duration: Duration) -> Task<()> {
203        let (runnable, task) = async_task::spawn(async move {}, {
204            let dispatcher = self.dispatcher.clone();
205            move |runnable| dispatcher.dispatch_after(duration, runnable)
206        });
207        runnable.schedule();
208        Task::Spawned(task)
209    }
210
    /// Test-support hook that is not implemented yet: calling it panics via `todo!`.
    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        todo!("start_waiting")
    }
215
    /// Test-support hook that is not implemented yet: calling it panics via `todo!`.
    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        todo!("finish_waiting")
    }
220
221    #[cfg(any(test, feature = "test-support"))]
222    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
223        self.dispatcher.as_test().unwrap().simulate_random_delay()
224    }
225
226    #[cfg(any(test, feature = "test-support"))]
227    pub fn advance_clock(&self, duration: Duration) {
228        self.dispatcher.as_test().unwrap().advance_clock(duration)
229    }
230
231    #[cfg(any(test, feature = "test-support"))]
232    pub fn run_until_parked(&self) {
233        self.dispatcher.as_test().unwrap().run_until_parked()
234    }
235
    /// Returns the CPU count reported by `num_cpus::get()`.
    pub fn num_cpus(&self) -> usize {
        num_cpus::get()
    }
239
    /// Returns whether the dispatcher reports the calling thread as the main thread.
    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }
243}
244
/// Collects futures that may borrow data for `'a`, to be run by `Executor::scoped`.
///
/// Dropping a `Scope` blocks until the completion channel is closed (see the
/// `Drop` impl).
pub struct Scope<'a> {
    /// Executor used to block the current thread while the scope is dropped.
    executor: Executor,
    /// Futures registered through `spawn`, with their lifetime erased to `'static`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    /// The scope's own sender; each registered future holds a clone and drops
    /// it after finishing.
    tx: Option<mpsc::Sender<()>>,
    /// Receiving end of the completion channel, waited on in `drop`.
    rx: mpsc::Receiver<()>,
    /// Ties the scope to the `'a` lifetime without storing a reference.
    lifetime: PhantomData<&'a ()>,
}
252
253impl<'a> Scope<'a> {
254    fn new(executor: Executor) -> Self {
255        let (tx, rx) = mpsc::channel(1);
256        Self {
257            executor,
258            tx: Some(tx),
259            rx,
260            futures: Default::default(),
261            lifetime: PhantomData,
262        }
263    }
264
265    pub fn spawn<F>(&mut self, f: F)
266    where
267        F: Future<Output = ()> + Send + 'a,
268    {
269        let tx = self.tx.clone().unwrap();
270
271        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
272        // dropping this `Scope` blocks until all of the futures have resolved.
273        let f = unsafe {
274            mem::transmute::<
275                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
276                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
277            >(Box::pin(async move {
278                f.await;
279                drop(tx);
280            }))
281        };
282        self.futures.push(f);
283    }
284}
285
286impl<'a> Drop for Scope<'a> {
287    fn drop(&mut self) {
288        self.tx.take().unwrap();
289
290        // Wait until the channel is closed, which means that all of the spawned
291        // futures have resolved.
292        self.executor.block(self.rx.next());
293    }
294}