executor.rs

  1use crate::{Instant, Priority, RunnableMeta, Scheduler, SessionId, Timer};
  2use std::{
  3    future::Future,
  4    marker::PhantomData,
  5    mem::ManuallyDrop,
  6    panic::Location,
  7    pin::Pin,
  8    rc::Rc,
  9    sync::{
 10        Arc,
 11        atomic::{AtomicBool, Ordering},
 12    },
 13    task::{Context, Poll},
 14    thread::{self, ThreadId},
 15    time::Duration,
 16};
 17
/// Executor for futures that do not need to be `Send`, bound to one scheduler session.
///
/// Every future spawned here is handed to [`Scheduler::schedule_foreground`] together with
/// this executor's [`SessionId`]. The executor itself is neither `Send` nor `Sync`.
#[derive(Clone)]
pub struct ForegroundExecutor {
    /// Session under which every runnable spawned by this executor is scheduled.
    session_id: SessionId,
    /// Scheduler that receives this executor's runnables.
    scheduler: Arc<dyn Scheduler>,
    /// Closed flag; clones of this executor share (and observe) the same flag.
    closed: Arc<AtomicBool>,
    /// `Rc` marker that makes the executor `!Send` and `!Sync`.
    not_send: PhantomData<Rc<()>>,
}
 25
 26impl ForegroundExecutor {
 27    pub fn new(session_id: SessionId, scheduler: Arc<dyn Scheduler>) -> Self {
 28        Self {
 29            session_id,
 30            scheduler,
 31            closed: Arc::new(AtomicBool::new(false)),
 32            not_send: PhantomData,
 33        }
 34    }
 35
 36    pub fn session_id(&self) -> SessionId {
 37        self.session_id
 38    }
 39
 40    pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
 41        &self.scheduler
 42    }
 43
 44    /// Returns the closed flag for this executor.
 45    pub fn closed(&self) -> &Arc<AtomicBool> {
 46        &self.closed
 47    }
 48
 49    /// Close this executor. Tasks will not run after this is called.
 50    pub fn close(&self) {
 51        self.closed.store(true, Ordering::SeqCst);
 52    }
 53
 54    #[track_caller]
 55    pub fn spawn<F>(&self, future: F) -> Task<F::Output>
 56    where
 57        F: Future + 'static,
 58        F::Output: 'static,
 59    {
 60        let session_id = self.session_id;
 61        let scheduler = Arc::clone(&self.scheduler);
 62        let location = Location::caller();
 63        let closed = self.closed.clone();
 64        let (runnable, task) = spawn_local_with_source_location(
 65            future,
 66            move |runnable| {
 67                scheduler.schedule_foreground(session_id, runnable);
 68            },
 69            RunnableMeta { location, closed },
 70        );
 71        runnable.schedule();
 72        Task(TaskState::Spawned(task))
 73    }
 74
 75    pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
 76        use std::cell::Cell;
 77
 78        let output = Cell::new(None);
 79        let future = async {
 80            output.set(Some(future.await));
 81        };
 82        let mut future = std::pin::pin!(future);
 83
 84        self.scheduler
 85            .block(Some(self.session_id), future.as_mut(), None);
 86
 87        output.take().expect("block_on future did not complete")
 88    }
 89
 90    /// Block until the future completes or timeout occurs.
 91    /// Returns Ok(output) if completed, Err(future) if timed out.
 92    pub fn block_with_timeout<Fut: Future>(
 93        &self,
 94        timeout: Duration,
 95        future: Fut,
 96    ) -> Result<Fut::Output, impl Future<Output = Fut::Output> + use<Fut>> {
 97        use std::cell::Cell;
 98
 99        let output = Cell::new(None);
100        let mut future = Box::pin(future);
101
102        {
103            let future_ref = &mut future;
104            let wrapper = async {
105                output.set(Some(future_ref.await));
106            };
107            let mut wrapper = std::pin::pin!(wrapper);
108
109            self.scheduler
110                .block(Some(self.session_id), wrapper.as_mut(), Some(timeout));
111        }
112
113        match output.take() {
114            Some(value) => Ok(value),
115            None => Err(future),
116        }
117    }
118
119    #[track_caller]
120    pub fn timer(&self, duration: Duration) -> Timer {
121        self.scheduler.timer(duration)
122    }
123
124    pub fn now(&self) -> Instant {
125        self.scheduler.clock().now()
126    }
127}
128
/// Executor for `Send` futures.
///
/// Futures are handed to the scheduler through `schedule_background_with_priority`, or run
/// on a thread obtained from `Scheduler::spawn_realtime`.
#[derive(Clone)]
pub struct BackgroundExecutor {
    /// Scheduler that receives runnables spawned through this executor.
    scheduler: Arc<dyn Scheduler>,
    /// Closed flag; clones of this executor share (and observe) the same flag.
    closed: Arc<AtomicBool>,
}
134
135impl BackgroundExecutor {
136    pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
137        Self {
138            scheduler,
139            closed: Arc::new(AtomicBool::new(false)),
140        }
141    }
142
143    /// Returns the closed flag for this executor.
144    pub fn closed(&self) -> &Arc<AtomicBool> {
145        &self.closed
146    }
147
148    /// Close this executor. Tasks will not run after this is called.
149    pub fn close(&self) {
150        self.closed.store(true, Ordering::SeqCst);
151    }
152
153    #[track_caller]
154    pub fn spawn<F>(&self, future: F) -> Task<F::Output>
155    where
156        F: Future + Send + 'static,
157        F::Output: Send + 'static,
158    {
159        self.spawn_with_priority(Priority::default(), future)
160    }
161
162    #[track_caller]
163    pub fn spawn_with_priority<F>(&self, priority: Priority, future: F) -> Task<F::Output>
164    where
165        F: Future + Send + 'static,
166        F::Output: Send + 'static,
167    {
168        let scheduler = Arc::clone(&self.scheduler);
169        let location = Location::caller();
170        let closed = self.closed.clone();
171        let (runnable, task) = async_task::Builder::new()
172            .metadata(RunnableMeta { location, closed })
173            .spawn(
174                move |_| future,
175                move |runnable| {
176                    scheduler.schedule_background_with_priority(runnable, priority);
177                },
178            );
179        runnable.schedule();
180        Task(TaskState::Spawned(task))
181    }
182
183    /// Spawns a future on a dedicated realtime thread for audio processing.
184    #[track_caller]
185    pub fn spawn_realtime<F>(&self, future: F) -> Task<F::Output>
186    where
187        F: Future + Send + 'static,
188        F::Output: Send + 'static,
189    {
190        let location = Location::caller();
191        let closed = self.closed.clone();
192        let (tx, rx) = flume::bounded::<async_task::Runnable<RunnableMeta>>(1);
193
194        self.scheduler.spawn_realtime(Box::new(move || {
195            while let Ok(runnable) = rx.recv() {
196                if runnable.metadata().is_closed() {
197                    continue;
198                }
199                runnable.run();
200            }
201        }));
202
203        let (runnable, task) = async_task::Builder::new()
204            .metadata(RunnableMeta { location, closed })
205            .spawn(
206                move |_| future,
207                move |runnable| {
208                    let _ = tx.send(runnable);
209                },
210            );
211        runnable.schedule();
212        Task(TaskState::Spawned(task))
213    }
214
215    #[track_caller]
216    pub fn timer(&self, duration: Duration) -> Timer {
217        self.scheduler.timer(duration)
218    }
219
220    pub fn now(&self) -> Instant {
221        self.scheduler.clock().now()
222    }
223
224    pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
225        &self.scheduler
226    }
227}
228
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// A task either wraps a value that was available when it was created ([`Task::ready`]) or
/// wraps a handle to a future spawned through `async_task` (for example by
/// [`ForegroundExecutor::spawn`] or [`BackgroundExecutor::spawn`]).
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
#[must_use]
#[derive(Debug)]
pub struct Task<T>(TaskState<T>);
238
/// Internal representation behind [`Task`].
#[derive(Debug)]
enum TaskState<T> {
    /// A task that is ready to return a value.
    /// Holds `None` once a poll has moved the value out.
    Ready(Option<T>),

