1use crate::{AppContext, PlatformDispatcher};
2use futures::{channel::mpsc, pin_mut, FutureExt};
3use smol::prelude::*;
4use std::{
5 fmt::Debug,
6 marker::PhantomData,
7 mem,
8 num::NonZeroUsize,
9 pin::Pin,
10 rc::Rc,
11 sync::{
12 atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
13 Arc,
14 },
15 task::{Context, Poll},
16 time::Duration,
17};
18use util::TryFutureExt;
19use waker_fn::waker_fn;
20
21#[cfg(any(test, feature = "test-support"))]
22use rand::rngs::StdRng;
23
24#[derive(Clone)]
25pub struct BackgroundExecutor {
26 dispatcher: Arc<dyn PlatformDispatcher>,
27}
28
29#[derive(Clone)]
30pub struct ForegroundExecutor {
31 dispatcher: Arc<dyn PlatformDispatcher>,
32 not_send: PhantomData<Rc<()>>,
33}
34
35/// Task is a primitive that allows work to happen in the background.
36/// It implements Future so you can `.await` on it.
37/// If you drop a task it will be cancelled immediately. Calling `.detach()` allows
38/// the task to continue running in the background, but with no way to return a value.
39#[must_use]
40#[derive(Debug)]
41pub enum Task<T> {
42 Ready(Option<T>),
43 Spawned(async_task::Task<T>),
44}
45
46impl<T> Task<T> {
47 /// Create a new task that will resolve with the value
48 pub fn ready(val: T) -> Self {
49 Task::Ready(Some(val))
50 }
51
52 /// Detaching a task runs it to completion in the background
53 pub fn detach(self) {
54 match self {
55 Task::Ready(_) => {}
56 Task::Spawned(task) => task.detach(),
57 }
58 }
59}
60
61impl<E, T> Task<Result<T, E>>
62where
63 T: 'static,
64 E: 'static + Debug,
65{
66 /// Run the task to completion in the background and log any
67 /// errors that occur.
68 #[track_caller]
69 pub fn detach_and_log_err(self, cx: &mut AppContext) {
70 let location = core::panic::Location::caller();
71 cx.foreground_executor()
72 .spawn(self.log_tracked_err(*location))
73 .detach();
74 }
75}
76
77impl<T> Future for Task<T> {
78 type Output = T;
79
80 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
81 match unsafe { self.get_unchecked_mut() } {
82 Task::Ready(val) => Poll::Ready(val.take().unwrap()),
83 Task::Spawned(task) => task.poll(cx),
84 }
85 }
86}
87
/// Opaque identifier attached to tasks spawned with `BackgroundExecutor::spawn_labeled`,
/// which tests can use to adjust scheduling (e.g. `BackgroundExecutor::deprioritize`).
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub struct TaskLabel(NonZeroUsize);
90
91impl Default for TaskLabel {
92 fn default() -> Self {
93 Self::new()
94 }
95}
96
97impl TaskLabel {
98 pub fn new() -> Self {
99 static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
100 Self(NEXT_TASK_LABEL.fetch_add(1, SeqCst).try_into().unwrap())
101 }
102}
103
/// Boxed, type-erased future that need not be `Send`; used for main-thread tasks.
type AnyLocalFuture<R> = Pin<Box<dyn Future<Output = R> + 'static>>;

