1use anyhow::Result;
2use async_task::Runnable;
3use parking_lot::Mutex;
4use rand_chacha::rand_core::SeedableRng;
5use rand_chacha::ChaCha8Rng;
6use std::any::Any;
7use std::collections::VecDeque;
8use std::future::Future;
9use std::marker::PhantomData;
10use std::sync::Arc;
11
12use std::thread::{self, ThreadId};
13
/// Opaque label that can be attached to a runnable when it is handed to a [`Scheduler`].
///
/// The wrapped `usize` is private. `Debug` is derived so that labels (and
/// `Option<TaskLabel>` arguments) can be printed in logs and assertion messages.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct TaskLabel(usize);
16
17pub trait Scheduler: Send + Sync + Any {
18 fn schedule(&self, runnable: Runnable, label: Option<TaskLabel>);
19 fn schedule_foreground(&self, runnable: Runnable, label: Option<TaskLabel>);
20 fn is_main_thread(&self) -> bool;
21}
22
/// Identifier reported by a task handle.
///
/// The wrapped `usize` is private. `Debug` is derived so that ids can be
/// printed in logs and assertion messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TaskId(usize);
25
26pub struct Task<R>(async_task::Task<R>);
27
28impl<R> Task<R> {
29 pub fn id(&self) -> TaskId {
30 TaskId(0) // Placeholder
31 }
32}
33
34impl Default for TaskLabel {
35 fn default() -> Self {
36 TaskLabel(0)
37 }
38}
39
/// Configuration consumed by `TestScheduler::new`.
///
/// `Debug` and `Clone` are derived so a configuration can be logged and reused
/// for several schedulers.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Whether task order should be randomized.
    // NOTE(review): no code visible in this file reads this field — confirm
    // whether randomization is implemented elsewhere.
    pub randomize_order: bool,
    /// Seed for the scheduler's random number generator.
    pub seed: u64,
}
44
45impl Default for SchedulerConfig {
46 fn default() -> Self {
47 Self {
48 randomize_order: true,
49 seed: 0,
50 }
51 }
52}
53
54pub struct TestScheduler {
55 inner: Mutex<TestSchedulerInner>,
56}
57
58struct TestSchedulerInner {
59 rng: ChaCha8Rng,
60 foreground_queue: VecDeque<Runnable>,
61 creation_thread_id: ThreadId,
62}
63
64impl TestScheduler {
65 pub fn new(config: SchedulerConfig) -> Self {
66 Self {
67 inner: Mutex::new(TestSchedulerInner {
68 rng: ChaCha8Rng::seed_from_u64(config.seed),
69 foreground_queue: VecDeque::new(),
70 creation_thread_id: thread::current().id(),
71 }),
72 }
73 }
74
75 pub fn tick(&self, background_only: bool) -> bool {
76 let mut inner = self.inner.lock();
77 if !background_only {
78 if let Some(runnable) = inner.foreground_queue.pop_front() {
79 drop(inner); // Unlock while running
80 runnable.run();
81 return true;
82 }
83 }
84 false
85 }
86
87 pub fn run(&self) {
88 while self.tick(false) {}
89 }
90}
91
92impl Scheduler for TestScheduler {
93 fn schedule(&self, runnable: Runnable, _label: Option<TaskLabel>) {
94 runnable.run();
95 }
96
97 fn schedule_foreground(&self, runnable: Runnable, _label: Option<TaskLabel>) {
98 self.inner.lock().foreground_queue.push_back(runnable);
99 }
100
101 fn is_main_thread(&self) -> bool {
102 thread::current().id() == self.inner.lock().creation_thread_id
103 }
104}
105
106pub struct ForegroundExecutor {
107 scheduler: Arc<dyn Scheduler>,
108 _phantom: PhantomData<()>,
109}
110
111impl ForegroundExecutor {
112 pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
113 Ok(Self {
114 scheduler,
115 _phantom: PhantomData,
116 })
117 }
118
119 pub fn spawn<R: 'static>(&self, future: impl Future<Output = R> + 'static) -> Task<R> {
120 let scheduler = self.scheduler.clone();
121 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
122 scheduler.schedule_foreground(runnable, None);
123 });
124 runnable.schedule();
125 Task(task)
126 }
127
128 pub fn spawn_labeled<R: 'static>(
129 &self,
130 future: impl Future<Output = R> + 'static,
131 label: TaskLabel,
132 ) -> Task<R> {
133 let scheduler = self.scheduler.clone();
134 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
135 scheduler.schedule_foreground(runnable, Some(label));
136 });
137 runnable.schedule();
138 Task(task)
139 }
140}
141
142pub struct BackgroundExecutor {
143 scheduler: Arc<dyn Scheduler>,
144}
145
146impl BackgroundExecutor {
147 pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
148 Ok(Self { scheduler })
149 }
150
151 pub fn spawn<R: 'static + Send>(
152 &self,
153 future: impl Future<Output = R> + Send + 'static,
154 ) -> Task<R> {
155 let scheduler = self.scheduler.clone();
156 let (runnable, task) = async_task::spawn(future, move |runnable| {
157 scheduler.schedule_foreground(runnable, None);
158 });
159 runnable.schedule();
160 Task(task)
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Builds a scheduler with the default configuration, shared behind an `Arc`.
    fn default_scheduler() -> Arc<TestScheduler> {
        Arc::new(TestScheduler::new(SchedulerConfig::default()))
    }

    /// A foreground task must not run when spawned, only once the scheduler
    /// is driven.
    #[test]
    fn test_basic_spawn_and_run() {
        let scheduler = default_scheduler();
        let executor = ForegroundExecutor::new(scheduler.clone()).unwrap();

        let did_run = Arc::new(AtomicBool::new(false));
        assert!(!did_run.load(Ordering::SeqCst));

        let task_flag = did_run.clone();
        // Keep the handle alive: dropping it would cancel the task.
        let _task = executor.spawn(async move {
            task_flag.store(true, Ordering::SeqCst);
        });

        // Spawning only queues the task.
        assert!(!did_run.load(Ordering::SeqCst));

        scheduler.run();

        assert!(did_run.load(Ordering::SeqCst));
    }

    /// A background task and a foreground task spawned on the same scheduler
    /// have both completed once the scheduler has been run.
    #[test]
    fn test_background_task_with_foreground_wait() {
        let scheduler = default_scheduler();

        let did_run = Arc::new(AtomicBool::new(false));
        assert!(!did_run.load(Ordering::SeqCst));

        // Spawn background task
        let background = BackgroundExecutor::new(scheduler.clone()).unwrap();
        let task_flag = did_run.clone();
        let _background_task = background.spawn(async move {
            task_flag.store(true, Ordering::SeqCst);
        });

        // Spawn foreground task (nothing special, just demonstrates both types)
        let foreground = ForegroundExecutor::new(scheduler.clone()).unwrap();
        let _fg_task = foreground.spawn(async move {
            // Foreground-specific work if needed
        });

        // Run all tasks
        scheduler.run();

        // Background task should have run and set the flag
        assert!(did_run.load(Ordering::SeqCst));
    }
}