use crate::{AppContext, PlatformDispatcher};
use futures::{channel::mpsc, pin_mut, FutureExt};
use smol::prelude::*;
use std::{
    fmt::Debug,
    marker::PhantomData,
    mem,
    num::NonZeroUsize,
    pin::Pin,
    rc::Rc,
    sync::{
        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
        Arc,
    },
    task::{Context, Poll},
    time::Duration,
};
use util::TryFutureExt;
use waker_fn::waker_fn;

#[cfg(any(test, feature = "test-support"))]
use rand::rngs::StdRng;

/// An executor that runs futures to completion on background threads.
#[derive(Clone)]
pub struct BackgroundExecutor {
    dispatcher: Arc<dyn PlatformDispatcher>,
}

/// An executor that runs futures to completion on the main thread. The
/// `PhantomData<Rc<()>>` marker makes it `!Send`, so it can only be used from
/// the thread on which it was created.
#[derive(Clone)]
pub struct ForegroundExecutor {
    dispatcher: Arc<dyn PlatformDispatcher>,
    not_send: PhantomData<Rc<()>>,
}

/// A handle to a future spawned on an executor, or a value that is already
/// available. Dropping a spawned task cancels it; call [`Task::detach`] to let
/// it run to completion in the background instead.
#[must_use]
#[derive(Debug)]
pub enum Task<T> {
    Ready(Option<T>),
    Spawned(async_task::Task<T>),
}
impl<T> Task<T> {
    /// Creates a task that is already complete with the given value.
    pub fn ready(val: T) -> Self {
        Task::Ready(Some(val))
    }

    /// Detaches the task so it keeps running after this handle is dropped.
    /// Its output is discarded.
    pub fn detach(self) {
        match self {
            Task::Ready(_) => {}
            Task::Spawned(task) => task.detach(),
        }
    }
}
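
// Illustrative sketch (not compiled): `Task::ready` is handy when an API must
// return a `Task<T>` but the value is already known, while `detach` lets a
// spawned task outlive its handle. `executor` is assumed to be a
// `BackgroundExecutor`, and `fetch_remote_value` / `write_log_entry` are
// hypothetical async helpers.
//
//     fn cached_or_fetch(executor: &BackgroundExecutor, cached: Option<String>) -> Task<String> {
//         match cached {
//             Some(value) => Task::ready(value),
//             None => executor.spawn(async { fetch_remote_value().await }),
//         }
//     }
//
//     // Fire-and-forget: keep the task running without awaiting its result.
//     executor.spawn(async { write_log_entry().await }).detach();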

impl<E, T> Task<Result<T, E>>
where
    T: 'static,
    E: 'static + Debug,
{
    /// Detaches the task, logging the error if it resolves to `Err`.
    pub fn detach_and_log_err(self, cx: &mut AppContext) {
        cx.foreground_executor().spawn(self.log_err()).detach();
    }
}

impl<T> Future for Task<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match unsafe { self.get_unchecked_mut() } {
            Task::Ready(val) => Poll::Ready(val.take().unwrap()),
            Task::Spawned(task) => task.poll(cx),
        }
    }
}

/// A label that identifies a category of background tasks so that tests can
/// reorder or deprioritize them (see [`BackgroundExecutor::spawn_labeled`]).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl TaskLabel {
    pub fn new() -> Self {
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        Self(NEXT_TASK_LABEL.fetch_add(1, SeqCst).try_into().unwrap())
    }
}

type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;

type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;

impl BackgroundExecutor {
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on a background thread.
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), None)
    }
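
    // Illustrative sketch (not compiled): spawning CPU-bound work onto a
    // background thread and synchronously waiting for it with `block`.
    // `executor` is assumed to be a `BackgroundExecutor`.
    //
    //     let task = executor.spawn(async {
    //         (0..1_000_000u64).sum::<u64>()
    //     });
    //     let total = executor.block(task);
    //     assert_eq!(total, 499_999_500_000);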

    /// Enqueues the given future to be run to completion on a background thread.
    /// The given label can be used to control the priority of the task in tests.
    pub fn spawn_labeled<R>(
        &self,
        label: TaskLabel,
        future: impl Future<Output = R> + Send + 'static,
    ) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), Some(label))
    }
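
    // Illustrative sketch (not compiled): labeling related tasks so a test can
    // deprioritize them and exercise a particular interleaving. The dispatcher
    // must be a test dispatcher for `deprioritize` to be usable, and
    // `rebuild_search_index` is a hypothetical async helper.
    //
    //     let indexing = TaskLabel::new();
    //     executor.deprioritize(indexing);
    //     executor
    //         .spawn_labeled(indexing, async { rebuild_search_index().await })
    //         .detach();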

    fn spawn_internal<R: Send + 'static>(
        &self,
        future: AnyFuture<R>,
        label: Option<TaskLabel>,
    ) -> Task<R> {
        let dispatcher = self.dispatcher.clone();
        let (runnable, task) =
            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
        runnable.schedule();
        Task::Spawned(task)
    }

    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(false, future, usize::MAX) {
            value
        } else {
            unreachable!()
        }
    }

    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(true, future, usize::MAX) {
            value
        } else {
            unreachable!()
        }
    }

    #[track_caller]
    pub(crate) fn block_internal<R>(
        &self,
        background_only: bool,
        future: impl Future<Output = R>,
        mut max_ticks: usize,
    ) -> Result<R, ()> {
        pin_mut!(future);
        let unparker = self.dispatcher.unparker();
        let awoken = Arc::new(AtomicBool::new(false));

        let waker = waker_fn({
            let awoken = awoken.clone();
            move || {
                awoken.store(true, SeqCst);
                unparker.unpark();
            }
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    if max_ticks == 0 {
                        return Err(());
                    }
                    max_ticks -= 1;

                    if !self.dispatcher.tick(background_only) {
                        if awoken.swap(false, SeqCst) {
                            continue;
                        }

                        #[cfg(any(test, feature = "test-support"))]
                        if let Some(test) = self.dispatcher.as_test() {
                            if !test.parking_allowed() {
                                let mut backtrace_message = String::new();
                                if let Some(backtrace) = test.waiting_backtrace() {
                                    backtrace_message =
                                        format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                                }
                                panic!("parked with nothing left to run{}", backtrace_message)
                            }
                        }

                        self.dispatcher.park();
                    }
                }
            }
        }
    }

    /// Blocks the current thread until the given future resolves or `duration`
    /// elapses. On timeout, the still-pending future is returned as the `Err`
    /// variant so it can be awaited later.
    pub fn block_with_timeout<R>(
        &self,
        duration: Duration,
        future: impl Future<Output = R>,
    ) -> Result<R, impl Future<Output = R>> {
        let mut future = Box::pin(future.fuse());
        if duration.is_zero() {
            return Err(future);
        }

        let max_ticks = if cfg!(any(test, feature = "test-support")) {
            self.dispatcher
                .as_test()
                .map_or(usize::MAX, |dispatcher| dispatcher.gen_block_on_ticks())
        } else {
            usize::MAX
        };
        let mut timer = self.timer(duration).fuse();

        let timeout = async {
            futures::select_biased! {
                value = future => Ok(value),
                _ = timer => Err(()),
            }
        };
        match self.block_internal(true, timeout, max_ticks) {
            Ok(Ok(value)) => Ok(value),
            _ => Err(future),
        }
    }
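
    // Illustrative sketch (not compiled): waiting briefly for a result and
    // falling back to the still-pending future on timeout. `executor` is
    // assumed to be a `BackgroundExecutor`; `load_settings` and `apply` are
    // hypothetical helpers.
    //
    //     let task = executor.spawn(async { load_settings().await });
    //     match executor.block_with_timeout(Duration::from_millis(100), task) {
    //         Ok(settings) => apply(settings),
    //         Err(pending) => {
    //             // Took too long; finish in the background instead.
    //             executor.spawn(async move { apply(pending.await) }).detach();
    //         }
    //     }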

    /// Runs the futures registered by `scheduler` on background threads and
    /// waits for all of them to complete before returning.
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone());
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn(f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }
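
    // Illustrative sketch (not compiled): `scoped` lets spawned futures borrow
    // data from the enclosing stack frame, because all of them are awaited
    // before `scoped` returns. `load_names` and `greet` are hypothetical
    // helpers.
    //
    //     let names: Vec<String> = load_names();
    //     executor
    //         .scoped(|scope| {
    //             for name in &names {
    //                 scope.spawn(async move { greet(name).await });
    //             }
    //         })
    //         .await;
    //
    // The point is that the futures borrow `names` rather than taking
    // ownership, which is what the `'scope` lifetime on `Scope` permits.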

    /// Returns a task that completes after the given duration has elapsed.
    pub fn timer(&self, duration: Duration) -> Task<()> {
        let (runnable, task) = async_task::spawn(async move {}, {
            let dispatcher = self.dispatcher.clone();
            move |runnable| dispatcher.dispatch_after(duration, runnable)
        });
        runnable.schedule();
        Task::Spawned(task)
    }
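
    // Illustrative sketch (not compiled): a simple delayed action built on
    // `timer`. `executor` is assumed to be a `BackgroundExecutor` and
    // `flush_pending_edits` a hypothetical async helper.
    //
    //     let timer = executor.timer(Duration::from_millis(250));
    //     executor
    //         .spawn(async move {
    //             timer.await;
    //             flush_pending_edits().await;
    //         })
    //         .detach();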

    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        self.dispatcher.as_test().unwrap().start_waiting();
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        self.dispatcher.as_test().unwrap().finish_waiting();
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn deprioritize(&self, task_label: TaskLabel) {
        self.dispatcher.as_test().unwrap().deprioritize(task_label)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn tick(&self) -> bool {
        self.dispatcher.as_test().unwrap().tick(false)
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }
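
    // Illustrative sketch (not compiled): with the test dispatcher, time and
    // scheduling are deterministic, so timers only fire when the clock is
    // advanced explicitly. `executor` is assumed to be backed by the test
    // dispatcher.
    //
    //     let task = executor.spawn({
    //         let executor = executor.clone();
    //         async move {
    //             executor.timer(Duration::from_secs(1)).await;
    //             42
    //         }
    //     });
    //     executor.run_until_parked(); // the timer has not fired yet
    //     executor.advance_clock(Duration::from_secs(1));
    //     assert_eq!(executor.block(task), 42);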

    #[cfg(any(test, feature = "test-support"))]
    pub fn allow_parking(&self) {
        self.dispatcher.as_test().unwrap().allow_parking();
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn rng(&self) -> StdRng {
        self.dispatcher.as_test().unwrap().rng()
    }

    pub fn num_cpus(&self) -> usize {
        num_cpus::get()
    }

    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }

    #[cfg(any(test, feature = "test-support"))]
    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
    }
}

impl ForegroundExecutor {
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self {
            dispatcher,
            not_send: PhantomData,
        }
    }

    /// Enqueues the given future to be run to completion on the main thread.
    pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
    where
        R: 'static,
    {
        let dispatcher = self.dispatcher.clone();
        fn inner<R: 'static>(
            dispatcher: Arc<dyn PlatformDispatcher>,
            future: AnyLocalFuture<R>,
        ) -> Task<R> {
            let (runnable, task) = async_task::spawn_local(future, move |runnable| {
                dispatcher.dispatch_on_main_thread(runnable)
            });
            runnable.schedule();
            Task::Spawned(task)
        }
        inner::<R>(dispatcher, Box::pin(future))
    }
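
    // Illustrative sketch (not compiled): foreground tasks may hold non-`Send`
    // values such as `Rc` and `RefCell` (both from `std`), since the spawn call
    // and the spawned future both run on the main thread. `foreground` is
    // assumed to be a `ForegroundExecutor`.
    //
    //     let shared = Rc::new(RefCell::new(Vec::new()));
    //     foreground
    //         .spawn({
    //             let shared = shared.clone();
    //             async move { shared.borrow_mut().push("ran on the main thread") }
    //         })
    //         .detach();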
}

pub struct Scope<'a> {
    executor: BackgroundExecutor,
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    tx: Option<mpsc::Sender<()>>,
    rx: mpsc::Receiver<()>,
    lifetime: PhantomData<&'a ()>,
}

impl<'a> Scope<'a> {
    fn new(executor: BackgroundExecutor) -> Self {
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let tx = self.tx.clone().unwrap();

        // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}

impl<'a> Drop for Scope<'a> {
    fn drop(&mut self) {
        self.tx.take().unwrap();

        // Wait until the channel is closed, which means that all of the spawned
        // futures have resolved.
        self.executor.block(self.rx.next());
    }
}