use crate::{AppContext, PlatformDispatcher};
use futures::{channel::mpsc, pin_mut};
use smol::prelude::*;
use std::{
    fmt::Debug,
    marker::PhantomData,
    mem,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
    time::Duration,
};
use util::TryFutureExt;
use waker_fn::waker_fn;

#[derive(Clone)]
pub struct Executor {
    dispatcher: Arc<dyn PlatformDispatcher>,
}

#[must_use]
pub enum Task<T> {
    Ready(Option<T>),
    Spawned(async_task::Task<T>),
}

impl<T> Task<T> {
    pub fn ready(val: T) -> Self {
        Task::Ready(Some(val))
    }

    pub fn detach(self) {
        match self {
            Task::Ready(_) => {}
            Task::Spawned(task) => task.detach(),
        }
    }
}

impl<E, T> Task<Result<T, E>>
where
    T: 'static + Send,
    E: 'static + Send + Debug,
{
    pub fn detach_and_log_err(self, cx: &mut AppContext) {
        cx.executor().spawn(self.log_err()).detach();
    }
}

impl<T> Future for Task<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match unsafe { self.get_unchecked_mut() } {
            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
            Task::Spawned(task) => task.poll(cx),
        }
    }
}

impl Executor {
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on any available
    /// background thread.
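    ///
    /// A minimal usage sketch (assumes an `Executor` value named `executor`
    /// is in scope; not runnable as a doctest because it needs a platform
    /// dispatcher):
    ///
    /// ```ignore
    /// let task = executor.spawn(async move {
    ///     // Runs on a background thread.
    ///     1 + 1
    /// });
    /// assert_eq!(executor.block(task), 2);
    /// ```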
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) =
            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
        runnable.schedule();
        Task::Spawned(task)
    }

    /// Runs the given closure on the application's main thread and returns
    /// the result asynchronously. If called on the main thread, the closure
    /// runs immediately and the returned task is already resolved.
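    ///
    /// A sketch, assuming an `Executor` named `executor` and a hypothetical
    /// main-thread-only helper `read_window_state`:
    ///
    /// ```ignore
    /// let state = executor.run_on_main(|| read_window_state()).await;
    /// ```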
    pub fn run_on_main<F, R>(&self, func: F) -> Task<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        if self.dispatcher.is_main_thread() {
            Task::ready(func())
        } else {
            self.spawn_on_main(move || async move { func() })
        }
    }

    /// Enqueues the given closure to be run on the application's event loop. The
    /// closure returns a future which will be run to completion on the main thread.
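    ///
    /// A sketch, assuming an `Executor` named `executor` and a hypothetical
    /// `!Send` type that must be created and used on the main thread:
    ///
    /// ```ignore
    /// let task = executor.spawn_on_main(|| async move {
    ///     let view = MainThreadOnlyView::new(); // hypothetical type
    ///     view.render().await
    /// });
    /// ```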
    pub fn spawn_on_main<F, R>(&self, func: impl FnOnce() -> F + Send + 'static) -> Task<R>
    where
        F: Future<Output = R> + 'static,
        R: Send + 'static,
    {
        let (runnable, task) = async_task::spawn(
            {
                let this = self.clone();
                async move {
                    let task = this.spawn_on_main_local(func());
                    task.await
                }
            },
            {
                let dispatcher = self.dispatcher.clone();
                move |runnable| dispatcher.dispatch_on_main_thread(runnable)
            },
        );
        runnable.schedule();
        Task::Spawned(task)
    }

    /// Enqueues the given future to be run on the application's event loop. The
    /// future may be `!Send`, but this method must be called on the main thread.
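    ///
    /// A sketch (assumes an `Executor` named `executor` that is already on
    /// the main thread; panics otherwise):
    ///
    /// ```ignore
    /// let value = std::rc::Rc::new(1); // `Rc` is `!Send`
    /// executor
    ///     .spawn_on_main_local(async move {
    ///         // The future never leaves the main thread.
    ///         drop(value);
    ///     })
    ///     .detach();
    /// ```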
    pub fn spawn_on_main_local<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
    where
        R: 'static,
    {
        assert!(
            self.dispatcher.is_main_thread(),
            "must be called on main thread"
        );

        let dispatcher = self.dispatcher.clone();
        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
            dispatcher.dispatch_on_main_thread(runnable)
        });
        runnable.schedule();
        Task::Spawned(task)
    }

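    /// Blocks the current thread until the given future resolves, polling the
    /// platform dispatcher while waiting.
    ///
    /// A sketch, assuming an `Executor` named `executor`:
    ///
    /// ```ignore
    /// let value = executor.block(async { 2 + 2 });
    /// assert_eq!(value, 4);
    /// ```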
    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        pin_mut!(future);
        let (parker, unparker) = parking::pair();
        let waker = waker_fn(move || {
            unparker.unpark();
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return result,
                Poll::Pending => {
                    if !self.dispatcher.poll() {
                        #[cfg(any(test, feature = "test-support"))]
                        if self.dispatcher.as_test().is_some() {
                            panic!("blocked with nothing left to run")
                        }
                        parker.park();
                    }
                }
            }
        }
    }

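    /// Blocks on the future for at most `duration`. Returns `Ok` with the output
    /// if the future completes in time, or `Err` with the still-pending future
    /// otherwise.
    ///
    /// A sketch, assuming an `Executor` named `executor` and a placeholder
    /// `slow_future`:
    ///
    /// ```ignore
    /// match executor.block_with_timeout(Duration::from_millis(100), slow_future) {
    ///     Ok(_value) => {
    ///         // Completed within 100ms.
    ///     }
    ///     Err(pending) => {
    ///         // Timed out; `pending` is the original future and can still be polled.
    ///     }
    /// }
    /// ```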
    pub fn block_with_timeout<R>(
        &self,
        duration: Duration,
        future: impl Future<Output = R>,
    ) -> Result<R, impl Future<Output = R>> {
        let mut future = Box::pin(future);
        if duration.is_zero() {
            return Err(future);
        }

        let timeout = {
            let future = &mut future;
            async {
                let timer = async {
                    self.timer(duration).await;
                    Err(())
                };
                let future = async move { Ok(future.await) };
                timer.race(future).await
            }
        };
        match self.block(timeout) {
            Ok(value) => Ok(value),
            Err(_) => Err(future),
        }
    }

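    /// Spawns a set of futures that may borrow from the enclosing scope and
    /// waits for all of them to finish before returning.
    ///
    /// A sketch, assuming an `Executor` named `executor`:
    ///
    /// ```ignore
    /// let items = vec![1, 2, 3];
    /// executor
    ///     .scoped(|scope| {
    ///         for item in &items {
    ///             scope.spawn(async move {
    ///                 // Borrowing `items` is sound because `scoped` does not
    ///                 // return until every spawned future has resolved.
    ///                 println!("{item}");
    ///             });
    ///         }
    ///     })
    ///     .await;
    /// ```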
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone());
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn(f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }

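    /// Returns a task that resolves after the given duration has elapsed on the
    /// platform dispatcher.
    ///
    /// A sketch of how it might be awaited from async code (the `Executor`
    /// reference is an assumption of the example):
    ///
    /// ```ignore
    /// async fn wait_a_bit(executor: &Executor) {
    ///     executor.timer(Duration::from_millis(250)).await;
    /// }
    /// ```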
    pub fn timer(&self, duration: Duration) -> Task<()> {
        let (runnable, task) = async_task::spawn(async move {}, {
            let dispatcher = self.dispatcher.clone();
            move |runnable| dispatcher.dispatch_after(duration, runnable)
        });
        runnable.schedule();
        Task::Spawned(task)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        todo!("start_waiting")
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        todo!("finish_waiting")
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }

    pub fn num_cpus(&self) -> usize {
        num_cpus::get()
    }

    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }
}

pub struct Scope<'a> {
    executor: Executor,
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    tx: Option<mpsc::Sender<()>>,
    rx: mpsc::Receiver<()>,
    lifetime: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    fn new(executor: Executor) -> Self {
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let tx = self.tx.clone().unwrap();

        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}

impl<'a> Drop for Scope<'a> {
    fn drop(&mut self) {
        self.tx.take().unwrap();

        // Wait until the channel is closed, which means that all of the spawned
        // futures have resolved.
        self.executor.block(self.rx.next());
    }
}