executor.rs

  1use crate::{AppContext, PlatformDispatcher};
  2use futures::channel::mpsc;
  3use smol::prelude::*;
  4use std::{
  5    fmt::Debug,
  6    marker::PhantomData,
  7    mem,
  8    pin::Pin,
  9    sync::Arc,
 10    task::{Context, Poll},
 11};
 12use util::TryFutureExt;
 13
 14#[derive(Clone)]
 15pub struct Executor {
 16    dispatcher: Arc<dyn PlatformDispatcher>,
 17}
 18
 19#[must_use]
 20pub enum Task<T> {
 21    Ready(Option<T>),
 22    Spawned(async_task::Task<T>),
 23}
 24
 25impl<T> Task<T> {
 26    pub fn ready(val: T) -> Self {
 27        Task::Ready(Some(val))
 28    }
 29
 30    pub fn detach(self) {
 31        match self {
 32            Task::Ready(_) => {}
 33            Task::Spawned(task) => task.detach(),
 34        }
 35    }
 36}
 37
 38impl<E, T> Task<Result<T, E>>
 39where
 40    T: 'static + Send,
 41    E: 'static + Send + Debug,
 42{
 43    pub fn detach_and_log_err(self, cx: &mut AppContext) {
 44        cx.executor().spawn(self.log_err()).detach();
 45    }
 46}
 47
 48impl<T> Future for Task<T> {
 49    type Output = T;
 50
 51    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
 52        match unsafe { self.get_unchecked_mut() } {
 53            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
 54            Task::Spawned(task) => task.poll(cx),
 55        }
 56    }
 57}
 58
 59impl Executor {
 60    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
 61        Self { dispatcher }
 62    }
 63
 64    /// Enqueues the given closure to be run on any thread. The closure returns
 65    /// a future which will be run to completion on any available thread.
 66    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
 67    where
 68        R: Send + 'static,
 69    {
 70        let dispatcher = self.dispatcher.clone();
 71        let (runnable, task) =
 72            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
 73        runnable.schedule();
 74        Task::Spawned(task)
 75    }
 76
 77    /// Enqueues the given closure to run on the application's event loop.
 78    /// Returns the result asynchronously.
 79    pub fn run_on_main<F, R>(&self, func: F) -> Task<R>
 80    where
 81        F: FnOnce() -> R + Send + 'static,
 82        R: Send + 'static,
 83    {
 84        if self.dispatcher.is_main_thread() {
 85            Task::ready(func())
 86        } else {
 87            self.spawn_on_main(move || async move { func() })
 88        }
 89    }
 90
 91    /// Enqueues the given closure to be run on the application's event loop. The
 92    /// closure returns a future which will be run to completion on the main thread.
 93    pub fn spawn_on_main<F, R>(&self, func: impl FnOnce() -> F + Send + 'static) -> Task<R>
 94    where
 95        F: Future<Output = R> + 'static,
 96        R: Send + 'static,
 97    {
 98        let (runnable, task) = async_task::spawn(
 99            {
100                let this = self.clone();
101                async move {
102                    let task = this.spawn_on_main_local(func());
103                    task.await
104                }
105            },
106            {
107                let dispatcher = self.dispatcher.clone();
108                move |runnable| dispatcher.dispatch_on_main_thread(runnable)
109            },
110        );
111        runnable.schedule();
112        Task::Spawned(task)
113    }
114
115    /// Enqueues the given closure to be run on the application's event loop. Must
116    /// be called on the main thread.
117    pub fn spawn_on_main_local<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
118    where
119        R: 'static,
120    {
121        assert!(
122            self.dispatcher.is_main_thread(),
123            "must be called on main thread"
124        );
125
126        let dispatcher = self.dispatcher.clone();
127        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
128            dispatcher.dispatch_on_main_thread(runnable)
129        });
130        runnable.schedule();
131        Task::Spawned(task)
132    }
133
134    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
135        // todo!("integrate with deterministic dispatcher")
136        futures::executor::block_on(future)
137    }
138
139    pub async fn scoped<'scope, F>(&self, scheduler: F)
140    where
141        F: FnOnce(&mut Scope<'scope>),
142    {
143        let mut scope = Scope::new(self.clone());
144        (scheduler)(&mut scope);
145        let spawned = mem::take(&mut scope.futures)
146            .into_iter()
147            .map(|f| self.spawn(f))
148            .collect::<Vec<_>>();
149        for task in spawned {
150            task.await;
151        }
152    }
153
154    pub fn is_main_thread(&self) -> bool {
155        self.dispatcher.is_main_thread()
156    }
157}
158
159pub struct Scope<'a> {
160    executor: Executor,
161    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
162    tx: Option<mpsc::Sender<()>>,
163    rx: mpsc::Receiver<()>,
164    lifetime: PhantomData<&'a ()>,
165}
166
167impl<'a> Scope<'a> {
168    fn new(executor: Executor) -> Self {
169        let (tx, rx) = mpsc::channel(1);
170        Self {
171            executor,
172            tx: Some(tx),
173            rx,
174            futures: Default::default(),
175            lifetime: PhantomData,
176        }
177    }
178
179    pub fn spawn<F>(&mut self, f: F)
180    where
181        F: Future<Output = ()> + Send + 'a,
182    {
183        let tx = self.tx.clone().unwrap();
184
185        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
186        // dropping this `Scope` blocks until all of the futures have resolved.
187        let f = unsafe {
188            mem::transmute::<
189                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
190                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
191            >(Box::pin(async move {
192                f.await;
193                drop(tx);
194            }))
195        };
196        self.futures.push(f);
197    }
198}
199
200impl<'a> Drop for Scope<'a> {
201    fn drop(&mut self) {
202        self.tx.take().unwrap();
203
204        // Wait until the channel is closed, which means that all of the spawned
205        // futures have resolved.
206        self.executor.block(self.rx.next());
207    }
208}