scheduler.rs

  1use anyhow::Result;
  2use async_task::Runnable;
  3use parking_lot::Mutex;
  4use rand_chacha::rand_core::SeedableRng;
  5use rand_chacha::ChaCha8Rng;
  6use std::any::Any;
  7use std::collections::VecDeque;
  8use std::future::Future;
  9use std::marker::PhantomData;
 10use std::sync::Arc;
 11
 12use std::thread::{self, ThreadId};
 13
 14pub trait Scheduler: Send + Sync + Any {
 15    fn schedule_foreground(&self, runnable: Runnable);
 16    fn is_main_thread(&self) -> bool;
 17}
 18
 19#[derive(Clone, Copy, PartialEq, Eq, Hash)]
 20pub struct TaskId(usize);
 21
 22pub struct Task<R>(async_task::Task<R>);
 23
 24impl<R> Task<R> {
 25    pub fn id(&self) -> TaskId {
 26        TaskId(0) // Placeholder
 27    }
 28}
 29
 30pub struct SchedulerConfig {
 31    pub randomize_order: bool,
 32    pub seed: u64,
 33}
 34
 35impl Default for SchedulerConfig {
 36    fn default() -> Self {
 37        Self {
 38            randomize_order: true,
 39            seed: 0,
 40        }
 41    }
 42}
 43
 44pub struct TestScheduler {
 45    inner: Mutex<TestSchedulerInner>,
 46}
 47
 48struct TestSchedulerInner {
 49    rng: ChaCha8Rng,
 50    foreground_queue: VecDeque<Runnable>,
 51    creation_thread_id: ThreadId,
 52}
 53
 54impl TestScheduler {
 55    pub fn new(config: SchedulerConfig) -> Self {
 56        Self {
 57            inner: Mutex::new(TestSchedulerInner {
 58                rng: ChaCha8Rng::seed_from_u64(config.seed),
 59                foreground_queue: VecDeque::new(),
 60                creation_thread_id: thread::current().id(),
 61            }),
 62        }
 63    }
 64
 65    pub fn tick(&self, background_only: bool) -> bool {
 66        let mut inner = self.inner.lock();
 67        if !background_only {
 68            if let Some(runnable) = inner.foreground_queue.pop_front() {
 69                drop(inner); // Unlock while running
 70                runnable.run();
 71                return true;
 72            }
 73        }
 74        false
 75    }
 76
 77    pub fn run(&self) {
 78        while self.tick(false) {}
 79    }
 80}
 81
 82impl Scheduler for TestScheduler {
 83    fn schedule_foreground(&self, runnable: Runnable) {
 84        self.inner.lock().foreground_queue.push_back(runnable);
 85    }
 86
 87    fn is_main_thread(&self) -> bool {
 88        thread::current().id() == self.inner.lock().creation_thread_id
 89    }
 90}
 91
 92pub struct ForegroundExecutor {
 93    scheduler: Arc<dyn Scheduler>,
 94    _phantom: PhantomData<()>,
 95}
 96
 97impl ForegroundExecutor {
 98    pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
 99        Ok(Self {
100            scheduler,
101            _phantom: PhantomData,
102        })
103    }
104
105    pub fn spawn<R: 'static + Send>(
106        &self,
107        future: impl Future<Output = R> + Send + 'static,
108    ) -> Task<R> {
109        let scheduler = self.scheduler.clone();
110        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
111            scheduler.schedule_foreground(runnable);
112        });
113        runnable.schedule();
114        Task(task)
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_basic_spawn_and_run() {
        let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::default()));
        let executor = ForegroundExecutor::new(scheduler.clone()).unwrap();

        let flag = Arc::new(AtomicBool::new(false));
        let _task = executor.spawn({
            let flag = flag.clone();
            async move {
                flag.store(true, Ordering::SeqCst);
            }
        });

        // Spawning only enqueues the task; nothing runs until the scheduler is driven.
        assert!(!flag.load(Ordering::SeqCst));

        scheduler.run();

        assert!(flag.load(Ordering::SeqCst));
    }

    #[test]
    fn test_foreground_tasks_run_in_spawn_order() {
        // Randomization is disabled explicitly so this test keeps pinning FIFO
        // order even if `randomize_order` is honoured later.
        let config = SchedulerConfig {
            randomize_order: false,
            seed: 0,
        };
        let scheduler = Arc::new(TestScheduler::new(config));
        let executor = ForegroundExecutor::new(scheduler.clone()).unwrap();

        let log = Arc::new(std::sync::Mutex::new(Vec::new()));
        // Keep the handles alive: dropping a task handle cancels the task.
        let _tasks: Vec<_> = (0..3)
            .map(|i| {
                let log = log.clone();
                executor.spawn(async move {
                    log.lock().unwrap().push(i);
                })
            })
            .collect();

        scheduler.run();

        assert_eq!(*log.lock().unwrap(), vec![0, 1, 2]);
    }

    #[test]
    fn test_tick_reports_whether_work_ran() {
        let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::default()));
        let executor = ForegroundExecutor::new(scheduler.clone()).unwrap();

        assert!(!scheduler.tick(false));

        let _task = executor.spawn(async {});
        // Background-only ticks never touch the foreground queue.
        assert!(!scheduler.tick(true));
        assert!(scheduler.tick(false));
        assert!(!scheduler.tick(false));
    }

    #[test]
    fn test_is_main_thread_only_on_creating_thread() {
        let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::default()));
        assert!(scheduler.is_main_thread());

        let from_other_thread = std::thread::spawn({
            let scheduler = scheduler.clone();
            move || scheduler.is_main_thread()
        })
        .join()
        .unwrap();
        assert!(!from_other_thread);
    }
}