1use anyhow::{anyhow, Result};
2use async_task::Runnable;
3use backtrace::Backtrace;
4use collections::HashMap;
5use parking_lot::Mutex;
6use postage::{barrier, prelude::Stream as _};
7use rand::prelude::*;
8use smol::{channel, future::yield_now, prelude::*, Executor, Timer};
9use std::{
10 any::Any,
11 fmt::{self, Display},
12 marker::PhantomData,
13 mem,
14 ops::RangeInclusive,
15 pin::Pin,
16 rc::Rc,
17 sync::{
18 atomic::{AtomicBool, Ordering::SeqCst},
19 Arc,
20 },
21 task::{Context, Poll},
22 thread,
23 time::{Duration, Instant},
24};
25use waker_fn::waker_fn;
26
27use crate::{
28 platform::{self, Dispatcher},
29 util, MutableAppContext,
30};
31
/// Executor for futures that stay on the thread that spawned them; futures spawned here do not
/// need to be `Send`.
pub enum Foreground {
    /// Schedules every poll on the platform's main thread through `dispatcher`.
    Platform {
        dispatcher: Arc<dyn platform::Dispatcher>,
        // `Rc<()>` is neither `Send` nor `Sync`, so this marker keeps `Foreground` from leaving
        // the thread it was constructed on.
        _not_send_or_sync: PhantomData<Rc<()>>,
    },
    /// Schedules runnables on a shared seeded executor, queued under the context id `cx_id`.
    Deterministic {
        cx_id: usize,
        executor: Arc<Deterministic>,
    },
}
42
/// Executor for `Send` futures.
pub enum Background {
    /// Runs on the shared seeded executor, interleaving work according to its RNG.
    Deterministic {
        executor: Arc<Deterministic>,
    },
    /// Runs on a multi-threaded `smol` executor.
    Production {
        executor: Arc<smol::Executor<'static>>,
        // Held only to keep the stop channel open: the worker threads started by
        // `Background::new` run until `recv` on this channel completes, i.e. until this
        // sender is dropped.
        _stop: channel::Sender<()>,
    },
}
52
53type AnyLocalFuture = Pin<Box<dyn 'static + Future<Output = Box<dyn Any + 'static>>>>;
54type AnyFuture = Pin<Box<dyn 'static + Send + Future<Output = Box<dyn Any + Send + 'static>>>>;
55type AnyTask = async_task::Task<Box<dyn Any + Send + 'static>>;
56type AnyLocalTask = async_task::Task<Box<dyn Any + 'static>>;
57
/// Handle to the eventual output of a spawned future; awaiting it yields a `T`.
///
/// Spawned futures are type-erased to `Box<dyn Any>` outputs; the `Future` impl downcasts the
/// output back to `T` when it becomes ready.
#[must_use]
pub enum Task<T> {
    /// An already-available value; becomes `None` once polling has taken it.
    Ready(Option<T>),
    /// Handle to a future spawned by a foreground (non-`Send`) executor.
    Local {
        any_task: AnyLocalTask,
        result_type: PhantomData<T>,
    },
    /// Handle to a future spawned by a background (`Send`) executor.
    Send {
        any_task: AnyTask,
        result_type: PhantomData<T>,
    },
}
70
// SAFETY: NOTE(review): `Task::Local` wraps an `AnyLocalTask`, whose erased output type
// `Box<dyn Any>` is not `Send`, so presumably `Task<T>` would not be `Send` automatically. This
// impl asserts that moving the handle to another thread is sound whenever the concrete output `T`
// is `Send`. The local future itself is still scheduled by its own executor. Confirm this matches
// `async_task`'s guarantees for tasks created with `spawn_local`.
unsafe impl<T: Send> Send for Task<T> {}
72
/// Mutable state shared by the deterministic executor and the schedule closures of the tasks it
/// spawns.
struct DeterministicState {
    // Source of every scheduling choice the executor makes.
    rng: StdRng,
    // Seed `rng` was created from; `Foreground::forbid_parking` reseeds `rng` with it.
    seed: u64,
    // Runnables waiting to run, grouped by foreground context id; each queue is run front-first.
    scheduled_from_foreground: HashMap<usize, Vec<ForegroundRunnable>>,
    // Runnables scheduled by background-spawned tasks; one is picked at random to run.
    scheduled_from_background: Vec<Runnable>,
    // When true, `will_park` panics instead of letting the executor park.
    forbid_parking: bool,
    // Range from which `Background::block_with_timeout` draws its tick budget.
    block_on_ticks: RangeInclusive<usize>,
    // Simulated clock; moved forward only by `Foreground::advance_clock`.
    now: Instant,
    // Timers awaiting the simulated clock: wakeup time plus the sender that
    // `Foreground::advance_clock` drops to fire the timer.
    pending_timers: Vec<(Instant, barrier::Sender)>,
    // Backtrace recorded by `Foreground::start_waiting`, included in `will_park`'s panic message.
    waiting_backtrace: Option<Backtrace>,
}
84
/// A runnable queued for a foreground context.
struct ForegroundRunnable {
    runnable: Runnable,
    // True for runnables of the future passed to `Deterministic::run`; after one of these runs,
    // `run_internal` polls the main task to see whether it finished.
    main: bool,
}
89
/// Single-threaded executor whose interleaving of runnables is driven by a seeded RNG.
pub struct Deterministic {
    // Shared with the schedule closures created in `spawn`/`spawn_from_foreground`.
    state: Arc<Mutex<DeterministicState>>,
    // Used to sleep when no runnables are queued; schedule closures and wakers unpark it.
    parker: Mutex<parking::Parker>,
}
94
impl Deterministic {
    /// Creates a deterministic executor whose scheduling choices all come from an RNG seeded
    /// with `seed`.
    ///
    /// The executor starts with empty run queues, parking allowed, a `block_on_ticks` range of
    /// `0..=1000`, no pending timers, and a simulated clock initialized to `Instant::now()`.
    pub fn new(seed: u64) -> Arc<Self> {
        Arc::new(Self {
            state: Arc::new(Mutex::new(DeterministicState {
                rng: StdRng::seed_from_u64(seed),
                seed,
                scheduled_from_foreground: Default::default(),
                scheduled_from_background: Default::default(),
                forbid_parking: false,
                block_on_ticks: 0..=1000,
                now: Instant::now(),
                pending_timers: Default::default(),
                waiting_backtrace: None,
            })),
            parker: Default::default(),
        })
    }

