1use crate::{AppContext, PlatformDispatcher};
2use futures::{channel::mpsc, pin_mut, FutureExt};
3use smol::prelude::*;
4use std::{
5 fmt::Debug,
6 marker::PhantomData,
7 mem,
8 num::NonZeroUsize,
9 pin::Pin,
10 rc::Rc,
11 sync::{
12 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
13 Arc,
14 },
15 task::{Context, Poll},
16 time::Duration,
17};
18use util::TryFutureExt;
19use waker_fn::waker_fn;
20
21#[cfg(any(test, feature = "test-support"))]
22use rand::rngs::StdRng;
23
/// A pointer to the executor that is currently running,
/// for spawning background tasks.
///
/// Cloning is cheap: every clone shares the same underlying
/// [`PlatformDispatcher`] through an [`Arc`].
#[derive(Clone)]
pub struct BackgroundExecutor {
    /// Dispatcher that receives every runnable scheduled by this executor.
    dispatcher: Arc<dyn PlatformDispatcher>,
}
30
/// A pointer to the executor that is currently running,
/// for spawning tasks on the main thread.
///
/// Cloning is cheap: every clone shares the same underlying
/// [`PlatformDispatcher`] through an [`Arc`].
#[derive(Clone)]
pub struct ForegroundExecutor {
    /// Dispatcher used to run spawned futures on the main thread.
    dispatcher: Arc<dyn PlatformDispatcher>,
    /// `Rc` is neither `Send` nor `Sync`, so this marker makes `ForegroundExecutor`
    /// `!Send + !Sync`: a handle cannot be moved to or shared with another thread.
    not_send: PhantomData<Rc<()>>,
}
38
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
#[must_use]
#[derive(Debug)]
pub enum Task<T> {
    /// A task that is ready to return a value.
    ///
    /// Holds `Some(value)` until the task is first polled, at which point the
    /// value is moved out and `None` is left behind.
    Ready(Option<T>),

    /// A task that is currently running on an executor.
    Spawned(async_task::Task<T>),
}
54
55impl<T> Task<T> {
56 /// Creates a new task that will resolve with the value
57 pub fn ready(val: T) -> Self {
58 Task::Ready(Some(val))
59 }
60
61 /// Detaching a task runs it to completion in the background
62 pub fn detach(self) {
63 match self {
64 Task::Ready(_) => {}
65 Task::Spawned(task) => task.detach(),
66 }
67 }
68}
69
70impl<E, T> Task<Result<T, E>>
71where
72 T: 'static,
73 E: 'static + Debug,
74{
75 /// Run the task to completion in the background and log any
76 /// errors that occur.
77 #[track_caller]
78 pub fn detach_and_log_err(self, cx: &AppContext) {
79 let location = core::panic::Location::caller();
80 cx.foreground_executor()
81 .spawn(self.log_tracked_err(*location))
82 .detach();
83 }
84}
85
86impl<T> Future for Task<T> {
87 type Output = T;
88
89 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
90 match unsafe { self.get_unchecked_mut() } {
91 Task::Ready(val) => Poll::Ready(val.take().unwrap()),
92 Task::Spawned(task) => task.poll(cx),
93 }
94 }
95}
96
/// An opaque identifier that can be attached to a spawned task so tests can
/// refer to it later (for example to change its priority).
///
/// Each call to [`TaskLabel::new`] takes the next value of a process-wide
/// counter, so labels created in the same process compare unequal.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    /// Same as [`TaskLabel::new`]: every call produces a fresh label.
    fn default() -> Self {
        TaskLabel::new()
    }
}

impl TaskLabel {
    /// Allocates a label that has not been handed out before.
    pub fn new() -> Self {
        // The counter starts at 1 so every id fits in a `NonZeroUsize`.
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        let id = NEXT_TASK_LABEL.fetch_add(1, SeqCst);
        TaskLabel(NonZeroUsize::try_from(id).unwrap())
    }
}
115
/// A boxed, type-erased future with no `Send` bound; used by
/// `ForegroundExecutor::spawn` for futures that run on the main thread.
type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;

