1use crate::{Instant, Priority, RunnableMeta, Scheduler, SessionId, Timer};
2use std::{
3 future::Future,
4 marker::PhantomData,
5 mem::ManuallyDrop,
6 panic::Location,
7 pin::Pin,
8 rc::Rc,
9 sync::Arc,
10 task::{Context, Poll},
11 thread::{self, ThreadId},
12 time::Duration,
13};
14
/// Executor for futures that must stay on the thread that spawned them.
///
/// Spawned futures do not need to be `Send`. Their runnables are handed to the shared
/// [`Scheduler`] through `schedule_foreground`, tagged with this executor's [`SessionId`].
#[derive(Clone)]
pub struct ForegroundExecutor {
    /// Session that every foreground runnable (and every `block_on`) is associated with.
    session_id: SessionId,
    /// Scheduler that receives spawned runnables, drives blocking calls and provides time.
    scheduler: Arc<dyn Scheduler>,
    // `Rc<()>` is neither `Send` nor `Sync`, so this marker keeps the executor itself
    // from being moved or shared across threads.
    not_send: PhantomData<Rc<()>>,
}
21
22impl ForegroundExecutor {
23 pub fn new(session_id: SessionId, scheduler: Arc<dyn Scheduler>) -> Self {
24 Self {
25 session_id,
26 scheduler,
27 not_send: PhantomData,
28 }
29 }
30
31 pub fn session_id(&self) -> SessionId {
32 self.session_id
33 }
34
35 pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
36 &self.scheduler
37 }
38
39 #[track_caller]
40 pub fn spawn<F>(&self, future: F) -> Task<F::Output>
41 where
42 F: Future + 'static,
43 F::Output: 'static,
44 {
45 let session_id = self.session_id;
46 let scheduler = Arc::clone(&self.scheduler);
47 let location = Location::caller();
48 let (runnable, task) = spawn_local_with_source_location(
49 future,
50 move |runnable| {
51 scheduler.schedule_foreground(session_id, runnable);
52 },
53 RunnableMeta { location },
54 );
55 runnable.schedule();
56 Task(TaskState::Spawned(task))
57 }
58
59 pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
60 use std::cell::Cell;
61
62 let output = Cell::new(None);
63 let future = async {
64 output.set(Some(future.await));
65 };
66 let mut future = std::pin::pin!(future);
67
68 self.scheduler
69 .block(Some(self.session_id), future.as_mut(), None);
70
71 output.take().expect("block_on future did not complete")
72 }
73
74 /// Block until the future completes or timeout occurs.
75 /// Returns Ok(output) if completed, Err(future) if timed out.
76 pub fn block_with_timeout<Fut: Future>(
77 &self,
78 timeout: Duration,
79 future: Fut,
80 ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
81 use std::cell::Cell;
82
83 let output = Cell::new(None);
84 let mut future = Box::pin(future);
85
86 {
87 let future_ref = &mut future;
88 let wrapper = async {
89 output.set(Some(future_ref.await));
90 };
91 let mut wrapper = std::pin::pin!(wrapper);
92
93 self.scheduler
94 .block(Some(self.session_id), wrapper.as_mut(), Some(timeout));
95 }
96
97 match output.take() {
98 Some(value) => Ok(value),
99 None => Err(future),
100 }
101 }
102
103 #[track_caller]
104 pub fn timer(&self, duration: Duration) -> Timer {
105 self.scheduler.timer(duration)
106 }
107
108 pub fn now(&self) -> Instant {
109 self.scheduler.clock().now()
110 }
111}
112
/// Executor for `Send + 'static` futures.
///
/// Spawned runnables are queued through the shared [`Scheduler`]
/// (`schedule_background_with_priority`), or run on a dedicated realtime thread
/// when spawned with `spawn_realtime`.
#[derive(Clone)]
pub struct BackgroundExecutor {
    /// Scheduler that receives spawned runnables and provides timers and the clock.
    scheduler: Arc<dyn Scheduler>,
}
117
118impl BackgroundExecutor {
119 pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
120 Self { scheduler }
121 }
122
123 #[track_caller]
124 pub fn spawn<F>(&self, future: F) -> Task<F::Output>
125 where
126 F: Future + Send + 'static,
127 F::Output: Send + 'static,
128 {
129 self.spawn_with_priority(Priority::default(), future)
130 }
131
132 #[track_caller]
133 pub fn spawn_with_priority<F>(&self, priority: Priority, future: F) -> Task<F::Output>
134 where
135 F: Future + Send + 'static,
136 F::Output: Send + 'static,
137 {
138 let scheduler = Arc::clone(&self.scheduler);
139 let location = Location::caller();
140 let (runnable, task) = async_task::Builder::new()
141 .metadata(RunnableMeta { location })
142 .spawn(
143 move |_| future,
144 move |runnable| {
145 scheduler.schedule_background_with_priority(runnable, priority);
146 },
147 );
148 runnable.schedule();
149 Task(TaskState::Spawned(task))
150 }
151
152 /// Spawns a future on a dedicated realtime thread for audio processing.
153 #[track_caller]
154 pub fn spawn_realtime<F>(&self, future: F) -> Task<F::Output>
155 where
156 F: Future + Send + 'static,
157 F::Output: Send + 'static,
158 {
159 let location = Location::caller();
160 let (tx, rx) = flume::bounded::<async_task::Runnable<RunnableMeta>>(1);
161
162 self.scheduler.spawn_realtime(Box::new(move || {
163 while let Ok(runnable) = rx.recv() {
164 runnable.run();
165 }
166 }));
167
168 let (runnable, task) = async_task::Builder::new()
169 .metadata(RunnableMeta { location })
170 .spawn(
171 move |_| future,
172 move |runnable| {
173 let _ = tx.send(runnable);
174 },
175 );
176 runnable.schedule();
177 Task(TaskState::Spawned(task))
178 }
179
180 #[track_caller]
181 pub fn timer(&self, duration: Duration) -> Timer {
182 self.scheduler.timer(duration)
183 }
184
185 pub fn now(&self) -> Instant {
186 self.scheduler.clock().now()
187 }
188
189 pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
190 &self.scheduler
191 }
192}
193
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// Dropping a spawned task cancels it immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
///
/// A task either already holds its value (see [`Task::ready`]) or wraps a task spawned
/// through `async_task`. Use [`Task::fallible`] to observe cancellation as `None`
/// instead of a panic.
#[must_use]
#[derive(Debug)]
pub struct Task<T>(TaskState<T>);
203
/// Internal representation behind [`Task`].
#[derive(Debug)]
enum TaskState<T> {
    /// A task that is ready to return a value.
    ///
    /// Holds `Some` until the value is taken by the first completed poll, then `None`.
    Ready(Option<T>),