    /// Returns a [`Background`] handle backed by this executor.
    pub fn build_background(self: &Arc<Self>) -> Arc<Background> {
        Arc::new(Background::Deterministic {
            executor: self.clone(),
        })
    }

    /// Returns a [`Foreground`] handle backed by this executor whose runnables are queued under
    /// the context id `id`.
    pub fn build_foreground(self: &Arc<Self>, id: usize) -> Rc<Foreground> {
        Rc::new(Foreground::Deterministic {
            cx_id: id,
            executor: self.clone(),
        })
    }

    /// Spawns a non-`Send` future whose runnable is appended to the `cx_id` queue every time it
    /// is scheduled.
    ///
    /// `main` marks the runnables of the future driven by [`Deterministic::run`], so that
    /// `run_internal` knows when to check that task for completion.
    fn spawn_from_foreground(
        &self,
        cx_id: usize,
        future: AnyLocalFuture,
        main: bool,
    ) -> AnyLocalTask {
        let state = self.state.clone();
        let unparker = self.parker.lock().unparker();
        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
            let mut state = state.lock();
            state
                .scheduled_from_foreground
                .entry(cx_id)
                .or_default()
                .push(ForegroundRunnable { runnable, main });
            // Wake the executor in case it parked waiting for work.
            unparker.unpark();
        });
        // Queue the initial poll.
        runnable.schedule();
        task
    }

    /// Spawns a `Send` future whose runnable is pushed onto the background queue every time it
    /// is scheduled.
    fn spawn(&self, future: AnyFuture) -> AnyTask {
        let state = self.state.clone();
        let unparker = self.parker.lock().unparker();
        let (runnable, task) = async_task::spawn(future, move |runnable| {
            let mut state = state.lock();
            state.scheduled_from_background.push(runnable);
            // Wake the executor in case it parked waiting for work.
            unparker.unpark();
        });
        // Queue the initial poll.
        runnable.schedule();
        task
    }

    /// Spawns `main_future` under `cx_id` and runs scheduled work until it completes, returning
    /// its type-erased output.
    ///
    /// Whenever the run queues drain without the main task finishing, the executor parks until
    /// unparked.
    fn run(&self, cx_id: usize, main_future: AnyLocalFuture) -> Box<dyn Any> {
        // Set by the waker used to poll the main task (see `run_internal`).
        let woken = Arc::new(AtomicBool::new(false));
        let mut main_task = self.spawn_from_foreground(cx_id, main_future, true);

        loop {
            if let Some(result) = self.run_internal(woken.clone(), Some(&mut main_task)) {
                return result;
            }

            // The parking check (which panics when parking is forbidden) only applies if the
            // main task's waker has not fired since the flag was last cleared.
            if !woken.load(SeqCst) {
                self.state.lock().will_park();
            }

            woken.store(false, SeqCst);
            self.parker.lock().park();
        }
    }

    /// Runs queued foreground and background runnables until both queues are empty. Does not
    /// park.
    fn run_until_parked(&self) {
        let woken = Arc::new(AtomicBool::new(false));
        self.run_internal(woken, None);
    }

    /// Repeatedly runs one randomly chosen runnable until both run queues are empty, or until
    /// `main_task` (when provided) completes.
    ///
    /// Returns `Some(output)` of the main task if it completed, otherwise `None`.
    ///
    /// NOTE: the order of `rng` calls here determines which interleaving a given seed produces.
    /// Reordering them changes what existing seeds reproduce.
    fn run_internal(
        &self,
        woken: Arc<AtomicBool>,
        mut main_task: Option<&mut AnyLocalTask>,
    ) -> Option<Box<dyn Any>> {
        let unparker = self.parker.lock().unparker();
        let waker = waker_fn(move || {
            woken.store(true, SeqCst);
            unparker.unpark();
        });

        let mut cx = Context::from_waker(&waker);
        loop {
            let mut state = self.state.lock();

            if state.scheduled_from_foreground.is_empty()
                && state.scheduled_from_background.is_empty()
            {
                return None;
            }

            // A coin flip decides between background and foreground work. If it picks
            // foreground while only background work is queued, neither branch runs and the
            // loop flips again.
            if !state.scheduled_from_background.is_empty() && state.rng.gen() {
                let background_len = state.scheduled_from_background.len();
                let ix = state.rng.gen_range(0..background_len);
                let runnable = state.scheduled_from_background.remove(ix);
                // Release the lock: running may schedule more work, which locks `state`.
                drop(state);
                runnable.run();
            } else if !state.scheduled_from_foreground.is_empty() {
                // NOTE(review): reproducibility per seed also assumes `collections::HashMap`
                // yields keys in a deterministic order — confirm.
                let available_cx_ids = state
                    .scheduled_from_foreground
                    .keys()
                    .copied()
                    .collect::<Vec<_>>();
                let cx_id_to_run = *available_cx_ids.iter().choose(&mut state.rng).unwrap();
                let scheduled_from_cx = state
                    .scheduled_from_foreground
                    .get_mut(&cx_id_to_run)
                    .unwrap();
                // Runnables of one context run in the order they were scheduled.
                let foreground_runnable = scheduled_from_cx.remove(0);
                if scheduled_from_cx.is_empty() {
                    state.scheduled_from_foreground.remove(&cx_id_to_run);
                }

                drop(state);

                foreground_runnable.runnable.run();
                if let Some(main_task) = main_task.as_mut() {
                    // Only a main runnable can have advanced the main task.
                    if foreground_runnable.main {
                        if let Poll::Ready(result) = main_task.poll(&mut cx) {
                            return Some(result);
                        }
                    }
                }
            }
        }
    }

    /// Polls `future`, interleaved with randomly chosen background runnables, for at most
    /// `max_ticks` iterations.
    ///
    /// Each tick draws an index in `0..=background_len`. An in-range index runs that background
    /// runnable; the extra index polls the future instead. If the future is still pending and no
    /// background work is queued, `will_park` is called and the executor parks. Foreground
    /// runnables are never run here.
    ///
    /// Returns `Some(output)` if the future completes, or `None` once the tick budget is spent.
    fn block<F, T>(&self, future: &mut F, max_ticks: usize) -> Option<T>
    where
        F: Unpin + Future<Output = T>,
    {
        let unparker = self.parker.lock().unparker();
        let waker = waker_fn(move || {
            unparker.unpark();
        });

        let mut cx = Context::from_waker(&waker);
        for _ in 0..max_ticks {
            let mut state = self.state.lock();
            let runnable_count = state.scheduled_from_background.len();
            let ix = state.rng.gen_range(0..=runnable_count);
            if ix < state.scheduled_from_background.len() {
                let runnable = state.scheduled_from_background.remove(ix);
                drop(state);
                runnable.run();
            } else {
                drop(state);
                if let Poll::Ready(result) = future.poll(&mut cx) {
                    return Some(result);
                }
                let mut state = self.state.lock();
                if state.scheduled_from_background.is_empty() {
                    state.will_park();
                    drop(state);
                    self.parker.lock().park();
                }

                continue;
            }
        }

        None
    }
}
276
277impl DeterministicState {
278 fn will_park(&mut self) {
279 if self.forbid_parking {
280 let mut backtrace_message = String::new();
281 if let Some(backtrace) = self.waiting_backtrace.as_mut() {
282 backtrace.resolve();
283 backtrace_message = format!(
284 "\nbacktrace of waiting future:\n{:?}",
285 util::CwdBacktrace(backtrace)
286 );
287 }
288
289 panic!(
290 "deterministic executor parked after a call to forbid_parking{}",
291 backtrace_message
292 );
293 }
294 }
295}
296
297impl Foreground {
298 pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
299 if dispatcher.is_main_thread() {
300 Ok(Self::Platform {
301 dispatcher,
302 _not_send_or_sync: PhantomData,
303 })
304 } else {
305 Err(anyhow!("must be constructed on main thread"))
306 }
307 }
308
309 pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
310 let future = any_local_future(future);
311 let any_task = match self {
312 Self::Deterministic { cx_id, executor } => {
313 executor.spawn_from_foreground(*cx_id, future, false)
314 }
315 Self::Platform { dispatcher, .. } => {
316 fn spawn_inner(
317 future: AnyLocalFuture,
318 dispatcher: &Arc<dyn Dispatcher>,
319 ) -> AnyLocalTask {
320 let dispatcher = dispatcher.clone();
321 let schedule =
322 move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
323 let (runnable, task) = async_task::spawn_local(future, schedule);
324 runnable.schedule();
325 task
326 }
327 spawn_inner(future, dispatcher)
328 }
329 };
330 Task::local(any_task)
331 }
332
333 pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
334 let future = any_local_future(future);
335 let any_value = match self {
336 Self::Deterministic { cx_id, executor } => executor.run(*cx_id, future),
337 Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
338 };
339 *any_value.downcast().unwrap()
340 }
341
342 pub fn run_until_parked(&self) {
343 match self {
344 Self::Deterministic { executor, .. } => executor.run_until_parked(),
345 _ => panic!("this method can only be called on a deterministic executor"),
346 }
347 }
348
349 pub fn parking_forbidden(&self) -> bool {
350 match self {
351 Self::Deterministic { executor, .. } => executor.state.lock().forbid_parking,
352 _ => panic!("this method can only be called on a deterministic executor"),
353 }
354 }
355
356 pub fn start_waiting(&self) {
357 match self {
358 Self::Deterministic { executor, .. } => {
359 executor.state.lock().waiting_backtrace = Some(Backtrace::new_unresolved());
360 }
361 _ => panic!("this method can only be called on a deterministic executor"),
362 }
363 }
364
365 pub fn finish_waiting(&self) {
366 match self {
367 Self::Deterministic { executor, .. } => {
368 executor.state.lock().waiting_backtrace.take();
369 }
370 _ => panic!("this method can only be called on a deterministic executor"),
371 }
372 }
373
374 pub fn forbid_parking(&self) {
375 match self {
376 Self::Deterministic { executor, .. } => {
377 let mut state = executor.state.lock();
378 state.forbid_parking = true;
379 state.rng = StdRng::seed_from_u64(state.seed);
380 }
381 _ => panic!("this method can only be called on a deterministic executor"),
382 }
383 }
384
385 pub async fn timer(&self, duration: Duration) {
386 match self {
387 Self::Deterministic { executor, .. } => {
388 let (tx, mut rx) = barrier::channel();
389 {
390 let mut state = executor.state.lock();
391 let wakeup_at = state.now + duration;
392 state.pending_timers.push((wakeup_at, tx));
393 }
394 rx.recv().await;
395 }
396 _ => {
397 Timer::after(duration).await;
398 }
399 }
400 }
401
402 pub fn advance_clock(&self, duration: Duration) {
403 match self {
404 Self::Deterministic { executor, .. } => {
405 executor.run_until_parked();
406
407 let mut state = executor.state.lock();
408 state.now += duration;
409 let now = state.now;
410 let mut pending_timers = mem::take(&mut state.pending_timers);
411 drop(state);
412
413 pending_timers.retain(|(wakeup, _)| *wakeup > now);
414 executor.state.lock().pending_timers.extend(pending_timers);
415 }
416 _ => panic!("this method can only be called on a deterministic executor"),
417 }
418 }
419
420 pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
421 match self {
422 Self::Deterministic { executor, .. } => executor.state.lock().block_on_ticks = range,
423 _ => panic!("this method can only be called on a deterministic executor"),
424 }
425 }
426}
427
428impl Background {
429 pub fn new() -> Self {
430 let executor = Arc::new(Executor::new());
431 let stop = channel::unbounded::<()>();
432
433 for i in 0..2 * num_cpus::get() {
434 let executor = executor.clone();
435 let stop = stop.1.clone();
436 thread::Builder::new()
437 .name(format!("background-executor-{}", i))
438 .spawn(move || smol::block_on(executor.run(stop.recv())))
439 .unwrap();
440 }
441
442 Self::Production {
443 executor,
444 _stop: stop.0,
445 }
446 }
447
448 pub fn num_cpus(&self) -> usize {
449 num_cpus::get()
450 }
451
452 pub fn spawn<T, F>(&self, future: F) -> Task<T>
453 where
454 T: 'static + Send,
455 F: Send + Future<Output = T> + 'static,
456 {
457 let future = any_future(future);
458 let any_task = match self {
459 Self::Production { executor, .. } => executor.spawn(future),
460 Self::Deterministic { executor } => executor.spawn(future),
461 };
462 Task::send(any_task)
463 }
464
465 pub fn block<F, T>(&self, future: F) -> T
466 where
467 F: Future<Output = T>,
468 {
469 smol::pin!(future);
470 match self {
471 Self::Production { .. } => smol::block_on(&mut future),
472 Self::Deterministic { executor, .. } => {
473 executor.block(&mut future, usize::MAX).unwrap()
474 }
475 }
476 }
477
478 pub fn block_with_timeout<F, T>(
479 &self,
480 timeout: Duration,
481 future: F,
482 ) -> Result<T, impl Future<Output = T>>
483 where
484 T: 'static,
485 F: 'static + Unpin + Future<Output = T>,
486 {
487 let mut future = any_local_future(future);
488 if !timeout.is_zero() {
489 let output = match self {
490 Self::Production { .. } => smol::block_on(util::timeout(timeout, &mut future)).ok(),
491 Self::Deterministic { executor, .. } => {
492 let max_ticks = {
493 let mut state = executor.state.lock();
494 let range = state.block_on_ticks.clone();
495 state.rng.gen_range(range)
496 };
497 executor.block(&mut future, max_ticks)
498 }
499 };
500 if let Some(output) = output {
501 return Ok(*output.downcast().unwrap());
502 }
503 }
504 Err(async { *future.await.downcast().unwrap() })
505 }
506
507 pub async fn scoped<'scope, F>(&self, scheduler: F)
508 where
509 F: FnOnce(&mut Scope<'scope>),
510 {
511 let mut scope = Scope {
512 futures: Default::default(),
513 _phantom: PhantomData,
514 };
515 (scheduler)(&mut scope);
516 let spawned = scope
517 .futures
518 .into_iter()
519 .map(|f| self.spawn(f))
520 .collect::<Vec<_>>();
521 for task in spawned {
522 task.await;
523 }
524 }
525
526 pub async fn simulate_random_delay(&self) {
527 match self {
528 Self::Deterministic { executor, .. } => {
529 if executor.state.lock().rng.gen_bool(0.2) {
530 let yields = executor.state.lock().rng.gen_range(1..=10);
531 for _ in 0..yields {
532 yield_now().await;
533 }
534 }
535 }
536 _ => panic!("this method can only be called on a deterministic executor"),
537 }
538 }
539}
540
/// Collects futures to be spawned by `Background::scoped`.
///
/// `'a` bounds what the queued futures may borrow. They are stored with that lifetime erased to
/// `'static` by `Scope::spawn`.
pub struct Scope<'a> {
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    _phantom: PhantomData<&'a ()>,
}
545
impl<'a> Scope<'a> {
    /// Queues `f` to be spawned when the enclosing `Background::scoped` call runs its tasks.
    ///
    /// `f` may borrow data that lives for `'a`.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        // SAFETY: NOTE(review): this transmute only widens the trait object's lifetime bound from
        // `'a` to `'static`. It is sound only if every queued future finishes or is dropped before
        // `'a` ends. `Background::scoped` provides that by awaiting every spawned task, which does
        // not hold if its future is dropped or leaked early. Confirm callers uphold this.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(f))
        };
        self.futures.push(f);
    }
}
560
561impl<T> Task<T> {
562 pub fn ready(value: T) -> Self {
563 Self::Ready(Some(value))
564 }
565
566 fn local(any_task: AnyLocalTask) -> Self {
567 Self::Local {
568 any_task,
569 result_type: PhantomData,
570 }
571 }
572
573 pub fn detach(self) {
574 match self {
575 Task::Ready(_) => {}
576 Task::Local { any_task, .. } => any_task.detach(),
577 Task::Send { any_task, .. } => any_task.detach(),
578 }
579 }
580}
581
582impl<T: 'static, E: 'static + Display> Task<Result<T, E>> {
583 pub fn detach_and_log_err(self, cx: &mut MutableAppContext) {
584 cx.spawn(|_| async move {
585 if let Err(err) = self.await {
586 log::error!("{}", err);
587 }
588 })
589 .detach();
590 }
591}
592
593impl<T: Send> Task<T> {
594 fn send(any_task: AnyTask) -> Self {
595 Self::Send {
596 any_task,
597 result_type: PhantomData,
598 }
599 }
600}
601
602impl<T: fmt::Debug> fmt::Debug for Task<T> {
603 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
604 match self {
605 Task::Ready(value) => value.fmt(f),
606 Task::Local { any_task, .. } => any_task.fmt(f),
607 Task::Send { any_task, .. } => any_task.fmt(f),
608 }
609 }
610}
611
612impl<T: 'static> Future for Task<T> {
613 type Output = T;
614
615 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
616 match unsafe { self.get_unchecked_mut() } {
617 Task::Ready(value) => Poll::Ready(value.take().unwrap()),
618 Task::Local { any_task, .. } => {
619 any_task.poll(cx).map(|value| *value.downcast().unwrap())
620 }
621 Task::Send { any_task, .. } => {
622 any_task.poll(cx).map(|value| *value.downcast().unwrap())
623 }
624 }
625 }
626}
627
628fn any_future<T, F>(future: F) -> AnyFuture
629where
630 T: 'static + Send,
631 F: Future<Output = T> + Send + 'static,
632{
633 async { Box::new(future.await) as Box<dyn Any + Send> }.boxed()
634}
635
636fn any_local_future<T, F>(future: F) -> AnyLocalFuture
637where
638 T: 'static,
639 F: Future<Output = T> + 'static,
640{
641 async { Box::new(future.await) as Box<dyn Any> }.boxed_local()
642}