Add FutureExt::with_timeout and use it for for Room::maintain_connection (#36175)

David Kleingeld and Antonio Scandurra created

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/call/src/call_impl/room.rs   | 90 ++++++++++++++----------------
crates/gpui/src/app/test_context.rs |  4 
crates/gpui/src/gpui.rs             |  2 
crates/gpui/src/util.rs             | 73 +++++++++++++++++++++----
4 files changed, 107 insertions(+), 62 deletions(-)

Detailed changes

crates/call/src/call_impl/room.rs 🔗

@@ -10,10 +10,10 @@ use client::{
 };
 use collections::{BTreeMap, HashMap, HashSet};
 use fs::Fs;
-use futures::{FutureExt, StreamExt};
+use futures::StreamExt;
 use gpui::{
-    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, ScreenCaptureSource,
-    ScreenCaptureStream, Task, WeakEntity,
+    App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, FutureExt as _,
+    ScreenCaptureSource, ScreenCaptureStream, Task, Timeout, WeakEntity,
 };
 use gpui_tokio::Tokio;
 use language::LanguageRegistry;
@@ -370,57 +370,53 @@ impl Room {
                     })?;
 
                 // Wait for client to re-establish a connection to the server.
-                {
-                    let mut reconnection_timeout =
-                        cx.background_executor().timer(RECONNECT_TIMEOUT).fuse();
-                    let client_reconnection = async {
-                        let mut remaining_attempts = 3;
-                        while remaining_attempts > 0 {
-                            if client_status.borrow().is_connected() {
-                                log::info!("client reconnected, attempting to rejoin room");
-
-                                let Some(this) = this.upgrade() else { break };
-                                match this.update(cx, |this, cx| this.rejoin(cx)) {
-                                    Ok(task) => {
-                                        if task.await.log_err().is_some() {
-                                            return true;
-                                        } else {
-                                            remaining_attempts -= 1;
-                                        }
+                let executor = cx.background_executor().clone();
+                let client_reconnection = async {
+                    let mut remaining_attempts = 3;
+                    while remaining_attempts > 0 {
+                        if client_status.borrow().is_connected() {
+                            log::info!("client reconnected, attempting to rejoin room");
+
+                            let Some(this) = this.upgrade() else { break };
+                            match this.update(cx, |this, cx| this.rejoin(cx)) {
+                                Ok(task) => {
+                                    if task.await.log_err().is_some() {
+                                        return true;
+                                    } else {
+                                        remaining_attempts -= 1;
                                     }
-                                    Err(_app_dropped) => return false,
                                 }
-                            } else if client_status.borrow().is_signed_out() {
-                                return false;
+                                Err(_app_dropped) => return false,
                             }
-
-                            log::info!(
-                                "waiting for client status change, remaining attempts {}",
-                                remaining_attempts
-                            );
-                            client_status.next().await;
+                        } else if client_status.borrow().is_signed_out() {
+                            return false;
                         }
-                        false
+
+                        log::info!(
+                            "waiting for client status change, remaining attempts {}",
+                            remaining_attempts
+                        );
+                        client_status.next().await;
                     }
-                    .fuse();
-                    futures::pin_mut!(client_reconnection);
-
-                    futures::select_biased! {
-                        reconnected = client_reconnection => {
-                            if reconnected {
-                                log::info!("successfully reconnected to room");
-                                // If we successfully joined the room, go back around the loop
-                                // waiting for future connection status changes.
-                                continue;
-                            }
-                        }
-                        _ = reconnection_timeout => {
-                            log::info!("room reconnection timeout expired");
-                        }
+                    false
+                };
+
+                match client_reconnection
+                    .with_timeout(RECONNECT_TIMEOUT, &executor)
+                    .await
+                {
+                    Ok(true) => {
+                        log::info!("successfully reconnected to room");
+                        // If we successfully joined the room, go back around the loop
+                        // waiting for future connection status changes.
+                        continue;
+                    }
+                    Ok(false) => break,
+                    Err(Timeout) => {
+                        log::info!("room reconnection timeout expired");
+                        break;
                     }
                 }
-
-                break;
             }
         }
 

crates/gpui/src/app/test_context.rs 🔗

@@ -585,7 +585,7 @@ impl<V: 'static> Entity<V> {
         cx.executor().advance_clock(advance_clock_by);
 
         async move {
-            let notification = crate::util::timeout(duration, rx.recv())
+            let notification = crate::util::smol_timeout(duration, rx.recv())
                 .await
                 .expect("next notification timed out");
             drop(subscription);
@@ -629,7 +629,7 @@ impl<V> Entity<V> {
         let handle = self.downgrade();
 
         async move {
-            crate::util::timeout(Duration::from_secs(1), async move {
+            crate::util::smol_timeout(Duration::from_secs(1), async move {
                 loop {
                     {
                         let cx = cx.borrow();

crates/gpui/src/gpui.rs 🔗

@@ -157,7 +157,7 @@ pub use taffy::{AvailableSpace, LayoutId};
 #[cfg(any(test, feature = "test-support"))]
 pub use test::*;
 pub use text_system::*;
-pub use util::arc_cow::ArcCow;
+pub use util::{FutureExt, Timeout, arc_cow::ArcCow};
 pub use view::*;
 pub use window::*;
 

crates/gpui/src/util.rs 🔗

@@ -1,13 +1,11 @@
-use std::sync::atomic::AtomicUsize;
-use std::sync::atomic::Ordering::SeqCst;
-#[cfg(any(test, feature = "test-support"))]
-use std::time::Duration;
-
-#[cfg(any(test, feature = "test-support"))]
-use futures::Future;
-
-#[cfg(any(test, feature = "test-support"))]
-use smol::future::FutureExt;
+use crate::{BackgroundExecutor, Task};
+use std::{
+    future::Future,
+    pin::Pin,
+    sync::atomic::{AtomicUsize, Ordering::SeqCst},
+    task,
+    time::Duration,
+};
 
 pub use util::*;
 
@@ -70,8 +68,59 @@ pub trait FluentBuilder {
     }
 }
 
+/// Extensions for Future types that provide additional combinators and utilities.
+pub trait FutureExt {
+    /// Requires a Future to complete before the specified duration has elapsed.
+    /// Similar to tokio::timeout.
+    fn with_timeout(self, timeout: Duration, executor: &BackgroundExecutor) -> WithTimeout<Self>
+    where
+        Self: Sized;
+}
+
+impl<T: Future> FutureExt for T {
+    fn with_timeout(self, timeout: Duration, executor: &BackgroundExecutor) -> WithTimeout<Self>
+    where
+        Self: Sized,
+    {
+        WithTimeout {
+            future: self,
+            timer: executor.timer(timeout),
+        }
+    }
+}
+
+pub struct WithTimeout<T> {
+    future: T,
+    timer: Task<()>,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("Timed out before future resolved")]
+/// Error returned by with_timeout when the timeout duration elapsed before the future resolved
+pub struct Timeout;
+
+impl<T: Future> Future for WithTimeout<T> {
+    type Output = Result<T::Output, Timeout>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut task::Context) -> task::Poll<Self::Output> {
+        // SAFETY: the fields of Timeout are private and we never move the future ourselves
+        // And its already pinned since we are being polled (all futures need to be pinned to be polled)
+        let this = unsafe { self.get_unchecked_mut() };
+        let future = unsafe { Pin::new_unchecked(&mut this.future) };
+        let timer = unsafe { Pin::new_unchecked(&mut this.timer) };
+
+        if let task::Poll::Ready(output) = future.poll(cx) {
+            task::Poll::Ready(Ok(output))
+        } else if timer.poll(cx).is_ready() {
+            task::Poll::Ready(Err(Timeout))
+        } else {
+            task::Poll::Pending
+        }
+    }
+}
+
 #[cfg(any(test, feature = "test-support"))]
-pub async fn timeout<F, T>(timeout: Duration, f: F) -> Result<T, ()>
+pub async fn smol_timeout<F, T>(timeout: Duration, f: F) -> Result<T, ()>
 where
     F: Future<Output = T>,
 {
@@ -80,7 +129,7 @@ where
         Err(())
     };
     let future = async move { Ok(f.await) };
-    timer.race(future).await
+    smol::future::FutureExt::race(timer, future).await
 }
 
 /// Increment the given atomic counter if it is not zero.