    /// A task that is currently running (a handle to a future spawned via `async_task`).
    Spawned(async_task::Task<T, RunnableMeta>),
}
247
248impl<T> Task<T> {
249    /// Creates a new task that will resolve with the value
250    pub fn ready(val: T) -> Self {
251        Task(TaskState::Ready(Some(val)))
252    }
253
254    /// Creates a Task from an async_task::Task
255    pub fn from_async_task(task: async_task::Task<T, RunnableMeta>) -> Self {
256        Task(TaskState::Spawned(task))
257    }
258
259    pub fn is_ready(&self) -> bool {
260        match &self.0 {
261            TaskState::Ready(_) => true,
262            TaskState::Spawned(task) => task.is_finished(),
263        }
264    }
265
266    /// Detaching a task runs it to completion in the background
267    pub fn detach(self) {
268        match self {
269            Task(TaskState::Ready(_)) => {}
270            Task(TaskState::Spawned(task)) => task.detach(),
271        }
272    }
273
274    /// Converts this task into a fallible task that returns `Option<T>`.
275    pub fn fallible(self) -> FallibleTask<T> {
276        FallibleTask(match self.0 {
277            TaskState::Ready(val) => FallibleTaskState::Ready(val),
278            TaskState::Spawned(task) => FallibleTaskState::Spawned(task.fallible()),
279        })
280    }
281}
282
/// A task that returns `Option<T>` instead of panicking when cancelled.
///
/// Obtained from [`Task::fallible`] or built directly with [`FallibleTask::ready`].
#[must_use]
pub struct FallibleTask<T>(FallibleTaskState<T>);
286
/// Internal representation behind [`FallibleTask`].
enum FallibleTaskState<T> {
    /// A task that is ready to return a value.
    /// The first poll takes the value, leaving `None` for later polls.
    Ready(Option<T>),

