1use crate::{Scheduler, SessionId, Timer};
2use futures::FutureExt as _;
3use std::{
4 future::Future,
5 marker::PhantomData,
6 mem::ManuallyDrop,
7 panic::Location,
8 pin::Pin,
9 rc::Rc,
10 sync::Arc,
11 task::{Context, Poll},
12 thread::{self, ThreadId},
13 time::Duration,
14};
15
16#[derive(Clone)]
17pub struct ForegroundExecutor {
18 session_id: SessionId,
19 scheduler: Arc<dyn Scheduler>,
20 not_send: PhantomData<Rc<()>>,
21}
22
23impl ForegroundExecutor {
24 pub fn new(session_id: SessionId, scheduler: Arc<dyn Scheduler>) -> Self {
25 Self {
26 session_id,
27 scheduler,
28 not_send: PhantomData,
29 }
30 }
31
32 #[track_caller]
33 pub fn spawn<F>(&self, future: F) -> Task<F::Output>
34 where
35 F: Future + 'static,
36 F::Output: 'static,
37 {
38 let session_id = self.session_id;
39 let scheduler = Arc::clone(&self.scheduler);
40 let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
41 scheduler.schedule_foreground(session_id, runnable);
42 });
43 runnable.schedule();
44 Task(TaskState::Spawned(task))
45 }
46
47 pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
48 let mut output = None;
49 self.scheduler.block(
50 Some(self.session_id),
51 async { output = Some(future.await) }.boxed_local(),
52 None,
53 );
54 output.unwrap()
55 }
56
57 pub fn block_with_timeout<Fut: Unpin + Future>(
58 &self,
59 timeout: Duration,
60 mut future: Fut,
61 ) -> Result<Fut::Output, Fut> {
62 let mut output = None;
63 self.scheduler.block(
64 Some(self.session_id),
65 async { output = Some((&mut future).await) }.boxed_local(),
66 Some(timeout),
67 );
68 output.ok_or(future)
69 }
70
71 pub fn timer(&self, duration: Duration) -> Timer {
72 self.scheduler.timer(duration)
73 }
74}
75
76#[derive(Clone)]
77pub struct BackgroundExecutor {
78 scheduler: Arc<dyn Scheduler>,
79}
80
81impl BackgroundExecutor {
82 pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
83 Self { scheduler }
84 }
85
86 pub fn spawn<F>(&self, future: F) -> Task<F::Output>
87 where
88 F: Future + Send + 'static,
89 F::Output: Send + 'static,
90 {
91 let scheduler = Arc::clone(&self.scheduler);
92 let (runnable, task) = async_task::spawn(future, move |runnable| {
93 scheduler.schedule_background(runnable);
94 });
95 runnable.schedule();
96 Task(TaskState::Spawned(task))
97 }
98
99 pub fn timer(&self, duration: Duration) -> Timer {
100 self.scheduler.timer(duration)
101 }
102
103 pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
104 &self.scheduler
105 }
106}
107
108/// Task is a primitive that allows work to happen in the background.
109///
110/// It implements [`Future`] so you can `.await` on it.
111///
112/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
113/// the task to continue running, but with no way to return a value.
114#[must_use]
115#[derive(Debug)]
116pub struct Task<T>(TaskState<T>);
117
118#[derive(Debug)]
119enum TaskState<T> {
120 /// A task that is ready to return a value
121 Ready(Option<T>),
122
123 /// A task that is currently running.
124 Spawned(async_task::Task<T>),
125}
126
127impl<T> Task<T> {
128 /// Creates a new task that will resolve with the value
129 pub const fn ready(val: T) -> Self {
130 Task(TaskState::Ready(Some(val)))
131 }
132
133 pub fn is_ready(&self) -> bool {
134 match &self.0 {
135 TaskState::Ready(_) => true,
136 TaskState::Spawned(task) => task.is_finished(),
137 }
138 }
139
140 /// Detaching a task runs it to completion in the background
141 pub fn detach(self) {
142 match self {
143 Task(TaskState::Ready(_)) => {}
144 Task(TaskState::Spawned(task)) => task.detach(),
145 }
146 }
147}
148
149impl<T> Future for Task<T> {
150 type Output = T;
151
152 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
153 match unsafe { self.get_unchecked_mut() } {
154 Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
155 Task(TaskState::Spawned(task)) => Pin::new(task).poll(cx),
156 }
157 }
158}
159
160/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
161///
162/// Copy-modified from:
163/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
164#[track_caller]
165fn spawn_local_with_source_location<Fut, S>(
166 future: Fut,
167 schedule: S,
168) -> (async_task::Runnable, async_task::Task<Fut::Output, ()>)
169where
170 Fut: Future + 'static,
171 Fut::Output: 'static,
172 S: async_task::Schedule + Send + Sync + 'static,
173{
174 #[inline]
175 fn thread_id() -> ThreadId {
176 std::thread_local! {
177 static ID: ThreadId = thread::current().id();
178 }
179 ID.try_with(|id| *id)
180 .unwrap_or_else(|_| thread::current().id())
181 }
182
183 struct Checked<F> {
184 id: ThreadId,
185 inner: ManuallyDrop<F>,
186 location: &'static Location<'static>,
187 }
188
189 impl<F> Drop for Checked<F> {
190 fn drop(&mut self) {
191 assert!(
192 self.id == thread_id(),
193 "local task dropped by a thread that didn't spawn it. Task spawned at {}",
194 self.location
195 );
196 unsafe {
197 ManuallyDrop::drop(&mut self.inner);
198 }
199 }
200 }
201
202 impl<F: Future> Future for Checked<F> {
203 type Output = F::Output;
204
205 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
206 assert!(
207 self.id == thread_id(),
208 "local task polled by a thread that didn't spawn it. Task spawned at {}",
209 self.location
210 );
211 unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
212 }
213 }
214
215 // Wrap the future into one that checks which thread it's on.
216 let future = Checked {
217 id: thread_id(),
218 inner: ManuallyDrop::new(future),
219 location: Location::caller(),
220 };
221
222 unsafe { async_task::spawn_unchecked(future, schedule) }
223}