@@ -1,6 +1,6 @@
use anyhow::{anyhow, Result};
use async_task::Runnable;
-use smol::{channel, prelude::*, Executor, Timer};
+use smol::{channel, prelude::*, Executor};
use std::{
any::Any,
fmt::{self, Display},
@@ -86,6 +86,19 @@ pub struct Deterministic {
parker: parking_lot::Mutex<parking::Parker>,
}
+pub enum Timer {
+ Production(smol::Timer),
+ #[cfg(any(test, feature = "test-support"))]
+ Deterministic(DeterministicTimer),
+}
+
+#[cfg(any(test, feature = "test-support"))]
+pub struct DeterministicTimer {
+ rx: postage::barrier::Receiver,
+ id: usize,
+ state: Arc<parking_lot::Mutex<DeterministicState>>,
+}
+
#[cfg(any(test, feature = "test-support"))]
impl Deterministic {
pub fn new(seed: u64) -> Arc<Self> {
@@ -306,30 +319,14 @@ impl Deterministic {
None
}
- pub fn timer(&self, duration: Duration) -> impl Future<Output = ()> {
- let (tx, mut rx) = postage::barrier::channel();
- let timer_id;
- {
- let mut state = self.state.lock();
- let wakeup_at = state.now + duration;
- timer_id = util::post_inc(&mut state.next_timer_id);
- state.pending_timers.push((timer_id, wakeup_at, tx));
- }
-
- let remove_timer = util::defer({
- let state = self.state.clone();
- move || {
- state
- .lock()
- .pending_timers
- .retain(|(id, _, _)| *id != timer_id);
- }
- });
-
- async move {
- postage::prelude::Stream::recv(&mut rx).await;
- drop(remove_timer);
- }
+ pub fn timer(&self, duration: Duration) -> Timer {
+ let (tx, rx) = postage::barrier::channel();
+ let mut state = self.state.lock();
+ let wakeup_at = state.now + duration;
+ let id = util::post_inc(&mut state.next_timer_id);
+ state.pending_timers.push((id, wakeup_at, tx));
+ let state = self.state.clone();
+ Timer::Deterministic(DeterministicTimer { rx, id, state })
}
pub fn advance_clock(&self, duration: Duration) {
@@ -344,6 +341,43 @@ impl Deterministic {
}
}
+impl Drop for Timer {
+ fn drop(&mut self) {
+ #[cfg(any(test, feature = "test-support"))]
+ if let Timer::Deterministic(DeterministicTimer { state, id, .. }) = self {
+ state
+ .lock()
+ .pending_timers
+ .retain(|(timer_id, _, _)| timer_id != id)
+ }
+ }
+}
+
+impl Future for Timer {
+ type Output = ();
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ match &mut *self {
+ #[cfg(any(test, feature = "test-support"))]
+ Self::Deterministic(DeterministicTimer { rx, .. }) => {
+ use postage::stream::{PollRecv, Stream as _};
+ smol::pin!(rx);
+ match rx.poll_recv(&mut postage::Context::from_waker(cx.waker())) {
+ PollRecv::Ready(()) | PollRecv::Closed => Poll::Ready(()),
+ PollRecv::Pending => Poll::Pending,
+ }
+ }
+ Self::Production(timer) => {
+ smol::pin!(timer);
+ match timer.poll(cx) {
+ Poll::Ready(_) => Poll::Ready(()),
+ Poll::Pending => Poll::Pending,
+ }
+ }
+ }
+ }
+}
+
#[cfg(any(test, feature = "test-support"))]
impl DeterministicState {
fn will_park(&mut self) {
@@ -464,23 +498,6 @@ impl Foreground {
}
}
- pub fn timer(&self, duration: Duration) -> impl Future<Output = ()> {
- let mut timer = None;
-
- #[cfg(any(test, feature = "test-support"))]
- if let Self::Deterministic { executor, .. } = self {
- timer = Some(executor.timer(duration));
- }
-
- async move {
- if let Some(timer) = timer {
- timer.await;
- } else {
- Timer::after(duration).await;
- }
- }
- }
-
#[cfg(any(test, feature = "test-support"))]
pub fn advance_clock(&self, duration: Duration) {
match self {
@@ -603,20 +620,11 @@ impl Background {
}
}
- pub fn timer(&self, duration: Duration) -> impl Future<Output = ()> {
- let mut timer = None;
-
- #[cfg(any(test, feature = "test-support"))]
- if let Self::Deterministic { executor, .. } = self {
- timer = Some(executor.timer(duration));
- }
-
- async move {
- if let Some(timer) = timer {
- timer.await;
- } else {
- Timer::after(duration).await;
- }
+ pub fn timer(&self, duration: Duration) -> Timer {
+ match self {
+ Background::Production { .. } => Timer::Production(smol::Timer::after(duration)),
+ #[cfg(any(test, feature = "test-support"))]
+ Background::Deterministic { executor } => executor.timer(duration),
}
}