1use crate::{Instant, Priority, RunnableMeta, Scheduler, SessionId, Timer};
2use std::{
3 future::Future,
4 marker::PhantomData,
5 mem::ManuallyDrop,
6 panic::Location,
7 pin::Pin,
8 rc::Rc,
9 sync::{
10 Arc,
11 atomic::{AtomicBool, Ordering},
12 },
13 task::{Context, Poll},
14 thread::{self, ThreadId},
15 time::Duration,
16};
17
18#[derive(Clone)]
19pub struct ForegroundExecutor {
20 session_id: SessionId,
21 scheduler: Arc<dyn Scheduler>,
22 closed: Arc<AtomicBool>,
23 not_send: PhantomData<Rc<()>>,
24}
25
26impl ForegroundExecutor {
27 pub fn new(session_id: SessionId, scheduler: Arc<dyn Scheduler>) -> Self {
28 Self {
29 session_id,
30 scheduler,
31 closed: Arc::new(AtomicBool::new(false)),
32 not_send: PhantomData,
33 }
34 }
35
36 pub fn session_id(&self) -> SessionId {
37 self.session_id
38 }
39
40 pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
41 &self.scheduler
42 }
43
44 /// Returns the closed flag for this executor.
45 pub fn closed(&self) -> &Arc<AtomicBool> {
46 &self.closed
47 }
48
49 /// Close this executor. Tasks will not run after this is called.
50 pub fn close(&self) {
51 self.closed.store(true, Ordering::SeqCst);
52 }
53
54 #[track_caller]
55 pub fn spawn<F>(&self, future: F) -> Task<F::Output>
56 where
57 F: Future + 'static,
58 F::Output: 'static,
59 {
60 let session_id = self.session_id;
61 let scheduler = Arc::clone(&self.scheduler);
62 let location = Location::caller();
63 let closed = self.closed.clone();
64 let (runnable, task) = spawn_local_with_source_location(
65 future,
66 move |runnable| {
67 scheduler.schedule_foreground(session_id, runnable);
68 },
69 RunnableMeta { location, closed },
70 );
71 runnable.schedule();
72 Task(TaskState::Spawned(task))
73 }
74
75 pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
76 use std::cell::Cell;
77
78 let output = Cell::new(None);
79 let future = async {
80 output.set(Some(future.await));
81 };
82 let mut future = std::pin::pin!(future);
83
84 self.scheduler
85 .block(Some(self.session_id), future.as_mut(), None);
86
87 output.take().expect("block_on future did not complete")
88 }
89
90 /// Block until the future completes or timeout occurs.
91 /// Returns Ok(output) if completed, Err(future) if timed out.
92 pub fn block_with_timeout<Fut: Future>(
93 &self,
94 timeout: Duration,
95 future: Fut,
96 ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
97 use std::cell::Cell;
98
99 let output = Cell::new(None);
100 let mut future = Box::pin(future);
101
102 {
103 let future_ref = &mut future;
104 let wrapper = async {
105 output.set(Some(future_ref.await));
106 };
107 let mut wrapper = std::pin::pin!(wrapper);
108
109 self.scheduler
110 .block(Some(self.session_id), wrapper.as_mut(), Some(timeout));
111 }
112
113 match output.take() {
114 Some(value) => Ok(value),
115 None => Err(future),
116 }
117 }
118
119 #[track_caller]
120 pub fn timer(&self, duration: Duration) -> Timer {
121 self.scheduler.timer(duration)
122 }
123
124 pub fn now(&self) -> Instant {
125 self.scheduler.clock().now()
126 }
127}
128
129#[derive(Clone)]
130pub struct BackgroundExecutor {
131 scheduler: Arc<dyn Scheduler>,
132 closed: Arc<AtomicBool>,
133}
134
135impl BackgroundExecutor {
136 pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
137 Self {
138 scheduler,
139 closed: Arc::new(AtomicBool::new(false)),
140 }
141 }
142
143 /// Returns the closed flag for this executor.
144 pub fn closed(&self) -> &Arc<AtomicBool> {
145 &self.closed
146 }
147
148 /// Close this executor. Tasks will not run after this is called.
149 pub fn close(&self) {
150 self.closed.store(true, Ordering::SeqCst);
151 }
152
153 #[track_caller]
154 pub fn spawn<F>(&self, future: F) -> Task<F::Output>
155 where
156 F: Future + Send + 'static,
157 F::Output: Send + 'static,
158 {
159 self.spawn_with_priority(Priority::default(), future)
160 }
161
162 #[track_caller]
163 pub fn spawn_with_priority<F>(&self, priority: Priority, future: F) -> Task<F::Output>
164 where
165 F: Future + Send + 'static,
166 F::Output: Send + 'static,
167 {
168 let scheduler = Arc::clone(&self.scheduler);
169 let location = Location::caller();
170 let closed = self.closed.clone();
171 let (runnable, task) = async_task::Builder::new()
172 .metadata(RunnableMeta { location, closed })
173 .spawn(
174 move |_| future,
175 move |runnable| {
176 scheduler.schedule_background_with_priority(runnable, priority);
177 },
178 );
179 runnable.schedule();
180 Task(TaskState::Spawned(task))
181 }
182
183 /// Spawns a future on a dedicated realtime thread for audio processing.
184 #[track_caller]
185 pub fn spawn_realtime<F>(&self, future: F) -> Task<F::Output>
186 where
187 F: Future + Send + 'static,
188 F::Output: Send + 'static,
189 {
190 let location = Location::caller();
191 let closed = self.closed.clone();
192 let (tx, rx) = flume::bounded::<async_task::Runnable<RunnableMeta>>(1);
193
194 self.scheduler.spawn_realtime(Box::new(move || {
195 while let Ok(runnable) = rx.recv() {
196 if runnable.metadata().is_closed() {
197 continue;
198 }
199 runnable.run();
200 }
201 }));
202
203 let (runnable, task) = async_task::Builder::new()
204 .metadata(RunnableMeta { location, closed })
205 .spawn(
206 move |_| future,
207 move |runnable| {
208 let _ = tx.send(runnable);
209 },
210 );
211 runnable.schedule();
212 Task(TaskState::Spawned(task))
213 }
214
215 #[track_caller]
216 pub fn timer(&self, duration: Duration) -> Timer {
217 self.scheduler.timer(duration)
218 }
219
220 pub fn now(&self) -> Instant {
221 self.scheduler.clock().now()
222 }
223
224 pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
225 &self.scheduler
226 }
227}
228
229/// Task is a primitive that allows work to happen in the background.
230///
231/// It implements [`Future`] so you can `.await` on it.
232///
233/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
234/// the task to continue running, but with no way to return a value.
235#[must_use]
236#[derive(Debug)]
237pub struct Task<T>(TaskState<T>);
238
239#[derive(Debug)]
240enum TaskState<T> {
241 /// A task that is ready to return a value
242 Ready(Option<T>),
243
244 /// A task that is currently running.
245 Spawned(async_task::Task<T, RunnableMeta>),
246}
247
248impl<T> Task<T> {
249 /// Creates a new task that will resolve with the value
250 pub fn ready(val: T) -> Self {
251 Task(TaskState::Ready(Some(val)))
252 }
253
254 /// Creates a Task from an async_task::Task
255 pub fn from_async_task(task: async_task::Task<T, RunnableMeta>) -> Self {
256 Task(TaskState::Spawned(task))
257 }
258
259 pub fn is_ready(&self) -> bool {
260 match &self.0 {
261 TaskState::Ready(_) => true,
262 TaskState::Spawned(task) => task.is_finished(),
263 }
264 }
265
266 /// Detaching a task runs it to completion in the background
267 pub fn detach(self) {
268 match self {
269 Task(TaskState::Ready(_)) => {}
270 Task(TaskState::Spawned(task)) => task.detach(),
271 }
272 }
273
274 /// Converts this task into a fallible task that returns `Option<T>`.
275 pub fn fallible(self) -> FallibleTask<T> {
276 FallibleTask(match self.0 {
277 TaskState::Ready(val) => FallibleTaskState::Ready(val),
278 TaskState::Spawned(task) => FallibleTaskState::Spawned(task.fallible()),
279 })
280 }
281}
282
283/// A task that returns `Option<T>` instead of panicking when cancelled.
284#[must_use]
285pub struct FallibleTask<T>(FallibleTaskState<T>);
286
287enum FallibleTaskState<T> {
288 /// A task that is ready to return a value
289 Ready(Option<T>),
290
291 /// A task that is currently running (wraps async_task::FallibleTask).
292 Spawned(async_task::FallibleTask<T, RunnableMeta>),
293}
294
295impl<T> FallibleTask<T> {
296 /// Creates a new fallible task that will resolve with the value.
297 pub fn ready(val: T) -> Self {
298 FallibleTask(FallibleTaskState::Ready(Some(val)))
299 }
300
301 /// Detaching a task runs it to completion in the background.
302 pub fn detach(self) {
303 match self.0 {
304 FallibleTaskState::Ready(_) => {}
305 FallibleTaskState::Spawned(task) => task.detach(),
306 }
307 }
308}
309
310impl<T> Future for FallibleTask<T> {
311 type Output = Option<T>;
312
313 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
314 match unsafe { self.get_unchecked_mut() } {
315 FallibleTask(FallibleTaskState::Ready(val)) => Poll::Ready(val.take()),
316 FallibleTask(FallibleTaskState::Spawned(task)) => Pin::new(task).poll(cx),
317 }
318 }
319}
320
321impl<T> std::fmt::Debug for FallibleTask<T> {
322 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323 match &self.0 {
324 FallibleTaskState::Ready(_) => f.debug_tuple("FallibleTask::Ready").finish(),
325 FallibleTaskState::Spawned(task) => {
326 f.debug_tuple("FallibleTask::Spawned").field(task).finish()
327 }
328 }
329 }
330}
331
332impl<T> Future for Task<T> {
333 type Output = T;
334
335 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
336 match unsafe { self.get_unchecked_mut() } {
337 Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
338 Task(TaskState::Spawned(task)) => Pin::new(task).poll(cx),
339 }
340 }
341}
342
343/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
344#[track_caller]
345fn spawn_local_with_source_location<Fut, S>(
346 future: Fut,
347 schedule: S,
348 metadata: RunnableMeta,
349) -> (
350 async_task::Runnable<RunnableMeta>,
351 async_task::Task<Fut::Output, RunnableMeta>,
352)
353where
354 Fut: Future + 'static,
355 Fut::Output: 'static,
356 S: async_task::Schedule<RunnableMeta> + Send + Sync + 'static,
357{
358 #[inline]
359 fn thread_id() -> ThreadId {
360 std::thread_local! {
361 static ID: ThreadId = thread::current().id();
362 }
363 ID.try_with(|id| *id)
364 .unwrap_or_else(|_| thread::current().id())
365 }
366
367 struct Checked<F> {
368 id: ThreadId,
369 inner: ManuallyDrop<F>,
370 location: &'static Location<'static>,
371 }
372
373 impl<F> Drop for Checked<F> {
374 fn drop(&mut self) {
375 assert!(
376 self.id == thread_id(),
377 "local task dropped by a thread that didn't spawn it. Task spawned at {}",
378 self.location
379 );
380 unsafe {
381 ManuallyDrop::drop(&mut self.inner);
382 }
383 }
384 }
385
386 impl<F: Future> Future for Checked<F> {
387 type Output = F::Output;
388
389 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
390 assert!(
391 self.id == thread_id(),
392 "local task polled by a thread that didn't spawn it. Task spawned at {}",
393 self.location
394 );
395 unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
396 }
397 }
398
399 let location = metadata.location;
400
401 unsafe {
402 async_task::Builder::new()
403 .metadata(metadata)
404 .spawn_unchecked(
405 move |_| Checked {
406 id: thread_id(),
407 inner: ManuallyDrop::new(future),
408 location,
409 },
410 schedule,
411 )
412 }
413}