WIP

Antonio Scandurra and Max Brunsfeld created

Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

gpui/src/executor.rs |  38 +++++++++++++++
server/src/rpc.rs    |   2 
zed/src/channel.rs   |  62 ++------------------------
zed/src/rpc.rs       | 107 +++++++++++++++++++++++++++++++++++++--------
zed/src/test.rs      |  54 ++++++++++++++++++++++
5 files changed, 182 insertions(+), 81 deletions(-)

Detailed changes

gpui/src/executor.rs 🔗

@@ -3,8 +3,9 @@ use async_task::Runnable;
 pub use async_task::Task;
 use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
 use parking_lot::Mutex;
+use postage::{barrier, prelude::Stream as _};
 use rand::prelude::*;
-use smol::{channel, prelude::*, Executor};
+use smol::{channel, prelude::*, Executor, Timer};
 use std::{
     fmt::{self, Debug},
     marker::PhantomData,
@@ -18,7 +19,7 @@ use std::{
     },
     task::{Context, Poll},
     thread,
-    time::Duration,
+    time::{Duration, Instant},
 };
 use waker_fn::waker_fn;
 
@@ -49,6 +50,8 @@ struct DeterministicState {
     spawned_from_foreground: Vec<(Runnable, Backtrace)>,
     forbid_parking: bool,
     block_on_ticks: RangeInclusive<usize>,
+    now: Instant,
+    pending_sleeps: Vec<(Instant, barrier::Sender)>,
 }
 
 pub struct Deterministic {
@@ -67,6 +70,8 @@ impl Deterministic {
                 spawned_from_foreground: Default::default(),
                 forbid_parking: false,
                 block_on_ticks: 0..=1000,
+                now: Instant::now(),
+                pending_sleeps: Default::default(),
             })),
             parker: Default::default(),
         }
@@ -407,6 +412,35 @@ impl Foreground {
         }
     }
 
