1use crate::{BackgroundExecutor, Task};
2use std::{
3 future::Future,
4 pin::Pin,
5 sync::atomic::{AtomicUsize, Ordering::SeqCst},
6 task,
7 time::Duration,
8};
9
10pub use util::*;
11
/// A helper trait for building complex objects with imperative conditionals in a fluent style.
///
/// Every method has a default implementation, so implementors typically only need
/// `impl FluentBuilder for MyType {}`.
pub trait FluentBuilder {
    /// Imperatively modify self with the given closure.
    ///
    /// Passes `self` to `f` by value and returns whatever `f` returns, which may be a
    /// different type than `Self`.
    fn map<U>(self, f: impl FnOnce(Self) -> U) -> U
    where
        Self: Sized,
    {
        f(self)
    }

    /// Conditionally modify self with the given closure.
    ///
    /// Applies `then` when `condition` is true; otherwise returns `self` unchanged.
    fn when(self, condition: bool, then: impl FnOnce(Self) -> Self) -> Self
    where
        Self: Sized,
    {
        self.map(|this| if condition { then(this) } else { this })
    }

    /// Modify self with one of two closures, chosen by `condition`.
    ///
    /// Applies `then` when `condition` is true and `else_fn` otherwise.
    fn when_else(
        self,
        condition: bool,
        then: impl FnOnce(Self) -> Self,
        else_fn: impl FnOnce(Self) -> Self,
    ) -> Self
    where
        Self: Sized,
    {
        self.map(|this| if condition { then(this) } else { else_fn(this) })
    }

    /// Conditionally unwrap and modify self with the given closure, if the given option is `Some`.
    ///
    /// The unwrapped value is passed to `then` together with `self`. When `option` is `None`,
    /// `self` is returned unchanged.
    fn when_some<T>(self, option: Option<T>, then: impl FnOnce(Self, T) -> Self) -> Self
    where
        Self: Sized,
    {
        self.map(|this| {
            if let Some(value) = option {
                then(this, value)
            } else {
                this
            }
        })
    }

    /// Conditionally modify self with the given closure, if the given option is `None`.
    ///
    /// The option is only borrowed, and its contents are never inspected. When it is `Some`,
    /// `self` is returned unchanged.
    fn when_none<T>(self, option: &Option<T>, then: impl FnOnce(Self) -> Self) -> Self
    where
        Self: Sized,
    {
        self.map(|this| if option.is_none() { then(this) } else { this })
    }
}
70
71/// Extensions for Future types that provide additional combinators and utilities.
72pub trait FutureExt {
73 /// Requires a Future to complete before the specified duration has elapsed.
74 /// Similar to tokio::timeout.
75 fn with_timeout(self, timeout: Duration, executor: &BackgroundExecutor) -> WithTimeout<Self>
76 where
77 Self: Sized;
78}
79
80impl<T: Future> FutureExt for T {
81 fn with_timeout(self, timeout: Duration, executor: &BackgroundExecutor) -> WithTimeout<Self>
82 where
83 Self: Sized,
84 {
85 WithTimeout {
86 future: self,
87 timer: executor.timer(timeout),
88 }
89 }
90}
91
92pub struct WithTimeout<T> {
93 future: T,
94 timer: Task<()>,
95}
96
97#[derive(Debug, thiserror::Error)]
98#[error("Timed out before future resolved")]
99/// Error returned by with_timeout when the timeout duration elapsed before the future resolved
100pub struct Timeout;
101
/// Polls the wrapped future first and the timer second, so a future that is ready on the same
/// poll in which the timer fires still yields `Ok`.
impl<T: Future> Future for WithTimeout<T> {
    type Output = Result<T::Output, Timeout>;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
        // SAFETY: `self` is already pinned (all futures must be pinned to be polled), and this
        // impl never moves `future` or `timer` out of `this`. The fields of `WithTimeout` are
        // private, and no `Drop` impl for it appears in this file. Projecting pins onto both
        // fields is therefore sound, provided no other code moves these fields out of a pinned
        // `WithTimeout`.
        let this = unsafe { self.get_unchecked_mut() };
        let future = unsafe { Pin::new_unchecked(&mut this.future) };
        let timer = unsafe { Pin::new_unchecked(&mut this.timer) };

        // The wrapped future takes priority. The timer is only polled if the future is pending.
        if let task::Poll::Ready(output) = future.poll(cx) {
            task::Poll::Ready(Ok(output))
        } else if timer.poll(cx).is_ready() {
            // The timer completed while the future was still pending.
            task::Poll::Ready(Err(Timeout))
        } else {
            // Neither inner poll has completed yet.
            task::Poll::Pending
        }
    }
}
121
/// Races `f` against a `smol` timer of length `timeout` (test support only).
///
/// Returns `Ok` with the future's output if it completes before the timer, or `Err(())` once
/// the timer fires. If both are ready on the same poll, the future's output is preferred.
/// This matches the priority of [`WithTimeout`].
#[cfg(any(test, feature = "test-support"))]
pub async fn smol_timeout<F, T>(timeout: Duration, f: F) -> Result<T, ()>
where
    F: Future<Output = T>,
{
    let future = async move { Ok(f.await) };
    let timer = async {
        smol::Timer::after(timeout).await;
        Err(())
    };
    // Use `or` rather than `race`: `race` polls its inputs in random order, so a future that
    // completes on the same poll as the timer could nondeterministically report a timeout.
    // `or` always polls `future` first.
    smol::future::FutureExt::or(future, timer).await
}
134
/// Atomically bump `counter` by one unless it currently holds zero.
///
/// Returns the value the counter holds after the increment, or `0` when the counter was zero
/// and was therefore left untouched.
pub(crate) fn atomic_incr_if_not_zero(counter: &AtomicUsize) -> usize {
    // `fetch_update` runs the same load + `compare_exchange_weak` retry loop, using SeqCst for
    // both the success and failure orderings. A `None` from the closure aborts without writing.
    let previous = counter.fetch_update(SeqCst, SeqCst, |current| {
        if current == 0 {
            None
        } else {
            Some(current + 1)
        }
    });
    match previous {
        Ok(before) => before + 1,
        // The closure only declines when the observed value is zero.
        Err(_) => 0,
    }
}