1use anyhow::{anyhow, Result};
2use async_task::Runnable;
3pub use async_task::Task;
4use parking_lot::Mutex;
5use rand::prelude::*;
6use smol::{channel, prelude::*, Executor};
7use std::{
8 marker::PhantomData,
9 mem,
10 pin::Pin,
11 rc::Rc,
12 sync::{mpsc::SyncSender, Arc},
13 thread,
14};
15
16use crate::platform;
17
/// Executor for work that stays on the thread that spawned it (`spawn` accepts
/// futures that are not `Send`).
pub enum Foreground {
    /// Schedules every runnable through the platform dispatcher's main-thread queue.
    Platform {
        dispatcher: Arc<dyn platform::Dispatcher>,
        // `Rc<()>` is neither `Send` nor `Sync`, so this marker keeps the whole enum
        // from being moved to or shared with another thread.
        _not_send_or_sync: PhantomData<Rc<()>>,
    },
    /// A local smol executor, driven by `Foreground::run` via `smol::block_on`.
    Test(smol::LocalExecutor<'static>),
    /// Seeded randomized scheduler; `deterministic()` shares it with a
    /// `Background::Deterministic`.
    Deterministic(Arc<Deterministic>),
}
26
/// Executor for `Send` futures that may run away from the foreground thread.
pub enum Background {
    /// Seeded randomized scheduler; reports a thread count of 1.
    Deterministic(Arc<Deterministic>),
    /// A shared smol executor driven by `threads` worker threads started in `Background::new`.
    Production {
        executor: Arc<smol::Executor<'static>>,
        threads: usize,
        // Held only so that it is dropped with this value: closing the channel makes each
        // worker's `stop.recv()` resolve, which ends its `executor.run(...)` loop.
        _stop: channel::Sender<()>,
    },
}
35
/// Run queues shared between `Deterministic`'s schedule callbacks and its `run` loop.
#[derive(Default)]
struct Runnables {
    /// Runnables that `Deterministic::run` chooses among at random.
    scheduled: Vec<Runnable>,
    /// The first schedule of each task created by `spawn_from_foreground`; when `run`
    /// selects this queue it takes the oldest entry (index 0).
    spawned_from_foreground: Vec<Runnable>,
    /// Present only while `Deterministic::run` is active; schedule callbacks signal it
    /// after queueing work.
    waker: Option<SyncSender<()>>,
}
42
/// Randomized scheduler that executes every runnable on the thread calling `run`,
/// choosing the next one with an RNG seeded from `seed` (so, for a given order of
/// queued work, the choices are reproducible).
pub struct Deterministic {
    seed: u64,
    // Queues shared with the schedule callbacks of every task spawned through this executor.
    runnables: Arc<Mutex<Runnables>>,
}
47
48impl Deterministic {
49 fn new(seed: u64) -> Self {
50 Self {
51 seed,
52 runnables: Default::default(),
53 }
54 }
55
56 pub fn spawn_from_foreground<F, T>(&self, future: F) -> Task<T>
57 where
58 T: 'static,
59 F: Future<Output = T> + 'static,
60 {
61 let scheduled_once = Mutex::new(false);
62 let runnables = self.runnables.clone();
63 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
64 let mut runnables = runnables.lock();
65 if *scheduled_once.lock() {
66 runnables.scheduled.push(runnable);
67 } else {
68 runnables.spawned_from_foreground.push(runnable);
69 *scheduled_once.lock() = true;
70 }
71 if let Some(waker) = runnables.waker.as_ref() {
72 waker.send(()).ok();
73 }
74 });
75 runnable.schedule();
76 task
77 }
78
79 pub fn spawn<F, T>(&self, future: F) -> Task<T>
80 where
81 T: 'static + Send,
82 F: 'static + Send + Future<Output = T>,
83 {
84 let runnables = self.runnables.clone();
85 let (runnable, task) = async_task::spawn(future, move |runnable| {
86 let mut runnables = runnables.lock();
87 runnables.scheduled.push(runnable);
88 if let Some(waker) = runnables.waker.as_ref() {
89 waker.send(()).ok();
90 }
91 });
92 runnable.schedule();
93 task
94 }
95
96 pub fn run<F, T>(&self, future: F) -> T
97 where
98 T: 'static,
99 F: Future<Output = T> + 'static,
100 {
101 let (wake_tx, wake_rx) = std::sync::mpsc::sync_channel(32);
102 let runnables = self.runnables.clone();
103 runnables.lock().waker = Some(wake_tx);
104
105 let (output_tx, output_rx) = std::sync::mpsc::channel();
106 self.spawn_from_foreground(async move {
107 let output = future.await;
108 output_tx.send(output).unwrap();
109 })
110 .detach();
111
112 let mut rng = StdRng::seed_from_u64(self.seed);
113 loop {
114 if let Ok(value) = output_rx.try_recv() {
115 runnables.lock().waker = None;
116 return value;
117 }
118
119 wake_rx.recv().unwrap();
120 let runnable = {
121 let mut runnables = runnables.lock();
122 let ix = rng.gen_range(
123 0..runnables.scheduled.len() + runnables.spawned_from_foreground.len(),
124 );
125 if ix < runnables.scheduled.len() {
126 runnables.scheduled.remove(ix)
127 } else {
128 runnables.spawned_from_foreground.remove(0)
129 }
130 };
131
132 runnable.run();
133 }
134 }
135}
136
137impl Foreground {
138 pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
139 if dispatcher.is_main_thread() {
140 Ok(Self::Platform {
141 dispatcher,
142 _not_send_or_sync: PhantomData,
143 })
144 } else {
145 Err(anyhow!("must be constructed on main thread"))
146 }
147 }
148
149 pub fn test() -> Self {
150 Self::Test(smol::LocalExecutor::new())
151 }
152
153 pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
154 match self {
155 Self::Platform { dispatcher, .. } => {
156 let dispatcher = dispatcher.clone();
157 let schedule = move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
158 let (runnable, task) = async_task::spawn_local(future, schedule);
159 runnable.schedule();
160 task
161 }
162 Self::Test(executor) => executor.spawn(future),
163 Self::Deterministic(executor) => executor.spawn_from_foreground(future),
164 }
165 }
166
167 pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
168 match self {
169 Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
170 Self::Test(executor) => smol::block_on(executor.run(future)),
171 Self::Deterministic(executor) => executor.run(future),
172 }
173 }
174}
175
176impl Background {
177 pub fn new() -> Self {
178 let executor = Arc::new(Executor::new());
179 let stop = channel::unbounded::<()>();
180 let threads = num_cpus::get();
181
182 for i in 0..threads {
183 let executor = executor.clone();
184 let stop = stop.1.clone();
185 thread::Builder::new()
186 .name(format!("background-executor-{}", i))
187 .spawn(move || smol::block_on(executor.run(stop.recv())))
188 .unwrap();
189 }
190
191 Self::Production {
192 executor,
193 threads,
194 _stop: stop.0,
195 }
196 }
197
198 pub fn threads(&self) -> usize {
199 match self {
200 Self::Deterministic(_) => 1,
201 Self::Production { threads, .. } => *threads,
202 }
203 }
204
205 pub fn spawn<T, F>(&self, future: F) -> Task<T>
206 where
207 T: 'static + Send,
208 F: Send + Future<Output = T> + 'static,
209 {
210 match self {
211 Self::Production { executor, .. } => executor.spawn(future),
212 Self::Deterministic(executor) => executor.spawn(future),
213 }
214 }
215
216 pub async fn scoped<'scope, F>(&self, scheduler: F)
217 where
218 F: FnOnce(&mut Scope<'scope>),
219 {
220 let mut scope = Scope {
221 futures: Default::default(),
222 _phantom: PhantomData,
223 };
224 (scheduler)(&mut scope);
225 let spawned = scope
226 .futures
227 .into_iter()
228 .map(|f| self.spawn(f))
229 .collect::<Vec<_>>();
230 for task in spawned {
231 task.await;
232 }
233 }
234}
235
/// Collects futures queued by the closure passed to `Background::scoped`, which spawns
/// and awaits them after the closure returns.
///
/// `'a` bounds the data the queued futures may borrow; they are stored with that bound
/// erased to `'static` by `Scope::spawn`.
pub struct Scope<'a> {
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    // Ties the scope to `'a` without storing a reference.
    _phantom: PhantomData<&'a ()>,
}
240
241impl<'a> Scope<'a> {
242 pub fn spawn<F>(&mut self, f: F)
243 where
244 F: Future<Output = ()> + Send + 'a,
245 {
246 let f = unsafe {
247 mem::transmute::<
248 Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
249 Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
250 >(Box::pin(f))
251 };
252 self.futures.push(f);
253 }
254}
255
256pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
257 let executor = Arc::new(Deterministic::new(seed));
258 (
259 Rc::new(Foreground::Deterministic(executor.clone())),
260 Arc::new(Background::Deterministic(executor)),
261 )
262}