1use crate::BackgroundExecutor;
2use std::{
3 future::Future,
4 pin::Pin,
5 sync::atomic::{AtomicUsize, Ordering::SeqCst},
6 task,
7 time::Duration,
8};
9
10use scheduler::Timer;
11pub use util::*;
12
/// A helper trait for building complex objects with imperative conditionals in a fluent style.
///
/// Every method has a default implementation, so an empty `impl FluentBuilder for T {}` is
/// enough to opt a type in. All conditional helpers are expressed in terms of [`Self::map`].
pub trait FluentBuilder {
    /// Hands `self` to `f` and returns whatever `f` produces.
    fn map<U>(self, f: impl FnOnce(Self) -> U) -> U
    where
        Self: Sized,
    {
        f(self)
    }

    /// Runs `then` on `self` only when `condition` is true; otherwise returns `self` untouched.
    fn when(self, condition: bool, then: impl FnOnce(Self) -> Self) -> Self
    where
        Self: Sized,
    {
        self.map(|builder| {
            if !condition {
                return builder;
            }
            then(builder)
        })
    }

    /// Runs `then` on `self` when `condition` is true, and `else_fn` otherwise.
    fn when_else(
        self,
        condition: bool,
        then: impl FnOnce(Self) -> Self,
        else_fn: impl FnOnce(Self) -> Self,
    ) -> Self
    where
        Self: Sized,
    {
        self.map(|builder| {
            if condition {
                return then(builder);
            }
            else_fn(builder)
        })
    }

    /// Runs `then` with `self` and the contained value when `option` is `Some`;
    /// otherwise returns `self` untouched.
    fn when_some<T>(self, option: Option<T>, then: impl FnOnce(Self, T) -> Self) -> Self
    where
        Self: Sized,
    {
        self.map(|builder| match option {
            Some(value) => then(builder, value),
            None => builder,
        })
    }

    /// Runs `then` on `self` when `option` is `None`; otherwise returns `self` untouched.
    ///
    /// The option is only inspected, never unwrapped or consumed.
    fn when_none<T>(self, option: &Option<T>, then: impl FnOnce(Self) -> Self) -> Self
    where
        Self: Sized,
    {
        self.map(|builder| match option {
            None => then(builder),
            Some(_) => builder,
        })
    }
}
65
66/// Extensions for Future types that provide additional combinators and utilities.
67pub trait FutureExt {
68 /// Requires a Future to complete before the specified duration has elapsed.
69 /// Similar to tokio::timeout.
70 fn with_timeout(self, timeout: Duration, executor: &BackgroundExecutor) -> WithTimeout<Self>
71 where
72 Self: Sized;
73}
74
75impl<T: Future> FutureExt for T {
76 fn with_timeout(self, timeout: Duration, executor: &BackgroundExecutor) -> WithTimeout<Self>
77 where
78 Self: Sized,
79 {
80 WithTimeout {
81 future: self,
82 timer: executor.timer(timeout),
83 }
84 }
85}
86
87pub struct WithTimeout<T> {
88 future: T,
89 timer: Timer,
90}
91
92#[derive(Debug, thiserror::Error)]
93#[error("Timed out before future resolved")]
94/// Error returned by with_timeout when the timeout duration elapsed before the future resolved
95pub struct Timeout;
96
97impl<T: Future> Future for WithTimeout<T> {
98 type Output = Result<T::Output, Timeout>;
99
100 fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
101 // SAFETY: the fields of Timeout are private and we never move the future ourselves
102 // And its already pinned since we are being polled (all futures need to be pinned to be polled)
103 let this = unsafe { &raw mut *self.get_unchecked_mut() };
104 let future = unsafe { Pin::new_unchecked(&mut (*this).future) };
105 let timer = unsafe { Pin::new_unchecked(&mut (*this).timer) };
106
107 if let task::Poll::Ready(output) = future.poll(cx) {
108 task::Poll::Ready(Ok(output))
109 } else if timer.poll(cx).is_ready() {
110 task::Poll::Ready(Err(Timeout))
111 } else {
112 task::Poll::Pending
113 }
114 }
115}
116
/// Awaits `f`, giving up once `timeout` has elapsed (test-support helper backed by `smol`).
///
/// Returns `Ok` with the future's output if it completes first, or `Err(())` if the timer fires
/// first. When both are ready on the same poll, the future's output is preferred. This matches
/// [`FutureExt::with_timeout`] and keeps tie-breaking deterministic.
#[cfg(any(test, feature = "test-support"))]
pub async fn smol_timeout<F, T>(timeout: Duration, f: F) -> Result<T, ()>
where
    F: Future<Output = T>,
{
    let future = async move { Ok(f.await) };
    let timer = async {
        smol::Timer::after(timeout).await;
        Err(())
    };
    // `or` always polls `future` first and prefers it when both are ready. `race` would poll the
    // two in random order, which made the winner of a tie nondeterministic.
    smol::future::FutureExt::or(future, timer).await
}
129
/// Increments the given atomic counter if it is not zero.
///
/// Returns the counter's new value. Returns `0` if the counter was zero, in which case it is
/// left untouched.
///
/// # Panics
///
/// Panics, without modifying the counter, if the increment would overflow `usize`. Without this
/// check, release builds would wrap the counter back to zero.
pub(crate) fn atomic_incr_if_not_zero(counter: &AtomicUsize) -> usize {
    counter
        .fetch_update(SeqCst, SeqCst, |current| {
            if current == 0 {
                // Never resurrect a counter that has already reached zero.
                None
            } else {
                Some(current.checked_add(1).expect("atomic counter overflowed"))
            }
        })
        // `Ok` carries the previous value, so the stored value is one higher. `Err` means the
        // closure declined to update, i.e. the counter was zero.
        .map_or(0, |previous| previous + 1)
}