1use anyhow::Result;
2use async_task::Runnable;
3use parking_lot::Mutex;
4use rand_chacha::rand_core::SeedableRng;
5use rand_chacha::ChaCha8Rng;
6use std::any::Any;
7use std::collections::VecDeque;
8use std::future::Future;
9use std::marker::PhantomData;
10use std::sync::Arc;
11
12use futures::channel::oneshot;
13use futures::executor;
14use std::thread::{self, ThreadId};
15
/// Opaque label that executors attach to a spawned task.
///
/// The label is forwarded unchanged to [`Scheduler::schedule`] /
/// [`Scheduler::schedule_foreground`] each time the task is scheduled.
/// Labels compare and hash by their inner value.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct TaskLabel(usize);
18
19pub trait Scheduler: Send + Sync + Any {
20 fn schedule(&self, runnable: Runnable, label: Option<TaskLabel>);
21 fn schedule_foreground(&self, runnable: Runnable, label: Option<TaskLabel>);
22 fn is_main_thread(&self) -> bool;
23}
24
/// Identifier reported by `Task::id`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TaskId(usize);
27
28pub struct Task<R>(async_task::Task<R>);
29
30impl<R> Task<R> {
31 pub fn id(&self) -> TaskId {
32 TaskId(0) // Placeholder
33 }
34}
35
36impl Default for TaskLabel {
37 fn default() -> Self {
38 TaskLabel(0)
39 }
40}
41
/// Settings used to construct a [`TestScheduler`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SchedulerConfig {
    /// Requests that queued work run in a randomized order. Defaults to `true`.
    ///
    /// NOTE(review): `TestScheduler::new` only reads `seed`, so this flag is not
    /// currently consulted anywhere in this file.
    pub randomize_order: bool,
    /// Seed passed to `ChaCha8Rng::seed_from_u64` when a `TestScheduler` is
    /// created. Defaults to `0`.
    pub seed: u64,
}

impl Default for SchedulerConfig {
    /// Randomized ordering requested, seed `0`.
    fn default() -> Self {
        Self {
            randomize_order: true,
            seed: 0,
        }
    }
}
55
56pub struct TestScheduler {
57 inner: Mutex<TestSchedulerInner>,
58}
59
60struct TestSchedulerInner {
61 rng: ChaCha8Rng,
62 foreground_queue: VecDeque<Runnable>,
63 creation_thread_id: ThreadId,
64}
65
66impl TestScheduler {
67 pub fn new(config: SchedulerConfig) -> Self {
68 Self {
69 inner: Mutex::new(TestSchedulerInner {
70 rng: ChaCha8Rng::seed_from_u64(config.seed),
71 foreground_queue: VecDeque::new(),
72 creation_thread_id: thread::current().id(),
73 }),
74 }
75 }
76
77 pub fn tick(&self, background_only: bool) -> bool {
78 let mut inner = self.inner.lock();
79 if !background_only {
80 if let Some(runnable) = inner.foreground_queue.pop_front() {
81 drop(inner); // Unlock while running
82 runnable.run();
83 return true;
84 }
85 }
86 false
87 }
88
89 pub fn run(&self) {
90 while self.tick(false) {}
91 }
92}
93
94impl Scheduler for TestScheduler {
95 fn schedule(&self, runnable: Runnable, _label: Option<TaskLabel>) {
96 runnable.run();
97 }
98
99 fn schedule_foreground(&self, runnable: Runnable, _label: Option<TaskLabel>) {
100 self.inner.lock().foreground_queue.push_back(runnable);
101 }
102
103 fn is_main_thread(&self) -> bool {
104 thread::current().id() == self.inner.lock().creation_thread_id
105 }
106}
107
108pub struct ForegroundExecutor {
109 scheduler: Arc<dyn Scheduler>,
110 _phantom: PhantomData<()>,
111}
112
113impl ForegroundExecutor {
114 pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
115 Ok(Self {
116 scheduler,
117 _phantom: PhantomData,
118 })
119 }
120
121 pub fn spawn<R: 'static>(&self, future: impl Future<Output = R> + 'static) -> Task<R> {
122 let scheduler = self.scheduler.clone();
123 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
124 scheduler.schedule_foreground(runnable, None);
125 });
126 runnable.schedule();
127 Task(task)
128 }
129
130 pub fn spawn_labeled<R: 'static>(
131 &self,
132 future: impl Future<Output = R> + 'static,
133 label: TaskLabel,
134 ) -> Task<R> {
135 let scheduler = self.scheduler.clone();
136 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
137 scheduler.schedule_foreground(runnable, Some(label));
138 });
139 runnable.schedule();
140 Task(task)
141 }
142}
143
144pub struct BackgroundExecutor {
145 scheduler: Arc<dyn Scheduler>,
146}
147
148impl BackgroundExecutor {
149 pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
150 Ok(Self { scheduler })
151 }
152
153 pub fn spawn<R: 'static + Send>(
154 &self,
155 future: impl Future<Output = R> + Send + 'static,
156 ) -> Task<R> {
157 let scheduler = self.scheduler.clone();
158 let (runnable, task) = async_task::spawn(future, move |runnable| {
159 scheduler.schedule_foreground(runnable, None);
160 });
161 runnable.schedule();
162 Task(task)
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_basic_spawn_and_run() {
        let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::default()));
        let executor = ForegroundExecutor::new(scheduler.clone()).unwrap();

        let flag = Arc::new(AtomicBool::new(false));
        assert!(!flag.load(Ordering::SeqCst));
        let _task = executor.spawn({
            let flag = flag.clone();
            async move {
                flag.store(true, Ordering::SeqCst);
            }
        });

        // Foreground work is only queued by `spawn`; nothing runs until the
        // scheduler is driven.
        assert!(!flag.load(Ordering::SeqCst));

        scheduler.run();

        assert!(flag.load(Ordering::SeqCst));
    }

    #[test]
    fn test_tick_background_only_leaves_foreground_queue_alone() {
        let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::default()));
        let executor = ForegroundExecutor::new(scheduler.clone()).unwrap();

        let flag = Arc::new(AtomicBool::new(false));
        let _task = executor.spawn({
            let flag = flag.clone();
            async move {
                flag.store(true, Ordering::SeqCst);
            }
        });

        // A background-only tick must not touch queued foreground work.
        assert!(!scheduler.tick(true));
        assert!(!flag.load(Ordering::SeqCst));

        // A full tick runs the one queued runnable; the next finds nothing.
        assert!(scheduler.tick(false));
        assert!(flag.load(Ordering::SeqCst));
        assert!(!scheduler.tick(false));
    }

    #[test]
    fn test_is_main_thread_only_on_creating_thread() {
        let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::default()));
        assert!(scheduler.is_main_thread());

        let other = scheduler.clone();
        let on_other_thread = std::thread::spawn(move || other.is_main_thread())
            .join()
            .expect("helper thread panicked");
        assert!(!on_other_thread);
    }

    #[test]
    fn test_background_task_with_foreground_wait() {
        let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::default()));

        // Create a oneshot channel to send data from background to foreground
        let (tx, rx) = oneshot::channel();

        // Spawn background task that sends 42
        let bg_executor = BackgroundExecutor::new(scheduler.clone()).unwrap();
        let _background_task = bg_executor.spawn(async move {
            tx.send(42).unwrap();
        });

        // Run all tasks
        scheduler.run();

        // Block on receiving the value from the background task
        let received = executor::block_on(rx).unwrap();

        // Assert on the result
        assert_eq!(received, 42);
    }
}