    /// A task that is currently running, backed by an `async_task` handle.
    Spawned(async_task::Task<T, RunnableMeta>),
}
212
213impl<T> Task<T> {
214 /// Creates a new task that will resolve with the value
215 pub fn ready(val: T) -> Self {
216 Task(TaskState::Ready(Some(val)))
217 }
218
219 /// Creates a Task from an async_task::Task
220 pub fn from_async_task(task: async_task::Task<T, RunnableMeta>) -> Self {
221 Task(TaskState::Spawned(task))
222 }
223
224 pub fn is_ready(&self) -> bool {
225 match &self.0 {
226 TaskState::Ready(_) => true,
227 TaskState::Spawned(task) => task.is_finished(),
228 }
229 }
230
231 /// Detaching a task runs it to completion in the background
232 pub fn detach(self) {
233 match self {
234 Task(TaskState::Ready(_)) => {}
235 Task(TaskState::Spawned(task)) => task.detach(),
236 }
237 }
238
239 /// Converts this task into a fallible task that returns `Option<T>`.
240 pub fn fallible(self) -> FallibleTask<T> {
241 FallibleTask(match self.0 {
242 TaskState::Ready(val) => FallibleTaskState::Ready(val),
243 TaskState::Spawned(task) => FallibleTaskState::Spawned(task.fallible()),
244 })
245 }
246}
247
/// A task that returns `Option<T>` instead of panicking when cancelled.
///
/// Obtained from [`Task::fallible`] or constructed directly with [`FallibleTask::ready`].
#[must_use]
pub struct FallibleTask<T>(FallibleTaskState<T>);
251
/// Internal representation behind [`FallibleTask`].
enum FallibleTaskState<T> {
    /// A task that is ready to return a value.
    ///
    /// Holds `Some` until polled; afterwards it is `None`, so further polls yield `None`.
    Ready(Option<T>),

