Refactor and write a simple unit test to verify reconnection logic

Antonio Scandurra created

Change summary

server/src/rpc.rs |  48 ++++++++++++++++----
zed/src/rpc.rs    | 112 ++++++++++++++++++++++++------------------------
zed/src/test.rs   |  38 +++++++++++----
3 files changed, 121 insertions(+), 77 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -1693,20 +1693,46 @@ mod tests {
             cx: &mut TestAppContext,
             name: &str,
         ) -> (UserId, Arc<Client>) {
-            let user_id = self.app_state.db.create_user(name, false).await.unwrap();
-            let client = Client::new();
-            let (client_conn, server_conn) = Conn::in_memory();
-            cx.background()
-                .spawn(
-                    self.server
-                        .handle_connection(server_conn, name.to_string(), user_id),
-                )
-                .detach();
+            let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
+            let client_name = name.to_string();
+            let mut client = Client::new();
+            let server = self.server.clone();
+            Arc::get_mut(&mut client)
+                .unwrap()
+                .set_login_and_connect_callbacks(
+                    move |cx| {
+                        cx.spawn(|_| async move {
+                            let access_token = "the-token".to_string();
+                            Ok((client_user_id.0 as u64, access_token))
+                        })
+                    },
+                    {
+                        move |user_id, access_token, cx| {
+                            assert_eq!(user_id, client_user_id.0 as u64);
+                            assert_eq!(access_token, "the-token");
+
+                            let server = server.clone();
+                            let client_name = client_name.clone();
+                            cx.spawn(move |cx| async move {
+                                let (client_conn, server_conn) = Conn::in_memory();
+                                cx.background()
+                                    .spawn(server.handle_connection(
+                                        server_conn,
+                                        client_name,
+                                        client_user_id,
+                                    ))
+                                    .detach();
+                                Ok(client_conn)
+                            })
+                        }
+                    },
+                );
+
             client
-                .set_connection(user_id.to_proto(), client_conn, &cx.to_async())
+                .authenticate_and_connect(&cx.to_async())
                 .await
                 .unwrap();
-            (user_id, client)
+            (client_user_id, client)
         }
 
         async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {

zed/src/rpc.rs 🔗

@@ -49,7 +49,10 @@ pub enum Status {
         user_id: u64,
     },
     ConnectionLost,
-    Reconnecting,
+    Reauthenticating,
+    Reconnecting {
+        user_id: u64,
+    },
     ReconnectionError {
         next_reconnection: Instant,
     },
@@ -164,9 +167,10 @@ impl Client {
                     }
                 }));
             }
-            _ => {
+            Status::Disconnected => {
                 state._maintain_connection.take();
             }
+            _ => {}
         }
     }
 
@@ -227,14 +231,20 @@ impl Client {
         self: &Arc<Self>,
         cx: &AsyncAppContext,
     ) -> anyhow::Result<()> {
-        if matches!(
-            *self.status().borrow(),
-            Status::Authenticating | Status::Connecting { .. } | Status::Connected { .. }
-        ) {
-            return Ok(());
-        }
+        let was_disconnected = match *self.status().borrow() {
+            Status::Disconnected => true,
+            Status::Connected { .. }
+            | Status::Connecting { .. }
+            | Status::Reconnecting { .. }
+            | Status::Reauthenticating => return Ok(()),
+            _ => false,
+        };
 
-        self.set_status(Status::Authenticating, cx);
+        if was_disconnected {
+            self.set_status(Status::Authenticating, cx);
+        } else {
+            self.set_status(Status::Reauthenticating, cx)
+        }
 
         let (user_id, access_token) = match self.authenticate(&cx).await {
             Ok(result) => result,
@@ -244,27 +254,25 @@ impl Client {
             }
         };
 
-        self.set_status(Status::Connecting { user_id }, cx);
-
-        let conn = match self.connect(user_id, &access_token, cx).await {
-            Ok(conn) => conn,
+        if was_disconnected {
+            self.set_status(Status::Connecting { user_id }, cx);
+        } else {
+            self.set_status(Status::Reconnecting { user_id }, cx);
+        }
+        match self.connect(user_id, &access_token, cx).await {
+            Ok(conn) => {
+                log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+                self.set_connection(user_id, conn, cx).await;
+                Ok(())
+            }
             Err(err) => {
                 self.set_status(Status::ConnectionError, cx);
-                return Err(err);
+                Err(err)
             }
-        };
-
-        self.set_connection(user_id, conn, cx).await?;
-        log::info!("connected to rpc address {}", *ZED_SERVER_URL);
-        Ok(())
+        }
     }
 
