executor.rs

  1use crate::{AppContext, PlatformDispatcher};
  2use futures::channel::mpsc;
  3use smol::prelude::*;
  4use std::{
  5    fmt::Debug,
  6    marker::PhantomData,
  7    mem,
  8    pin::Pin,
  9    sync::Arc,
 10    task::{Context, Poll},
 11    time::Duration,
 12};
 13use util::TryFutureExt;
 14
 15#[derive(Clone)]
 16pub struct Executor {
 17    dispatcher: Arc<dyn PlatformDispatcher>,
 18}
 19
 20#[must_use]
 21pub enum Task<T> {
 22    Ready(Option<T>),
 23    Spawned(async_task::Task<T>),
 24}
 25
 26impl<T> Task<T> {
 27    pub fn ready(val: T) -> Self {
 28        Task::Ready(Some(val))
 29    }
 30
 31    pub fn detach(self) {
 32        match self {
 33            Task::Ready(_) => {}
 34            Task::Spawned(task) => task.detach(),
 35        }
 36    }
 37}
 38
 39impl<E, T> Task<Result<T, E>>
 40where
 41    T: 'static + Send,
 42    E: 'static + Send + Debug,
 43{
 44    pub fn detach_and_log_err(self, cx: &mut AppContext) {
 45        cx.executor().spawn(self.log_err()).detach();
 46    }
 47}
 48
 49impl<T> Future for Task<T> {
 50    type Output = T;
 51
 52    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 53        match unsafe { self.get_unchecked_mut() } {
 54            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
 55            Task::Spawned(task) => task.poll(cx),
 56        }
 57    }
 58}
 59
 60impl Executor {
 61    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
 62        Self { dispatcher }
 63    }
 64
 65    /// Enqueues the given closure to be run on any thread. The closure returns
 66    /// a future which will be run to completion on any available thread.
 67    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
 68    where
 69        R: Send + 'static,
 70    {
 71        let dispatcher = self.dispatcher.clone();
 72        let (runnable, task) =
 73            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
 74        runnable.schedule();
 75        Task::Spawned(task)
 76    }
 77
 78    /// Enqueues the given closure to run on the application's event loop.
 79    /// Returns the result asynchronously.
 80    pub fn run_on_main<F, R>(&self, func: F) -> Task<R>
 81    where
 82        F: FnOnce() -> R + Send + 'static,
 83        R: Send + 'static,
 84    {
 85        if self.dispatcher.is_main_thread() {
 86            Task::ready(func())
 87        } else {
 88            self.spawn_on_main(move || async move { func() })
 89        }
 90    }
 91
 92    /// Enqueues the given closure to be run on the application's event loop. The
 93    /// closure returns a future which will be run to completion on the main thread.
 94    pub fn spawn_on_main<F, R>(&self, func: impl FnOnce() -> F + Send + 'static) -> Task<R>
 95    where
 96        F: Future<Output = R> + 'static,
 97        R: Send + 'static,
 98    {
 99        let (runnable, task) = async_task::spawn(
100            {
101                let this = self.clone();
102                async move {
103                    let task = this.spawn_on_main_local(func());
104                    task.await
105                }
106            },
107            {
108                let dispatcher = self.dispatcher.clone();
109                move |runnable| dispatcher.dispatch_on_main_thread(runnable)
110            },
111        );
112        runnable.schedule();
113        Task::Spawned(task)
114    }
115
116    /// Enqueues the given closure to be run on the application's event loop. Must
117    /// be called on the main thread.
118    pub fn spawn_on_main_local<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
119    where
120        R: 'static,
121    {
122        assert!(
123            self.dispatcher.is_main_thread(),
124            "must be called on main thread"
125        );
126
127        let dispatcher = self.dispatcher.clone();
128        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
129            dispatcher.dispatch_on_main_thread(runnable)
130        });
131        runnable.schedule();
132        Task::Spawned(task)
133    }
134
135    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
136        // todo!("integrate with deterministic dispatcher")
137        futures::executor::block_on(future)
138    }
139
140    pub async fn scoped<'scope, F>(&self, scheduler: F)
141    where
142        F: FnOnce(&mut Scope<'scope>),
143    {
144        let mut scope = Scope::new(self.clone());
145        (scheduler)(&mut scope);
146        let spawned = mem::take(&mut scope.futures)
147            .into_iter()
148            .map(|f| self.spawn(f))
149            .collect::<Vec<_>>();
150        for task in spawned {
151            task.await;
152        }
153    }
154
155    pub fn timer(&self, duration: Duration) -> smol::Timer {
156        // todo!("integrate with deterministic dispatcher")
157        smol::Timer::after(duration)
158    }
159
160    pub fn is_main_thread(&self) -> bool {
161        self.dispatcher.is_main_thread()
162    }
163}
164
165pub struct Scope<'a> {
166    executor: Executor,
167    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
168    tx: Option<mpsc::Sender<()>>,
169    rx: mpsc::Receiver<()>,
170    lifetime: PhantomData<&'a ()>,
171}
172
173impl<'a> Scope<'a> {
174    fn new(executor: Executor) -> Self {
175        let (tx, rx) = mpsc::channel(1);
176        Self {
177            executor,
178            tx: Some(tx),
179            rx,
180            futures: Default::default(),
181            lifetime: PhantomData,
182        }
183    }
184
185    pub fn spawn<F>(&mut self, f: F)
186    where
187        F: Future<Output = ()> + Send + 'a,
188    {
189        let tx = self.tx.clone().unwrap();
190
191        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
192        // dropping this `Scope` blocks until all of the futures have resolved.
193        let f = unsafe {
194            mem::transmute::<
195                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
196                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
197            >(Box::pin(async move {
198                f.await;
199                drop(tx);
200            }))
201        };
202        self.futures.push(f);
203    }
204}
205
206impl<'a> Drop for Scope<'a> {
207    fn drop(&mut self) {
208        self.tx.take().unwrap();
209
210        // Wait until the channel is closed, which means that all of the spawned
211        // futures have resolved.
212        self.executor.block(self.rx.next());
213    }
214}