executor.rs

  1use crate::{Instant, Priority, RunnableMeta, Scheduler, SessionId, Timer};
  2use std::{
  3    future::Future,
  4    marker::PhantomData,
  5    mem::ManuallyDrop,
  6    panic::Location,
  7    pin::Pin,
  8    rc::Rc,
  9    sync::Arc,
 10    task::{Context, Poll},
 11    thread::{self, ThreadId},
 12    time::Duration,
 13};
 14
 15#[derive(Clone)]
 16pub struct ForegroundExecutor {
 17    session_id: SessionId,
 18    scheduler: Arc<dyn Scheduler>,
 19    not_send: PhantomData<Rc<()>>,
 20}
 21
 22impl ForegroundExecutor {
 23    pub fn new(session_id: SessionId, scheduler: Arc<dyn Scheduler>) -> Self {
 24        Self {
 25            session_id,
 26            scheduler,
 27            not_send: PhantomData,
 28        }
 29    }
 30
 31    pub fn session_id(&self) -> SessionId {
 32        self.session_id
 33    }
 34
 35    pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
 36        &self.scheduler
 37    }
 38
 39    #[track_caller]
 40    pub fn spawn<F>(&self, future: F) -> Task<F::Output>
 41    where
 42        F: Future + 'static,
 43        F::Output: 'static,
 44    {
 45        let session_id = self.session_id;
 46        let scheduler = Arc::clone(&self.scheduler);
 47        let location = Location::caller();
 48        let (runnable, task) = spawn_local_with_source_location(
 49            future,
 50            move |runnable| {
 51                scheduler.schedule_foreground(session_id, runnable);
 52            },
 53            RunnableMeta { location },
 54        );
 55        runnable.schedule();
 56        Task(TaskState::Spawned(task))
 57    }
 58
 59    pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
 60        use std::cell::Cell;
 61
 62        let output = Cell::new(None);
 63        let future = async {
 64            output.set(Some(future.await));
 65        };
 66        let mut future = std::pin::pin!(future);
 67
 68        self.scheduler
 69            .block(Some(self.session_id), future.as_mut(), None);
 70
 71        output.take().expect("block_on future did not complete")
 72    }
 73
 74    /// Block until the future completes or timeout occurs.
 75    /// Returns Ok(output) if completed, Err(future) if timed out.
 76    pub fn block_with_timeout<Fut: Future>(
 77        &self,
 78        timeout: Duration,
 79        future: Fut,
 80    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
 81        use std::cell::Cell;
 82
 83        let output = Cell::new(None);
 84        let mut future = Box::pin(future);
 85
 86        {
 87            let future_ref = &mut future;
 88            let wrapper = async {
 89                output.set(Some(future_ref.await));
 90            };
 91            let mut wrapper = std::pin::pin!(wrapper);
 92
 93            self.scheduler
 94                .block(Some(self.session_id), wrapper.as_mut(), Some(timeout));
 95        }
 96
 97        match output.take() {
 98            Some(value) => Ok(value),
 99            None => Err(future),
100        }
101    }
102
103    #[track_caller]
104    pub fn timer(&self, duration: Duration) -> Timer {
105        self.scheduler.timer(duration)
106    }
107
108    pub fn now(&self) -> Instant {
109        self.scheduler.clock().now()
110    }
111}
112
113#[derive(Clone)]
114pub struct BackgroundExecutor {
115    scheduler: Arc<dyn Scheduler>,
116}
117
118impl BackgroundExecutor {
119    pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
120        Self { scheduler }
121    }
122
123    #[track_caller]
124    pub fn spawn<F>(&self, future: F) -> Task<F::Output>
125    where
126        F: Future + Send + 'static,
127        F::Output: Send + 'static,
128    {
129        self.spawn_with_priority(Priority::default(), future)
130    }
131
132    #[track_caller]
133    pub fn spawn_with_priority<F>(&self, priority: Priority, future: F) -> Task<F::Output>
134    where
135        F: Future + Send + 'static,
136        F::Output: Send + 'static,
137    {
138        let scheduler = Arc::clone(&self.scheduler);
139        let location = Location::caller();
140        let (runnable, task) = async_task::Builder::new()
141            .metadata(RunnableMeta { location })
142            .spawn(
143                move |_| future,
144                move |runnable| {
145                    scheduler.schedule_background_with_priority(runnable, priority);
146                },
147            );
148        runnable.schedule();
149        Task(TaskState::Spawned(task))
150    }
151
152    /// Spawns a future on a dedicated realtime thread for audio processing.
153    #[track_caller]
154    pub fn spawn_realtime<F>(&self, future: F) -> Task<F::Output>
155    where
156        F: Future + Send + 'static,
157        F::Output: Send + 'static,
158    {
159        let location = Location::caller();
160        let (tx, rx) = flume::bounded::<async_task::Runnable<RunnableMeta>>(1);
161
162        self.scheduler.spawn_realtime(Box::new(move || {
163            while let Ok(runnable) = rx.recv() {
164                runnable.run();
165            }
166        }));
167
168        let (runnable, task) = async_task::Builder::new()
169            .metadata(RunnableMeta { location })
170            .spawn(
171                move |_| future,
172                move |runnable| {
173                    let _ = tx.send(runnable);
174                },
175            );
176        runnable.schedule();
177        Task(TaskState::Spawned(task))
178    }
179
180    #[track_caller]
181    pub fn timer(&self, duration: Duration) -> Timer {
182        self.scheduler.timer(duration)
183    }
184
185    pub fn now(&self) -> Instant {
186        self.scheduler.clock().now()
187    }
188
189    pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
190        &self.scheduler
191    }
192}
193
194/// Task is a primitive that allows work to happen in the background.
195///
196/// It implements [`Future`] so you can `.await` on it.
197///
198/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
199/// the task to continue running, but with no way to return a value.
200#[must_use]
201#[derive(Debug)]
202pub struct Task<T>(TaskState<T>);
203
204#[derive(Debug)]
205enum TaskState<T> {
206    /// A task that is ready to return a value
207    Ready(Option<T>),
208
209    /// A task that is currently running.
210    Spawned(async_task::Task<T, RunnableMeta>),
211}
212
213impl<T> Task<T> {
214    /// Creates a new task that will resolve with the value
215    pub fn ready(val: T) -> Self {
216        Task(TaskState::Ready(Some(val)))
217    }
218
219    /// Creates a Task from an async_task::Task
220    pub fn from_async_task(task: async_task::Task<T, RunnableMeta>) -> Self {
221        Task(TaskState::Spawned(task))
222    }
223
224    pub fn is_ready(&self) -> bool {
225        match &self.0 {
226            TaskState::Ready(_) => true,
227            TaskState::Spawned(task) => task.is_finished(),
228        }
229    }
230
231    /// Detaching a task runs it to completion in the background
232    pub fn detach(self) {
233        match self {
234            Task(TaskState::Ready(_)) => {}
235            Task(TaskState::Spawned(task)) => task.detach(),
236        }
237    }
238
239    /// Converts this task into a fallible task that returns `Option<T>`.
240    pub fn fallible(self) -> FallibleTask<T> {
241        FallibleTask(match self.0 {
242            TaskState::Ready(val) => FallibleTaskState::Ready(val),
243            TaskState::Spawned(task) => FallibleTaskState::Spawned(task.fallible()),
244        })
245    }
246}
247
248/// A task that returns `Option<T>` instead of panicking when cancelled.
249#[must_use]
250pub struct FallibleTask<T>(FallibleTaskState<T>);
251
252enum FallibleTaskState<T> {
253    /// A task that is ready to return a value
254    Ready(Option<T>),
255
256    /// A task that is currently running (wraps async_task::FallibleTask).
257    Spawned(async_task::FallibleTask<T, RunnableMeta>),
258}
259
260impl<T> FallibleTask<T> {
261    /// Creates a new fallible task that will resolve with the value.
262    pub fn ready(val: T) -> Self {
263        FallibleTask(FallibleTaskState::Ready(Some(val)))
264    }
265
266    /// Detaching a task runs it to completion in the background.
267    pub fn detach(self) {
268        match self.0 {
269            FallibleTaskState::Ready(_) => {}
270            FallibleTaskState::Spawned(task) => task.detach(),
271        }
272    }
273}
274
275impl<T> Future for FallibleTask<T> {
276    type Output = Option<T>;
277
278    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
279        match unsafe { self.get_unchecked_mut() } {
280            FallibleTask(FallibleTaskState::Ready(val)) => Poll::Ready(val.take()),
281            FallibleTask(FallibleTaskState::Spawned(task)) => Pin::new(task).poll(cx),
282        }
283    }
284}
285
286impl<T> std::fmt::Debug for FallibleTask<T> {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        match &self.0 {
289            FallibleTaskState::Ready(_) => f.debug_tuple("FallibleTask::Ready").finish(),
290            FallibleTaskState::Spawned(task) => {
291                f.debug_tuple("FallibleTask::Spawned").field(task).finish()
292            }
293        }
294    }
295}
296
297impl<T> Future for Task<T> {
298    type Output = T;
299
300    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
301        match unsafe { self.get_unchecked_mut() } {
302            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
303            Task(TaskState::Spawned(task)) => Pin::new(task).poll(cx),
304        }
305    }
306}
307
308/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
309#[track_caller]
310fn spawn_local_with_source_location<Fut, S>(
311    future: Fut,
312    schedule: S,
313    metadata: RunnableMeta,
314) -> (
315    async_task::Runnable<RunnableMeta>,
316    async_task::Task<Fut::Output, RunnableMeta>,
317)
318where
319    Fut: Future + 'static,
320    Fut::Output: 'static,
321    S: async_task::Schedule<RunnableMeta> + Send + Sync + 'static,
322{
323    #[inline]
324    fn thread_id() -> ThreadId {
325        std::thread_local! {
326            static ID: ThreadId = thread::current().id();
327        }
328        ID.try_with(|id| *id)
329            .unwrap_or_else(|_| thread::current().id())
330    }
331
332    struct Checked<F> {
333        id: ThreadId,
334        inner: ManuallyDrop<F>,
335        location: &'static Location<'static>,
336    }
337
338    impl<F> Drop for Checked<F> {
339        fn drop(&mut self) {
340            assert_eq!(
341                self.id,
342                thread_id(),
343                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
344                self.location
345            );
346            unsafe {
347                ManuallyDrop::drop(&mut self.inner);
348            }
349        }
350    }
351
352    impl<F: Future> Future for Checked<F> {
353        type Output = F::Output;
354
355        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
356            assert!(
357                self.id == thread_id(),
358                "local task polled by a thread that didn't spawn it. Task spawned at {}",
359                self.location
360            );
361            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
362        }
363    }
364
365    let location = metadata.location;
366
367    unsafe {
368        async_task::Builder::new()
369            .metadata(metadata)
370            .spawn_unchecked(
371                move |_| Checked {
372                    id: thread_id(),
373                    inner: ManuallyDrop::new(future),
374                    location,
375                },
376                schedule,
377            )
378    }
379}