+    pub async fn sleep(&self, duration: Duration) {
+        match self {
+            Self::Deterministic(executor) => {
+                let (tx, mut rx) = barrier::channel();
+                {
+                    let mut state = executor.state.lock();
+                    let wakeup_at = state.now + duration;
+                    state.pending_sleeps.push((wakeup_at, tx));
+                }
+                rx.recv().await;
+            }
+            _ => {
+                Timer::after(duration).await;
+            }
+        }
+    }
+
+    pub fn advance_clock(&self, duration: Duration) {
+        match self {
+            Self::Deterministic(executor) => {
+                let mut state = executor.state.lock();
+                state.now += duration;
+                let now = state.now;
+                state.pending_sleeps.retain(|(wakeup, _)| *wakeup > now);
+            }
+            _ => panic!("this method can only be called on a deterministic executor"),
+        }
+    }
+
     pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
         match self {
             Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,

server/src/rpc.rs 🔗

@@ -1469,7 +1469,7 @@ mod tests {
             .await;
 
         // Drop client B's connection and ensure client A observes client B leaving the worktree.
-        client_b.disconnect().await.unwrap();
+        client_b.disconnect(&cx_b.to_async()).await.unwrap();
         worktree_a
             .condition(&cx_a, |tree, _| tree.peers().len() == 0)
             .await;

zed/src/channel.rs 🔗

@@ -443,9 +443,8 @@ impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::test::FakeServer;
     use gpui::TestAppContext;
-    use postage::mpsc::Receiver;
-    use zrpc::{test::Channel, ConnectionId, Peer, Receipt};
 
     #[gpui::test]
     async fn test_channel_messages(mut cx: TestAppContext) {
@@ -458,7 +457,7 @@ mod tests {
         channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));
 
         // Get the available channels.
-        let get_channels = server.receive::<proto::GetChannels>().await;
+        let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
         server
             .respond(
                 get_channels.receipt(),
@@ -489,7 +488,7 @@ mod tests {
             })
             .unwrap();
         channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
-        let join_channel = server.receive::<proto::JoinChannel>().await;
+        let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
         server
             .respond(
                 join_channel.receipt(),
@@ -514,7 +513,7 @@ mod tests {
             .await;
 
         // Client requests all users for the received messages
-        let mut get_users = server.receive::<proto::GetUsers>().await;
+        let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
         get_users.payload.user_ids.sort();
         assert_eq!(get_users.payload.user_ids, vec![5, 6]);
         server
@@ -571,7 +570,7 @@ mod tests {
             .await;
 
         // Client requests user for message since they haven't seen them yet
-        let get_users = server.receive::<proto::GetUsers>().await;
+        let get_users = server.receive::<proto::GetUsers>().await.unwrap();
         assert_eq!(get_users.payload.user_ids, vec![7]);
         server
             .respond(
@@ -607,7 +606,7 @@ mod tests {
         channel.update(&mut cx, |channel, cx| {
             assert!(channel.load_more_messages(cx));
         });
-        let get_messages = server.receive::<proto::GetChannelMessages>().await;
+        let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
         assert_eq!(get_messages.payload.channel_id, 5);
         assert_eq!(get_messages.payload.before_message_id, 10);
         server
@@ -653,53 +652,4 @@ mod tests {
             );
         });
     }
-
-    struct FakeServer {
-        peer: Arc<Peer>,
-        incoming: Receiver<Box<dyn proto::AnyTypedEnvelope>>,
-        connection_id: ConnectionId,
-    }
-
-    impl FakeServer {
-        async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
-            let (client_conn, server_conn) = Channel::bidirectional();
-            let peer = Peer::new();
-            let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
-            cx.background().spawn(io).detach();
-
-            client
-                .set_connection(user_id, client_conn, &cx.to_async())
-                .await
-                .unwrap();
-
-            Self {
-                peer,
-                incoming,
-                connection_id,
-            }
-        }
-
-        async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
-            self.peer.send(self.connection_id, message).await.unwrap();
-        }
-
-        async fn receive<M: proto::EnvelopedMessage>(&mut self) -> TypedEnvelope<M> {
-            *self
-                .incoming
-                .recv()
-                .await
-                .unwrap()
-                .into_any()
-                .downcast::<TypedEnvelope<M>>()
-                .unwrap()
-        }
-
-        async fn respond<T: proto::RequestMessage>(
-            &self,
-            receipt: Receipt<T>,
-            response: T::Response,
-        ) {
-            self.peer.respond(receipt, response).await.unwrap()
-        }
-    }
 }

zed/src/rpc.rs 🔗

@@ -3,10 +3,12 @@ use anyhow::{anyhow, Context, Result};
 use async_tungstenite::tungstenite::{
     http::Request, Error as WebSocketError, Message as WebSocketMessage,
 };
+use futures::StreamExt as _;
 use gpui::{AsyncAppContext, Entity, ModelContext, Task};
 use lazy_static::lazy_static;
 use parking_lot::RwLock;
 use postage::{prelude::Stream, watch};
+use smol::Timer;
 use std::{
     any::TypeId,
     collections::HashMap,
@@ -42,6 +44,10 @@ pub enum Status {
         user_id: u64,
     },
     ConnectionLost,
+    Reconnecting,
+    ReconnectionError {
+        next_reconnection: Instant,
+    },
 }
 
 struct ClientState {
@@ -51,6 +57,8 @@ struct ClientState {
         (TypeId, u64),
         Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
     >,
+    _maintain_connection: Option<Task<()>>,
+    heartbeat_interval: Duration,
 }
 
 impl Default for ClientState {
@@ -59,6 +67,8 @@ impl Default for ClientState {
             status: watch::channel_with(Status::Disconnected),
             entity_id_extractors: Default::default(),
             model_handlers: Default::default(),
+            _maintain_connection: None,
+            heartbeat_interval: Duration::from_secs(5),
         }
     }
 }
@@ -95,9 +105,35 @@ impl Client {
         self.state.read().status.1.clone()
     }
 
-    fn set_status(&self, status: Status) {
+    fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
         let mut state = self.state.write();
         *state.status.0.borrow_mut() = status;
+        match status {
+            Status::Connected { .. } => {
+                let heartbeat_interval = state.heartbeat_interval;
+                let this = self.clone();
+                let foreground = cx.foreground();
+                state._maintain_connection = Some(cx.foreground().spawn(async move {
+                    let mut next_ping_id = 0;
+                    loop {
+                        foreground.sleep(heartbeat_interval).await;
+                        this.request(proto::Ping { id: next_ping_id })
+                            .await
+                            .unwrap();
+                        next_ping_id += 1;
+                    }
+                }));
+            }
+            Status::ConnectionLost => {
+                state._maintain_connection = Some(cx.foreground().spawn(async move {
+                    // TODO: try to reconnect
+                }));
+            }
+            Status::Disconnected => {
+                state._maintain_connection.take();
+            }
+            _ => {}
+        }
     }
 
     pub fn subscribe_from_model<T, M, F>(
@@ -167,14 +203,14 @@ impl Client {
         let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
         let user_id = user_id.parse::<u64>()?;
 
-        self.set_status(Status::Connecting);
+        self.set_status(Status::Connecting, cx);
         match self.connect(user_id, &access_token, cx).await {
             Ok(()) => {
                 log::info!("connected to rpc address {}", *ZED_SERVER_URL);
                 Ok(())
             }
             Err(err) => {
-                self.set_status(Status::ConnectionError);
+                self.set_status(Status::ConnectionError, cx);
                 Err(err)
             }
         }
@@ -256,20 +292,24 @@ impl Client {
                 .detach();
         }
 
-        self.set_status(Status::Connected {
-            connection_id,
-            user_id,
-        });
+        self.set_status(
+            Status::Connected {
+                connection_id,
+                user_id,
+            },
+            cx,
+        );
 
         let handle_io = cx.background().spawn(handle_io);
         let this = self.clone();
+        let cx = cx.clone();
         cx.foreground()
             .spawn(async move {
                 match handle_io.await {
-                    Ok(()) => this.set_status(Status::Disconnected),
+                    Ok(()) => this.set_status(Status::Disconnected, &cx),
                     Err(err) => {
                         log::error!("connection error: {:?}", err);
-                        this.set_status(Status::ConnectionLost);
+                        this.set_status(Status::ConnectionLost, &cx);
                     }
                 }
             })
@@ -359,10 +399,10 @@ impl Client {
         })
     }
 
-    pub async fn disconnect(&self) -> Result<()> {
+    pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
         let conn_id = self.connection_id()?;
         self.peer.disconnect(conn_id).await;
-        self.set_status(Status::Disconnected);
+        self.set_status(Status::Disconnected, cx);
         Ok(())
     }
 
@@ -444,13 +484,40 @@ const LOGIN_RESPONSE: &'static str = "
 </html>
 ";
 
-#[test]
-fn test_encode_and_decode_worktree_url() {
-    let url = encode_worktree_url(5, "deadbeef");
-    assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
-    assert_eq!(
-        decode_worktree_url(&format!("\n {}\t", url)),
-        Some((5, "deadbeef".to_string()))
-    );
-    assert_eq!(decode_worktree_url("not://the-right-format"), None);
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::test::FakeServer;
+    use gpui::TestAppContext;
+
+    #[gpui::test(iterations = 1000)]
+    async fn test_heartbeat(cx: TestAppContext) {
+        let user_id = 5;
+        let client = Client::new();
+
+        client.state.write().heartbeat_interval = Duration::from_millis(1);
+        let mut server = FakeServer::for_client(user_id, &client, &cx).await;
+
+        let ping = server.receive::<proto::Ping>().await.unwrap();
+        assert_eq!(ping.payload.id, 0);
+        server.respond(ping.receipt(), proto::Pong { id: 0 }).await;
+
+        let ping = server.receive::<proto::Ping>().await.unwrap();
+        assert_eq!(ping.payload.id, 1);
+        server.respond(ping.receipt(), proto::Pong { id: 1 }).await;
+
+        client.disconnect(&cx.to_async()).await.unwrap();
+        assert!(server.receive::<proto::Ping>().await.is_err());
+    }
+
+    #[test]
+    fn test_encode_and_decode_worktree_url() {
+        let url = encode_worktree_url(5, "deadbeef");
+        assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
+        assert_eq!(
+            decode_worktree_url(&format!("\n {}\t", url)),
+            Some((5, "deadbeef".to_string()))
+        );
+        assert_eq!(decode_worktree_url("not://the-right-format"), None);
+    }
 }

zed/src/test.rs 🔗

@@ -3,14 +3,16 @@ use crate::{
     channel::ChannelList,
     fs::RealFs,
     language::LanguageRegistry,
-    rpc,
+    rpc::{self, Client},
     settings::{self, ThemeRegistry},
     time::ReplicaId,
     user::UserStore,
     AppState,
 };
-use gpui::{Entity, ModelHandle, MutableAppContext};
+use anyhow::{anyhow, Result};
+use gpui::{Entity, ModelHandle, MutableAppContext, TestAppContext};
 use parking_lot::Mutex;
+use postage::{mpsc, prelude::Stream as _};
 use smol::channel;
 use std::{
     marker::PhantomData,
@@ -18,6 +20,7 @@ use std::{
     sync::Arc,
 };
 use tempdir::TempDir;
+use zrpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
 
 #[cfg(feature = "test-support")]
 pub use zrpc::test::Channel;
@@ -195,3 +198,50 @@ impl<T: Entity> Observer<T> {
         (observer, notify_rx)
     }
 }
+
+pub struct FakeServer {
+    peer: Arc<Peer>,
+    incoming: mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>,
+    connection_id: ConnectionId,
+}
+
+impl FakeServer {
+    pub async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
+        let (client_conn, server_conn) = zrpc::test::Channel::bidirectional();
+        let peer = Peer::new();
+        let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
+        cx.background().spawn(io).detach();
+
+        client
+            .set_connection(user_id, client_conn, &cx.to_async())
+            .await
+            .unwrap();
+
+        Self {
+            peer,
+            incoming,
+            connection_id,
+        }
+    }
+
+    pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
+        self.peer.send(self.connection_id, message).await.unwrap();
+    }
+
+    pub async fn receive<M: proto::EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
+        let message = self
+            .incoming
+            .recv()
+            .await
+            .ok_or_else(|| anyhow!("other half hung up"))?;
+        Ok(*message.into_any().downcast::<TypedEnvelope<M>>().unwrap())
+    }
+
+    pub async fn respond<T: proto::RequestMessage>(
+        &self,
+        receipt: Receipt<T>,
+        response: T::Response,
+    ) {
+        self.peer.respond(receipt, response).await.unwrap()
+    }
+}