executor.rs

  1use crate::{Scheduler, SessionId, Timer};
  2use futures::FutureExt as _;
  3use std::{
  4    future::Future,
  5    marker::PhantomData,
  6    mem::ManuallyDrop,
  7    panic::Location,
  8    pin::Pin,
  9    rc::Rc,
 10    sync::Arc,
 11    task::{Context, Poll},
 12    thread::{self, ThreadId},
 13    time::Duration,
 14};
 15
 16#[derive(Clone)]
 17pub struct ForegroundExecutor {
 18    session_id: SessionId,
 19    scheduler: Arc<dyn Scheduler>,
 20    not_send: PhantomData<Rc<()>>,
 21}
 22
 23impl ForegroundExecutor {
 24    pub fn new(session_id: SessionId, scheduler: Arc<dyn Scheduler>) -> Self {
 25        Self {
 26            session_id,
 27            scheduler,
 28            not_send: PhantomData,
 29        }
 30    }
 31
 32    #[track_caller]
 33    pub fn spawn<F>(&self, future: F) -> Task<F::Output>
 34    where
 35        F: Future + 'static,
 36        F::Output: 'static,
 37    {
 38        let session_id = self.session_id;
 39        let scheduler = Arc::clone(&self.scheduler);
 40        let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
 41            scheduler.schedule_foreground(session_id, runnable);
 42        });
 43        runnable.schedule();
 44        Task(TaskState::Spawned(task))
 45    }
 46
 47    pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
 48        let mut output = None;
 49        self.scheduler.block(
 50            Some(self.session_id),
 51            async { output = Some(future.await) }.boxed_local(),
 52            None,
 53        );
 54        output.unwrap()
 55    }
 56
 57    pub fn block_with_timeout<Fut: Unpin + Future>(
 58        &self,
 59        timeout: Duration,
 60        mut future: Fut,
 61    ) -> Result<Fut::Output, Fut> {
 62        let mut output = None;
 63        self.scheduler.block(
 64            Some(self.session_id),
 65            async { output = Some((&mut future).await) }.boxed_local(),
 66            Some(timeout),
 67        );
 68        output.ok_or(future)
 69    }
 70
 71    pub fn timer(&self, duration: Duration) -> Timer {
 72        self.scheduler.timer(duration)
 73    }
 74}
 75
 76#[derive(Clone)]
 77pub struct BackgroundExecutor {
 78    scheduler: Arc<dyn Scheduler>,
 79}
 80
 81impl BackgroundExecutor {
 82    pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
 83        Self { scheduler }
 84    }
 85
 86    pub fn spawn<F>(&self, future: F) -> Task<F::Output>
 87    where
 88        F: Future + Send + 'static,
 89        F::Output: Send + 'static,
 90    {
 91        let scheduler = Arc::clone(&self.scheduler);
 92        let (runnable, task) = async_task::spawn(future, move |runnable| {
 93            scheduler.schedule_background(runnable);
 94        });
 95        runnable.schedule();
 96        Task(TaskState::Spawned(task))
 97    }
 98
 99    pub fn timer(&self, duration: Duration) -> Timer {
100        self.scheduler.timer(duration)
101    }
102
103    pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
104        &self.scheduler
105    }
106}
107
108/// Task is a primitive that allows work to happen in the background.
109///
110/// It implements [`Future`] so you can `.await` on it.
111///
112/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
113/// the task to continue running, but with no way to return a value.
114#[must_use]
115#[derive(Debug)]
116pub struct Task<T>(TaskState<T>);
117
118#[derive(Debug)]
119enum TaskState<T> {
120    /// A task that is ready to return a value
121    Ready(Option<T>),
122
123    /// A task that is currently running.
124    Spawned(async_task::Task<T>),
125}
126
127impl<T> Task<T> {
128    /// Creates a new task that will resolve with the value
129    pub const fn ready(val: T) -> Self {
130        Task(TaskState::Ready(Some(val)))
131    }
132
133    pub fn is_ready(&self) -> bool {
134        match &self.0 {
135            TaskState::Ready(_) => true,
136            TaskState::Spawned(task) => task.is_finished(),
137        }
138    }
139
140    /// Detaching a task runs it to completion in the background
141    pub fn detach(self) {
142        match self {
143            Task(TaskState::Ready(_)) => {}
144            Task(TaskState::Spawned(task)) => task.detach(),
145        }
146    }
147}
148
149impl<T> Future for Task<T> {
150    type Output = T;
151
152    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
153        match unsafe { self.get_unchecked_mut() } {
154            Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
155            Task(TaskState::Spawned(task)) => Pin::new(task).poll(cx),
156        }
157    }
158}
159
160/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
161///
162/// Copy-modified from:
163/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
164#[track_caller]
165fn spawn_local_with_source_location<Fut, S>(
166    future: Fut,
167    schedule: S,
168) -> (async_task::Runnable, async_task::Task<Fut::Output, ()>)
169where
170    Fut: Future + 'static,
171    Fut::Output: 'static,
172    S: async_task::Schedule + Send + Sync + 'static,
173{
174    #[inline]
175    fn thread_id() -> ThreadId {
176        std::thread_local! {
177            static ID: ThreadId = thread::current().id();
178        }
179        ID.try_with(|id| *id)
180            .unwrap_or_else(|_| thread::current().id())
181    }
182
183    struct Checked<F> {
184        id: ThreadId,
185        inner: ManuallyDrop<F>,
186        location: &'static Location<'static>,
187    }
188
189    impl<F> Drop for Checked<F> {
190        fn drop(&mut self) {
191            assert!(
192                self.id == thread_id(),
193                "local task dropped by a thread that didn't spawn it. Task spawned at {}",
194                self.location
195            );
196            unsafe {
197                ManuallyDrop::drop(&mut self.inner);
198            }
199        }
200    }
201
202    impl<F: Future> Future for Checked<F> {
203        type Output = F::Output;
204
205        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
206            assert!(
207                self.id == thread_id(),
208                "local task polled by a thread that didn't spawn it. Task spawned at {}",
209                self.location
210            );
211            unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
212        }
213    }
214
215    // Wrap the future into one that checks which thread it's on.
216    let future = Checked {
217        id: thread_id(),
218        inner: ManuallyDrop::new(future),
219        location: Location::caller(),
220    };
221
222    unsafe { async_task::spawn_unchecked(future, schedule) }
223}