-    pub async fn set_connection(
-        self: &Arc<Self>,
-        user_id: u64,
-        conn: Conn,
-        cx: &AsyncAppContext,
-    ) -> Result<()> {
+    async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
         let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
         cx.foreground()
             .spawn({
@@ -321,7 +329,6 @@ impl Client {
                 }
             })
             .detach();
-        Ok(())
     }
 
     fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
@@ -489,35 +496,6 @@ impl Client {
     }
 }
 
-pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
-    type Output: 'a + Future<Output = anyhow::Result<()>>;
-
-    fn handle(
-        &self,
-        message: TypedEnvelope<M>,
-        rpc: &'a Client,
-        cx: &'a mut gpui::AsyncAppContext,
-    ) -> Self::Output;
-}
-
-impl<'a, M, F, Fut> MessageHandler<'a, M> for F
-where
-    M: proto::EnvelopedMessage,
-    F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
-    Fut: 'a + Future<Output = anyhow::Result<()>>,
-{
-    type Output = Fut;
-
-    fn handle(
-        &self,
-        message: TypedEnvelope<M>,
-        rpc: &'a Client,
-        cx: &'a mut gpui::AsyncAppContext,
-    ) -> Self::Output {
-        (self)(message, rpc, cx)
-    }
-}
-
 const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
 
 pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
@@ -550,6 +528,8 @@ mod tests {
 
     #[gpui::test(iterations = 10)]
     async fn test_heartbeat(cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
         let user_id = 5;
         let mut client = Client::new();
         let server = FakeServer::for_client(user_id, &mut client, &cx).await;
@@ -568,6 +548,28 @@ mod tests {
         assert!(server.receive::<proto::Ping>().await.is_err());
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_reconnection(cx: TestAppContext) {
+        cx.foreground().forbid_parking();
+
+        let user_id = 5;
+        let mut client = Client::new();
+        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+        let mut status = client.status();
+        assert!(matches!(
+            status.recv().await,
+            Some(Status::Connected { .. })
+        ));
+
+        server.forbid_connections();
+        server.disconnect().await;
+        while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
+
+        server.allow_connections();
+        cx.foreground().advance_clock(Duration::from_secs(10));
+        while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
+    }
+
     #[test]
     fn test_encode_and_decode_worktree_url() {
         let url = encode_worktree_url(5, "deadbeef");

zed/src/test.rs 🔗

@@ -17,7 +17,10 @@ use smol::channel;
 use std::{
     marker::PhantomData,
     path::{Path, PathBuf},
-    sync::Arc,
+    sync::{
+        atomic::{AtomicBool, Ordering::SeqCst},
+        Arc,
+    },
 };
 use tempdir::TempDir;
 use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
@@ -200,6 +203,7 @@ pub struct FakeServer {
     peer: Arc<Peer>,
     incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
     connection_id: Mutex<Option<ConnectionId>>,
+    forbid_connections: AtomicBool,
 }
 
 impl FakeServer {
@@ -212,6 +216,7 @@ impl FakeServer {
             peer: Peer::new(),
             incoming: Default::default(),
             connection_id: Default::default(),
+            forbid_connections: Default::default(),
         });
 
         Arc::get_mut(client)
@@ -230,15 +235,14 @@ impl FakeServer {
                         assert_eq!(access_token, "the-token");
                         cx.spawn({
                             let server = server.clone();
-                            move |cx| async move { Ok(server.connect(&cx).await) }
+                            move |cx| async move { server.connect(&cx).await }
                         })
                     }
                 },
             );
 
-        let conn = result.connect(&cx.to_async()).await;
         client
-            .set_connection(client_user_id, conn, &cx.to_async())
+            .authenticate_and_connect(&cx.to_async())
             .await
             .unwrap();
         result
@@ -250,13 +254,25 @@ impl FakeServer {
         self.incoming.lock().take();
     }
 
-    async fn connect(&self, cx: &AsyncAppContext) -> Conn {
-        let (client_conn, server_conn) = Conn::in_memory();
-        let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
-        cx.background().spawn(io).detach();
-        *self.incoming.lock() = Some(incoming);
-        *self.connection_id.lock() = Some(connection_id);
-        client_conn
+    async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
+        if self.forbid_connections.load(SeqCst) {
+            Err(anyhow!("server is forbidding connections"))
+        } else {
+            let (client_conn, server_conn) = Conn::in_memory();
+            let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
+            cx.background().spawn(io).detach();
+            *self.incoming.lock() = Some(incoming);
+            *self.connection_id.lock() = Some(connection_id);
+            Ok(client_conn)
+        }
+    }
+
+    pub fn forbid_connections(&self) {
+        self.forbid_connections.store(true, SeqCst);
+    }
+
+    pub fn allow_connections(&self) {
+        self.forbid_connections.store(false, SeqCst);
     }
 
     pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {