1use anyhow::Result;
2use async_task::Runnable;
3use parking_lot::Mutex;
4use rand_chacha::rand_core::SeedableRng;
5use rand_chacha::ChaCha8Rng;
6use std::any::Any;
7use std::collections::VecDeque;
8use std::future::Future;
9use std::marker::PhantomData;
10use std::sync::Arc;
11
12use std::thread::{self, ThreadId};
13
14pub trait Scheduler: Send + Sync + Any {
15 fn schedule_foreground(&self, runnable: Runnable);
16 fn is_main_thread(&self) -> bool;
17}
18
/// Opaque identifier for a spawned [`Task`].
///
/// `Debug` is derived so ids can appear in assertions and diagnostics.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TaskId(usize);
21
22pub struct Task<R>(async_task::Task<R>);
23
24impl<R> Task<R> {
25 pub fn id(&self) -> TaskId {
26 TaskId(0) // Placeholder
27 }
28}
29
/// Configuration consumed by [`TestScheduler::new`].
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Whether work should be run in a randomized order.
    ///
    /// NOTE(review): `TestScheduler::new` does not read this flag; its
    /// foreground queue is drained front-to-back regardless of the value.
    pub randomize_order: bool,
    /// Seed for the scheduler's `ChaCha8Rng`.
    pub seed: u64,
}

impl Default for SchedulerConfig {
    /// Randomized ordering enabled, seed `0`.
    fn default() -> Self {
        Self {
            randomize_order: true,
            seed: 0,
        }
    }
}
43
44pub struct TestScheduler {
45 inner: Mutex<TestSchedulerInner>,
46}
47
48struct TestSchedulerInner {
49 rng: ChaCha8Rng,
50 foreground_queue: VecDeque<Runnable>,
51 creation_thread_id: ThreadId,
52}
53
54impl TestScheduler {
55 pub fn new(config: SchedulerConfig) -> Self {
56 Self {
57 inner: Mutex::new(TestSchedulerInner {
58 rng: ChaCha8Rng::seed_from_u64(config.seed),
59 foreground_queue: VecDeque::new(),
60 creation_thread_id: thread::current().id(),
61 }),
62 }
63 }
64
65 pub fn tick(&self, background_only: bool) -> bool {
66 let mut inner = self.inner.lock();
67 if !background_only {
68 if let Some(runnable) = inner.foreground_queue.pop_front() {
69 drop(inner); // Unlock while running
70 runnable.run();
71 return true;
72 }
73 }
74 false
75 }
76
77 pub fn run(&self) {
78 while self.tick(false) {}
79 }
80}
81
82impl Scheduler for TestScheduler {
83 fn schedule_foreground(&self, runnable: Runnable) {
84 self.inner.lock().foreground_queue.push_back(runnable);
85 }
86
87 fn is_main_thread(&self) -> bool {
88 thread::current().id() == self.inner.lock().creation_thread_id
89 }
90}
91
92pub struct ForegroundExecutor {
93 scheduler: Arc<dyn Scheduler>,
94 _phantom: PhantomData<()>,
95}
96
97impl ForegroundExecutor {
98 pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
99 Ok(Self {
100 scheduler,
101 _phantom: PhantomData,
102 })
103 }
104
105 pub fn spawn<R: 'static + Send>(
106 &self,
107 future: impl Future<Output = R> + Send + 'static,
108 ) -> Task<R> {
109 let scheduler = self.scheduler.clone();
110 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
111 scheduler.schedule_foreground(runnable);
112 });
113 runnable.schedule();
114 Task(task)
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Builds a default-configured scheduler and an executor bound to it.
    fn setup() -> (Arc<TestScheduler>, ForegroundExecutor) {
        let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::default()));
        let executor = ForegroundExecutor::new(scheduler.clone()).unwrap();
        (scheduler, executor)
    }

    #[test]
    fn test_basic_spawn_and_run() {
        let (scheduler, executor) = setup();

        let flag = Arc::new(AtomicBool::new(false));
        // Keep the handle bound so the task stays alive until it has run.
        let _task = executor.spawn({
            let flag = flag.clone();
            async move {
                flag.store(true, Ordering::SeqCst);
            }
        });

        // Spawning only enqueues; nothing runs until the scheduler is driven.
        assert!(!flag.load(Ordering::SeqCst));

        scheduler.run();

        assert!(flag.load(Ordering::SeqCst));
    }

    #[test]
    fn test_foreground_tasks_run_in_spawn_order() {
        let (scheduler, executor) = setup();
        let order = Arc::new(std::sync::Mutex::new(Vec::new()));

        // Collect the handles so none of the tasks is dropped before running.
        let _tasks: Vec<_> = (0..3)
            .map(|i| {
                let order = order.clone();
                executor.spawn(async move {
                    order.lock().unwrap().push(i);
                })
            })
            .collect();

        scheduler.run();

        assert_eq!(*order.lock().unwrap(), vec![0, 1, 2]);
    }

    #[test]
    fn test_is_main_thread() {
        let (scheduler, _executor) = setup();
        assert!(scheduler.is_main_thread());

        let remote = scheduler.clone();
        let on_other_thread = std::thread::spawn(move || remote.is_main_thread())
            .join()
            .unwrap();
        assert!(!on_other_thread);
    }

    #[test]
    fn test_tick_background_only_skips_foreground() {
        let (scheduler, executor) = setup();

        let flag = Arc::new(AtomicBool::new(false));
        let _task = executor.spawn({
            let flag = flag.clone();
            async move {
                flag.store(true, Ordering::SeqCst);
            }
        });

        assert!(!scheduler.tick(true));
        assert!(!flag.load(Ordering::SeqCst));

        assert!(scheduler.tick(false));
        assert!(flag.load(Ordering::SeqCst));

        assert!(!scheduler.tick(false));
    }
}