    /// A task that is currently running (wraps async_task::FallibleTask).
    Spawned(async_task::FallibleTask<T, RunnableMeta>),
}
294
295impl<T> FallibleTask<T> {
296    /// Creates a new fallible task that will resolve with the value.
297    pub fn ready(val: T) -> Self {
298        FallibleTask(FallibleTaskState::Ready(Some(val)))
299    }
300
301    /// Detaching a task runs it to completion in the background.
302    pub fn detach(self) {
303        match self.0 {
304            FallibleTaskState::Ready(_) => {}
305            FallibleTaskState::Spawned(task) => task.detach(),
306        }
307    }
308}
309
310impl<T> Future for FallibleTask<T> {
311    type Output = Option<T>;
312
313    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
314        match unsafe { self.get_unchecked_mut() } {
315            FallibleTask(FallibleTaskState::Ready(val)) => Poll::Ready(val.take()),
316            FallibleTask(FallibleTaskState::Spawned(task)) => Pin::new(task).poll(cx),
317        }
318    }
319}
320
321impl<T> std::fmt::Debug for FallibleTask<T> {
322    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
323        match &self.0 {
324            FallibleTaskState::Ready(_) => f.debug_tuple("FallibleTask::Ready").finish(),
325            FallibleTaskState::Spawned(task) => {
326                f.debug_tuple("FallibleTask::Spawned").field(task).finish()
327            }
328        }
329    }
330}
331
332impl<T> Future for Task<T> {
333    type Output = T;
334
335    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
336        match unsafe { self.get_unchecked_mut() } {
337            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
338            Task(TaskState::Spawned(task)) => Pin::new(task).poll(cx),
339        }
340    }
341}
342
/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// The future is wrapped in a thread-checking guard. Polling or dropping it on any thread
/// other than the one that called this function panics with a message naming
/// `metadata.location`.
///
/// Returns the `(Runnable, Task)` pair produced by `async_task::Builder::spawn_unchecked`;
/// the caller is responsible for scheduling the runnable.
#[track_caller]
fn spawn_local_with_source_location<Fut, S>(
    future: Fut,
    schedule: S,
    metadata: RunnableMeta,
) -> (
    async_task::Runnable<RunnableMeta>,
    async_task::Task<Fut::Output, RunnableMeta>,
)
where
    Fut: Future + 'static,
    Fut::Output: 'static,
    S: async_task::Schedule<RunnableMeta> + Send + Sync + 'static,
{
    /// Id of the calling thread, cached in a thread-local.
    ///
    /// Falls back to `thread::current()` when the thread-local can no longer be accessed
    /// (`try_with` fails while it is being or has been destroyed).
    #[inline]
    fn thread_id() -> ThreadId {
        std::thread_local! {
            static ID: ThreadId = thread::current().id();
        }
        ID.try_with(|id| *id)
            .unwrap_or_else(|_| thread::current().id())
    }

    /// Wraps a (possibly `!Send`) future and records the thread it was created on, so that
    /// cross-thread polls and drops are caught at runtime.
    struct Checked<F> {
        /// Thread that constructed this wrapper, i.e. the spawning thread.
        id: ThreadId,
        /// The wrapped future. `ManuallyDrop` lets `Drop` refuse to destroy it on the wrong thread.
        inner: ManuallyDrop<F>,
        /// Spawn site reported in the panic messages.
        location: &'static Location<'static>,
    }

    impl<F> Drop for Checked<F> {
        fn drop(&mut self) {
            // If this assertion fails, `inner` is never dropped (it is leaked) rather than
            // being destroyed on a foreign thread.
            assert!(
                self.id == thread_id(),
                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is dropped exactly once, here, and is not accessed afterwards.
            unsafe {
                ManuallyDrop::drop(&mut self.inner);
            }
        }
    }

    impl<F: Future> Future for Checked<F> {
        type Output = F::Output;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            assert!(
                self.id == thread_id(),
                "local task polled by a thread that didn't spawn it. Task spawned at {}",
                self.location
            );
            // SAFETY: `inner` is structurally pinned. It is never moved out of `Checked`;
            // it is only dropped in place via `ManuallyDrop::drop`. Projecting the pin onto it
            // is therefore sound.
            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
        }
    }

    let location = metadata.location;

    // SAFETY: `spawn_unchecked` lifts the `Send` requirement on the future. That is sound
    // here because `Checked` panics on any poll or drop from a thread other than the
    // spawning one. `Fut` and its output are `'static`, so no borrowed data can dangle.
    unsafe {
        async_task::Builder::new()
            .metadata(metadata)
            .spawn_unchecked(
                move |_| Checked {
                    id: thread_id(),
                    inner: ManuallyDrop::new(future),
                    location,
                },
                schedule,
            )
    }
}