scheduler.rs

  1use anyhow::Result;
  2use async_task::Runnable;
  3use parking_lot::Mutex;
  4use rand_chacha::rand_core::SeedableRng;
  5use rand_chacha::ChaCha8Rng;
  6use std::any::Any;
  7use std::collections::VecDeque;
  8use std::future::Future;
  9use std::marker::PhantomData;
 10use std::sync::Arc;
 11
 12use futures::channel::oneshot;
 13use futures::executor;
 14use std::thread::{self, ThreadId};
 15
 16#[derive(Copy, Clone, PartialEq, Eq, Hash)]
 17pub struct TaskLabel(usize);
 18
 19pub trait Scheduler: Send + Sync + Any {
 20    fn schedule(&self, runnable: Runnable, label: Option<TaskLabel>);
 21    fn schedule_foreground(&self, runnable: Runnable, label: Option<TaskLabel>);
 22    fn is_main_thread(&self) -> bool;
 23}
 24
 25#[derive(Clone, Copy, PartialEq, Eq, Hash)]
 26pub struct TaskId(usize);
 27
 28pub struct Task<R>(async_task::Task<R>);
 29
 30impl<R> Task<R> {
 31    pub fn id(&self) -> TaskId {
 32        TaskId(0) // Placeholder
 33    }
 34}
 35
 36impl Default for TaskLabel {
 37    fn default() -> Self {
 38        TaskLabel(0)
 39    }
 40}
 41
 42pub struct SchedulerConfig {
 43    pub randomize_order: bool,
 44    pub seed: u64,
 45}
 46
 47impl Default for SchedulerConfig {
 48    fn default() -> Self {
 49        Self {
 50            randomize_order: true,
 51            seed: 0,
 52        }
 53    }
 54}
 55
 56pub struct TestScheduler {
 57    inner: Mutex<TestSchedulerInner>,
 58}
 59
 60struct TestSchedulerInner {
 61    rng: ChaCha8Rng,
 62    foreground_queue: VecDeque<Runnable>,
 63    creation_thread_id: ThreadId,
 64}
 65
 66impl TestScheduler {
 67    pub fn new(config: SchedulerConfig) -> Self {
 68        Self {
 69            inner: Mutex::new(TestSchedulerInner {
 70                rng: ChaCha8Rng::seed_from_u64(config.seed),
 71                foreground_queue: VecDeque::new(),
 72                creation_thread_id: thread::current().id(),
 73            }),
 74        }
 75    }
 76
 77    pub fn tick(&self, background_only: bool) -> bool {
 78        let mut inner = self.inner.lock();
 79        if !background_only {
 80            if let Some(runnable) = inner.foreground_queue.pop_front() {
 81                drop(inner); // Unlock while running
 82                runnable.run();
 83                return true;
 84            }
 85        }
 86        false
 87    }
 88
 89    pub fn run(&self) {
 90        while self.tick(false) {}
 91    }
 92}
 93
 94impl Scheduler for TestScheduler {
 95    fn schedule(&self, runnable: Runnable, _label: Option<TaskLabel>) {
 96        runnable.run();
 97    }
 98
 99    fn schedule_foreground(&self, runnable: Runnable, _label: Option<TaskLabel>) {
100        self.inner.lock().foreground_queue.push_back(runnable);
101    }
102
103    fn is_main_thread(&self) -> bool {
104        thread::current().id() == self.inner.lock().creation_thread_id
105    }
106}
107
108pub struct ForegroundExecutor {
109    scheduler: Arc<dyn Scheduler>,
110    _phantom: PhantomData<()>,
111}
112
113impl ForegroundExecutor {
114    pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
115        Ok(Self {
116            scheduler,
117            _phantom: PhantomData,
118        })
119    }
120
121    pub fn spawn<R: 'static>(&self, future: impl Future<Output = R> + 'static) -> Task<R> {
122        let scheduler = self.scheduler.clone();
123        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
124            scheduler.schedule_foreground(runnable, None);
125        });
126        runnable.schedule();
127        Task(task)
128    }
129
130    pub fn spawn_labeled<R: 'static>(
131        &self,
132        future: impl Future<Output = R> + 'static,
133        label: TaskLabel,
134    ) -> Task<R> {
135        let scheduler = self.scheduler.clone();
136        let (runnable, task) = async_task::spawn_local(future, move |runnable| {
137            scheduler.schedule_foreground(runnable, Some(label));
138        });
139        runnable.schedule();
140        Task(task)
141    }
142}
143
144pub struct BackgroundExecutor {
145    scheduler: Arc<dyn Scheduler>,
146}
147
148impl BackgroundExecutor {
149    pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
150        Ok(Self { scheduler })
151    }
152
153    pub fn spawn<R: 'static + Send>(
154        &self,
155        future: impl Future<Output = R> + Send + 'static,
156    ) -> Task<R> {
157        let scheduler = self.scheduler.clone();
158        let (runnable, task) = async_task::spawn(future, move |runnable| {
159            scheduler.schedule_foreground(runnable, None);
160        });
161        runnable.schedule();
162        Task(task)
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// A foreground task must not run until the scheduler is driven.
    #[test]
    fn test_basic_spawn_and_run() {
        let config = SchedulerConfig::default();
        let scheduler = Arc::new(TestScheduler::new(config));
        let executor = ForegroundExecutor::new(Arc::clone(&scheduler)).unwrap();

        let ran = Arc::new(AtomicBool::new(false));
        assert!(!ran.load(Ordering::SeqCst));

        let ran_in_task = Arc::clone(&ran);
        let _task = executor.spawn(async move {
            ran_in_task.store(true, Ordering::SeqCst);
        });

        // Spawning only enqueues the runnable; nothing has executed yet.
        assert!(!ran.load(Ordering::SeqCst));

        scheduler.run();

        assert!(ran.load(Ordering::SeqCst));
    }

    /// A value sent from a background task is received after the scheduler runs.
    #[test]
    fn test_background_task_with_foreground_wait() {
        let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::default()));

        // Oneshot channel carrying a value from the background task to the test thread.
        let (sender, receiver) = oneshot::channel();

        let background = BackgroundExecutor::new(Arc::clone(&scheduler)).unwrap();
        let _background_task = background.spawn(async move {
            sender.send(42).unwrap();
        });

        scheduler.run();

        let received = executor::block_on(receiver).unwrap();
        assert_eq!(received, 42);
    }
}