1use anyhow::{anyhow, Result};
2use async_task::Runnable;
3pub use async_task::Task;
4use parking_lot::Mutex;
5use rand::prelude::*;
6use smol::{channel, prelude::*, Executor};
7use std::{
8 marker::PhantomData,
9 mem,
10 pin::Pin,
11 rc::Rc,
12 sync::{
13 atomic::{AtomicBool, Ordering::SeqCst},
14 mpsc::SyncSender,
15 Arc,
16 },
17 thread,
18};
19
20use crate::platform;
21
/// Executor for futures that are spawned from, and polled on, the foreground
/// thread. `spawn` accepts non-`Send` futures in every variant, and the
/// `_not_send_or_sync` marker makes `Foreground` itself neither `Send` nor `Sync`.
pub enum Foreground {
    /// Schedules runnables through the platform dispatcher's
    /// `run_on_main_thread`; only constructible on the main thread.
    Platform {
        dispatcher: Arc<dyn platform::Dispatcher>,
        // `Rc` is neither `Send` nor `Sync`, so this marker strips both
        // auto traits from `Foreground`.
        _not_send_or_sync: PhantomData<Rc<()>>,
    },
    /// Single-threaded `smol` executor, driven via `Foreground::run`.
    Test(smol::LocalExecutor<'static>),
    /// Seeded executor shared with a `Background::Deterministic`
    /// (see the `deterministic` constructor).
    Deterministic(Arc<Deterministic>),
}
30
/// Executor for `Send` futures.
pub enum Background {
    /// Shares a seeded executor with `Foreground::Deterministic`; queued
    /// runnables are only executed inside `Deterministic::run`.
    Deterministic(Arc<Deterministic>),
    /// Multi-threaded `smol` executor polled by worker threads started in
    /// `Background::new`.
    Production {
        executor: Arc<smol::Executor<'static>>,
        /// Number of worker threads started by `Background::new`.
        threads: usize,
        // Never sent on. Each worker runs `executor.run(stop.recv())`, so it
        // keeps going until that receive completes — presumably when this
        // sender is dropped together with the `Background`.
        _stop: channel::Sender<()>,
    },
}
39
/// Mutable state of a [`Deterministic`] executor, guarded by its mutex.
struct DeterministicState {
    /// Chooses which queued runnable `Deterministic::run` executes next.
    rng: StdRng,
    /// Seed `rng` was created from; `Foreground::reset` reseeds `rng` from it.
    seed: u64,
    /// Runnables that can be picked at a random index.
    scheduled: Vec<Runnable>,
    /// The first schedule of each task spawned via `spawn_from_foreground`;
    /// taken from the front (FIFO order) when the random pick falls past
    /// the end of `scheduled`.
    spawned_from_foreground: Vec<Runnable>,
    /// Present while `Deterministic::run` is active; schedule callbacks
    /// signal on it to wake the run loop.
    waker: Option<SyncSender<()>>,
}
47
/// Executor that runs queued tasks in an order chosen by a seeded RNG,
/// intended so that a given seed reproduces the same task interleaving.
/// Cloned handles of the inner state are captured by every schedule callback.
pub struct Deterministic(Arc<Mutex<DeterministicState>>);
49
50impl Deterministic {
51 fn new(seed: u64) -> Self {
52 Self(Arc::new(Mutex::new(DeterministicState {
53 rng: StdRng::seed_from_u64(seed),
54 seed,
55 scheduled: Default::default(),
56 spawned_from_foreground: Default::default(),
57 waker: None,
58 })))
59 }
60
61 pub fn spawn_from_foreground<F, T>(&self, future: F) -> Task<T>
62 where
63 T: 'static,
64 F: Future<Output = T> + 'static,
65 {
66 let scheduled_once = AtomicBool::new(false);
67 let state = self.0.clone();
68 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
69 let mut state = state.lock();
70 if scheduled_once.fetch_or(true, SeqCst) {
71 state.scheduled.push(runnable);
72 } else {
73 state.spawned_from_foreground.push(runnable);
74 }
75 if let Some(waker) = state.waker.as_ref() {
76 waker.send(()).ok();
77 }
78 });
79 runnable.schedule();
80 task
81 }
82
83 pub fn spawn<F, T>(&self, future: F) -> Task<T>
84 where
85 T: 'static + Send,
86 F: 'static + Send + Future<Output = T>,
87 {
88 let state = self.0.clone();
89 let (runnable, task) = async_task::spawn(future, move |runnable| {
90 let mut state = state.lock();
91 state.scheduled.push(runnable);
92 if let Some(waker) = state.waker.as_ref() {
93 waker.send(()).ok();
94 }
95 });
96 runnable.schedule();
97 task
98 }
99
100 pub fn run<F, T>(&self, future: F) -> T
101 where
102 T: 'static,
103 F: Future<Output = T> + 'static,
104 {
105 let (wake_tx, wake_rx) = std::sync::mpsc::sync_channel(32);
106 let state = self.0.clone();
107 state.lock().waker = Some(wake_tx);
108
109 let (output_tx, output_rx) = std::sync::mpsc::channel();
110 self.spawn_from_foreground(async move {
111 let output = future.await;
112 output_tx.send(output).unwrap();
113 })
114 .detach();
115
116 loop {
117 if let Ok(value) = output_rx.try_recv() {
118 state.lock().waker = None;
119 return value;
120 }
121
122 wake_rx.recv().unwrap();
123 let runnable = {
124 let state = &mut *state.lock();
125 let ix = state
126 .rng
127 .gen_range(0..state.scheduled.len() + state.spawned_from_foreground.len());
128 if ix < state.scheduled.len() {
129 state.scheduled.remove(ix)
130 } else {
131 state.spawned_from_foreground.remove(0)
132 }
133 };
134
135 runnable.run();
136 }
137 }
138}
139
140impl Foreground {
141 pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
142 if dispatcher.is_main_thread() {
143 Ok(Self::Platform {
144 dispatcher,
145 _not_send_or_sync: PhantomData,
146 })
147 } else {
148 Err(anyhow!("must be constructed on main thread"))
149 }
150 }
151
152 pub fn test() -> Self {
153 Self::Test(smol::LocalExecutor::new())
154 }
155
156 pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
157 match self {
158 Self::Platform { dispatcher, .. } => {
159 let dispatcher = dispatcher.clone();
160 let schedule = move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
161 let (runnable, task) = async_task::spawn_local(future, schedule);
162 runnable.schedule();
163 task
164 }
165 Self::Test(executor) => executor.spawn(future),
166 Self::Deterministic(executor) => executor.spawn_from_foreground(future),
167 }
168 }
169
170 pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
171 match self {
172 Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
173 Self::Test(executor) => smol::block_on(executor.run(future)),
174 Self::Deterministic(executor) => executor.run(future),
175 }
176 }
177
178 pub fn reset(&self) {
179 match self {
180 Self::Platform { .. } => panic!("can't call this method on a platform executor"),
181 Self::Test(_) => panic!("can't call this method on a test executor"),
182 Self::Deterministic(executor) => {
183 let state = &mut *executor.0.lock();
184 state.rng = StdRng::seed_from_u64(state.seed);
185 }
186 }
187 }
188}
189
190impl Background {
191 pub fn new() -> Self {
192 let executor = Arc::new(Executor::new());
193 let stop = channel::unbounded::<()>();
194 let threads = num_cpus::get();
195
196 for i in 0..threads {
197 let executor = executor.clone();
198 let stop = stop.1.clone();
199 thread::Builder::new()
200 .name(format!("background-executor-{}", i))
201 .spawn(move || smol::block_on(executor.run(stop.recv())))
202 .unwrap();
203 }
204
205 Self::Production {
206 executor,
207 threads,
208 _stop: stop.0,
209 }
210 }
211
212 pub fn threads(&self) -> usize {
213 match self {
214 Self::Deterministic(_) => 1,
215 Self::Production { threads, .. } => *threads,
216 }
217 }
218
219 pub fn spawn<T, F>(&self, future: F) -> Task<T>
220 where
221 T: 'static + Send,
222 F: Send + Future<Output = T> + 'static,
223 {
224 match self {
225 Self::Production { executor, .. } => executor.spawn(future),
226 Self::Deterministic(executor) => executor.spawn(future),
227 }
228 }
229
230 pub async fn scoped<'scope, F>(&self, scheduler: F)
231 where
232 F: FnOnce(&mut Scope<'scope>),
233 {
234 let mut scope = Scope {
235 futures: Default::default(),
236 _phantom: PhantomData,
237 };
238 (scheduler)(&mut scope);
239 let spawned = scope
240 .futures
241 .into_iter()
242 .map(|f| self.spawn(f))
243 .collect::<Vec<_>>();
244 for task in spawned {
245 task.await;
246 }
247 }
248}
249
/// Collects futures that `Background::scoped` spawns once its scheduler
/// closure returns.
pub struct Scope<'a> {
    /// Queued futures; `Scope::spawn` transmutes their `'a` lifetime to
    /// `'static` so they can be passed to `Background::spawn`.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    /// Ties the scope to `'a`, the lifetime of data its futures may borrow.
    _phantom: PhantomData<&'a ()>,
}
254
impl<'a> Scope<'a> {
    /// Queues `f`, which may borrow data living for `'a`, to be spawned on the
    /// background executor by `Background::scoped`.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        // SAFETY: this only erases the lifetime bound (`'a` -> `'static`); the
        // pointee type and layout of the boxed trait object are unchanged.
        // Soundness relies on every queued future finishing before `'a` ends.
        // `Background::scoped` awaits each spawned task before it returns.
        // NOTE(review): if the `scoped` future is dropped or leaked before all
        // tasks complete, a task may still touch `'a` data afterwards —
        // confirm that callers always drive `scoped` to completion.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(f))
        };
        self.futures.push(f);
    }
}
269
270pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
271 let executor = Arc::new(Deterministic::new(seed));
272 (
273 Rc::new(Foreground::Deterministic(executor.clone())),
274 Arc::new(Background::Deterministic(executor)),
275 )
276}