1use crate::{AppContext, PlatformDispatcher};
2use futures::{channel::mpsc, pin_mut};
3use smol::prelude::*;
4use std::{
5 borrow::BorrowMut,
6 fmt::Debug,
7 marker::PhantomData,
8 mem,
9 pin::Pin,
10 sync::{
11 atomic::{AtomicBool, Ordering::SeqCst},
12 Arc,
13 },
14 task::{Context, Poll},
15 time::Duration,
16};
17use util::TryFutureExt;
18use waker_fn::waker_fn;
19
20#[derive(Clone)]
21pub struct Executor {
22 dispatcher: Arc<dyn PlatformDispatcher>,
23}
24
25#[must_use]
26pub enum Task<T> {
27 Ready(Option<T>),
28 Spawned(async_task::Task<T>),
29}
30
31impl<T> Task<T> {
32 pub fn ready(val: T) -> Self {
33 Task::Ready(Some(val))
34 }
35
36 pub fn detach(self) {
37 match self {
38 Task::Ready(_) => {}
39 Task::Spawned(task) => task.detach(),
40 }
41 }
42}
43
44impl<E, T> Task<Result<T, E>>
45where
46 T: 'static + Send,
47 E: 'static + Send + Debug,
48{
49 pub fn detach_and_log_err(self, cx: &mut AppContext) {
50 cx.executor().spawn(self.log_err()).detach();
51 }
52}
53
54impl<T> Future for Task<T> {
55 type Output = T;
56
57 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
58 match unsafe { self.get_unchecked_mut() } {
59 Task::Ready(val) => Poll::Ready(val.take().unwrap()),
60 Task::Spawned(task) => task.poll(cx),
61 }
62 }
63}
64
65impl Executor {
66 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
67 Self { dispatcher }
68 }
69
70 /// Enqueues the given closure to be run on any thread. The closure returns
71 /// a future which will be run to completion on any available thread.
72 pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
73 where
74 R: Send + 'static,
75 {
76 let dispatcher = self.dispatcher.clone();
77 let (runnable, task) =
78 async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
79 runnable.schedule();
80 Task::Spawned(task)
81 }
82
83 /// Enqueues the given closure to run on the application's event loop.
84 /// Returns the result asynchronously.
85 pub fn run_on_main<F, R>(&self, func: F) -> Task<R>
86 where
87 F: FnOnce() -> R + Send + 'static,
88 R: Send + 'static,
89 {
90 if self.dispatcher.is_main_thread() {
91 Task::ready(func())
92 } else {
93 self.spawn_on_main(move || async move { func() })
94 }
95 }
96
97 /// Enqueues the given closure to be run on the application's event loop. The
98 /// closure returns a future which will be run to completion on the main thread.
99 pub fn spawn_on_main<F, R>(&self, func: impl FnOnce() -> F + Send + 'static) -> Task<R>
100 where
101 F: Future<Output = R> + 'static,
102 R: Send + 'static,
103 {
104 let (runnable, task) = async_task::spawn(
105 {
106 let this = self.clone();
107 async move {
108 let task = this.spawn_on_main_local(func());
109 task.await
110 }
111 },
112 {
113 let dispatcher = self.dispatcher.clone();
114 move |runnable| dispatcher.dispatch_on_main_thread(runnable)
115 },
116 );
117 runnable.schedule();
118 Task::Spawned(task)
119 }
120
121 /// Enqueues the given closure to be run on the application's event loop. Must
122 /// be called on the main thread.
123 pub fn spawn_on_main_local<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
124 where
125 R: 'static,
126 {
127 assert!(
128 self.dispatcher.is_main_thread(),
129 "must be called on main thread"
130 );
131
132 let dispatcher = self.dispatcher.clone();
133 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
134 dispatcher.dispatch_on_main_thread(runnable)
135 });
136 runnable.schedule();
137 Task::Spawned(task)
138 }
139
140 pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
141 pin_mut!(future);
142 let (parker, unparker) = parking::pair();
143 let awoken = Arc::new(AtomicBool::new(false));
144 let awoken2 = awoken.clone();
145
146 let waker = waker_fn(move || {
147 awoken2.store(true, SeqCst);
148 unparker.unpark();
149 });
150 let mut cx = std::task::Context::from_waker(&waker);
151
152 loop {
153 match future.as_mut().poll(&mut cx) {
154 Poll::Ready(result) => return result,
155 Poll::Pending => {
156 if !self.dispatcher.poll() {
157 if awoken.swap(false, SeqCst) {
158 continue;
159 }
160
161 #[cfg(any(test, feature = "test-support"))]
162 if let Some(test) = self.dispatcher.as_test() {
163 if !test.parking_allowed() {
164 panic!("blocked with nothing left to run")
165 }
166 }
167 parker.park();
168 }
169 }
170 }
171 }
172 }
173
174 pub fn block_with_timeout<R>(
175 &self,
176 duration: Duration,
177 future: impl Future<Output = R>,
178 ) -> Result<R, impl Future<Output = R>> {
179 let mut future = Box::pin(future);
180 let timeout = {
181 let future = &mut future;
182 async {
183 let timer = async {
184 self.timer(duration).await;
185 Err(())
186 };
187 let future = async move { Ok(future.await) };
188 timer.race(future).await
189 }
190 };
191 match self.block(timeout) {
192 Ok(value) => Ok(value),
193 Err(_) => Err(future),
194 }
195 }
196
197 pub async fn scoped<'scope, F>(&self, scheduler: F)
198 where
199 F: FnOnce(&mut Scope<'scope>),
200 {
201 let mut scope = Scope::new(self.clone());
202 (scheduler)(&mut scope);
203 let spawned = mem::take(&mut scope.futures)
204 .into_iter()
205 .map(|f| self.spawn(f))
206 .collect::<Vec<_>>();
207 for task in spawned {
208 task.await;
209 }
210 }
211
212 pub fn timer(&self, duration: Duration) -> Task<()> {
213 let (runnable, task) = async_task::spawn(async move {}, {
214 let dispatcher = self.dispatcher.clone();
215 move |runnable| dispatcher.dispatch_after(duration, runnable)
216 });
217 runnable.schedule();
218 Task::Spawned(task)
219 }
220
221 #[cfg(any(test, feature = "test-support"))]
222 pub fn start_waiting(&self) {
223 todo!("start_waiting")
224 }
225
226 #[cfg(any(test, feature = "test-support"))]
227 pub fn finish_waiting(&self) {
228 todo!("finish_waiting")
229 }
230
231 #[cfg(any(test, feature = "test-support"))]
232 pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
233 self.spawn(self.dispatcher.as_test().unwrap().simulate_random_delay())
234 }
235
236 #[cfg(any(test, feature = "test-support"))]
237 pub fn advance_clock(&self, duration: Duration) {
238 self.dispatcher.as_test().unwrap().advance_clock(duration)
239 }
240
241 #[cfg(any(test, feature = "test-support"))]
242 pub fn run_until_parked(&self) {
243 self.dispatcher.as_test().unwrap().run_until_parked()
244 }
245
246 #[cfg(any(test, feature = "test-support"))]
247 pub fn allow_parking(&self) {
248 self.dispatcher.as_test().unwrap().allow_parking();
249 }
250
251 pub fn num_cpus(&self) -> usize {
252 num_cpus::get()
253 }
254
255 pub fn is_main_thread(&self) -> bool {
256 self.dispatcher.is_main_thread()
257 }
258}
259
260pub struct Scope<'a> {
261 executor: Executor,
262 futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
263 tx: Option<mpsc::Sender<()>>,
264 rx: mpsc::Receiver<()>,
265 lifetime: PhantomData<&'a ()>,
266}
267
268impl<'a> Scope<'a> {
269 fn new(executor: Executor) -> Self {
270 let (tx, rx) = mpsc::channel(1);
271 Self {
272 executor,
273 tx: Some(tx),
274 rx,
275 futures: Default::default(),
276 lifetime: PhantomData,
277 }
278 }
279
280 pub fn spawn<F>(&mut self, f: F)
281 where
282 F: Future<Output = ()> + Send + 'a,
283 {
284 let tx = self.tx.clone().unwrap();
285
286 // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
287 // dropping this `Scope` blocks until all of the futures have resolved.
288 let f = unsafe {
289 mem::transmute::<
290 Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
291 Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
292 >(Box::pin(async move {
293 f.await;
294 drop(tx);
295 }))
296 };
297 self.futures.push(f);
298 }
299}
300
301impl<'a> Drop for Scope<'a> {
302 fn drop(&mut self) {
303 self.tx.take().unwrap();
304
305 // Wait until the channel is closed, which means that all of the spawned
306 // futures have resolved.
307 self.executor.block(self.rx.next());
308 }
309}