1use anyhow::{anyhow, Result};
2use async_task::Runnable;
3use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
4use parking_lot::Mutex;
5use postage::{barrier, prelude::Stream as _};
6use rand::prelude::*;
7use smol::{channel, prelude::*, Executor, Timer};
8use std::{
9 any::Any,
10 fmt::{self, Debug},
11 marker::PhantomData,
12 mem,
13 ops::RangeInclusive,
14 pin::Pin,
15 rc::Rc,
16 sync::{
17 atomic::{AtomicBool, Ordering::SeqCst},
18 Arc,
19 },
20 task::{Context, Poll},
21 thread,
22 time::{Duration, Instant},
23};
24use waker_fn::waker_fn;
25
26use crate::{
27 platform::{self, Dispatcher},
28 util,
29};
30
/// Executor for work that stays on a single ("foreground") thread.
///
/// Futures passed to `Foreground::spawn` do not need to be `Send`.
pub enum Foreground {
    /// Hands runnables to the platform dispatcher's main thread (see
    /// `Foreground::spawn`). `Foreground::platform` only constructs this variant
    /// when `dispatcher.is_main_thread()` returns true.
    Platform {
        dispatcher: Arc<dyn platform::Dispatcher>,
        // `Rc<()>` is neither `Send` nor `Sync`, so this marker keeps the
        // executor on the thread that created it.
        _not_send_or_sync: PhantomData<Rc<()>>,
    },
    /// A `smol` local executor, driven by `Foreground::run`.
    Test(smol::LocalExecutor<'static>),
    /// Seeded scheduler shared with a `Background::Deterministic` (see
    /// `deterministic`).
    Deterministic(Arc<Deterministic>),
}
39
/// Executor for `Send` futures.
pub enum Background {
    /// Runs on the same seeded scheduler as a `Foreground::Deterministic`.
    Deterministic {
        executor: Arc<Deterministic>,
        // Tasks registered by `spawn_critical`; drained by `block_on_critical_tasks`.
        critical_tasks: Mutex<Vec<Task<()>>>,
    },
    /// A `smol` executor driven by the worker threads started in `Background::new`.
    Production {
        executor: Arc<smol::Executor<'static>>,
        // Tasks registered by `spawn_critical`; drained by `block_on_critical_tasks`.
        critical_tasks: Mutex<Vec<Task<()>>>,
        // The only sender of the channel each worker waits on via
        // `executor.run(stop.recv())`; dropping it closes the channel.
        _stop: channel::Sender<()>,
    },
}
51
// Type-erased futures and task handles. Each future's output is boxed as
// `dyn Any` so the executors need not be generic; `Task::poll` downcasts the
// output back to the concrete type.

/// A boxed future, possibly `!Send`, with a type-erased output.
type AnyLocalFuture = Pin<Box<dyn 'static + Future<Output = Box<dyn Any + 'static>>>>;
/// A boxed `Send` future with a type-erased `Send` output.
type AnyFuture = Pin<Box<dyn 'static + Send + Future<Output = Box<dyn Any + Send + 'static>>>>;
/// Handle to a spawned `AnyFuture`.
type AnyTask = async_task::Task<Box<dyn Any + Send + 'static>>;
/// Handle to a spawned `AnyLocalFuture`.
type AnyLocalTask = async_task::Task<Box<dyn Any + 'static>>;
56
/// Handle to a spawned future whose output has type `T`.
///
/// Awaiting the handle yields the output; see `Task::detach` for giving it up.
#[must_use]
pub enum Task<T> {
    /// Produced by `Foreground::spawn`; the underlying future may be `!Send`.
    Local {
        any_task: AnyLocalTask,
        result_type: PhantomData<T>,
    },
    /// Produced by `Background::spawn`.
    Send {
        any_task: AnyTask,
        result_type: PhantomData<T>,
    },
}

// NOTE(review): this marks every `Task<T: Send>` as `Send`, including the
// `Local` variant, whose `AnyLocalTask` is not `Send` by itself. Moving a local
// handle to another thread presumably relies on `async_task`'s runtime
// thread checks for local tasks — confirm this is intended.
unsafe impl<T: Send> Send for Task<T> {}
70
/// Mutable state of the deterministic scheduler, guarded by `Deterministic::state`.
struct DeterministicState {
    /// Source of every scheduling decision; initially seeded from `seed`.
    rng: StdRng,
    /// The original seed; `Foreground::forbid_parking` re-seeds `rng` from it.
    seed: u64,
    /// Foreground runnables scheduled again after their first schedule, each with
    /// the backtrace captured when the task was spawned.
    scheduled_from_foreground: Vec<(Runnable, Backtrace)>,
    /// Runnables of tasks spawned through `Deterministic::spawn`.
    scheduled_from_background: Vec<(Runnable, Backtrace)>,
    /// Foreground runnables on their first schedule; `run_internal` always takes
    /// the front entry, so these start in spawn order.
    spawned_from_foreground: Vec<(Runnable, Backtrace)>,
    /// When set, `run` and `block_on` panic instead of parking (subject to their
    /// woken checks).
    forbid_parking: bool,
    /// Inclusive range from which `block_on` draws its maximum tick count.
    block_on_ticks: RangeInclusive<usize>,
    /// Virtual clock; only `Foreground::advance_clock` moves it forward.
    now: Instant,
    /// Timers registered by `Foreground::timer`. `advance_clock` fires a due
    /// timer by dropping its sender.
    pending_timers: Vec<(Instant, barrier::Sender)>,
}
82
/// Scheduler that picks the next runnable with a seeded RNG. The intent is that
/// the same seed reproduces the same interleaving.
pub struct Deterministic {
    state: Arc<Mutex<DeterministicState>>,
    // Parks the driving thread in `run` / `block_on`; spawned tasks' schedule
    // callbacks and wakers unpark it.
    parker: Mutex<parking::Parker>,
}
87
88impl Deterministic {
89 fn new(seed: u64) -> Self {
90 Self {
91 state: Arc::new(Mutex::new(DeterministicState {
92 rng: StdRng::seed_from_u64(seed),
93 seed,
94 scheduled_from_foreground: Default::default(),
95 scheduled_from_background: Default::default(),
96 spawned_from_foreground: Default::default(),
97 forbid_parking: false,
98 block_on_ticks: 0..=1000,
99 now: Instant::now(),
100 pending_timers: Default::default(),
101 })),
102 parker: Default::default(),
103 }
104 }
105
106 fn spawn_from_foreground(&self, future: AnyLocalFuture) -> AnyLocalTask {
107 let backtrace = Backtrace::new_unresolved();
108 let scheduled_once = AtomicBool::new(false);
109 let state = self.state.clone();
110 let unparker = self.parker.lock().unparker();
111 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
112 let mut state = state.lock();
113 let backtrace = backtrace.clone();
114 if scheduled_once.fetch_or(true, SeqCst) {
115 state.scheduled_from_foreground.push((runnable, backtrace));
116 } else {
117 state.spawned_from_foreground.push((runnable, backtrace));
118 }
119 unparker.unpark();
120 });
121 runnable.schedule();
122 task
123 }
124
125 fn spawn(&self, future: AnyFuture) -> AnyTask {
126 let backtrace = Backtrace::new_unresolved();
127 let state = self.state.clone();
128 let unparker = self.parker.lock().unparker();
129 let (runnable, task) = async_task::spawn(future, move |runnable| {
130 let mut state = state.lock();
131 state
132 .scheduled_from_background
133 .push((runnable, backtrace.clone()));
134 unparker.unpark();
135 });
136 runnable.schedule();
137 task
138 }
139
140 fn run(&self, mut future: AnyLocalFuture) -> Box<dyn Any> {
141 let woken = Arc::new(AtomicBool::new(false));
142 loop {
143 if let Some(result) = self.run_internal(woken.clone(), &mut future) {
144 return result;
145 }
146
147 if !woken.load(SeqCst) && self.state.lock().forbid_parking {
148 panic!("deterministic executor parked after a call to forbid_parking");
149 }
150
151 woken.store(false, SeqCst);
152 self.parker.lock().park();
153 }
154 }
155
156 fn run_until_parked(&self) {
157 let woken = Arc::new(AtomicBool::new(false));
158 let mut future = any_local_future(std::future::pending::<()>());
159 self.run_internal(woken, &mut future);
160 }
161
162 fn run_internal(
163 &self,
164 woken: Arc<AtomicBool>,
165 future: &mut AnyLocalFuture,
166 ) -> Option<Box<dyn Any>> {
167 let unparker = self.parker.lock().unparker();
168 let waker = waker_fn(move || {
169 woken.store(true, SeqCst);
170 unparker.unpark();
171 });
172
173 let mut cx = Context::from_waker(&waker);
174 let mut trace = Trace::default();
175 loop {
176 let mut state = self.state.lock();
177 let runnable_count = state.scheduled_from_foreground.len()
178 + state.scheduled_from_background.len()
179 + state.spawned_from_foreground.len();
180
181 let ix = state.rng.gen_range(0..=runnable_count);
182 if ix < state.scheduled_from_foreground.len() {
183 let (_, backtrace) = &state.scheduled_from_foreground[ix];
184 trace.record(&state, backtrace.clone());
185 let runnable = state.scheduled_from_foreground.remove(ix).0;
186 drop(state);
187 runnable.run();
188 } else if ix - state.scheduled_from_foreground.len()
189 < state.scheduled_from_background.len()
190 {
191 let ix = ix - state.scheduled_from_foreground.len();
192 let (_, backtrace) = &state.scheduled_from_background[ix];
193 trace.record(&state, backtrace.clone());
194 let runnable = state.scheduled_from_background.remove(ix).0;
195 drop(state);
196 runnable.run();
197 } else if ix < runnable_count {
198 let (_, backtrace) = &state.spawned_from_foreground[0];
199 trace.record(&state, backtrace.clone());
200 let runnable = state.spawned_from_foreground.remove(0).0;
201 drop(state);
202 runnable.run();
203 } else {
204 drop(state);
205 if let Poll::Ready(result) = future.poll(&mut cx) {
206 return Some(result);
207 }
208
209 let state = self.state.lock();
210 if state.scheduled_from_foreground.is_empty()
211 && state.scheduled_from_background.is_empty()
212 && state.spawned_from_foreground.is_empty()
213 {
214 return None;
215 }
216 }
217 }
218 }
219
220 fn block_on(&self, future: &mut AnyLocalFuture) -> Option<Box<dyn Any>> {
221 let unparker = self.parker.lock().unparker();
222 let waker = waker_fn(move || {
223 unparker.unpark();
224 });
225 let max_ticks = {
226 let mut state = self.state.lock();
227 let range = state.block_on_ticks.clone();
228 state.rng.gen_range(range)
229 };
230
231 let mut cx = Context::from_waker(&waker);
232 let mut trace = Trace::default();
233 for _ in 0..max_ticks {
234 let mut state = self.state.lock();
235 let runnable_count = state.scheduled_from_background.len();
236 let ix = state.rng.gen_range(0..=runnable_count);
237 if ix < state.scheduled_from_background.len() {
238 let (_, backtrace) = &state.scheduled_from_background[ix];
239 trace.record(&state, backtrace.clone());
240 let runnable = state.scheduled_from_background.remove(ix).0;
241 drop(state);
242 runnable.run();
243 } else {
244 drop(state);
245 if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
246 return Some(result);
247 }
248 let state = self.state.lock();
249 if state.scheduled_from_background.is_empty() {
250 if state.forbid_parking {
251 panic!("deterministic executor parked after a call to forbid_parking");
252 }
253 drop(state);
254 self.parker.lock().park();
255 }
256
257 continue;
258 }
259 }
260
261 None
262 }
263}
264
/// Log of the scheduling decisions made during one `run_internal` / `block_on`
/// call. `Trace`'s `Drop` impl prints it when the `EXECUTOR_TRACE_*`
/// environment variables request it.
#[derive(Default)]
struct Trace {
    /// For each executed step, the spawn-site backtrace of the runnable that ran.
    executed: Vec<Backtrace>,
    /// For each executed step, the spawn-site backtraces of the runnables then
    /// queued in `scheduled_from_foreground`.
    scheduled: Vec<Vec<Backtrace>>,
    /// For each executed step, the spawn-site backtraces then queued in
    /// `spawned_from_foreground`.
    spawned_from_foreground: Vec<Vec<Backtrace>>,
}
271
272impl Trace {
273 fn record(&mut self, state: &DeterministicState, executed: Backtrace) {
274 self.scheduled.push(
275 state
276 .scheduled_from_foreground
277 .iter()
278 .map(|(_, backtrace)| backtrace.clone())
279 .collect(),
280 );
281 self.spawned_from_foreground.push(
282 state
283 .spawned_from_foreground
284 .iter()
285 .map(|(_, backtrace)| backtrace.clone())
286 .collect(),
287 );
288 self.executed.push(executed);
289 }
290
291 fn resolve(&mut self) {
292 for backtrace in &mut self.executed {
293 backtrace.resolve();
294 }
295
296 for backtraces in &mut self.scheduled {
297 for backtrace in backtraces {
298 backtrace.resolve();
299 }
300 }
301
302 for backtraces in &mut self.spawned_from_foreground {
303 for backtrace in backtraces {
304 backtrace.resolve();
305 }
306 }
307 }
308}
309
310impl Debug for Trace {
311 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312 struct FirstCwdFrameInBacktrace<'a>(&'a Backtrace);
313
314 impl<'a> Debug for FirstCwdFrameInBacktrace<'a> {
315 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
316 let cwd = std::env::current_dir().unwrap();
317 let mut print_path = |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
318 fmt::Display::fmt(&path, fmt)
319 };
320 let mut fmt = BacktraceFmt::new(f, backtrace::PrintFmt::Full, &mut print_path);
321 for frame in self.0.frames() {
322 let mut formatted_frame = fmt.frame();
323 if frame
324 .symbols()
325 .iter()
326 .any(|s| s.filename().map_or(false, |f| f.starts_with(&cwd)))
327 {
328 formatted_frame.backtrace_frame(frame)?;
329 break;
330 }
331 }
332 fmt.finish()
333 }
334 }
335
336 for ((backtrace, scheduled), spawned_from_foreground) in self
337 .executed
338 .iter()
339 .zip(&self.scheduled)
340 .zip(&self.spawned_from_foreground)
341 {
342 writeln!(f, "Scheduled")?;
343 for backtrace in scheduled {
344 writeln!(f, "- {:?}", FirstCwdFrameInBacktrace(backtrace))?;
345 }
346 if scheduled.is_empty() {
347 writeln!(f, "None")?;
348 }
349 writeln!(f, "==========")?;
350
351 writeln!(f, "Spawned from foreground")?;
352 for backtrace in spawned_from_foreground {
353 writeln!(f, "- {:?}", FirstCwdFrameInBacktrace(backtrace))?;
354 }
355 if spawned_from_foreground.is_empty() {
356 writeln!(f, "None")?;
357 }
358 writeln!(f, "==========")?;
359
360 writeln!(f, "Run: {:?}", FirstCwdFrameInBacktrace(backtrace))?;
361 writeln!(f, "+++++++++++++++++++")?;
362 }
363
364 Ok(())
365 }
366}
367
368impl Drop for Trace {
369 fn drop(&mut self) {
370 let trace_on_panic = if let Ok(trace_on_panic) = std::env::var("EXECUTOR_TRACE_ON_PANIC") {
371 trace_on_panic == "1" || trace_on_panic == "true"
372 } else {
373 false
374 };
375 let trace_always = if let Ok(trace_always) = std::env::var("EXECUTOR_TRACE_ALWAYS") {
376 trace_always == "1" || trace_always == "true"
377 } else {
378 false
379 };
380
381 if trace_always || (trace_on_panic && thread::panicking()) {
382 self.resolve();
383 dbg!(self);
384 }
385 }
386}
387
388impl Foreground {
389 pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
390 if dispatcher.is_main_thread() {
391 Ok(Self::Platform {
392 dispatcher,
393 _not_send_or_sync: PhantomData,
394 })
395 } else {
396 Err(anyhow!("must be constructed on main thread"))
397 }
398 }
399
400 pub fn test() -> Self {
401 Self::Test(smol::LocalExecutor::new())
402 }
403
404 pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
405 let future = any_local_future(future);
406 let any_task = match self {
407 Self::Deterministic(executor) => executor.spawn_from_foreground(future),
408 Self::Platform { dispatcher, .. } => {
409 fn spawn_inner(
410 future: AnyLocalFuture,
411 dispatcher: &Arc<dyn Dispatcher>,
412 ) -> AnyLocalTask {
413 let dispatcher = dispatcher.clone();
414 let schedule =
415 move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
416 let (runnable, task) = async_task::spawn_local(future, schedule);
417 runnable.schedule();
418 task
419 }
420 spawn_inner(future, dispatcher)
421 }
422 Self::Test(executor) => executor.spawn(future),
423 };
424 Task::local(any_task)
425 }
426
427 pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
428 let future = any_local_future(future);
429 let any_value = match self {
430 Self::Deterministic(executor) => executor.run(future),
431 Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
432 Self::Test(executor) => smol::block_on(executor.run(future)),
433 };
434 *any_value.downcast().unwrap()
435 }
436
437 pub fn forbid_parking(&self) {
438 match self {
439 Self::Deterministic(executor) => {
440 let mut state = executor.state.lock();
441 state.forbid_parking = true;
442 state.rng = StdRng::seed_from_u64(state.seed);
443 }
444 _ => panic!("this method can only be called on a deterministic executor"),
445 }
446 }
447
448 pub async fn timer(&self, duration: Duration) {
449 match self {
450 Self::Deterministic(executor) => {
451 let (tx, mut rx) = barrier::channel();
452 {
453 let mut state = executor.state.lock();
454 let wakeup_at = state.now + duration;
455 state.pending_timers.push((wakeup_at, tx));
456 }
457 rx.recv().await;
458 }
459 _ => {
460 Timer::after(duration).await;
461 }
462 }
463 }
464
465 pub fn advance_clock(&self, duration: Duration) {
466 match self {
467 Self::Deterministic(executor) => {
468 executor.run_until_parked();
469
470 let mut state = executor.state.lock();
471 state.now += duration;
472 let now = state.now;
473 let mut pending_timers = mem::take(&mut state.pending_timers);
474 drop(state);
475
476 pending_timers.retain(|(wakeup, _)| *wakeup > now);
477 executor.state.lock().pending_timers.extend(pending_timers);
478 }
479 _ => panic!("this method can only be called on a deterministic executor"),
480 }
481 }
482
483 pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
484 match self {
485 Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,
486 _ => panic!("this method can only be called on a deterministic executor"),
487 }
488 }
489}
490
491impl Background {
492 pub fn new() -> Self {
493 let executor = Arc::new(Executor::new());
494 let stop = channel::unbounded::<()>();
495
496 for i in 0..2 * num_cpus::get() {
497 let executor = executor.clone();
498 let stop = stop.1.clone();
499 thread::Builder::new()
500 .name(format!("background-executor-{}", i))
501 .spawn(move || smol::block_on(executor.run(stop.recv())))
502 .unwrap();
503 }
504
505 Self::Production {
506 executor,
507 critical_tasks: Default::default(),
508 _stop: stop.0,
509 }
510 }
511
512 pub fn num_cpus(&self) -> usize {
513 num_cpus::get()
514 }
515
516 pub fn spawn<T, F>(&self, future: F) -> Task<T>
517 where
518 T: 'static + Send,
519 F: Send + Future<Output = T> + 'static,
520 {
521 let future = any_future(future);
522 let any_task = match self {
523 Self::Production { executor, .. } => executor.spawn(future),
524 Self::Deterministic { executor, .. } => executor.spawn(future),
525 };
526 Task::send(any_task)
527 }
528
529 pub fn spawn_critical<T, F>(&self, future: F)
530 where
531 T: 'static + Send,
532 F: Send + Future<Output = T> + 'static,
533 {
534 let task = self.spawn(async move {
535 future.await;
536 });
537 match self {
538 Self::Production { critical_tasks, .. }
539 | Self::Deterministic { critical_tasks, .. } => critical_tasks.lock().push(task),
540 }
541 }
542
543 pub fn block_on_critical_tasks(&self, timeout: Duration) -> bool {
544 match self {
545 Background::Production { critical_tasks, .. }
546 | Self::Deterministic { critical_tasks, .. } => {
547 let tasks = mem::take(&mut *critical_tasks.lock());
548 self.block_with_timeout(timeout, futures::future::join_all(tasks))
549 .is_ok()
550 }
551 }
552 }
553
554 pub fn block_with_timeout<F, T>(
555 &self,
556 timeout: Duration,
557 future: F,
558 ) -> Result<T, impl Future<Output = T>>
559 where
560 T: 'static,
561 F: 'static + Unpin + Future<Output = T>,
562 {
563 let mut future = any_local_future(future);
564 if !timeout.is_zero() {
565 let output = match self {
566 Self::Production { .. } => smol::block_on(util::timeout(timeout, &mut future)).ok(),
567 Self::Deterministic { executor, .. } => executor.block_on(&mut future),
568 };
569 if let Some(output) = output {
570 return Ok(*output.downcast().unwrap());
571 }
572 }
573 Err(async { *future.await.downcast().unwrap() })
574 }
575
576 pub async fn scoped<'scope, F>(&self, scheduler: F)
577 where
578 F: FnOnce(&mut Scope<'scope>),
579 {
580 let mut scope = Scope {
581 futures: Default::default(),
582 _phantom: PhantomData,
583 };
584 (scheduler)(&mut scope);
585 let spawned = scope
586 .futures
587 .into_iter()
588 .map(|f| self.spawn(f))
589 .collect::<Vec<_>>();
590 for task in spawned {
591 task.await;
592 }
593 }
594}
595
/// Collects futures that may borrow data for `'a`; `Background::scoped`
/// spawns them together.
pub struct Scope<'a> {
    // Stored as `'static` after the unchecked lifetime extension in
    // `Scope::spawn`; `_phantom` keeps the real lifetime `'a` on the type.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    _phantom: PhantomData<&'a ()>,
}
600
impl<'a> Scope<'a> {
    /// Queues `f` to be spawned on the background executor by `Background::scoped`.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        // SAFETY(review): the transmute only changes the trait object's lifetime
        // bound from `'a` to `'static`, so the future can be spawned on a
        // `'static` executor. This is sound only if the future finishes before
        // any data it borrows for `'a` is dropped. `Background::scoped` awaits
        // every spawned future, but nothing here enforces that if the `scoped`
        // future is dropped or leaked early — TODO confirm callers.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(f))
        };
        self.futures.push(f);
    }
}
615
616pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
617 let executor = Arc::new(Deterministic::new(seed));
618 (
619 Rc::new(Foreground::Deterministic(executor.clone())),
620 Arc::new(Background::Deterministic {
621 executor,
622 critical_tasks: Default::default(),
623 }),
624 )
625}
626
627impl<T> Task<T> {
628 fn local(any_task: AnyLocalTask) -> Self {
629 Self::Local {
630 any_task,
631 result_type: PhantomData,
632 }
633 }
634
635 pub fn detach(self) {
636 match self {
637 Task::Local { any_task, .. } => any_task.detach(),
638 Task::Send { any_task, .. } => any_task.detach(),
639 }
640 }
641}
642
643impl<T: Send> Task<T> {
644 fn send(any_task: AnyTask) -> Self {
645 Self::Send {
646 any_task,
647 result_type: PhantomData,
648 }
649 }
650}
651
652impl<T: fmt::Debug> fmt::Debug for Task<T> {
653 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
654 match self {
655 Task::Local { any_task, .. } => any_task.fmt(f),
656 Task::Send { any_task, .. } => any_task.fmt(f),
657 }
658 }
659}
660
661impl<T: 'static> Future for Task<T> {
662 type Output = T;
663
664 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
665 match unsafe { self.get_unchecked_mut() } {
666 Task::Local { any_task, .. } => {
667 any_task.poll(cx).map(|value| *value.downcast().unwrap())
668 }
669 Task::Send { any_task, .. } => {
670 any_task.poll(cx).map(|value| *value.downcast().unwrap())
671 }
672 }
673 }
674}
675
676fn any_future<T, F>(future: F) -> AnyFuture
677where
678 T: 'static + Send,
679 F: Future<Output = T> + Send + 'static,
680{
681 async { Box::new(future.await) as Box<dyn Any + Send> }.boxed()
682}
683
684fn any_local_future<T, F>(future: F) -> AnyLocalFuture
685where
686 T: 'static,
687 F: Future<Output = T> + 'static,
688{
689 async { Box::new(future.await) as Box<dyn Any> }.boxed_local()
690}