1use crate::{AppContext, PlatformDispatcher};
2use futures::{channel::mpsc, pin_mut, FutureExt};
3use smol::prelude::*;
4use std::{
5 fmt::Debug,
6 marker::PhantomData,
7 mem,
8 pin::Pin,
9 sync::{
10 atomic::{AtomicBool, Ordering::SeqCst},
11 Arc,
12 },
13 task::{Context, Poll},
14 time::Duration,
15};
16use util::TryFutureExt;
17use waker_fn::waker_fn;
18
/// Schedules futures through a shared `PlatformDispatcher`.
///
/// Cloning is cheap: every clone shares the same dispatcher through an `Arc`.
#[derive(Clone)]
pub struct Executor {
    /// Platform hook used to run, delay, and main-thread-dispatch runnables.
    dispatcher: Arc<dyn PlatformDispatcher>,
}
23
24#[must_use]
25pub enum Task<T> {
26 Ready(Option<T>),
27 Spawned(async_task::Task<T>),
28}
29
30impl<T> Task<T> {
31 pub fn ready(val: T) -> Self {
32 Task::Ready(Some(val))
33 }
34
35 pub fn detach(self) {
36 match self {
37 Task::Ready(_) => {}
38 Task::Spawned(task) => task.detach(),
39 }
40 }
41}
42
43impl<E, T> Task<Result<T, E>>
44where
45 T: 'static + Send,
46 E: 'static + Send + Debug,
47{
48 pub fn detach_and_log_err(self, cx: &mut AppContext) {
49 cx.executor().spawn(self.log_err()).detach();
50 }
51}
52
53impl<T> Future for Task<T> {
54 type Output = T;
55
56 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
57 match unsafe { self.get_unchecked_mut() } {
58 Task::Ready(val) => Poll::Ready(val.take().unwrap()),
59 Task::Spawned(task) => task.poll(cx),
60 }
61 }
62}
63
64impl Executor {
65 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
66 Self { dispatcher }
67 }
68
69 /// Enqueues the given closure to be run on any thread. The closure returns
70 /// a future which will be run to completion on any available thread.
71 pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
72 where
73 R: Send + 'static,
74 {
75 let dispatcher = self.dispatcher.clone();
76 let (runnable, task) =
77 async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
78 runnable.schedule();
79 Task::Spawned(task)
80 }
81
82 /// Enqueues the given closure to run on the application's event loop.
83 /// Returns the result asynchronously.
84 pub fn run_on_main<F, R>(&self, func: F) -> Task<R>
85 where
86 F: FnOnce() -> R + Send + 'static,
87 R: Send + 'static,
88 {
89 if self.dispatcher.is_main_thread() {
90 Task::ready(func())
91 } else {
92 self.spawn_on_main(move || async move { func() })
93 }
94 }
95
96 /// Enqueues the given closure to be run on the application's event loop. The
97 /// closure returns a future which will be run to completion on the main thread.
98 pub fn spawn_on_main<F, R>(&self, func: impl FnOnce() -> F + Send + 'static) -> Task<R>
99 where
100 F: Future<Output = R> + 'static,
101 R: Send + 'static,
102 {
103 let (runnable, task) = async_task::spawn(
104 {
105 let this = self.clone();
106 async move {
107 let task = this.spawn_on_main_local(func());
108 task.await
109 }
110 },
111 {
112 let dispatcher = self.dispatcher.clone();
113 move |runnable| dispatcher.dispatch_on_main_thread(runnable)
114 },
115 );
116 runnable.schedule();
117 Task::Spawned(task)
118 }
119
120 /// Enqueues the given closure to be run on the application's event loop. Must
121 /// be called on the main thread.
122 pub fn spawn_on_main_local<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
123 where
124 R: 'static,
125 {
126 assert!(
127 self.dispatcher.is_main_thread(),
128 "must be called on main thread"
129 );
130
131 let dispatcher = self.dispatcher.clone();
132 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
133 dispatcher.dispatch_on_main_thread(runnable)
134 });
135 runnable.schedule();
136 Task::Spawned(task)
137 }
138
139 pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
140 pin_mut!(future);
141 let (parker, unparker) = parking::pair();
142 let awoken = Arc::new(AtomicBool::new(false));
143 let awoken2 = awoken.clone();
144
145 let waker = waker_fn(move || {
146 awoken2.store(true, SeqCst);
147 unparker.unpark();
148 });
149 let mut cx = std::task::Context::from_waker(&waker);
150
151 loop {
152 match future.as_mut().poll(&mut cx) {
153 Poll::Ready(result) => return result,
154 Poll::Pending => {
155 if !self.dispatcher.poll() {
156 if awoken.swap(false, SeqCst) {
157 continue;
158 }
159
160 #[cfg(any(test, feature = "test-support"))]
161 if let Some(test) = self.dispatcher.as_test() {
162 if !test.parking_allowed() {
163 let mut backtrace_message = String::new();
164 if let Some(backtrace) = test.waiting_backtrace() {
165 backtrace_message =
166 format!("\nbacktrace of waiting future:\n{:?}", backtrace);
167 }
168 panic!("parked with nothing left to run\n{:?}", backtrace_message)
169 }
170 }
171 parker.park();
172 }
173 }
174 }
175 }
176 }
177
178 pub fn block_with_timeout<R>(
179 &self,
180 duration: Duration,
181 future: impl Future<Output = R>,
182 ) -> Result<R, impl Future<Output = R>> {
183 let mut future = Box::pin(future.fuse());
184 if duration.is_zero() {
185 return Err(future);
186 }
187
188 let mut timer = self.timer(duration).fuse();
189 let timeout = async {
190 futures::select_biased! {
191 value = future => Ok(value),
192 _ = timer => Err(()),
193 }
194 };
195 match self.block(timeout) {
196 Ok(value) => Ok(value),
197 Err(_) => Err(future),
198 }
199 }
200
201 pub async fn scoped<'scope, F>(&self, scheduler: F)
202 where
203 F: FnOnce(&mut Scope<'scope>),
204 {
205 let mut scope = Scope::new(self.clone());
206 (scheduler)(&mut scope);
207 let spawned = mem::take(&mut scope.futures)
208 .into_iter()
209 .map(|f| self.spawn(f))
210 .collect::<Vec<_>>();
211 for task in spawned {
212 task.await;
213 }
214 }
215
216 pub fn timer(&self, duration: Duration) -> Task<()> {
217 let (runnable, task) = async_task::spawn(async move {}, {
218 let dispatcher = self.dispatcher.clone();
219 move |runnable| dispatcher.dispatch_after(duration, runnable)
220 });
221 runnable.schedule();
222 Task::Spawned(task)
223 }
224
225 #[cfg(any(test, feature = "test-support"))]
226 pub fn start_waiting(&self) {
227 self.dispatcher.as_test().unwrap().start_waiting();
228 }
229
230 #[cfg(any(test, feature = "test-support"))]
231 pub fn finish_waiting(&self) {
232 self.dispatcher.as_test().unwrap().finish_waiting();
233 }
234
235 #[cfg(any(test, feature = "test-support"))]
236 pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
237 self.spawn(self.dispatcher.as_test().unwrap().simulate_random_delay())
238 }
239
240 #[cfg(any(test, feature = "test-support"))]
241 pub fn advance_clock(&self, duration: Duration) {
242 self.dispatcher.as_test().unwrap().advance_clock(duration)
243 }
244
245 #[cfg(any(test, feature = "test-support"))]
246 pub fn run_until_parked(&self) {
247 self.dispatcher.as_test().unwrap().run_until_parked()
248 }
249
250 #[cfg(any(test, feature = "test-support"))]
251 pub fn allow_parking(&self) {
252 self.dispatcher.as_test().unwrap().allow_parking();
253 }
254
255 pub fn num_cpus(&self) -> usize {
256 num_cpus::get()
257 }
258
259 pub fn is_main_thread(&self) -> bool {
260 self.dispatcher.is_main_thread()
261 }
262}
263
/// Collects futures that may borrow data living for `'a`, for use with
/// `Executor::scoped`.
///
/// Every future spawned through the scope holds a clone of `tx`. Dropping the
/// scope drops the original sender and then blocks until `rx` reports the
/// channel closed, i.e. until every spawned future has finished or been dropped.
pub struct Scope<'a> {
    /// Executor used by `Drop` to block while waiting for the spawned futures.
    executor: Executor,
    /// Futures registered via `spawn`, lifetime-erased to `'static`; drained by
    /// `Executor::scoped`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    /// The scope's own sender; `None` once the scope has started dropping.
    tx: Option<mpsc::Sender<()>>,
    /// Only used to detect that all sender clones have been dropped.
    rx: mpsc::Receiver<()>,
    /// Ties the scope to `'a` without storing a reference.
    lifetime: PhantomData<&'a ()>,
}
271
272impl<'a> Scope<'a> {
273 fn new(executor: Executor) -> Self {
274 let (tx, rx) = mpsc::channel(1);
275 Self {
276 executor,
277 tx: Some(tx),
278 rx,
279 futures: Default::default(),
280 lifetime: PhantomData,
281 }
282 }
283
284 pub fn spawn<F>(&mut self, f: F)
285 where
286 F: Future<Output = ()> + Send + 'a,
287 {
288 let tx = self.tx.clone().unwrap();
289
290 // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
291 // dropping this `Scope` blocks until all of the futures have resolved.
292 let f = unsafe {
293 mem::transmute::<
294 Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
295 Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
296 >(Box::pin(async move {
297 f.await;
298 drop(tx);
299 }))
300 };
301 self.futures.push(f);
302 }
303}
304
305impl<'a> Drop for Scope<'a> {
306 fn drop(&mut self) {
307 self.tx.take().unwrap();
308
309 // Wait until the channel is closed, which means that all of the spawned
310 // futures have resolved.
311 self.executor.block(self.rx.next());
312 }
313}