1use crate::{AppContext, PlatformDispatcher};
2use futures::channel::mpsc;
3use smol::prelude::*;
4use std::{
5 fmt::Debug,
6 marker::PhantomData,
7 mem,
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11};
12use util::TryFutureExt;
13
/// Handle for scheduling futures through a platform-provided dispatcher.
///
/// Cloning an `Executor` only clones the `Arc` around the dispatcher, so clones share
/// the same underlying dispatcher.
#[derive(Clone)]
pub struct Executor {
    // Platform hook used to enqueue runnables, either on any thread (`dispatch`) or on
    // the main thread (`dispatch_on_main_thread`).
    dispatcher: Arc<dyn PlatformDispatcher>,
}
18
19#[must_use]
20pub enum Task<T> {
21 Ready(Option<T>),
22 Spawned(async_task::Task<T>),
23}
24
25impl<T> Task<T> {
26 pub fn ready(val: T) -> Self {
27 Task::Ready(Some(val))
28 }
29
30 pub fn detach(self) {
31 match self {
32 Task::Ready(_) => {}
33 Task::Spawned(task) => task.detach(),
34 }
35 }
36}
37
38impl<E, T> Task<Result<T, E>>
39where
40 T: 'static + Send,
41 E: 'static + Send + Debug,
42{
43 pub fn detach_and_log_err(self, cx: &mut AppContext) {
44 cx.executor().spawn(self.log_err()).detach();
45 }
46}
47
48impl<T> Future for Task<T> {
49 type Output = T;
50
51 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
52 match unsafe { self.get_unchecked_mut() } {
53 Task::Ready(val) => Poll::Ready(val.take().unwrap()),
54 Task::Spawned(task) => task.poll(cx),
55 }
56 }
57}
58
59impl Executor {
60 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
61 Self { dispatcher }
62 }
63
64 /// Enqueues the given closure to be run on any thread. The closure returns
65 /// a future which will be run to completion on any available thread.
66 pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
67 where
68 R: Send + 'static,
69 {
70 let dispatcher = self.dispatcher.clone();
71 let (runnable, task) =
72 async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
73 runnable.schedule();
74 Task::Spawned(task)
75 }
76
77 /// Enqueues the given closure to run on the application's event loop.
78 /// Returns the result asynchronously.
79 pub fn run_on_main<F, R>(&self, func: F) -> Task<R>
80 where
81 F: FnOnce() -> R + Send + 'static,
82 R: Send + 'static,
83 {
84 if self.dispatcher.is_main_thread() {
85 Task::ready(func())
86 } else {
87 self.spawn_on_main(move || async move { func() })
88 }
89 }
90
91 /// Enqueues the given closure to be run on the application's event loop. The
92 /// closure returns a future which will be run to completion on the main thread.
93 pub fn spawn_on_main<F, R>(&self, func: impl FnOnce() -> F + Send + 'static) -> Task<R>
94 where
95 F: Future<Output = R> + 'static,
96 R: Send + 'static,
97 {
98 let (runnable, task) = async_task::spawn(
99 {
100 let this = self.clone();
101 async move {
102 let task = this.spawn_on_main_local(func());
103 task.await
104 }
105 },
106 {
107 let dispatcher = self.dispatcher.clone();
108 move |runnable| dispatcher.dispatch_on_main_thread(runnable)
109 },
110 );
111 runnable.schedule();
112 Task::Spawned(task)
113 }
114
115 /// Enqueues the given closure to be run on the application's event loop. Must
116 /// be called on the main thread.
117 pub fn spawn_on_main_local<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
118 where
119 R: 'static,
120 {
121 assert!(
122 self.dispatcher.is_main_thread(),
123 "must be called on main thread"
124 );
125
126 let dispatcher = self.dispatcher.clone();
127 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
128 dispatcher.dispatch_on_main_thread(runnable)
129 });
130 runnable.schedule();
131 Task::Spawned(task)
132 }
133
134 pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
135 // todo!("integrate with deterministic dispatcher")
136 futures::executor::block_on(future)
137 }
138
139 pub async fn scoped<'scope, F>(&self, scheduler: F)
140 where
141 F: FnOnce(&mut Scope<'scope>),
142 {
143 let mut scope = Scope::new(self.clone());
144 (scheduler)(&mut scope);
145 let spawned = mem::take(&mut scope.futures)
146 .into_iter()
147 .map(|f| self.spawn(f))
148 .collect::<Vec<_>>();
149 for task in spawned {
150 task.await;
151 }
152 }
153
154 pub fn is_main_thread(&self) -> bool {
155 self.dispatcher.is_main_thread()
156 }
157}
158
/// Collects futures that may borrow data living for `'a`, for use by [`Executor::scoped`].
///
/// Dropping a `Scope` blocks until the channel closes, i.e. until every sender clone
/// handed to its queued futures has been dropped.
pub struct Scope<'a> {
    // Executor used in `Drop` to block while waiting for outstanding futures.
    executor: Executor,
    // Futures queued by `spawn`, with their `'a` lifetime erased to `'static`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    // The scope's own sender. Each queued future owns a clone and drops it when it
    // finishes. Taken in `Drop` so that the channel can close.
    tx: Option<mpsc::Sender<()>>,
    // Waited on in `Drop`; nothing is ever sent, it is only used to observe closure.
    rx: mpsc::Receiver<()>,
    // Ties the scope to the `'a` borrows its futures may hold.
    lifetime: PhantomData<&'a ()>,
}
166
167impl<'a> Scope<'a> {
168 fn new(executor: Executor) -> Self {
169 let (tx, rx) = mpsc::channel(1);
170 Self {
171 executor,
172 tx: Some(tx),
173 rx,
174 futures: Default::default(),
175 lifetime: PhantomData,
176 }
177 }
178
179 pub fn spawn<F>(&mut self, f: F)
180 where
181 F: Future<Output = ()> + Send + 'a,
182 {
183 let tx = self.tx.clone().unwrap();
184
185 // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
186 // dropping this `Scope` blocks until all of the futures have resolved.
187 let f = unsafe {
188 mem::transmute::<
189 Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
190 Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
191 >(Box::pin(async move {
192 f.await;
193 drop(tx);
194 }))
195 };
196 self.futures.push(f);
197 }
198}
199
200impl<'a> Drop for Scope<'a> {
201 fn drop(&mut self) {
202 self.tx.take().unwrap();
203
204 // Wait until the channel is closed, which means that all of the spawned
205 // futures have resolved.
206 self.executor.block(self.rx.next());
207 }
208}