@@ -11,8 +11,12 @@ use std::sync::Arc;
use std::thread::{self, ThreadId};
+#[derive(Copy, Clone, PartialEq, Eq, Hash)]
+pub struct TaskLabel(usize);
+
pub trait Scheduler: Send + Sync + Any {
- fn schedule_foreground(&self, runnable: Runnable);
+ fn schedule(&self, runnable: Runnable, label: Option<TaskLabel>);
+ fn schedule_foreground(&self, runnable: Runnable, label: Option<TaskLabel>);
fn is_main_thread(&self) -> bool;
}
@@ -27,6 +31,12 @@ impl<R> Task<R> {
}
}
+impl Default for TaskLabel {
+ fn default() -> Self {
+ TaskLabel(0)
+ }
+}
+
pub struct SchedulerConfig {
pub randomize_order: bool,
pub seed: u64,
@@ -80,7 +90,11 @@ impl TestScheduler {
}
impl Scheduler for TestScheduler {
- fn schedule_foreground(&self, runnable: Runnable) {
+ fn schedule(&self, runnable: Runnable, _label: Option<TaskLabel>) {
+ runnable.run();
+ }
+
+ fn schedule_foreground(&self, runnable: Runnable, _label: Option<TaskLabel>) {
self.inner.lock().foreground_queue.push_back(runnable);
}
@@ -102,13 +116,45 @@ impl ForegroundExecutor {
})
}
+ pub fn spawn<R: 'static>(&self, future: impl Future<Output = R> + 'static) -> Task<R> {
+ let scheduler = self.scheduler.clone();
+ let (runnable, task) = async_task::spawn_local(future, move |runnable| {
+ scheduler.schedule_foreground(runnable, None);
+ });
+ runnable.schedule();
+ Task(task)
+ }
+
+ pub fn spawn_labeled<R: 'static>(
+ &self,
+ future: impl Future<Output = R> + 'static,
+ label: TaskLabel,
+ ) -> Task<R> {
+ let scheduler = self.scheduler.clone();
+ let (runnable, task) = async_task::spawn_local(future, move |runnable| {
+ scheduler.schedule_foreground(runnable, Some(label));
+ });
+ runnable.schedule();
+ Task(task)
+ }
+}
+
+pub struct BackgroundExecutor {
+ scheduler: Arc<dyn Scheduler>,
+}
+
+impl BackgroundExecutor {
+ pub fn new(scheduler: Arc<dyn Scheduler>) -> Result<Self> {
+ Ok(Self { scheduler })
+ }
+
pub fn spawn<R: 'static + Send>(
&self,
future: impl Future<Output = R> + Send + 'static,
) -> Task<R> {
let scheduler = self.scheduler.clone();
- let (runnable, task) = async_task::spawn_local(future, move |runnable| {
- scheduler.schedule_foreground(runnable);
+ let (runnable, task) = async_task::spawn(future, move |runnable| {
+ scheduler.schedule_foreground(runnable, None);
});
runnable.schedule();
Task(task)
@@ -141,4 +187,33 @@ mod tests {
assert!(flag.load(Ordering::SeqCst));
}
+
+ #[test]
+ fn test_background_task_with_foreground_wait() {
+ let scheduler = Arc::new(TestScheduler::new(SchedulerConfig::default()));
+
+ let flag = Arc::new(AtomicBool::new(false));
+ assert!(!flag.load(Ordering::SeqCst));
+
+ // Spawn background task
+ let bg_executor = BackgroundExecutor::new(scheduler.clone()).unwrap();
+ let _background_task = bg_executor.spawn({
+ let flag = flag.clone();
+ async move {
+ flag.store(true, Ordering::SeqCst);
+ }
+ });
+
+ // Spawn foreground task (nothing special, just demonstrates both types)
+ let fg_executor = ForegroundExecutor::new(scheduler.clone()).unwrap();
+ let _fg_task = fg_executor.spawn(async move {
+ // Foreground-specific work if needed
+ });
+
+ // Run all tasks
+ scheduler.run();
+
+ // Background task should have run and set the flag
+ assert!(flag.load(Ordering::SeqCst));
+ }
}