1use crate::{AppContext, PlatformDispatcher};
2use futures::{channel::mpsc, pin_mut, FutureExt};
3use smol::prelude::*;
4use std::{
5 fmt::Debug,
6 marker::PhantomData,
7 mem,
8 num::NonZeroUsize,
9 pin::Pin,
10 rc::Rc,
11 sync::{
12 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
13 Arc,
14 },
15 task::{Context, Poll},
16 time::Duration,
17};
18use util::TryFutureExt;
19use waker_fn::waker_fn;
20
21#[cfg(any(test, feature = "test-support"))]
22use rand::rngs::StdRng;
23
/// Handle used to schedule and drive futures through a [`PlatformDispatcher`].
///
/// The only state is a shared pointer to the dispatcher, so cloning an executor
/// just bumps the `Arc` reference count.
#[derive(Clone)]
pub struct BackgroundExecutor {
    /// Dispatcher that every scheduling, ticking and parking call is forwarded to.
    dispatcher: Arc<dyn PlatformDispatcher>,
}
28
/// Handle used to schedule futures onto the main thread via a [`PlatformDispatcher`].
#[derive(Clone)]
pub struct ForegroundExecutor {
    /// Dispatcher that main-thread runnables are handed to.
    dispatcher: Arc<dyn PlatformDispatcher>,
    /// Zero-sized marker: `Rc<()>` is neither `Send` nor `Sync`, so this field
    /// keeps the executor itself from being sent to or shared with other threads.
    not_send: PhantomData<Rc<()>>,
}
34
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running in the background, but with no way to return a value.
#[must_use]
#[derive(Debug)]
pub enum Task<T> {
    /// A task whose value is already available. The value is wrapped in an
    /// `Option` so that it can be moved out the first time the task is polled.
    Ready(Option<T>),
    /// A task backed by an `async_task::Task` that was handed to a dispatcher.
    Spawned(async_task::Task<T>),
}
47
48impl<T> Task<T> {
49 /// Create a new task that will resolve with the value
50 pub fn ready(val: T) -> Self {
51 Task::Ready(Some(val))
52 }
53
54 /// Detaching a task runs it to completion in the background
55 pub fn detach(self) {
56 match self {
57 Task::Ready(_) => {}
58 Task::Spawned(task) => task.detach(),
59 }
60 }
61}
62
63impl<E, T> Task<Result<T, E>>
64where
65 T: 'static,
66 E: 'static + Debug,
67{
68 /// Run the task to completion in the background and log any
69 /// errors that occur.
70 #[track_caller]
71 pub fn detach_and_log_err(self, cx: &AppContext) {
72 let location = core::panic::Location::caller();
73 cx.foreground_executor()
74 .spawn(self.log_tracked_err(*location))
75 .detach();
76 }
77}
78
79impl<T> Future for Task<T> {
80 type Output = T;
81
82 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
83 match unsafe { self.get_unchecked_mut() } {
84 Task::Ready(val) => Poll::Ready(val.take().unwrap()),
85 Task::Spawned(task) => task.poll(cx),
86 }
87 }
88}
89
/// Process-unique identifier that can be attached to a background task, see
/// `BackgroundExecutor::spawn_labeled`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl Default for TaskLabel {
    /// Equivalent to [`TaskLabel::new`]: every default label is a fresh one.
    fn default() -> Self {
        TaskLabel::new()
    }
}

impl TaskLabel {
    /// Allocates the next unused label from a global counter that starts at 1.
    pub fn new() -> Self {
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        let id = NEXT_TASK_LABEL.fetch_add(1, SeqCst);
        // The counter starts at 1, so the conversion can only fail if it wraps to 0.
        Self(NonZeroUsize::try_from(id).unwrap())
    }
}
105
/// A boxed, pinned, type-erased future with no `Send` bound; used for work
/// spawned by `ForegroundExecutor::spawn`.
type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;
107
/// A boxed, pinned, type-erased future that is `Send`; used for work spawned by
/// `BackgroundExecutor`.
type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;
109
110/// BackgroundExecutor lets you run things on background threads.
111/// In production this is a thread pool with no ordering guarantees.
112/// In tests this is simulated by running tasks one by one in a deterministic
113/// (but arbitrary) order controlled by the `SEED` environment variable.
114impl BackgroundExecutor {
115 #[doc(hidden)]
116 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
117 Self { dispatcher }
118 }
119
120 /// Enqueues the given future to be run to completion on a background thread.
121 pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
122 where
123 R: Send + 'static,
124 {
125 self.spawn_internal::<R>(Box::pin(future), None)
126 }
127
128 /// Enqueues the given future to be run to completion on a background thread.
129 /// The given label can be used to control the priority of the task in tests.
130 pub fn spawn_labeled<R>(
131 &self,
132 label: TaskLabel,
133 future: impl Future<Output = R> + Send + 'static,
134 ) -> Task<R>
135 where
136 R: Send + 'static,
137 {
138 self.spawn_internal::<R>(Box::pin(future), Some(label))
139 }
140
141 fn spawn_internal<R: Send + 'static>(
142 &self,
143 future: AnyFuture<R>,
144 label: Option<TaskLabel>,
145 ) -> Task<R> {
146 let dispatcher = self.dispatcher.clone();
147 let (runnable, task) =
148 async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
149 runnable.schedule();
150 Task::Spawned(task)
151 }
152
153 /// Used by the test harness to run an async test in a synchronous fashion.
154 #[cfg(any(test, feature = "test-support"))]
155 #[track_caller]
156 pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
157 if let Ok(value) = self.block_internal(false, future, usize::MAX) {
158 value
159 } else {
160 unreachable!()
161 }
162 }
163
164 /// Block the current thread until the given future resolves.
165 /// Consider using `block_with_timeout` instead.
166 pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
167 if let Ok(value) = self.block_internal(true, future, usize::MAX) {
168 value
169 } else {
170 unreachable!()
171 }
172 }
173
174 #[track_caller]
175 pub(crate) fn block_internal<R>(
176 &self,
177 background_only: bool,
178 future: impl Future<Output = R>,
179 mut max_ticks: usize,
180 ) -> Result<R, ()> {
181 pin_mut!(future);
182 let unparker = self.dispatcher.unparker();
183 let awoken = Arc::new(AtomicBool::new(false));
184
185 let waker = waker_fn({
186 let awoken = awoken.clone();
187 move || {
188 awoken.store(true, SeqCst);
189 unparker.unpark();
190 }
191 });
192 let mut cx = std::task::Context::from_waker(&waker);
193
194 loop {
195 match future.as_mut().poll(&mut cx) {
196 Poll::Ready(result) => return Ok(result),
197 Poll::Pending => {
198 if max_ticks == 0 {
199 return Err(());
200 }
201 max_ticks -= 1;
202
203 if !self.dispatcher.tick(background_only) {
204 if awoken.swap(false, SeqCst) {
205 continue;
206 }
207
208 #[cfg(any(test, feature = "test-support"))]
209 if let Some(test) = self.dispatcher.as_test() {
210 if !test.parking_allowed() {
211 let mut backtrace_message = String::new();
212 if let Some(backtrace) = test.waiting_backtrace() {
213 backtrace_message =
214 format!("\nbacktrace of waiting future:\n{:?}", backtrace);
215 }
216 panic!("parked with nothing left to run\n{:?}", backtrace_message)
217 }
218 }
219
220 self.dispatcher.park();
221 }
222 }
223 }
224 }
225 }
226
227 /// Block the current thread until the given future resolves
228 /// or `duration` has elapsed.
229 pub fn block_with_timeout<R>(
230 &self,
231 duration: Duration,
232 future: impl Future<Output = R>,
233 ) -> Result<R, impl Future<Output = R>> {
234 let mut future = Box::pin(future.fuse());
235 if duration.is_zero() {
236 return Err(future);
237 }
238
239 #[cfg(any(test, feature = "test-support"))]
240 let max_ticks = self
241 .dispatcher
242 .as_test()
243 .map_or(usize::MAX, |dispatcher| dispatcher.gen_block_on_ticks());
244 #[cfg(not(any(test, feature = "test-support")))]
245 let max_ticks = usize::MAX;
246
247 let mut timer = self.timer(duration).fuse();
248
249 let timeout = async {
250 futures::select_biased! {
251 value = future => Ok(value),
252 _ = timer => Err(()),
253 }
254 };
255 match self.block_internal(true, timeout, max_ticks) {
256 Ok(Ok(value)) => Ok(value),
257 _ => Err(future),
258 }
259 }
260
261 /// Scoped lets you start a number of tasks and waits
262 /// for all of them to complete before returning.
263 pub async fn scoped<'scope, F>(&self, scheduler: F)
264 where
265 F: FnOnce(&mut Scope<'scope>),
266 {
267 let mut scope = Scope::new(self.clone());
268 (scheduler)(&mut scope);
269 let spawned = mem::take(&mut scope.futures)
270 .into_iter()
271 .map(|f| self.spawn(f))
272 .collect::<Vec<_>>();
273 for task in spawned {
274 task.await;
275 }
276 }
277
278 /// Returns a task that will complete after the given duration.
279 /// Depending on other concurrent tasks the elapsed duration may be longer
280 /// than requested.
281 pub fn timer(&self, duration: Duration) -> Task<()> {
282 let (runnable, task) = async_task::spawn(async move {}, {
283 let dispatcher = self.dispatcher.clone();
284 move |runnable| dispatcher.dispatch_after(duration, runnable)
285 });
286 runnable.schedule();
287 Task::Spawned(task)
288 }
289
290 /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
291 #[cfg(any(test, feature = "test-support"))]
292 pub fn start_waiting(&self) {
293 self.dispatcher.as_test().unwrap().start_waiting();
294 }
295
296 /// in tests, removes the debugging data added by start_waiting
297 #[cfg(any(test, feature = "test-support"))]
298 pub fn finish_waiting(&self) {
299 self.dispatcher.as_test().unwrap().finish_waiting();
300 }
301
302 /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
303 #[cfg(any(test, feature = "test-support"))]
304 pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
305 self.dispatcher.as_test().unwrap().simulate_random_delay()
306 }
307
308 /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
309 #[cfg(any(test, feature = "test-support"))]
310 pub fn deprioritize(&self, task_label: TaskLabel) {
311 self.dispatcher.as_test().unwrap().deprioritize(task_label)
312 }
313
314 /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
315 #[cfg(any(test, feature = "test-support"))]
316 pub fn advance_clock(&self, duration: Duration) {
317 self.dispatcher.as_test().unwrap().advance_clock(duration)
318 }
319
320 /// in tests, run one task.
321 #[cfg(any(test, feature = "test-support"))]
322 pub fn tick(&self) -> bool {
323 self.dispatcher.as_test().unwrap().tick(false)
324 }
325
326 /// in tests, run all tasks that are ready to run. If after doing so
327 /// the test still has outstanding tasks, this will panic. (See also `allow_parking`)
328 #[cfg(any(test, feature = "test-support"))]
329 pub fn run_until_parked(&self) {
330 self.dispatcher.as_test().unwrap().run_until_parked()
331 }
332
333 /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
334 /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
335 /// do take real async time to run.
336 #[cfg(any(test, feature = "test-support"))]
337 pub fn allow_parking(&self) {
338 self.dispatcher.as_test().unwrap().allow_parking();
339 }
340
341 /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
342 #[cfg(any(test, feature = "test-support"))]
343 pub fn rng(&self) -> StdRng {
344 self.dispatcher.as_test().unwrap().rng()
345 }
346
347 /// How many CPUs are available to the dispatcher
348 pub fn num_cpus(&self) -> usize {
349 num_cpus::get()
350 }
351
352 /// Whether we're on the main thread.
353 pub fn is_main_thread(&self) -> bool {
354 self.dispatcher.is_main_thread()
355 }
356
357 #[cfg(any(test, feature = "test-support"))]
358 /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
359 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
360 self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
361 }
362}
363
364/// ForegroundExecutor runs things on the main thread.
365impl ForegroundExecutor {
366 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
367 Self {
368 dispatcher,
369 not_send: PhantomData,
370 }
371 }
372
373 /// Enqueues the given Task to run on the main thread at some point in the future.
374 pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
375 where
376 R: 'static,
377 {
378 let dispatcher = self.dispatcher.clone();
379 fn inner<R: 'static>(
380 dispatcher: Arc<dyn PlatformDispatcher>,
381 future: AnyLocalFuture<R>,
382 ) -> Task<R> {
383 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
384 dispatcher.dispatch_on_main_thread(runnable)
385 });
386 runnable.schedule();
387 Task::Spawned(task)
388 }
389 inner::<R>(dispatcher, Box::pin(future))
390 }
391}
392
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
    /// Executor used by `Drop` to block until the completion channel closes.
    executor: BackgroundExecutor,
    /// Futures queued by [`Scope::spawn`]; their `'a` lifetime has been erased
    /// to `'static` so they can be handed to the executor.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    /// Sender whose clones are held by the queued futures. It is `Some` from
    /// construction until `Drop` takes it.
    tx: Option<mpsc::Sender<()>>,
    /// Receiver that `Drop` waits on; it yields `None` once every sender is gone.
    rx: mpsc::Receiver<()>,
    /// Ties the scope to the `'a` lifetime that queued futures may borrow from.
    lifetime: PhantomData<&'a ()>,
}
401
402impl<'a> Scope<'a> {
403 fn new(executor: BackgroundExecutor) -> Self {
404 let (tx, rx) = mpsc::channel(1);
405 Self {
406 executor,
407 tx: Some(tx),
408 rx,
409 futures: Default::default(),
410 lifetime: PhantomData,
411 }
412 }
413
414 pub fn spawn<F>(&mut self, f: F)
415 where
416 F: Future<Output = ()> + Send + 'a,
417 {
418 let tx = self.tx.clone().unwrap();
419
420 // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
421 // dropping this `Scope` blocks until all of the futures have resolved.
422 let f = unsafe {
423 mem::transmute::<
424 Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
425 Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
426 >(Box::pin(async move {
427 f.await;
428 drop(tx);
429 }))
430 };
431 self.futures.push(f);
432 }
433}
434
435impl<'a> Drop for Scope<'a> {
436 fn drop(&mut self) {
437 self.tx.take().unwrap();
438
439 // Wait until the channel is closed, which means that all of the spawned
440 // futures have resolved.
441 self.executor.block(self.rx.next());
442 }
443}