1use crate::{AppContext, PlatformDispatcher};
2use futures::{channel::mpsc, pin_mut, FutureExt};
3use smol::prelude::*;
4use std::{
5 fmt::Debug,
6 marker::PhantomData,
7 mem,
8 num::NonZeroUsize,
9 pin::Pin,
10 rc::Rc,
11 sync::{
12 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
13 Arc,
14 },
15 task::{Context, Poll},
16 time::Duration,
17};
18use util::TryFutureExt;
19use waker_fn::waker_fn;
20
21#[cfg(any(test, feature = "test-support"))]
22use rand::rngs::StdRng;
23
24/// A pointer to the executor that is currently running,
25/// for spawning background tasks.
26#[derive(Clone)]
27pub struct BackgroundExecutor {
28 dispatcher: Arc<dyn PlatformDispatcher>,
29}
30
31/// A pointer to the executor that is currently running,
32/// for spawning tasks on the main thread.
33#[derive(Clone)]
34pub struct ForegroundExecutor {
35 dispatcher: Arc<dyn PlatformDispatcher>,
36 not_send: PhantomData<Rc<()>>,
37}
38
39/// Task is a primitive that allows work to happen in the background.
40///
41/// It implements [`Future`] so you can `.await` on it.
42///
43/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
44/// the task to continue running, but with no way to return a value.
45#[must_use]
46#[derive(Debug)]
47pub enum Task<T> {
48 /// A task that is ready to return a value
49 Ready(Option<T>),
50
51 /// A task that is currently running.
52 Spawned(async_task::Task<T>),
53}
54
55impl<T> Task<T> {
56 /// Creates a new task that will resolve with the value
57 pub fn ready(val: T) -> Self {
58 Task::Ready(Some(val))
59 }
60
61 /// Detaching a task runs it to completion in the background
62 pub fn detach(self) {
63 match self {
64 Task::Ready(_) => {}
65 Task::Spawned(task) => task.detach(),
66 }
67 }
68}
69
70impl<E, T> Task<Result<T, E>>
71where
72 T: 'static,
73 E: 'static + Debug,
74{
75 /// Run the task to completion in the background and log any
76 /// errors that occur.
77 #[track_caller]
78 pub fn detach_and_log_err(self, cx: &AppContext) {
79 let location = core::panic::Location::caller();
80 cx.foreground_executor()
81 .spawn(self.log_tracked_err(*location))
82 .detach();
83 }
84}
85
86impl<T> Future for Task<T> {
87 type Output = T;
88
89 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
90 match unsafe { self.get_unchecked_mut() } {
91 Task::Ready(val) => Poll::Ready(val.take().unwrap()),
92 Task::Spawned(task) => task.poll(cx),
93 }
94 }
95}
96
/// Opaque, copyable identifier attached to a task (see `spawn_labeled`) so tests can
/// refer to that task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskLabel(NonZeroUsize);
101
102impl Default for TaskLabel {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl TaskLabel {
109 /// Construct a new task label.
110 pub fn new() -> Self {
111 static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
112 Self(NEXT_TASK_LABEL.fetch_add(1, SeqCst).try_into().unwrap())
113 }
114}
115
/// Boxed, pinned, `'static` future that may be `!Send`; used by the foreground
/// executor's main-thread spawning.
type AnyLocalFuture<R> = Pin<Box<dyn Future<Output = R> + 'static>>;

/// Boxed, pinned, `'static` future that is `Send`; used by the background executor.
type AnyFuture<R> = Pin<Box<dyn Future<Output = R> + Send + 'static>>;
119
/// BackgroundExecutor lets you run things on background threads.
/// In production this is a thread pool with no ordering guarantees.
/// In tests this is simulated by running tasks one by one in a deterministic
/// (but arbitrary) order controlled by the `SEED` environment variable.
impl BackgroundExecutor {
    /// Creates an executor that schedules all of its work through `dispatcher`.
    #[doc(hidden)]
    pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
        Self { dispatcher }
    }

    /// Enqueues the given future to be run to completion on a background thread.
    ///
    /// Dropping the returned [`Task`] cancels the work; call [`Task::detach`] to let it
    /// run unobserved.
    pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), None)
    }

    /// Enqueues the given future to be run to completion on a background thread.
    /// The given label can be used to control the priority of the task in tests.
    pub fn spawn_labeled<R>(
        &self,
        label: TaskLabel,
        future: impl Future<Output = R> + Send + 'static,
    ) -> Task<R>
    where
        R: Send + 'static,
    {
        self.spawn_internal::<R>(Box::pin(future), Some(label))
    }

    /// Shared implementation of [`Self::spawn`] and [`Self::spawn_labeled`].
    ///
    /// It takes an already-boxed future and is generic only over the output type `R`.
    /// So it is instantiated once per `R`, not once per concrete future type.
    fn spawn_internal<R: Send + 'static>(
        &self,
        future: AnyFuture<R>,
        label: Option<TaskLabel>,
    ) -> Task<R> {
        let dispatcher = self.dispatcher.clone();
        // The schedule closure forwards each runnable to the dispatcher together with
        // the (optional) label.
        let (runnable, task) =
            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
        // Hand the first run to the dispatcher immediately.
        runnable.schedule();
        Task::Spawned(task)
    }

    /// Used by the test harness to run an async test in a synchronous fashion.
    ///
    /// Unlike [`Self::block`], this passes `background_only = false` to the dispatcher's
    /// `tick`. Presumably that lets foreground tasks run as well; TODO confirm against
    /// the test dispatcher.
    #[cfg(any(test, feature = "test-support"))]
    #[track_caller]
    pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(false, future, usize::MAX) {
            value
        } else {
            // With a `usize::MAX` tick budget, `block_internal` only returns `Err` after
            // `usize::MAX` pending polls, which is treated as impossible.
            unreachable!()
        }
    }

    /// Block the current thread until the given future resolves.
    /// Consider using `block_with_timeout` instead.
    ///
    /// While waiting, this passes `background_only = true` to the dispatcher's `tick`.
    pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
        if let Ok(value) = self.block_internal(true, future, usize::MAX) {
            value
        } else {
            // See `block_test`: an unbounded tick budget makes `Err` unreachable in practice.
            unreachable!()
        }
    }

    /// Polls `future` on the current thread until it completes.
    ///
    /// Each time the future is pending, this runs one step of dispatcher work via
    /// `tick(background_only)`. If there is nothing to run and no wake-up has arrived,
    /// it parks the thread.
    ///
    /// Returns `Err(())` once the future has returned `Pending` more than `max_ticks`
    /// times.
    ///
    /// In test builds, panics instead of parking when the test dispatcher forbids
    /// parking. The panic message includes the waiting hint and the backtrace, when set.
    #[track_caller]
    pub(crate) fn block_internal<R>(
        &self,
        background_only: bool,
        future: impl Future<Output = R>,
        mut max_ticks: usize,
    ) -> Result<R, ()> {
        pin_mut!(future);
        let unparker = self.dispatcher.unparker();
        // Set by the waker. The loop checks it before parking, so a wake-up that arrives
        // while the thread is not parked still causes a re-poll.
        let awoken = Arc::new(AtomicBool::new(false));

        let waker = waker_fn({
            let awoken = awoken.clone();
            move || {
                awoken.store(true, SeqCst);
                unparker.unpark();
            }
        });
        let mut cx = std::task::Context::from_waker(&waker);

        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(result) => return Ok(result),
                Poll::Pending => {
                    if max_ticks == 0 {
                        return Err(());
                    }
                    max_ticks -= 1;

                    // Try to make progress on other queued work before considering parking.
                    if !self.dispatcher.tick(background_only) {
                        // Nothing ran. If we were woken in the meantime, re-poll right away.
                        if awoken.swap(false, SeqCst) {
                            continue;
                        }

                        #[cfg(any(test, feature = "test-support"))]
                        if let Some(test) = self.dispatcher.as_test() {
                            if !test.parking_allowed() {
                                let mut backtrace_message = String::new();
                                let mut waiting_message = String::new();
                                if let Some(backtrace) = test.waiting_backtrace() {
                                    backtrace_message =
                                        format!("\nbacktrace of waiting future:\n{:?}", backtrace);
                                }
                                if let Some(waiting_hint) = test.waiting_hint() {
                                    waiting_message = format!("\n waiting on: {}\n", waiting_hint);
                                }
                                panic!(
                                    "parked with nothing left to run{waiting_message}{backtrace_message}",
                                )
                            }
                        }

                        // Nothing to run and no pending wake-up. Park the thread; this
                        // presumably returns once the waker calls `unparker.unpark()`.
                        self.dispatcher.park();
                    }
                }
            }
        }
    }

    /// Block the current thread until the given future resolves
    /// or `duration` has elapsed.
    ///
    /// On timeout, the (fused) future is handed back in `Err` so the caller can keep
    /// awaiting it. A zero `duration` returns `Err` immediately, without polling.
    ///
    /// In test builds, the number of ticks is also capped by the test dispatcher's
    /// `gen_block_on_ticks` (configurable via `set_block_on_ticks`). Exceeding that cap
    /// is also reported as `Err`.
    pub fn block_with_timeout<R>(
        &self,
        duration: Duration,
        future: impl Future<Output = R>,
    ) -> Result<R, impl Future<Output = R>> {
        // Fused and boxed so it can be polled here and still handed back on timeout.
        let mut future = Box::pin(future.fuse());
        if duration.is_zero() {
            return Err(future);
        }

        #[cfg(any(test, feature = "test-support"))]
        let max_ticks = self
            .dispatcher
            .as_test()
            .map_or(usize::MAX, |dispatcher| dispatcher.gen_block_on_ticks());
        #[cfg(not(any(test, feature = "test-support")))]
        let max_ticks = usize::MAX;

        let mut timer = self.timer(duration).fuse();

        // `select_biased!` checks branches in order, so if both are ready the value wins.
        let timeout = async {
            futures::select_biased! {
                value = future => Ok(value),
                _ = timer => Err(()),
            }
        };
        match self.block_internal(true, timeout, max_ticks) {
            Ok(Ok(value)) => Ok(value),
            _ => Err(future),
        }
    }

    /// Scoped lets you start a number of tasks and waits
    /// for all of them to complete before returning.
    ///
    /// The futures registered by `scheduler` via [`Scope::spawn`] are only spawned after
    /// `scheduler` returns.
    ///
    /// NOTE(review): memory safety relies on `Scope`'s `Drop` impl blocking until every
    /// spawned future has finished. If the future returned here is leaked (e.g. via
    /// `mem::forget`) after it started the tasks, that `Drop` never runs. Confirm that no
    /// caller can do this.
    pub async fn scoped<'scope, F>(&self, scheduler: F)
    where
        F: FnOnce(&mut Scope<'scope>),
    {
        let mut scope = Scope::new(self.clone());
        (scheduler)(&mut scope);
        let spawned = mem::take(&mut scope.futures)
            .into_iter()
            .map(|f| self.spawn(f))
            .collect::<Vec<_>>();
        for task in spawned {
            task.await;
        }
    }

    /// Returns a task that will complete after the given duration.
    /// Depending on other concurrent tasks the elapsed duration may be longer
    /// than requested.
    pub fn timer(&self, duration: Duration) -> Task<()> {
        // The task body is empty. The delay comes entirely from scheduling it through
        // `dispatch_after`.
        let (runnable, task) = async_task::spawn(async move {}, {
            let dispatcher = self.dispatcher.clone();
            move |runnable| dispatcher.dispatch_after(duration, runnable)
        });
        runnable.schedule();
        Task::Spawned(task)
    }

    /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
    ///
    /// Panics unless the dispatcher is a test dispatcher.
    #[cfg(any(test, feature = "test-support"))]
    pub fn start_waiting(&self) {
        self.dispatcher.as_test().unwrap().start_waiting();
    }

    /// in tests, removes the debugging data added by start_waiting
    ///
    /// Panics unless the dispatcher is a test dispatcher.
    #[cfg(any(test, feature = "test-support"))]
    pub fn finish_waiting(&self) {
        self.dispatcher.as_test().unwrap().finish_waiting();
    }

    /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
    ///
    /// Panics unless the dispatcher is a test dispatcher.
    #[cfg(any(test, feature = "test-support"))]
    pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
        self.dispatcher.as_test().unwrap().simulate_random_delay()
    }

    /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
    ///
    /// Panics unless the dispatcher is a test dispatcher.
    #[cfg(any(test, feature = "test-support"))]
    pub fn deprioritize(&self, task_label: TaskLabel) {
        self.dispatcher.as_test().unwrap().deprioritize(task_label)
    }

    /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
    ///
    /// Panics unless the dispatcher is a test dispatcher.
    #[cfg(any(test, feature = "test-support"))]
    pub fn advance_clock(&self, duration: Duration) {
        self.dispatcher.as_test().unwrap().advance_clock(duration)
    }

    /// in tests, run one task.
    ///
    /// Returns the test dispatcher's `tick(false)` result. Panics unless the dispatcher
    /// is a test dispatcher.
    #[cfg(any(test, feature = "test-support"))]
    pub fn tick(&self) -> bool {
        self.dispatcher.as_test().unwrap().tick(false)
    }

    /// in tests, run all tasks that are ready to run. If after doing so
    /// the test still has outstanding tasks, this will panic. (See also `allow_parking`)
    #[cfg(any(test, feature = "test-support"))]
    pub fn run_until_parked(&self) {
        self.dispatcher.as_test().unwrap().run_until_parked()
    }

    /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
    /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
    /// do take real async time to run.
    #[cfg(any(test, feature = "test-support"))]
    pub fn allow_parking(&self) {
        self.dispatcher.as_test().unwrap().allow_parking();
    }

    /// undoes the effect of [`Self::allow_parking`].
    #[cfg(any(test, feature = "test-support"))]
    pub fn forbid_parking(&self) {
        self.dispatcher.as_test().unwrap().forbid_parking();
    }

    /// adds detail to the "parked with nothing left to run" message.
    #[cfg(any(test, feature = "test-support"))]
    pub fn set_waiting_hint(&self, msg: Option<String>) {
        self.dispatcher.as_test().unwrap().set_waiting_hint(msg);
    }

    /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
    #[cfg(any(test, feature = "test-support"))]
    pub fn rng(&self) -> StdRng {
        self.dispatcher.as_test().unwrap().rng()
    }

    /// How many CPUs are available to the dispatcher
    ///
    /// This returns `num_cpus::get()`; the dispatcher itself is not consulted.
    pub fn num_cpus(&self) -> usize {
        num_cpus::get()
    }

    /// Whether we're on the main thread.
    pub fn is_main_thread(&self) -> bool {
        self.dispatcher.is_main_thread()
    }

    #[cfg(any(test, feature = "test-support"))]
    /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
    pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
        self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
    }
}
391
392/// ForegroundExecutor runs things on the main thread.
393impl ForegroundExecutor {
394 /// Creates a new ForegroundExecutor from the given PlatformDispatcher.
395 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
396 Self {
397 dispatcher,
398 not_send: PhantomData,
399 }
400 }
401
402 /// Enqueues the given Task to run on the main thread at some point in the future.
403 pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
404 where
405 R: 'static,
406 {
407 let dispatcher = self.dispatcher.clone();
408 fn inner<R: 'static>(
409 dispatcher: Arc<dyn PlatformDispatcher>,
410 future: AnyLocalFuture<R>,
411 ) -> Task<R> {
412 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
413 dispatcher.dispatch_on_main_thread(runnable)
414 });
415 runnable.schedule();
416 Task::Spawned(task)
417 }
418 inner::<R>(dispatcher, Box::pin(future))
419 }
420}
421
422/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
423pub struct Scope<'a> {
424 executor: BackgroundExecutor,
425 futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
426 tx: Option<mpsc::Sender<()>>,
427 rx: mpsc::Receiver<()>,
428 lifetime: PhantomData<&'a ()>,
429}
430
431impl<'a> Scope<'a> {
432 fn new(executor: BackgroundExecutor) -> Self {
433 let (tx, rx) = mpsc::channel(1);
434 Self {
435 executor,
436 tx: Some(tx),
437 rx,
438 futures: Default::default(),
439 lifetime: PhantomData,
440 }
441 }
442
443 /// Spawn a future into this scope.
444 pub fn spawn<F>(&mut self, f: F)
445 where
446 F: Future<Output = ()> + Send + 'a,
447 {
448 let tx = self.tx.clone().unwrap();
449
450 // SAFETY: The 'a lifetime is guaranteed to outlive any of these futures because
451 // dropping this `Scope` blocks until all of the futures have resolved.
452 let f = unsafe {
453 mem::transmute::<
454 Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
455 Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
456 >(Box::pin(async move {
457 f.await;
458 drop(tx);
459 }))
460 };
461 self.futures.push(f);
462 }
463}
464
465impl<'a> Drop for Scope<'a> {
466 fn drop(&mut self) {
467 self.tx.take().unwrap();
468
469 // Wait until the channel is closed, which means that all of the spawned
470 // futures have resolved.
471 self.executor.block(self.rx.next());
472 }
473}