/// A boxed, type-erased `Send` future; used by `BackgroundExecutor::spawn_internal`
/// for futures that may run on background threads.
type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;
119
120/// BackgroundExecutor lets you run things on background threads.
121/// In production this is a thread pool with no ordering guarantees.
122/// In tests this is simulated by running tasks one by one in a deterministic
123/// (but arbitrary) order controlled by the `SEED` environment variable.
124impl BackgroundExecutor {
125 #[doc(hidden)]
126 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
127 Self { dispatcher }
128 }
129
130 /// Enqueues the given future to be run to completion on a background thread.
131 pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
132 where
133 R: Send + 'static,
134 {
135 self.spawn_internal::<R>(Box::pin(future), None)
136 }
137
138 /// Enqueues the given future to be run to completion on a background thread.
139 /// The given label can be used to control the priority of the task in tests.
140 pub fn spawn_labeled<R>(
141 &self,
142 label: TaskLabel,
143 future: impl Future<Output = R> + Send + 'static,
144 ) -> Task<R>
145 where
146 R: Send + 'static,
147 {
148 self.spawn_internal::<R>(Box::pin(future), Some(label))
149 }
150
151 fn spawn_internal<R: Send + 'static>(
152 &self,
153 future: AnyFuture<R>,
154 label: Option<TaskLabel>,
155 ) -> Task<R> {
156 let dispatcher = self.dispatcher.clone();
157 let (runnable, task) =
158 async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
159 runnable.schedule();
160 Task::Spawned(task)
161 }
162
163 /// Used by the test harness to run an async test in a synchronous fashion.
164 #[cfg(any(test, feature = "test-support"))]
165 #[track_caller]
166 pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
167 if let Ok(value) = self.block_internal(false, future, usize::MAX) {
168 value
169 } else {
170 unreachable!()
171 }
172 }
173
174 /// Block the current thread until the given future resolves.
175 /// Consider using `block_with_timeout` instead.
176 pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
177 if let Ok(value) = self.block_internal(true, future, usize::MAX) {
178 value
179 } else {
180 unreachable!()
181 }
182 }
183
184 #[track_caller]
185 pub(crate) fn block_internal<R>(
186 &self,
187 background_only: bool,
188 future: impl Future<Output = R>,
189 mut max_ticks: usize,
190 ) -> Result<R, ()> {
191 pin_mut!(future);
192 let unparker = self.dispatcher.unparker();
193 let awoken = Arc::new(AtomicBool::new(false));
194
195 let waker = waker_fn({
196 let awoken = awoken.clone();
197 move || {
198 awoken.store(true, SeqCst);
199 unparker.unpark();
200 }
201 });
202 let mut cx = std::task::Context::from_waker(&waker);
203
204 loop {
205 match future.as_mut().poll(&mut cx) {
206 Poll::Ready(result) => return Ok(result),
207 Poll::Pending => {
208 if max_ticks == 0 {
209 return Err(());
210 }
211 max_ticks -= 1;
212
213 if !self.dispatcher.tick(background_only) {
214 if awoken.swap(false, SeqCst) {
215 continue;
216 }
217
218 #[cfg(any(test, feature = "test-support"))]
219 if let Some(test) = self.dispatcher.as_test() {
220 if !test.parking_allowed() {
221 let mut backtrace_message = String::new();
222 if let Some(backtrace) = test.waiting_backtrace() {
223 backtrace_message =
224 format!("\nbacktrace of waiting future:\n{:?}", backtrace);
225 }
226 panic!("parked with nothing left to run\n{:?}", backtrace_message)
227 }
228 }
229
230 self.dispatcher.park();
231 }
232 }
233 }
234 }
235 }
236
237 /// Block the current thread until the given future resolves
238 /// or `duration` has elapsed.
239 pub fn block_with_timeout<R>(
240 &self,
241 duration: Duration,
242 future: impl Future<Output = R>,
243 ) -> Result<R, impl Future<Output = R>> {
244 let mut future = Box::pin(future.fuse());
245 if duration.is_zero() {
246 return Err(future);
247 }
248
249 #[cfg(any(test, feature = "test-support"))]
250 let max_ticks = self
251 .dispatcher
252 .as_test()
253 .map_or(usize::MAX, |dispatcher| dispatcher.gen_block_on_ticks());
254 #[cfg(not(any(test, feature = "test-support")))]
255 let max_ticks = usize::MAX;
256
257 let mut timer = self.timer(duration).fuse();
258
259 let timeout = async {
260 futures::select_biased! {
261 value = future => Ok(value),
262 _ = timer => Err(()),
263 }
264 };
265 match self.block_internal(true, timeout, max_ticks) {
266 Ok(Ok(value)) => Ok(value),
267 _ => Err(future),
268 }
269 }
270
271 /// Scoped lets you start a number of tasks and waits
272 /// for all of them to complete before returning.
273 pub async fn scoped<'scope, F>(&self, scheduler: F)
274 where
275 F: FnOnce(&mut Scope<'scope>),
276 {
277 let mut scope = Scope::new(self.clone());
278 (scheduler)(&mut scope);
279 let spawned = mem::take(&mut scope.futures)
280 .into_iter()
281 .map(|f| self.spawn(f))
282 .collect::<Vec<_>>();
283 for task in spawned {
284 task.await;
285 }
286 }
287
288 /// Returns a task that will complete after the given duration.
289 /// Depending on other concurrent tasks the elapsed duration may be longer
290 /// than requested.
291 pub fn timer(&self, duration: Duration) -> Task<()> {
292 let (runnable, task) = async_task::spawn(async move {}, {
293 let dispatcher = self.dispatcher.clone();
294 move |runnable| dispatcher.dispatch_after(duration, runnable)
295 });
296 runnable.schedule();
297 Task::Spawned(task)
298 }
299
300 /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
301 #[cfg(any(test, feature = "test-support"))]
302 pub fn start_waiting(&self) {
303 self.dispatcher.as_test().unwrap().start_waiting();
304 }
305
306 /// in tests, removes the debugging data added by start_waiting
307 #[cfg(any(test, feature = "test-support"))]
308 pub fn finish_waiting(&self) {
309 self.dispatcher.as_test().unwrap().finish_waiting();
310 }
311
312 /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
313 #[cfg(any(test, feature = "test-support"))]
314 pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
315 self.dispatcher.as_test().unwrap().simulate_random_delay()
316 }
317
318 /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
319 #[cfg(any(test, feature = "test-support"))]
320 pub fn deprioritize(&self, task_label: TaskLabel) {
321 self.dispatcher.as_test().unwrap().deprioritize(task_label)
322 }
323
324 /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
325 #[cfg(any(test, feature = "test-support"))]
326 pub fn advance_clock(&self, duration: Duration) {
327 self.dispatcher.as_test().unwrap().advance_clock(duration)
328 }
329
330 /// in tests, run one task.
331 #[cfg(any(test, feature = "test-support"))]
332 pub fn tick(&self) -> bool {
333 self.dispatcher.as_test().unwrap().tick(false)
334 }
335
336 /// in tests, run all tasks that are ready to run. If after doing so
337 /// the test still has outstanding tasks, this will panic. (See also `allow_parking`)
338 #[cfg(any(test, feature = "test-support"))]
339 pub fn run_until_parked(&self) {
340 self.dispatcher.as_test().unwrap().run_until_parked()
341 }
342
343 /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
344 /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
345 /// do take real async time to run.
346 #[cfg(any(test, feature = "test-support"))]
347 pub fn allow_parking(&self) {
348 self.dispatcher.as_test().unwrap().allow_parking();
349 }
350
351 /// undoes the effect of [`allow_parking`].
352 #[cfg(any(test, feature = "test-support"))]
353 pub fn forbid_parking(&self) {
354 self.dispatcher.as_test().unwrap().forbid_parking();
355 }
356
357 /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
358 #[cfg(any(test, feature = "test-support"))]
359 pub fn rng(&self) -> StdRng {
360 self.dispatcher.as_test().unwrap().rng()
361 }
362
363 /// How many CPUs are available to the dispatcher
364 pub fn num_cpus(&self) -> usize {
365 num_cpus::get()
366 }
367
368 /// Whether we're on the main thread.
369 pub fn is_main_thread(&self) -> bool {
370 self.dispatcher.is_main_thread()
371 }
372
373 #[cfg(any(test, feature = "test-support"))]
374 /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
375 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
376 self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
377 }
378}
379
380/// ForegroundExecutor runs things on the main thread.
381impl ForegroundExecutor {
382 /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
383 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
384 Self {
385 dispatcher,
386 not_send: PhantomData,
387 }
388 }
389
390 /// Enqueues the given Task to run on the main thread at some point in the future.
391 pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
392 where
393 R: 'static,
394 {
395 let dispatcher = self.dispatcher.clone();
396 fn inner<R: 'static>(
397 dispatcher: Arc<dyn PlatformDispatcher>,
398 future: AnyLocalFuture<R>,
399 ) -> Task<R> {
400 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
401 dispatcher.dispatch_on_main_thread(runnable)
402 });
403 runnable.schedule();
404 Task::Spawned(task)
405 }
406 inner::<R>(dispatcher, Box::pin(future))
407 }
408}
409
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
    /// Executor used when dropping the scope to block until every future has finished.
    executor: BackgroundExecutor,
    /// Futures registered via `Scope::spawn` and not yet spawned. Their real `'a`
    /// lifetime has been erased to `'static`; see the SAFETY note in `Scope::spawn`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    /// The scope's own sender. Each registered future owns a clone and drops it when
    /// it completes. Becomes `None` when the scope starts dropping.
    tx: Option<mpsc::Sender<()>>,
    /// Receiver whose stream ends once every sender has been dropped.
    rx: mpsc::Receiver<()>,
    /// Ties the scope to the lifetime `'a` of the data its futures may borrow.
    lifetime: PhantomData<&'a ()>,
}
418
impl<'a> Scope<'a> {
    /// Creates an empty scope whose completion wait runs on `executor`.
    fn new(executor: BackgroundExecutor) -> Self {
        // Nothing in this file ever sends on this channel. It is used only to
        // observe when every sender clone has been dropped, so the buffer size
        // does not matter.
        let (tx, rx) = mpsc::channel(1);
        Self {
            executor,
            tx: Some(tx),
            rx,
            futures: Default::default(),
            lifetime: PhantomData,
        }
    }

    /// Spawn a future into this scope.
    ///
    /// The future is only queued here. `BackgroundExecutor::scoped` spawns the
    /// queued futures after its scheduler closure returns. The future may borrow
    /// data that lives for `'a`.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        // `tx` is only taken in `Drop`, so it is always `Some` while `&mut self` exists.
        let tx = self.tx.clone().unwrap();

        // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
        // dropping this `Scope` blocks until all of the futures have resolved.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(async move {
                f.await;
                // Releasing this sender clone signals (via channel closure) that
                // this future has finished.
                drop(tx);
            }))
        };
        self.futures.push(f);
    }
}
452
453impl<'a> Drop for Scope<'a> {
454 fn drop(&mut self) {
455 self.tx.take().unwrap();
456
457 // Wait until the channel is closed, which means that all of the spawned
458 // futures have resolved.
459 self.executor.block(self.rx.next());
460 }
461}