Send websocket pings from both the client and the server

Created by Max Brunsfeld, Nathan Sobo, and Antonio Scandurra

Remove the client-only logic for sending protobuf pings.

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
Co-Authored-By: Antonio Scandurra <me@as-cii.com>

Change summary

Cargo.lock                  |   6 +
crates/client/src/client.rs |  54 ++++------------
crates/client/src/test.rs   |   3 
crates/gpui/Cargo.toml      |   1 
crates/rpc/Cargo.toml       |   2 
crates/rpc/src/peer.rs      | 124 +++++++++++++++++++++++++-------------
crates/rpc/src/proto.rs     |   7 ++
crates/server/Cargo.toml    |   1 
crates/server/src/rpc.rs    |  68 +++++++++++++++++++--
9 files changed, 174 insertions(+), 92 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -998,7 +998,6 @@ dependencies = [
 name = "clock"
 version = "0.1.0"
 dependencies = [
- "rpc",
  "smallvec",
 ]
 
@@ -2236,6 +2235,7 @@ dependencies = [
  "tiny-skia",
  "tree-sitter",
  "usvg",
+ "util",
  "waker-fn",
 ]
 
@@ -3959,6 +3959,7 @@ dependencies = [
  "async-lock",
  "async-tungstenite",
  "base64 0.13.0",
+ "clock",
  "futures",
  "gpui",
  "log",
@@ -3972,6 +3973,7 @@ dependencies = [
  "smol",
  "smol-timeout",
  "tempdir",
+ "util",
  "zstd",
 ]
 
@@ -5574,7 +5576,6 @@ name = "util"
 version = "0.1.0"
 dependencies = [
  "anyhow",
- "clock",
  "futures",
  "log",
  "rand 0.8.3",
@@ -5959,6 +5960,7 @@ name = "zed-server"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "async-io",
  "async-sqlx-session",
  "async-std",
  "async-trait",

crates/client/src/client.rs 🔗

@@ -137,8 +137,8 @@ struct ClientState {
     credentials: Option<Credentials>,
     status: (watch::Sender<Status>, watch::Receiver<Status>),
     entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
-    _maintain_connection: Option<Task<()>>,
-    heartbeat_interval: Duration,
+    _reconnect_task: Option<Task<()>>,
+    reconnect_interval: Duration,
     models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>,
     models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
     model_types_by_message_type: HashMap<TypeId, TypeId>,
@@ -168,8 +168,8 @@ impl Default for ClientState {
             credentials: None,
             status: watch::channel_with(Status::SignedOut),
             entity_id_extractors: Default::default(),
-            _maintain_connection: None,
-            heartbeat_interval: Duration::from_secs(5),
+            _reconnect_task: None,
+            reconnect_interval: Duration::from_secs(5),
             models_by_message_type: Default::default(),
             models_by_entity_type_and_remote_id: Default::default(),
             model_types_by_message_type: Default::default(),
@@ -236,7 +236,7 @@ impl Client {
     #[cfg(any(test, feature = "test-support"))]
     pub fn tear_down(&self) {
         let mut state = self.state.write();
-        state._maintain_connection.take();
+        state._reconnect_task.take();
         state.message_handlers.clear();
         state.models_by_message_type.clear();
         state.models_by_entity_type_and_remote_id.clear();
@@ -283,21 +283,13 @@ impl Client {
 
         match status {
             Status::Connected { .. } => {
-                let heartbeat_interval = state.heartbeat_interval;
-                let this = self.clone();
-                let foreground = cx.foreground();
-                state._maintain_connection = Some(cx.foreground().spawn(async move {
-                    loop {
-                        foreground.timer(heartbeat_interval).await;
-                        let _ = this.request(proto::Ping {}).await;
-                    }
-                }));
+                state._reconnect_task = None;
             }
             Status::ConnectionLost => {
                 let this = self.clone();
                 let foreground = cx.foreground();
-                let heartbeat_interval = state.heartbeat_interval;
-                state._maintain_connection = Some(cx.spawn(|cx| async move {
+                let reconnect_interval = state.reconnect_interval;
+                state._reconnect_task = Some(cx.spawn(|cx| async move {
                     let mut rng = StdRng::from_entropy();
                     let mut delay = Duration::from_millis(100);
                     while let Err(error) = this.authenticate_and_connect(&cx).await {
@@ -311,12 +303,12 @@ impl Client {
                         foreground.timer(delay).await;
                         delay = delay
                             .mul_f32(rng.gen_range(1.0..=2.0))
-                            .min(heartbeat_interval);
+                            .min(reconnect_interval);
                     }
                 }));
             }
             Status::SignedOut | Status::UpgradeRequired => {
-                state._maintain_connection.take();
+                state._reconnect_task.take();
             }
             _ => {}
         }
@@ -548,7 +540,11 @@ impl Client {
     }
 
     async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
-        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
+        let executor = cx.background();
+        let (connection_id, handle_io, mut incoming) = self
+            .peer
+            .add_connection(conn, move |duration| executor.timer(duration))
+            .await;
         cx.foreground()
             .spawn({
                 let cx = cx.clone();
@@ -940,26 +936,6 @@ mod tests {
     use crate::test::{FakeHttpClient, FakeServer};
     use gpui::TestAppContext;
 
-    #[gpui::test(iterations = 10)]
-    async fn test_heartbeat(cx: &mut TestAppContext) {
-        cx.foreground().forbid_parking();
-
-        let user_id = 5;
-        let mut client = Client::new(FakeHttpClient::with_404_response());
-        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
-
-        cx.foreground().advance_clock(Duration::from_secs(10));
-        let ping = server.receive::<proto::Ping>().await.unwrap();
-        server.respond(ping.receipt(), proto::Ack {}).await;
-
-        cx.foreground().advance_clock(Duration::from_secs(10));
-        let ping = server.receive::<proto::Ping>().await.unwrap();
-        server.respond(ping.receipt(), proto::Ack {}).await;
-
-        client.disconnect(&cx.to_async()).unwrap();
-        assert!(server.receive::<proto::Ping>().await.is_err());
-    }
-
     #[gpui::test(iterations = 10)]
     async fn test_reconnection(cx: &mut TestAppContext) {
         cx.foreground().forbid_parking();

crates/client/src/test.rs 🔗

@@ -75,7 +75,8 @@ impl FakeServer {
                         }
 
                         let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
-                        let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
+                        let (connection_id, io, incoming) =
+                            peer.add_test_connection(server_conn, cx.background()).await;
                         cx.background().spawn(io).detach();
                         let mut state = state.lock();
                         state.connection_id = Some(connection_id);

crates/gpui/Cargo.toml 🔗

@@ -14,6 +14,7 @@ test-support = ["backtrace", "dhat", "env_logger", "collections/test-support"]
 [dependencies]
 collections = { path = "../collections" }
 gpui_macros = { path = "../gpui_macros" }
+util = { path = "../util" }
 sum_tree = { path = "../sum_tree" }
 async-task = "4.0.3"
 backtrace = { version = "0.3", optional = true }

crates/rpc/Cargo.toml 🔗

@@ -26,7 +26,9 @@ rsa = "0.4"
 serde = { version = "1", features = ["derive"] }
 smol-timeout = "0.6"
 zstd = "0.9"
+clock = { path = "../clock" }
 gpui = { path = "../gpui", optional = true }
+util = { path = "../util" }
 
 [build-dependencies]
 prost-build = "0.8"

crates/rpc/src/peer.rs 🔗

@@ -94,6 +94,7 @@ pub struct ConnectionState {
         Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
 }
 
+const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2);
 const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
 
 impl Peer {
@@ -104,14 +105,20 @@ impl Peer {
         })
     }
 
-    pub async fn add_connection(
+    pub async fn add_connection<F, Fut, Out>(
         self: &Arc<Self>,
         connection: Connection,
+        create_timer: F,
     ) -> (
         ConnectionId,
         impl Future<Output = anyhow::Result<()>> + Send,
         BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
-    ) {
+    )
+    where
+        F: Send + Fn(Duration) -> Fut,
+        Fut: Send + Future<Output = Out>,
+        Out: Send,
+    {
         // For outgoing messages, use an unbounded channel so that application code
         // can always send messages without yielding. For incoming messages, use a
         // bounded channel so that other peers will receive backpressure if they send
@@ -121,7 +128,7 @@ impl Peer {
 
         let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
         let connection_state = ConnectionState {
-            outgoing_tx,
+            outgoing_tx: outgoing_tx.clone(),
             next_message_id: Default::default(),
             response_channels: Arc::new(Mutex::new(Some(Default::default()))),
         };
@@ -131,39 +138,43 @@ impl Peer {
         let this = self.clone();
         let response_channels = connection_state.response_channels.clone();
         let handle_io = async move {
-            let result = 'outer: loop {
+            let _end_connection = util::defer(|| {
+                response_channels.lock().take();
+                this.connections.write().remove(&connection_id);
+            });
+
+            loop {
                 let read_message = reader.read_message().fuse();
                 futures::pin_mut!(read_message);
                 loop {
                     futures::select_biased! {
                         outgoing = outgoing_rx.next().fuse() => match outgoing {
                             Some(outgoing) => {
-                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
-                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
-                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
-                                    _ => {}
+                                if let Some(result) = writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
+                                    result.context("failed to write RPC message")?;
+                                } else {
+                                    Err(anyhow!("timed out writing message"))?;
                                 }
                             }
-                            None => break 'outer Ok(()),
+                            None => return Ok(()),
                         },
-                        incoming = read_message => match incoming {
-                            Ok(incoming) => {
-                                if incoming_tx.send(incoming).await.is_err() {
-                                    break 'outer Ok(());
-                                }
-                                break;
-                            }
-                            Err(error) => {
-                                break 'outer Err(error).context("received invalid RPC message")
+                        incoming = read_message => {
+                            let incoming = incoming.context("received invalid rpc message")?;
+                            if incoming_tx.send(incoming).await.is_err() {
+                                return Ok(());
                             }
+                            break;
                         },
+                        _ = create_timer(KEEPALIVE_INTERVAL).fuse() => {
+                            if let Some(result) = writer.ping().timeout(WRITE_TIMEOUT).await {
+                                result.context("failed to send websocket ping")?;
+                            } else {
+                                Err(anyhow!("timed out sending websocket ping"))?;
+                            }
+                        }
                     }
                 }
-            };
-
-            response_channels.lock().take();
-            this.connections.write().remove(&connection_id);
-            result
+            }
         };
 
         let response_channels = connection_state.response_channels.clone();
@@ -191,18 +202,31 @@ impl Peer {
 
                     None
                 } else {
-                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
-                        Some(envelope)
-                    } else {
+                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                         log::error!("unable to construct a typed envelope");
                         None
-                    }
+                    })
                 }
             }
         });
         (connection_id, handle_io, incoming_rx.boxed())
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub async fn add_test_connection(
+        self: &Arc<Self>,
+        connection: Connection,
+        executor: Arc<gpui::executor::Background>,
+    ) -> (
+        ConnectionId,
+        impl Future<Output = anyhow::Result<()>> + Send,
+        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
+    ) {
+        let executor = executor.clone();
+        self.add_connection(connection, move |duration| executor.timer(duration))
+            .await
+    }
+
     pub fn disconnect(&self, connection_id: ConnectionId) {
         self.connections.write().remove(&connection_id);
     }
@@ -349,15 +373,21 @@ mod tests {
 
         let (client1_to_server_conn, server_to_client_1_conn, _) =
             Connection::in_memory(cx.background());
-        let (client1_conn_id, io_task1, client1_incoming) =
-            client1.add_connection(client1_to_server_conn).await;
-        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
+        let (client1_conn_id, io_task1, client1_incoming) = client1
+            .add_test_connection(client1_to_server_conn, cx.background())
+            .await;
+        let (_, io_task2, server_incoming1) = server
+            .add_test_connection(server_to_client_1_conn, cx.background())
+            .await;
 
         let (client2_to_server_conn, server_to_client_2_conn, _) =
             Connection::in_memory(cx.background());
-        let (client2_conn_id, io_task3, client2_incoming) =
-            client2.add_connection(client2_to_server_conn).await;
-        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
+        let (client2_conn_id, io_task3, client2_incoming) = client2
+            .add_test_connection(client2_to_server_conn, cx.background())
+            .await;
+        let (_, io_task4, server_incoming2) = server
+            .add_test_connection(server_to_client_2_conn, cx.background())
+            .await;
 
         executor.spawn(io_task1).detach();
         executor.spawn(io_task2).detach();
@@ -440,10 +470,12 @@ mod tests {
 
         let (client_to_server_conn, server_to_client_conn, _) =
             Connection::in_memory(cx.background());
-        let (client_to_server_conn_id, io_task1, mut client_incoming) =
-            client.add_connection(client_to_server_conn).await;
-        let (server_to_client_conn_id, io_task2, mut server_incoming) =
-            server.add_connection(server_to_client_conn).await;
+        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
+            .add_test_connection(client_to_server_conn, cx.background())
+            .await;
+        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
+            .add_test_connection(server_to_client_conn, cx.background())
+            .await;
 
         executor.spawn(io_task1).detach();
         executor.spawn(io_task2).detach();
@@ -538,10 +570,12 @@ mod tests {
 
         let (client_to_server_conn, server_to_client_conn, _) =
             Connection::in_memory(cx.background());
-        let (client_to_server_conn_id, io_task1, mut client_incoming) =
-            client.add_connection(client_to_server_conn).await;
-        let (server_to_client_conn_id, io_task2, mut server_incoming) =
-            server.add_connection(server_to_client_conn).await;
+        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
+            .add_test_connection(client_to_server_conn, cx.background())
+            .await;
+        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
+            .add_test_connection(server_to_client_conn, cx.background())
+            .await;
 
         executor.spawn(io_task1).detach();
         executor.spawn(io_task2).detach();
@@ -649,7 +683,9 @@ mod tests {
         let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
 
         let client = Peer::new();
-        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
+        let (connection_id, io_handler, mut incoming) = client
+            .add_test_connection(client_conn, cx.background())
+            .await;
 
         let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
         executor
@@ -683,7 +719,9 @@ mod tests {
         let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
 
         let client = Peer::new();
-        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
+        let (connection_id, io_handler, mut incoming) = client
+            .add_test_connection(client_conn, cx.background())
+            .await;
         executor.spawn(io_handler).detach();
         executor
             .spawn(async move { incoming.next().await })

crates/rpc/src/proto.rs 🔗

@@ -318,6 +318,13 @@ where
         self.stream.send(WebSocketMessage::Binary(buffer)).await?;
         Ok(())
     }
+
+    pub async fn ping(&mut self) -> Result<(), WebSocketError> {
+        self.stream
+            .send(WebSocketMessage::Ping(Default::default()))
+            .await?;
+        Ok(())
+    }
 }
 
 impl<S> MessageStream<S>

crates/server/Cargo.toml 🔗

@@ -16,6 +16,7 @@ required-features = ["seed-support"]
 collections = { path = "../collections" }
 rpc = { path = "../rpc" }
 anyhow = "1.0.40"
+async-io = "1.3"
 async-std = { version = "1.8.0", features = ["attributes"] }
 async-trait = "0.1.50"
 async-tungstenite = "0.16"

crates/server/src/rpc.rs 🔗

@@ -6,6 +6,7 @@ use super::{
     AppState,
 };
 use anyhow::anyhow;
+use async_io::Timer;
 use async_std::task;
 use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
 use collections::{HashMap, HashSet};
@@ -16,7 +17,12 @@ use rpc::{
     Connection, ConnectionId, Peer, TypedEnvelope,
 };
 use sha1::{Digest as _, Sha1};
-use std::{any::TypeId, future::Future, sync::Arc, time::Instant};
+use std::{
+    any::TypeId,
+    future::Future,
+    sync::Arc,
+    time::{Duration, Instant},
+};
 use store::{Store, Worktree};
 use surf::StatusCode;
 use tide::log;
@@ -40,10 +46,13 @@ pub struct Server {
     notifications: Option<mpsc::UnboundedSender<()>>,
 }
 
-pub trait Executor {
+pub trait Executor: Send + Clone {
+    type Timer: Send + Future;
     fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
+    fn timer(&self, duration: Duration) -> Self::Timer;
 }
 
+#[derive(Clone)]
 pub struct RealExecutor;
 
 const MESSAGE_COUNT_PER_PAGE: usize = 100;
@@ -167,8 +176,18 @@ impl Server {
     ) -> impl Future<Output = ()> {
         let mut this = self.clone();
         async move {
-            let (connection_id, handle_io, mut incoming_rx) =
-                this.peer.add_connection(connection).await;
+            let (connection_id, handle_io, mut incoming_rx) = this
+                .peer
+                .add_connection(connection, {
+                    let executor = executor.clone();
+                    move |duration| {
+                        let timer = executor.timer(duration);
+                        async move {
+                            timer.await;
+                        }
+                    }
+                })
+                .await;
 
             if let Some(send_connection_id) = send_connection_id.as_mut() {
                 let _ = send_connection_id.send(connection_id).await;
@@ -883,9 +902,15 @@ impl Server {
 }
 
 impl Executor for RealExecutor {
+    type Timer = Timer;
+
     fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
         task::spawn(future);
     }
+
+    fn timer(&self, duration: Duration) -> Self::Timer {
+        Timer::after(duration)
+    }
 }
 
 fn broadcast<F>(
@@ -1759,7 +1784,7 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_peer_disconnection(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
+    async fn test_leaving_project(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
         cx_a.foreground().forbid_parking();
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = FakeFs::new(cx_a.background());
@@ -1817,16 +1842,39 @@ mod tests {
         .await
         .unwrap();
 
-        // See that a guest has joined as client A.
+        // Client A sees that a guest has joined.
         project_a
             .condition(&cx_a, |p, _| p.collaborators().len() == 1)
             .await;
 
-        // Drop client B's connection and ensure client A observes client B leaving the worktree.
+        // Drop client B's connection and ensure client A observes client B leaving the project.
         client_b.disconnect(&cx_b.to_async()).unwrap();
         project_a
             .condition(&cx_a, |p, _| p.collaborators().len() == 0)
             .await;
+
+        // Rejoin the project as client B
+        let _project_b = Project::remote(
+            project_id,
+            client_b.clone(),
+            client_b.user_store.clone(),
+            lang_registry.clone(),
+            fs.clone(),
+            &mut cx_b.to_async(),
+        )
+        .await
+        .unwrap();
+
+        // Client A sees that a guest has re-joined.
+        project_a
+            .condition(&cx_a, |p, _| p.collaborators().len() == 1)
+            .await;
+
+        // Simulate connection loss for client B and ensure client A observes client B leaving the project.
+        server.disconnect_client(client_b.current_user_id(cx_b));
+        project_a
+            .condition(&cx_a, |p, _| p.collaborators().len() == 0)
+            .await;
     }
 
     #[gpui::test(iterations = 10)]
@@ -5031,9 +5079,15 @@ mod tests {
     }
 
     impl Executor for Arc<gpui::executor::Background> {
+        type Timer = BoxFuture<'static, ()>;
+
         fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
             self.spawn(future).detach();
         }
+
+        fn timer(&self, duration: Duration) -> Self::Timer {
+            self.as_ref().timer(duration).boxed()
+        }
     }
 
     fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {