1use crate::{AppContext, PlatformDispatcher};
2use futures::{channel::mpsc, pin_mut, FutureExt};
3use smol::prelude::*;
4use std::{
5 fmt::Debug,
6 marker::PhantomData,
7 mem,
8 num::NonZeroUsize,
9 pin::Pin,
10 rc::Rc,
11 sync::{
12 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
13 Arc,
14 },
15 task::{Context, Poll},
16 time::Duration,
17};
18use util::TryFutureExt;
19use waker_fn::waker_fn;
20
21#[cfg(any(test, feature = "test-support"))]
22use rand::rngs::StdRng;
23
/// Handle used to spawn and block on work scheduled through a
/// [`PlatformDispatcher`] (see the `impl BackgroundExecutor` below for the API).
///
/// Cloning is cheap: every clone shares the same dispatcher through an `Arc`.
#[derive(Clone)]
pub struct BackgroundExecutor {
    // Shared dispatcher that receives every runnable this executor schedules.
    dispatcher: Arc<dyn PlatformDispatcher>,
}
28
/// Handle used to spawn futures that are dispatched to the main thread
/// (via `PlatformDispatcher::dispatch_on_main_thread`).
///
/// Cloning is cheap: every clone shares the same dispatcher through an `Arc`.
#[derive(Clone)]
pub struct ForegroundExecutor {
    // Shared dispatcher that receives every runnable this executor schedules.
    dispatcher: Arc<dyn PlatformDispatcher>,
    // `Rc<()>` is neither `Send` nor `Sync`, so this marker makes the executor
    // `!Send + !Sync` and keeps it on the thread that owns it.
    not_send: PhantomData<Rc<()>>,
}
34
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running in the background, but with no way to return a value.
#[must_use]
#[derive(Debug)]
pub enum Task<T> {
    /// A task whose value is already available (see [`Task::ready`]).
    /// The `Option` becomes `None` once `poll` has handed the value out.
    Ready(Option<T>),
    /// A task created by one of the executors, backed by an `async_task::Task`.
    Spawned(async_task::Task<T>),
}
47
48impl<T> Task<T> {
49 /// Create a new task that will resolve with the value
50 pub fn ready(val: T) -> Self {
51 Task::Ready(Some(val))
52 }
53
54 /// Detaching a task runs it to completion in the background
55 pub fn detach(self) {
56 match self {
57 Task::Ready(_) => {}
58 Task::Spawned(task) => task.detach(),
59 }
60 }
61}
62
63impl<E, T> Task<Result<T, E>>
64where
65 T: 'static,
66 E: 'static + Debug,
67{
68 /// Run the task to completion in the background and log any
69 /// errors that occur.
70 #[track_caller]
71 pub fn detach_and_log_err(self, cx: &mut AppContext) {
72 let location = core::panic::Location::caller();
73 cx.foreground_executor()
74 .spawn(self.log_tracked_err(*location))
75 .detach();
76 }
77}
78
79impl<T> Future for Task<T> {
80 type Output = T;
81
82 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
83 match unsafe { self.get_unchecked_mut() } {
84 Task::Ready(val) => Poll::Ready(val.take().unwrap()),
85 Task::Spawned(task) => task.poll(cx),
86 }
87 }
88}
89
/// Opaque identifier that can be attached to background tasks through
/// `BackgroundExecutor::spawn_labeled`.
///
/// Every label is drawn from a single process-wide counter, so labels from
/// separate calls to [`TaskLabel::new`] (or `default`) compare unequal,
/// barring counter wraparound.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);

impl TaskLabel {
    /// Allocates the next label from the global counter, which starts at 1.
    pub fn new() -> Self {
        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
        let raw_id = NEXT_TASK_LABEL.fetch_add(1, SeqCst);
        TaskLabel(NonZeroUsize::try_from(raw_id).unwrap())
    }
}

impl Default for TaskLabel {
    /// Same as [`TaskLabel::new`]: allocates a fresh label.
    fn default() -> Self {
        TaskLabel::new()
    }
}
105
/// Heap-allocated, pinned, type-erased future with no `Send` bound; used for
/// futures spawned on the foreground (main-thread) executor.
type AnyLocalFuture<R> = Pin<Box<dyn Future<Output = R> + 'static>>;

/// Heap-allocated, pinned, type-erased future that is `Send`; used for
/// futures spawned on the background executor.
type AnyFuture<R> = Pin<Box<dyn Future<Output = R> + Send + 'static>>;
109
110/// BackgroundExecutor lets you run things on background threads.
111/// In production this is a thread pool with no ordering guarantees.
112/// In tests this is simalated by running tasks one by one in a deterministic
113/// (but arbitrary) order controlled by the `SEED` environment variable.
114impl BackgroundExecutor {
115 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
116 Self { dispatcher }
117 }
118
119 /// Enqueues the given future to be run to completion on a background thread.
120 pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
121 where
122 R: Send + 'static,
123 {
124 self.spawn_internal::<R>(Box::pin(future), None)
125 }
126
127 /// Enqueues the given future to be run to completion on a background thread.
128 /// The given label can be used to control the priority of the task in tests.
129 pub fn spawn_labeled<R>(
130 &self,
131 label: TaskLabel,
132 future: impl Future<Output = R> + Send + 'static,
133 ) -> Task<R>
134 where
135 R: Send + 'static,
136 {
137 self.spawn_internal::<R>(Box::pin(future), Some(label))
138 }
139
140 fn spawn_internal<R: Send + 'static>(
141 &self,
142 future: AnyFuture<R>,
143 label: Option<TaskLabel>,
144 ) -> Task<R> {
145 let dispatcher = self.dispatcher.clone();
146 let (runnable, task) =
147 async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
148 runnable.schedule();
149 Task::Spawned(task)
150 }
151
152 /// Used by the test harness to run an async test in a syncronous fashion.
153 #[cfg(any(test, feature = "test-support"))]
154 #[track_caller]
155 pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
156 if let Ok(value) = self.block_internal(false, future, usize::MAX) {
157 value
158 } else {
159 unreachable!()
160 }
161 }
162
163 /// Block the current thread until the given future resolves.
164 /// Consider using `block_with_timeout` instead.
165 pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
166 if let Ok(value) = self.block_internal(true, future, usize::MAX) {
167 value
168 } else {
169 unreachable!()
170 }
171 }
172
173 #[track_caller]
174 pub(crate) fn block_internal<R>(
175 &self,
176 background_only: bool,
177 future: impl Future<Output = R>,
178 mut max_ticks: usize,
179 ) -> Result<R, ()> {
180 pin_mut!(future);
181 let unparker = self.dispatcher.unparker();
182 let awoken = Arc::new(AtomicBool::new(false));
183
184 let waker = waker_fn({
185 let awoken = awoken.clone();
186 move || {
187 awoken.store(true, SeqCst);
188 unparker.unpark();
189 }
190 });
191 let mut cx = std::task::Context::from_waker(&waker);
192
193 loop {
194 match future.as_mut().poll(&mut cx) {
195 Poll::Ready(result) => return Ok(result),
196 Poll::Pending => {
197 if max_ticks == 0 {
198 return Err(());
199 }
200 max_ticks -= 1;
201
202 if !self.dispatcher.tick(background_only) {
203 if awoken.swap(false, SeqCst) {
204 continue;
205 }
206
207 #[cfg(any(test, feature = "test-support"))]
208 if let Some(test) = self.dispatcher.as_test() {
209 if !test.parking_allowed() {
210 let mut backtrace_message = String::new();
211 if let Some(backtrace) = test.waiting_backtrace() {
212 backtrace_message =
213 format!("\nbacktrace of waiting future:\n{:?}", backtrace);
214 }
215 panic!("parked with nothing left to run\n{:?}", backtrace_message)
216 }
217 }
218
219 self.dispatcher.park();
220 }
221 }
222 }
223 }
224 }
225
226 /// Block the current thread until the given future resolves
227 /// or `duration` has elapsed.
228 pub fn block_with_timeout<R>(
229 &self,
230 duration: Duration,
231 future: impl Future<Output = R>,
232 ) -> Result<R, impl Future<Output = R>> {
233 let mut future = Box::pin(future.fuse());
234 if duration.is_zero() {
235 return Err(future);
236 }
237
238 #[cfg(any(test, feature = "test-support"))]
239 let max_ticks = self
240 .dispatcher
241 .as_test()
242 .map_or(usize::MAX, |dispatcher| dispatcher.gen_block_on_ticks());
243 #[cfg(not(any(test, feature = "test-support")))]
244 let max_ticks = usize::MAX;
245
246 let mut timer = self.timer(duration).fuse();
247
248 let timeout = async {
249 futures::select_biased! {
250 value = future => Ok(value),
251 _ = timer => Err(()),
252 }
253 };
254 match self.block_internal(true, timeout, max_ticks) {
255 Ok(Ok(value)) => Ok(value),
256 _ => Err(future),
257 }
258 }
259
260 /// Scoped lets you start a number of tasks and waits
261 /// for all of them to complete before returning.
262 pub async fn scoped<'scope, F>(&self, scheduler: F)
263 where
264 F: FnOnce(&mut Scope<'scope>),
265 {
266 let mut scope = Scope::new(self.clone());
267 (scheduler)(&mut scope);
268 let spawned = mem::take(&mut scope.futures)
269 .into_iter()
270 .map(|f| self.spawn(f))
271 .collect::<Vec<_>>();
272 for task in spawned {
273 task.await;
274 }
275 }
276
277 /// Returns a task that will complete after the given duration.
278 /// Depending on other concurrent tasks the elapsed duration may be longer
279 /// than reqested.
280 pub fn timer(&self, duration: Duration) -> Task<()> {
281 let (runnable, task) = async_task::spawn(async move {}, {
282 let dispatcher = self.dispatcher.clone();
283 move |runnable| dispatcher.dispatch_after(duration, runnable)
284 });
285 runnable.schedule();
286 Task::Spawned(task)
287 }
288
289 /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
290 #[cfg(any(test, feature = "test-support"))]
291 pub fn start_waiting(&self) {
292 self.dispatcher.as_test().unwrap().start_waiting();
293 }
294
295 /// in tests, removes the debugging data added by start_waiting
296 #[cfg(any(test, feature = "test-support"))]
297 pub fn finish_waiting(&self) {
298 self.dispatcher.as_test().unwrap().finish_waiting();
299 }
300
301 /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
302 #[cfg(any(test, feature = "test-support"))]
303 pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
304 self.dispatcher.as_test().unwrap().simulate_random_delay()
305 }
306
307 /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
308 #[cfg(any(test, feature = "test-support"))]
309 pub fn deprioritize(&self, task_label: TaskLabel) {
310 self.dispatcher.as_test().unwrap().deprioritize(task_label)
311 }
312
313 /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
314 #[cfg(any(test, feature = "test-support"))]
315 pub fn advance_clock(&self, duration: Duration) {
316 self.dispatcher.as_test().unwrap().advance_clock(duration)
317 }
318
319 /// in tests, run one task.
320 #[cfg(any(test, feature = "test-support"))]
321 pub fn tick(&self) -> bool {
322 self.dispatcher.as_test().unwrap().tick(false)
323 }
324
325 /// in tests, run all tasks that are ready to run. If after doing so
326 /// the test still has outstanding tasks, this will panic. (See also `allow_parking`)
327 #[cfg(any(test, feature = "test-support"))]
328 pub fn run_until_parked(&self) {
329 self.dispatcher.as_test().unwrap().run_until_parked()
330 }
331
332 /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
333 /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
334 /// do take real async time to run.
335 #[cfg(any(test, feature = "test-support"))]
336 pub fn allow_parking(&self) {
337 self.dispatcher.as_test().unwrap().allow_parking();
338 }
339
340 /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
341 #[cfg(any(test, feature = "test-support"))]
342 pub fn rng(&self) -> StdRng {
343 self.dispatcher.as_test().unwrap().rng()
344 }
345
346 /// How many CPUs are available to the dispatcher
347 pub fn num_cpus(&self) -> usize {
348 num_cpus::get()
349 }
350
351 /// Whether we're on the main thread.
352 pub fn is_main_thread(&self) -> bool {
353 self.dispatcher.is_main_thread()
354 }
355
356 #[cfg(any(test, feature = "test-support"))]
357 /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
358 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
359 self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
360 }
361}
362
363/// ForegroundExecutor runs things on the main thread.
364impl ForegroundExecutor {
365 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
366 Self {
367 dispatcher,
368 not_send: PhantomData,
369 }
370 }
371
372 /// Enqueues the given Task to run on the main thread at some point in the future.
373 pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
374 where
375 R: 'static,
376 {
377 let dispatcher = self.dispatcher.clone();
378 fn inner<R: 'static>(
379 dispatcher: Arc<dyn PlatformDispatcher>,
380 future: AnyLocalFuture<R>,
381 ) -> Task<R> {
382 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
383 dispatcher.dispatch_on_main_thread(runnable)
384 });
385 runnable.schedule();
386 Task::Spawned(task)
387 }
388 inner::<R>(dispatcher, Box::pin(future))
389 }
390}
391
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
///
/// Completion is tracked with an `mpsc` channel that never carries messages.
/// Each queued future owns a clone of the sender and drops it when it finishes.
/// The receiver therefore only yields `None` once every sender is gone.
///
/// NOTE: field order is also drop order.
pub struct Scope<'a> {
    // Executor used (in `Drop`) to block until all queued futures have finished.
    executor: BackgroundExecutor,
    // Futures queued by `spawn`, with their `'a` lifetime erased to `'static`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    // The scope's own sender; cloned into every queued future and taken in `Drop`.
    tx: Option<mpsc::Sender<()>>,
    // Receiving half; closes once all senders have been dropped.
    rx: mpsc::Receiver<()>,
    // Ties the scope to the `'a` borrows its futures may hold.
    lifetime: PhantomData<&'a ()>,
}
400
401impl<'a> Scope<'a> {
402 fn new(executor: BackgroundExecutor) -> Self {
403 let (tx, rx) = mpsc::channel(1);
404 Self {
405 executor,
406 tx: Some(tx),
407 rx,
408 futures: Default::default(),
409 lifetime: PhantomData,
410 }
411 }
412
413 pub fn spawn<F>(&mut self, f: F)
414 where
415 F: Future<Output = ()> + Send + 'a,
416 {
417 let tx = self.tx.clone().unwrap();
418
419 // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
420 // dropping this `Scope` blocks until all of the futures have resolved.
421 let f = unsafe {
422 mem::transmute::<
423 Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
424 Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
425 >(Box::pin(async move {
426 f.await;
427 drop(tx);
428 }))
429 };
430 self.futures.push(f);
431 }
432}
433
434impl<'a> Drop for Scope<'a> {
435 fn drop(&mut self) {
436 self.tx.take().unwrap();
437
438 // Wait until the channel is closed, which means that all of the spawned
439 // futures have resolved.
440 self.executor.block(self.rx.next());
441 }
442}