use crate::{AppContext, PlatformDispatcher};
use futures::{channel::mpsc, pin_mut, FutureExt};
use smol::prelude::*;
use std::{
    fmt::Debug,
    marker::PhantomData,
    mem,
    pin::Pin,
    rc::Rc,
    sync::{
        atomic::{AtomicBool, Ordering::SeqCst},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};
use util::TryFutureExt;
use waker_fn::waker_fn;

#[derive(Clone)]
pub struct BackgroundExecutor {
    dispatcher: Arc<dyn PlatformDispatcher>,
}

#[derive(Clone)]
pub struct ForegroundExecutor {
    dispatcher: Arc<dyn PlatformDispatcher>,
    not_send: PhantomData<Rc<()>>,
}

#[must_use]
pub enum Task<T> {
    Ready(Option<T>),
    Spawned(async_task::Task<T>),
}

impl<T> Task<T> {
    pub fn ready(val: T) -> Self {
        Task::Ready(Some(val))
    }

    pub fn detach(self) {
        match self {
            Task::Ready(_) => {}
            Task::Spawned(task) => task.detach(),
        }
    }
}

impl<E, T> Task<Result<T, E>>
where
    T: 'static + Send,
    E: 'static + Send + Debug,
{
    pub fn detach_and_log_err(self, cx: &mut AppContext) {
        cx.background_executor().spawn(self.log_err()).detach();
    }
}

impl<T> Future for Task<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match unsafe { self.get_unchecked_mut() } {
            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
            Task::Spawned(task) => task.poll(cx),
        }
    }
}

impl BackgroundExecutor {
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on any available
    /// background thread.
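    ///
    /// A minimal usage sketch (assumes a `cx: &mut AppContext` in scope;
    /// `expensive_work` is a stand-in for your own async function):
    ///
    /// ```ignore
    /// let task = cx.background_executor().spawn(async move {
    ///     expensive_work().await
    /// });
    /// task.detach(); // or `task.await` from another async context
    /// ```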
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) =
            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
        runnable.schedule();
        Task::Spawned(task)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
        let (runnable, task) = unsafe {
            async_task::spawn_unchecked(future, {
                let dispatcher = self.dispatcher.clone();
                move |runnable| dispatcher.dispatch_on_main_thread(runnable)
            })
        };

        runnable.schedule();

        self.block_internal(false, task)
    }

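    /// Blocks the current thread until the given future completes, polling the
    /// dispatcher for other runnables while the future is pending and parking
    /// when there is nothing left to run.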
    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        self.block_internal(true, future)
    }

    pub(crate) fn block_internal<R>(
        &self,
        background_only: bool,
        future: impl Future<Output = R>,
    ) -> R {
        pin_mut!(future);
        let (parker, unparker) = parking::pair();
        let awoken = Arc::new(AtomicBool::new(false));
        let awoken2 = awoken.clone();

        let waker = waker_fn(move || {
            awoken2.store(true, SeqCst);
            unparker.unpark();
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return result,
                Poll::Pending => {
                    if !self.dispatcher.poll(background_only) {
                        if awoken.swap(false, SeqCst) {
                            continue;
                        }

                        #[cfg(any(test, feature = "test-support"))]
                        if let Some(test) = self.dispatcher.as_test() {
                            if !test.parking_allowed() {
                                let mut backtrace_message = String::new();
                                if let Some(backtrace) = test.waiting_backtrace() {
                                    backtrace_message =
                                        format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                                }
                                panic!("parked with nothing left to run{}", backtrace_message)
                            }
                        }

                        parker.park();
                    }
                }
            }
        }
    }

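    /// Blocks the current thread until the given future completes or `duration`
    /// elapses. On timeout, the partially polled future is returned so the
    /// caller can continue awaiting it elsewhere.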
    pub fn block_with_timeout<R>(
        &self,
        duration: Duration,
        future: impl Future<Output = R>,
    ) -> Result<R, impl Future<Output = R>> {
        let mut future = Box::pin(future.fuse());
        if duration.is_zero() {
            return Err(future);
        }

        let mut timer = self.timer(duration).fuse();
        let timeout = async {
            futures::select_biased! {
                value = future => Ok(value),
                _ = timer => Err(()),
            }
        };
        match self.block(timeout) {
            Ok(value) => Ok(value),
            Err(_) => Err(future),
        }
    }

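    /// Spawns the futures scheduled on the given `Scope` onto background
    /// threads and waits for all of them to complete. Because completion is
    /// awaited before returning, the futures may borrow from the enclosing
    /// scope.
    ///
    /// A minimal usage sketch (`data` and `process_chunk` are stand-ins for
    /// your own buffer and async function):
    ///
    /// ```ignore
    /// executor
    ///     .scoped(|scope| {
    ///         for chunk in data.chunks_mut(64) {
    ///             scope.spawn(process_chunk(chunk));
    ///         }
    ///     })
    ///     .await;
    /// ```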
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone());
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn(f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }

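    /// Returns a task that completes after the given duration has elapsed.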
    pub fn timer(&self, duration: Duration) -> Task<()> {
        let (runnable, task) = async_task::spawn(async move {}, {
            let dispatcher = self.dispatcher.clone();
            move |runnable| dispatcher.dispatch_after(duration, runnable)
        });
        runnable.schedule();
        Task::Spawned(task)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        self.dispatcher.as_test().unwrap().start_waiting();
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        self.dispatcher.as_test().unwrap().finish_waiting();
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn allow_parking(&self) {
        self.dispatcher.as_test().unwrap().allow_parking();
    }

    pub fn num_cpus(&self) -> usize {
        num_cpus::get()
    }

    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }
}

impl ForegroundExecutor {
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self {
            dispatcher,
            not_send: PhantomData,
        }
    }

    /// Enqueues the given future to be run to completion on the main thread.
    /// Unlike [`BackgroundExecutor::spawn`], neither the future nor its output
    /// needs to be `Send`.
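    ///
    /// A minimal usage sketch (assumes an `executor: &ForegroundExecutor`, e.g.
    /// obtained from the app context; `update_ui` is a stand-in for your own
    /// async function):
    ///
    /// ```ignore
    /// let task = executor.spawn(async move {
    ///     update_ui().await
    /// });
    /// task.detach();
    /// ```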
    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
    where
        R: 'static,
    {
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
            dispatcher.dispatch_on_main_thread(runnable)
        });
        runnable.schedule();
        Task::Spawned(task)
    }
}

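/// Collects futures that may borrow from an enclosing scope; used by
/// [`BackgroundExecutor::scoped`]. Each spawned future holds a clone of `tx`,
/// so the channel closes only once every future has resolved, which is what
/// the `Drop` implementation waits for.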
pub struct Scope<'a> {
    executor: BackgroundExecutor,
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    tx: Option<mpsc::Sender<()>>,
    rx: mpsc::Receiver<()>,
    lifetime: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    fn new(executor: BackgroundExecutor) -> Self {
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let tx = self.tx.clone().unwrap();

        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}

impl<'a> Drop for Scope<'a> {
    fn drop(&mut self) {
        self.tx.take().unwrap();

        // Wait until the channel is closed, which means that all of the spawned
        // futures have resolved.
        self.executor.block(self.rx.next());
    }
}