Test `RpcClient::subscribe`

Antonio Scandurra created

Change summary

zed-rpc/proto/zed.proto | 10 +++
zed-rpc/src/proto.rs    | 11 ++++
zed/src/rpc_client.rs   | 97 +++++++++++++++++++++++++++++++++++++-----
3 files changed, 104 insertions(+), 14 deletions(-)

Detailed changes

zed-rpc/proto/zed.proto 🔗

@@ -9,6 +9,7 @@ message FromClient {
         NewWorktree new_worktree = 3;
         ShareWorktree share_worktree = 4;
         UploadFile upload_file = 5;
+        SubscribeToPathRequests subscribe_to_path_requests = 6;
     }
 
     message Auth {
@@ -32,6 +33,8 @@ message FromClient {
         bytes path = 1;
         bytes content = 2;
     }
+
+    message SubscribeToPathRequests {}
 }
 
 message FromServer {
@@ -41,6 +44,7 @@ message FromServer {
         AuthResponse auth_response = 2;
         NewWorktreeResponse new_worktree_response = 3;
         ShareWorktreeResponse share_worktree_response = 4;
+        PathRequest path_request = 5;
     }
 
     message AuthResponse {
@@ -54,4 +58,8 @@ message FromServer {
     message ShareWorktreeResponse {
         repeated int32 needed_file_indices = 1;
     }
-}
+
+    message PathRequest {
+        bytes path = 1;
+    }
+}

zed-rpc/src/proto.rs 🔗

@@ -67,10 +67,21 @@ macro_rules! send_message {
     };
 }
 
+macro_rules! subscribe_message {
+    ($subscription:ident, $event:ident) => {
+        directed_message!($subscription, ClientMessage, from_client);
+        directed_message!($event, ServerMessage, from_server);
+        impl SubscribeMessage for from_client::$subscription {
+            type Event = from_server::$event;
+        }
+    };
+}
+
 request_message!(Auth, AuthResponse);
 request_message!(NewWorktree, NewWorktreeResponse);
 request_message!(ShareWorktree, ShareWorktreeResponse);
 send_message!(UploadFile);
+subscribe_message!(SubscribeToPathRequests, PathRequest);
 
 /// A stream of protobuf messages.
 pub struct MessageStream<T> {

zed/src/rpc_client.rs 🔗

@@ -72,10 +72,9 @@ impl RpcClient {
         loop {
             let read_message = stream.read_message::<proto::FromServer>();
             let dropped = drop_rx.recv();
-            smol::pin!(read_message);
-            smol::pin!(dropped);
-            let result = futures::future::select(&mut read_message, &mut dropped).await;
-            match result {
+            smol::pin!(read_message, dropped);
+
+            match futures::future::select(&mut read_message, &mut dropped).await {
                 Either::Left((Ok(incoming), _)) => {
                     if let Some(variant) = incoming.variant {
                         if let Some(request_id) = incoming.request_id {
@@ -98,12 +97,9 @@ impl RpcClient {
                     }
                 }
                 Either::Left((Err(error), _)) => {
-                    log::warn!("invalid incoming RPC message {:?}", error)
-                }
-                Either::Right(_) => {
-                    eprintln!("done with incoming loop");
-                    break;
+                    log::warn!("invalid incoming RPC message {:?}", error);
                 }
+                Either::Right(_) => break,
             }
         }
     }
@@ -165,7 +161,7 @@ impl RpcClient {
     }
 
     pub async fn subscribe<T: SubscribeMessage>(
-        &mut self,
+        &self,
         subscription: T,
     ) -> Result<impl Stream<Item = Result<T::Event>>> {
         let message_id = self.next_message_id.fetch_add(1, atomic::Ordering::SeqCst);
@@ -206,7 +202,7 @@ mod tests {
     #[gpui::test]
     async fn test_request_response(cx: gpui::TestAppContext) {
         let executor = cx.read(|app| app.background_executor().clone());
-        let socket_dir_path = TempDir::new("request-response-socket").unwrap();
+        let socket_dir_path = TempDir::new("request-response").unwrap();
         let socket_path = socket_dir_path.path().join(".sock");
         let listener = UnixListener::bind(&socket_path).unwrap();
         let client_conn = UnixStream::connect(&socket_path).await.unwrap();
@@ -267,10 +263,85 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_subscribe(cx: gpui::TestAppContext) {
+        let executor = cx.read(|app| app.background_executor().clone());
+        let socket_dir_path = TempDir::new("subscribe").unwrap();
+        let socket_path = socket_dir_path.path().join(".sock");
+        let listener = UnixListener::bind(&socket_path).unwrap();
+        let client_conn = UnixStream::connect(&socket_path).await.unwrap();
+        let (server_conn, _) = listener.accept().await.unwrap();
+
+        let mut server_stream = MessageStream::new(server_conn);
+        let client = RpcClient::new(client_conn, executor.clone());
+
+        let mut events = client
+            .subscribe(proto::from_client::SubscribeToPathRequests {})
+            .await
+            .unwrap();
+
+        let subscription = server_stream
+            .read_message::<proto::FromClient>()
+            .await
+            .unwrap();
+        assert_eq!(
+            subscription.variant,
+            Some(proto::from_client::Variant::SubscribeToPathRequests(
+                proto::from_client::SubscribeToPathRequests {}
+            ))
+        );
+        server_stream
+            .write_message(&proto::FromServer {
+                request_id: Some(subscription.id),
+                variant: Some(proto::from_server::Variant::PathRequest(
+                    proto::from_server::PathRequest {
+                        path: b"path-1".to_vec(),
+                    },
+                )),
+            })
+            .await
+            .unwrap();
+        server_stream
+            .write_message(&proto::FromServer {
+                request_id: Some(99999),
+                variant: Some(proto::from_server::Variant::PathRequest(
+                    proto::from_server::PathRequest {
+                        path: b"path-2".to_vec(),
+                    },
+                )),
+            })
+            .await
+            .unwrap();
+        server_stream
+            .write_message(&proto::FromServer {
+                request_id: Some(subscription.id),
+                variant: Some(proto::from_server::Variant::PathRequest(
+                    proto::from_server::PathRequest {
+                        path: b"path-3".to_vec(),
+                    },
+                )),
+            })
+            .await
+            .unwrap();
+
+        assert_eq!(
+            events.recv().await.unwrap().unwrap(),
+            proto::from_server::PathRequest {
+                path: b"path-1".to_vec()
+            }
+        );
+        assert_eq!(
+            events.recv().await.unwrap().unwrap(),
+            proto::from_server::PathRequest {
+                path: b"path-3".to_vec()
+            }
+        );
+    }
+
     #[gpui::test]
     async fn test_drop_client(cx: gpui::TestAppContext) {
         let executor = cx.read(|app| app.background_executor().clone());
-        let socket_dir_path = TempDir::new("request-response-socket").unwrap();
+        let socket_dir_path = TempDir::new("drop-client").unwrap();
         let socket_path = socket_dir_path.path().join(".sock");
         let listener = UnixListener::bind(&socket_path).unwrap();
         let client_conn = UnixStream::connect(&socket_path).await.unwrap();
@@ -295,7 +366,7 @@ mod tests {
     #[gpui::test]
     async fn test_io_error(cx: gpui::TestAppContext) {
         let executor = cx.read(|app| app.background_executor().clone());
-        let socket_dir_path = TempDir::new("request-response-socket").unwrap();
+        let socket_dir_path = TempDir::new("io-error").unwrap();
         let socket_path = socket_dir_path.path().join(".sock");
         let _listener = UnixListener::bind(&socket_path).unwrap();
         let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();