/// Boxed, type-erased future that is `Send`; used for background-thread tasks.
type AnyFuture<R> = Pin<Box<dyn Future<Output = R> + Send + 'static>>;
107
108/// BackgroundExecutor lets you run things on background threads.
109/// In production this is a thread pool with no ordering guarantees.
110/// In tests this is simalated by running tasks one by one in a deterministic
111/// (but arbitrary) order controlled by the `SEED` environment variable.
112impl BackgroundExecutor {
113 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
114 Self { dispatcher }
115 }
116
117 /// Enqueues the given future to be run to completion on a background thread.
118 pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
119 where
120 R: Send + 'static,
121 {
122 self.spawn_internal::<R>(Box::pin(future), None)
123 }
124
125 /// Enqueues the given future to be run to completion on a background thread.
126 /// The given label can be used to control the priority of the task in tests.
127 pub fn spawn_labeled<R>(
128 &self,
129 label: TaskLabel,
130 future: impl Future<Output = R> + Send + 'static,
131 ) -> Task<R>
132 where
133 R: Send + 'static,
134 {
135 self.spawn_internal::<R>(Box::pin(future), Some(label))
136 }
137
138 fn spawn_internal<R: Send + 'static>(
139 &self,
140 future: AnyFuture<R>,
141 label: Option<TaskLabel>,
142 ) -> Task<R> {
143 let dispatcher = self.dispatcher.clone();
144 let (runnable, task) =
145 async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
146 runnable.schedule();
147 Task::Spawned(task)
148 }
149
150 /// Used by the test harness to run an async test in a syncronous fashion.
151 #[cfg(any(test, feature = "test-support"))]
152 #[track_caller]
153 pub fn block_test<R>(&self, future: impl Future<Output = R>) -> R {
154 if let Ok(value) = self.block_internal(false, future, usize::MAX) {
155 value
156 } else {
157 unreachable!()
158 }
159 }
160
161 /// Block the current thread until the given future resolves.
162 /// Consider using `block_with_timeout` instead.
163 pub fn block<R>(&self, future: impl Future<Output = R>) -> R {
164 if let Ok(value) = self.block_internal(true, future, usize::MAX) {
165 value
166 } else {
167 unreachable!()
168 }
169 }
170
171 #[track_caller]
172 pub(crate) fn block_internal<R>(
173 &self,
174 background_only: bool,
175 future: impl Future<Output = R>,
176 mut max_ticks: usize,
177 ) -> Result<R, ()> {
178 pin_mut!(future);
179 let unparker = self.dispatcher.unparker();
180 let awoken = Arc::new(AtomicBool::new(false));
181
182 let waker = waker_fn({
183 let awoken = awoken.clone();
184 move || {
185 awoken.store(true, SeqCst);
186 unparker.unpark();
187 }
188 });
189 let mut cx = std::task::Context::from_waker(&waker);
190
191 loop {
192 match future.as_mut().poll(&mut cx) {
193 Poll::Ready(result) => return Ok(result),
194 Poll::Pending => {
195 if max_ticks == 0 {
196 return Err(());
197 }
198 max_ticks -= 1;
199
200 if !self.dispatcher.tick(background_only) {
201 if awoken.swap(false, SeqCst) {
202 continue;
203 }
204
205 #[cfg(any(test, feature = "test-support"))]
206 if let Some(test) = self.dispatcher.as_test() {
207 if !test.parking_allowed() {
208 let mut backtrace_message = String::new();
209 if let Some(backtrace) = test.waiting_backtrace() {
210 backtrace_message =
211 format!("\nbacktrace of waiting future:\n{:?}", backtrace);
212 }
213 panic!("parked with nothing left to run\n{:?}", backtrace_message)
214 }
215 }
216
217 self.dispatcher.park();
218 }
219 }
220 }
221 }
222 }
223
224 /// Block the current thread until the given future resolves
225 /// or `duration` has elapsed.
226 pub fn block_with_timeout<R>(
227 &self,
228 duration: Duration,
229 future: impl Future<Output = R>,
230 ) -> Result<R, impl Future<Output = R>> {
231 let mut future = Box::pin(future.fuse());
232 if duration.is_zero() {
233 return Err(future);
234 }
235
236 #[cfg(any(test, feature = "test-support"))]
237 let max_ticks = self
238 .dispatcher
239 .as_test()
240 .map_or(usize::MAX, |dispatcher| dispatcher.gen_block_on_ticks());
241 #[cfg(not(any(test, feature = "test-support")))]
242 let max_ticks = usize::MAX;
243
244 let mut timer = self.timer(duration).fuse();
245
246 let timeout = async {
247 futures::select_biased! {
248 value = future => Ok(value),
249 _ = timer => Err(()),
250 }
251 };
252 match self.block_internal(true, timeout, max_ticks) {
253 Ok(Ok(value)) => Ok(value),
254 _ => Err(future),
255 }
256 }
257
258 /// Scoped lets you start a number of tasks and waits
259 /// for all of them to complete before returning.
260 pub async fn scoped<'scope, F>(&self, scheduler: F)
261 where
262 F: FnOnce(&mut Scope<'scope>),
263 {
264 let mut scope = Scope::new(self.clone());
265 (scheduler)(&mut scope);
266 let spawned = mem::take(&mut scope.futures)
267 .into_iter()
268 .map(|f| self.spawn(f))
269 .collect::<Vec<_>>();
270 for task in spawned {
271 task.await;
272 }
273 }
274
275 /// Returns a task that will complete after the given duration.
276 /// Depending on other concurrent tasks the elapsed duration may be longer
277 /// than reqested.
278 pub fn timer(&self, duration: Duration) -> Task<()> {
279 let (runnable, task) = async_task::spawn(async move {}, {
280 let dispatcher = self.dispatcher.clone();
281 move |runnable| dispatcher.dispatch_after(duration, runnable)
282 });
283 runnable.schedule();
284 Task::Spawned(task)
285 }
286
287 /// in tests, start_waiting lets you indicate which task is waiting (for debugging only)
288 #[cfg(any(test, feature = "test-support"))]
289 pub fn start_waiting(&self) {
290 self.dispatcher.as_test().unwrap().start_waiting();
291 }
292
293 /// in tests, removes the debugging data added by start_waiting
294 #[cfg(any(test, feature = "test-support"))]
295 pub fn finish_waiting(&self) {
296 self.dispatcher.as_test().unwrap().finish_waiting();
297 }
298
299 /// in tests, run an arbitrary number of tasks (determined by the SEED environment variable)
300 #[cfg(any(test, feature = "test-support"))]
301 pub fn simulate_random_delay(&self) -> impl Future<Output = ()> {
302 self.dispatcher.as_test().unwrap().simulate_random_delay()
303 }
304
305 /// in tests, indicate that a given task from `spawn_labeled` should run after everything else
306 #[cfg(any(test, feature = "test-support"))]
307 pub fn deprioritize(&self, task_label: TaskLabel) {
308 self.dispatcher.as_test().unwrap().deprioritize(task_label)
309 }
310
311 /// in tests, move time forward. This does not run any tasks, but does make `timer`s ready.
312 #[cfg(any(test, feature = "test-support"))]
313 pub fn advance_clock(&self, duration: Duration) {
314 self.dispatcher.as_test().unwrap().advance_clock(duration)
315 }
316
317 /// in tests, run one task.
318 #[cfg(any(test, feature = "test-support"))]
319 pub fn tick(&self) -> bool {
320 self.dispatcher.as_test().unwrap().tick(false)
321 }
322
323 /// in tests, run all tasks that are ready to run. If after doing so
324 /// the test still has outstanding tasks, this will panic. (See also `allow_parking`)
325 #[cfg(any(test, feature = "test-support"))]
326 pub fn run_until_parked(&self) {
327 self.dispatcher.as_test().unwrap().run_until_parked()
328 }
329
330 /// in tests, prevents `run_until_parked` from panicking if there are outstanding tasks.
331 /// This is useful when you are integrating other (non-GPUI) futures, like disk access, that
332 /// do take real async time to run.
333 #[cfg(any(test, feature = "test-support"))]
334 pub fn allow_parking(&self) {
335 self.dispatcher.as_test().unwrap().allow_parking();
336 }
337
338 /// in tests, returns the rng used by the dispatcher and seeded by the `SEED` environment variable
339 #[cfg(any(test, feature = "test-support"))]
340 pub fn rng(&self) -> StdRng {
341 self.dispatcher.as_test().unwrap().rng()
342 }
343
344 /// How many CPUs are available to the dispatcher
345 pub fn num_cpus(&self) -> usize {
346 num_cpus::get()
347 }
348
349 /// Whether we're on the main thread.
350 pub fn is_main_thread(&self) -> bool {
351 self.dispatcher.is_main_thread()
352 }
353
354 #[cfg(any(test, feature = "test-support"))]
355 /// in tests, control the number of ticks that `block_with_timeout` will run before timing out.
356 pub fn set_block_on_ticks(&self, range: std::ops::RangeInclusive<usize>) {
357 self.dispatcher.as_test().unwrap().set_block_on_ticks(range);
358 }
359}
360
361/// ForegroundExecutor runs things on the main thread.
362impl ForegroundExecutor {
363 pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
364 Self {
365 dispatcher,
366 not_send: PhantomData,
367 }
368 }
369
370 /// Enqueues the given Task to run on the main thread at some point in the future.
371 pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
372 where
373 R: 'static,
374 {
375 let dispatcher = self.dispatcher.clone();
376 fn inner<R: 'static>(
377 dispatcher: Arc<dyn PlatformDispatcher>,
378 future: AnyLocalFuture<R>,
379 ) -> Task<R> {
380 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
381 dispatcher.dispatch_on_main_thread(runnable)
382 });
383 runnable.schedule();
384 Task::Spawned(task)
385 }
386 inner::<R>(dispatcher, Box::pin(future))
387 }
388}
389
390/// Scope manages a set of tasks that are enqueued and waited on together. See `BackgroundExecutor#scoped`
391pub struct Scope<'a> {
392 executor: BackgroundExecutor,
393 futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
394 tx: Option<mpsc::Sender<()>>,
395 rx: mpsc::Receiver<()>,
396 lifetime: PhantomData<&'a ()>,
397}
398
399impl<'a> Scope<'a> {
400 fn new(executor: BackgroundExecutor) -> Self {
401 let (tx, rx) = mpsc::channel(1);
402 Self {
403 executor,
404 tx: Some(tx),
405 rx,
406 futures: Default::default(),
407 lifetime: PhantomData,
408 }
409 }
410
411 pub fn spawn<F>(&mut self, f: F)
412 where
413 F: Future<Output = ()> + Send + 'a,
414 {
415 let tx = self.tx.clone().unwrap();
416
417 // Safety: The 'a lifetime is guaranteed to outlive any of these futures because
418 // dropping this `Scope` blocks until all of the futures have resolved.
419 let f = unsafe {
420 mem::transmute::<
421 Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
422 Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
423 >(Box::pin(async move {
424 f.await;
425 drop(tx);
426 }))
427 };
428 self.futures.push(f);
429 }
430}
431
432impl<'a> Drop for Scope<'a> {
433 fn drop(&mut self) {
434 self.tx.take().unwrap();
435
436 // Wait until the channel is closed, which means that all of the spawned
437 // futures have resolved.
438 self.executor.block(self.rx.next());
439 }
440}