1use anyhow::{anyhow, Result};
2use async_task::Runnable;
3pub use async_task::Task;
4use parking_lot::Mutex;
5use rand::prelude::*;
6use smol::{channel, prelude::*, Executor};
7use std::{
8 marker::PhantomData,
9 mem,
10 pin::Pin,
11 rc::Rc,
12 sync::{mpsc::SyncSender, Arc},
13 thread,
14};
15
16use crate::platform;
17
/// Executor for futures that are spawned without a `Send` bound, i.e. futures
/// that are expected to be polled on a single thread.
pub enum Foreground {
    /// Schedules runnables via the platform dispatcher's `run_on_main_thread`.
    Platform {
        dispatcher: Arc<dyn platform::Dispatcher>,
        // `Rc<()>` is neither `Send` nor `Sync`, so this marker keeps the whole
        // executor from being moved to or shared with another thread.
        _not_send_or_sync: PhantomData<Rc<()>>,
    },
    /// Single-threaded smol executor, driven by `Foreground::run`.
    Test(smol::LocalExecutor<'static>),
    /// Seeded, randomized executor (it can be shared with a
    /// `Background::Deterministic`, as `deterministic()` does).
    Deterministic(Arc<Deterministic>),
}
26
/// Executor for `Send` futures.
pub enum Background {
    /// Routes everything onto the seeded deterministic executor; reports one thread.
    Deterministic(Arc<Deterministic>),
    /// Multi-threaded smol executor served by `threads` dedicated worker threads.
    Production {
        executor: Arc<smol::Executor<'static>>,
        threads: usize,
        // Never sent on. The worker threads run the executor until `stop.recv()`
        // resolves; dropping this sender (with the `Background`) closes the channel,
        // which is what lets them finish.
        _stop: channel::Sender<()>,
    },
}
35
/// Executor that polls every task on the thread that calls `Deterministic::run`,
/// choosing the next runnable with a PRNG seeded from `seed` so that a given seed
/// yields a repeatable scheduling order (assuming tasks are only woken from that thread).
pub struct Deterministic {
    seed: u64,
    // `.0`: runnables waiting to be polled.
    // `.1`: wake sender installed while `run` is active (`None` otherwise); the
    //       schedule callbacks use it to notify `run` that work was queued.
    runnables: Arc<Mutex<(Vec<Runnable>, Option<SyncSender<()>>)>>,
}
40
41impl Deterministic {
42 fn new(seed: u64) -> Self {
43 Self {
44 seed,
45 runnables: Default::default(),
46 }
47 }
48
49 pub fn spawn_local<F, T>(&self, future: F) -> Task<T>
50 where
51 T: 'static,
52 F: Future<Output = T> + 'static,
53 {
54 let runnables = self.runnables.clone();
55 let (runnable, task) = async_task::spawn_local(future, move |runnable| {
56 let mut runnables = runnables.lock();
57 runnables.0.push(runnable);
58 if let Some(wake_tx) = runnables.1.as_ref() {
59 wake_tx.send(()).ok();
60 }
61 });
62 runnable.schedule();
63 task
64 }
65
66 pub fn spawn<F, T>(&self, future: F) -> Task<T>
67 where
68 T: 'static + Send,
69 F: 'static + Send + Future<Output = T>,
70 {
71 let runnables = self.runnables.clone();
72 let (runnable, task) = async_task::spawn(future, move |runnable| {
73 let mut runnables = runnables.lock();
74 runnables.0.push(runnable);
75 if let Some(wake_tx) = runnables.1.as_ref() {
76 wake_tx.send(()).ok();
77 }
78 });
79 runnable.schedule();
80 task
81 }
82
83 pub fn run<F, T>(&self, future: F) -> T
84 where
85 T: 'static,
86 F: Future<Output = T> + 'static,
87 {
88 let (wake_tx, wake_rx) = std::sync::mpsc::sync_channel(32);
89 let runnables = self.runnables.clone();
90 runnables.lock().1 = Some(wake_tx);
91
92 let (output_tx, output_rx) = std::sync::mpsc::channel();
93 self.spawn_local(async move {
94 let output = future.await;
95 output_tx.send(output).unwrap();
96 })
97 .detach();
98
99 let mut rng = StdRng::seed_from_u64(self.seed);
100 loop {
101 if let Ok(value) = output_rx.try_recv() {
102 runnables.lock().1 = None;
103 return value;
104 }
105
106 wake_rx.recv().unwrap();
107 let runnable = {
108 let mut runnables = runnables.lock();
109 let runnables = &mut runnables.0;
110 let index = rng.gen_range(0..runnables.len());
111 runnables.remove(index)
112 };
113
114 runnable.run();
115 }
116 }
117}
118
119impl Foreground {
120 pub fn platform(dispatcher: Arc<dyn platform::Dispatcher>) -> Result<Self> {
121 if dispatcher.is_main_thread() {
122 Ok(Self::Platform {
123 dispatcher,
124 _not_send_or_sync: PhantomData,
125 })
126 } else {
127 Err(anyhow!("must be constructed on main thread"))
128 }
129 }
130
131 pub fn test() -> Self {
132 Self::Test(smol::LocalExecutor::new())
133 }
134
135 pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
136 match self {
137 Self::Platform { dispatcher, .. } => {
138 let dispatcher = dispatcher.clone();
139 let schedule = move |runnable: Runnable| dispatcher.run_on_main_thread(runnable);
140 let (runnable, task) = async_task::spawn_local(future, schedule);
141 runnable.schedule();
142 task
143 }
144 Self::Test(executor) => executor.spawn(future),
145 Self::Deterministic(executor) => executor.spawn_local(future),
146 }
147 }
148
149 pub fn run<T: 'static>(&self, future: impl 'static + Future<Output = T>) -> T {
150 match self {
151 Self::Platform { .. } => panic!("you can't call run on a platform foreground executor"),
152 Self::Test(executor) => smol::block_on(executor.run(future)),
153 Self::Deterministic(executor) => executor.run(future),
154 }
155 }
156}
157
158impl Background {
159 pub fn new() -> Self {
160 let executor = Arc::new(Executor::new());
161 let stop = channel::unbounded::<()>();
162 let threads = num_cpus::get();
163
164 for i in 0..threads {
165 let executor = executor.clone();
166 let stop = stop.1.clone();
167 thread::Builder::new()
168 .name(format!("background-executor-{}", i))
169 .spawn(move || smol::block_on(executor.run(stop.recv())))
170 .unwrap();
171 }
172
173 Self::Production {
174 executor,
175 threads,
176 _stop: stop.0,
177 }
178 }
179
180 pub fn threads(&self) -> usize {
181 match self {
182 Self::Deterministic(_) => 1,
183 Self::Production { threads, .. } => *threads,
184 }
185 }
186
187 pub fn spawn<T, F>(&self, future: F) -> Task<T>
188 where
189 T: 'static + Send,
190 F: Send + Future<Output = T> + 'static,
191 {
192 match self {
193 Self::Production { executor, .. } => executor.spawn(future),
194 Self::Deterministic(executor) => executor.spawn(future),
195 }
196 }
197
198 pub async fn scoped<'scope, F>(&self, scheduler: F)
199 where
200 F: FnOnce(&mut Scope<'scope>),
201 {
202 let mut scope = Scope {
203 futures: Default::default(),
204 _phantom: PhantomData,
205 };
206 (scheduler)(&mut scope);
207 let spawned = scope
208 .futures
209 .into_iter()
210 .map(|f| self.spawn(f))
211 .collect::<Vec<_>>();
212 for task in spawned {
213 task.await;
214 }
215 }
216}
217
/// Collects futures that may borrow data living for `'a`; `Background::scoped`
/// spawns them after its scheduler closure returns and awaits them all.
pub struct Scope<'a> {
    // Stored with an erased `'static` bound; `Scope::spawn` is where the real `'a`
    // bound is cast away.
    futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
    // Ties the scope to `'a` even though no field actually holds an `'a` reference.
    _phantom: PhantomData<&'a ()>,
}
222
impl<'a> Scope<'a> {
    /// Queues `f` to be spawned on the background executor once the scheduler
    /// closure passed to `Background::scoped` returns.
    ///
    /// `f` only needs to live for `'a`, so it may borrow non-`'static` data.
    pub fn spawn<F>(&mut self, f: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        // SAFETY: the transmute only widens the trait object's lifetime bound from
        // `'a` to `'static`; the `Pin<Box<dyn ...>>` representation is unchanged.
        // Soundness relies on the future never outliving `'a`: `Background::scoped`
        // spawns every queued future and awaits each resulting task before returning.
        // NOTE(review): that guarantee appears to hold only if the `scoped` future is
        // polled to completion — if it is dropped or leaked (e.g. `mem::forget`)
        // early, or a task panics while being awaited, remaining tasks may still run
        // with borrows of `'a` data. Confirm before relying on this elsewhere.
        let f = unsafe {
            mem::transmute::<
                Pin<Box<dyn Future<Output = ()> + Send + 'a>>,
                Pin<Box<dyn Future<Output = ()> + Send + 'static>>,
            >(Box::pin(f))
        };
        self.futures.push(f);
    }
}
237
238pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
239 let executor = Arc::new(Deterministic::new(seed));
240 (
241 Rc::new(Foreground::Deterministic(executor.clone())),
242 Arc::new(Background::Deterministic(executor)),
243 )
244}