    /// A task that is currently running (wraps async_task::FallibleTask).
    Spawned(async_task::FallibleTask<T, RunnableMeta>),
}
259
260impl<T> FallibleTask<T> {
261 /// Creates a new fallible task that will resolve with the value.
262 pub fn ready(val: T) -> Self {
263 FallibleTask(FallibleTaskState::Ready(Some(val)))
264 }
265
266 /// Detaching a task runs it to completion in the background.
267 pub fn detach(self) {
268 match self.0 {
269 FallibleTaskState::Ready(_) => {}
270 FallibleTaskState::Spawned(task) => task.detach(),
271 }
272 }
273}
274
275impl<T> Future for FallibleTask<T> {
276 type Output = Option<T>;
277
278 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
279 match unsafe { self.get_unchecked_mut() } {
280 FallibleTask(FallibleTaskState::Ready(val)) => Poll::Ready(val.take()),
281 FallibleTask(FallibleTaskState::Spawned(task)) => Pin::new(task).poll(cx),
282 }
283 }
284}
285
286impl<T> std::fmt::Debug for FallibleTask<T> {
287 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288 match &self.0 {
289 FallibleTaskState::Ready(_) => f.debug_tuple("FallibleTask::Ready").finish(),
290 FallibleTaskState::Spawned(task) => {
291 f.debug_tuple("FallibleTask::Spawned").field(task).finish()
292 }
293 }
294 }
295}
296
297impl<T> Future for Task<T> {
298 type Output = T;
299
300 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
301 match unsafe { self.get_unchecked_mut() } {
302 Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
303 Task(TaskState::Spawned(task)) => Pin::new(task).poll(cx),
304 }
305 }
306}
307
/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// `future` is wrapped in a `Checked` adapter that remembers the spawning thread. Polling
/// or dropping it from any other thread panics with a message naming `metadata.location`.
/// That check is what allows a non-`Send` future to be passed to `spawn_unchecked` below.
#[track_caller]
fn spawn_local_with_source_location<Fut, S>(
    future: Fut,
    schedule: S,
    metadata: RunnableMeta,
) -> (
    async_task::Runnable<RunnableMeta>,
    async_task::Task<Fut::Output, RunnableMeta>,
)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<RunnableMeta> + Send + Sync + 'static,
{
    /// Returns the current thread's id, cached in a thread-local.
    ///
    /// Falls back to `thread::current()` when the thread-local can no longer be
    /// accessed (`try_with` fails, e.g. during thread-local destruction).
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    /// Future adapter that ties its inner future to the thread that created it.
    struct Checked<F> {
        /// Thread that created the adapter; the only thread allowed to poll or drop `inner`.
        id: ThreadId,
        /// The wrapped future. `ManuallyDrop` keeps it from being dropped automatically,
        /// including during unwinding after the thread assertion in `Drop` panics.
        inner: ManuallyDrop<F>,
        /// Spawn site, reported in the panic messages.
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            // Panic (leaking `inner`) rather than run the future's destructor on a
            // thread other than the one that spawned it.
            assert_eq!(
                self.id,
                thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is dropped exactly once, here, and never accessed afterwards.
            unsafe {
                ManuallyDrop::drop(&mut self.inner);
            }
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is treated as structurally pinned: it is never moved out of
            // `Checked`; its only removal is the in-place `ManuallyDrop::drop` in `Drop`.
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    let location = metadata.location;

    // SAFETY: `spawn_unchecked` permits a non-`Send` future only if it is polled and
    // dropped on its original thread. `Checked` enforces that with the assertions above.
    // `id` is captured when async_task invokes the closure during this call, i.e. on the
    // spawning thread. `Fut` and `S` are `'static` and `S` is `Send + Sync`, which covers
    // the remaining requirements.
    unsafe {
        async_task::Builder::new()
            .metadata(metadata)
            .spawn_unchecked(
                move |_| Checked {
                    id: thread_id(),
                    inner: ManuallyDrop::new(future),
                    location,
                },
                schedule,
            )
    }
}