scheduler.rs

  1use anyhow::Result;
  2use async_task::Runnable;
  3use parking_lot::Mutex;
  4use rand_chacha::rand_core::SeedableRng;
  5use rand_chacha::ChaCha8Rng;
  6use std::any::Any;
  7use std::collections::VecDeque;
  8use std::future::Future;
  9use std::marker::PhantomData;
 10use std::sync::Arc;
 11
 12use std::thread::{self, ThreadId};
 13
 14#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 15pub struct TaskLabel(usize);
 16
 17pub trait Scheduler: Send + Sync + Any {
 18    fn schedule(&self, runnable: Runnable, label: Option<TaskLabel>);
 19    fn schedule_foreground(&self, runnable: Runnable, label: Option<TaskLabel>);
 20    fn is_main_thread(&self) -> bool;
 21}
 22
 23#[derive(Clone, Copy, PartialEq, Eq, Hash)]
 24pub struct TaskId(usize);
 25
 26pub struct Task<R>(async_task::Task<R>);
 27
 28impl<R> Task<R> {
 29    pub fn id(&self) -> TaskId {
 30        TaskId(0) // Placeholder
 31    }
 32}
 33
 34impl Default for TaskLabel {
 35    fn default() -> Self {
 36        TaskLabel(0)
 37    }
 38}
 39
 40pub struct SchedulerConfig {
 41    pub randomize_order: bool,
 42    pub seed: u64,
 43}
 44
 45impl Default for SchedulerConfig {
 46    fn default() -> Self {
 47        Self {
 48            randomize_order: true,
 49            seed: 0,
 50        }
 51    }
 52}
 53
 54pub struct TestScheduler {
 55    inner: Mutex<TestSchedulerInner>,
 56}
 57
 58struct TestSchedulerInner {
 59    rng: ChaCha8Rng,
 60    foreground_queue: VecDeque<Runnable>,
 61    creation_thread_id: ThreadId,
 62}
 63
 64impl TestScheduler {
 65    pub fn new(config: SchedulerConfig) -> Self {
 66        Self {
 67            inner: Mutex::new(TestSchedulerInner {
 68                rng: ChaCha8Rng::seed_from_u64(config.seed),
 69                foreground_queue: VecDeque::new(),
 70                creation_thread_id: thread::current().id(),
 71            }),
 72        }
 73    }
 74
 75    pub fn tick(&self, background_only: bool) -> bool {
 76        let mut inner = self.inner.lock();
 77        if !background_only {
 78            if let Some(runnable) = inner.foreground_queue.pop_front() {
 79                drop(inner); // Unlock while running
 80                runnable.run();
 81                return true;
 82            }
 83        }
 84        false
 85    }
 86
 87    pub fn run(&self) {
 88        while self.tick(false) {}
 89    }
 90}
 91
 92impl Scheduler for TestScheduler {
 93    fn schedule(&self, runnable: Runnable, _label: Option<TaskLabel>) {
 94        runnable.run();
 95    }
 96
 97    fn schedule_foreground(&self, runnable: Runnable, _label: Option<TaskLabel>) {
 98        self.inner.lock().foreground_queue.push_back(runnable);
 99    }
100
101    fn is_main_thread(&self) -> bool {
102        thread::current().id() == self.inner.lock().creation_thread_id
103    }
104}
105
106pub struct ForegroundExecutor {
107    scheduler: Arc<dyn Scheduler>,
108    _phantom: PhantomData<()>,
109}
110
111impl ForegroundExecutor {
112    pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
113        Ok(Self {
114            scheduler,
115            _phantom: PhantomData,
116        })
117    }
118
119    pub fn spawn<R: 'static>(&self, future: impl Future<Output = R> + 'static) -> Task<R> {
120        let scheduler = self.scheduler.clone();
121        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
122            scheduler.schedule_foreground(runnable, None);
123        });
124        runnable.schedule();
125        Task(task)
126    }
127
128    pub fn spawn_labeled<R: 'static>(
129        &self,
130        future: impl Future<Output = R> + 'static,
131        label: TaskLabel,
132    ) -> Task<R> {
133        let scheduler = self.scheduler.clone();
134        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
135            scheduler.schedule_foreground(runnable, Some(label));
136        });
137        runnable.schedule();
138        Task(task)
139    }
140}
141
142pub struct BackgroundExecutor {
143    scheduler: Arc<dyn Scheduler>,
144}
145
146impl BackgroundExecutor {
147    pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
148        Ok(Self { scheduler })
149    }
150
151    pub fn spawn<R: 'static + Send>(
152        &self,
153        future: impl Future<Output = R> + Send + 'static,
154    ) -> Task<R> {
155        let scheduler = self.scheduler.clone();
156        let (runnable, task) = async_task::spawn(future, move |runnable| {
157            scheduler.schedule_foreground(runnable, None);
158        });
159        runnable.schedule();
160        Task(task)
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Builds a `TestScheduler` with the default configuration.
    fn new_scheduler() -> Arc<TestScheduler> {
        Arc::new(TestScheduler::new(SchedulerConfig::default()))
    }

    /// `spawn` only queues a foreground task; the task runs once the scheduler is driven.
    #[test]
    fn test_basic_spawn_and_run() {
        let scheduler = new_scheduler();
        let executor = ForegroundExecutor::new(scheduler.clone()).unwrap();

        let ran = Arc::new(AtomicBool::new(false));
        assert!(!ran.load(Ordering::SeqCst));

        let ran_in_task = Arc::clone(&ran);
        let _task = executor.spawn(async move {
            ran_in_task.store(true, Ordering::SeqCst);
        });

        // Not run yet: spawning only schedules the task.
        assert!(!ran.load(Ordering::SeqCst));

        scheduler.run();

        assert!(ran.load(Ordering::SeqCst));
    }

    /// A background task completes once the scheduler has been run, alongside a foreground one.
    #[test]
    fn test_background_task_with_foreground_wait() {
        let scheduler = new_scheduler();

        let ran = Arc::new(AtomicBool::new(false));
        assert!(!ran.load(Ordering::SeqCst));

        let background = BackgroundExecutor::new(scheduler.clone()).unwrap();
        let ran_in_task = Arc::clone(&ran);
        let _background_task = background.spawn(async move {
            ran_in_task.store(true, Ordering::SeqCst);
        });

        // A no-op foreground task, spawned alongside the background one.
        let foreground = ForegroundExecutor::new(scheduler.clone()).unwrap();
        let _fg_task = foreground.spawn(async {});

        scheduler.run();

        assert!(ran.load(Ordering::SeqCst));
    }
}