Fix termination of peer's incoming future

Max Brunsfeld created

* Re-enable peer tests
* Enhance request/response unit test to exercise
  peers interacting with each other end-to-end

Change summary

Cargo.lock          |   1 
zed-rpc/Cargo.toml  |   1 
zed-rpc/src/peer.rs | 387 ++++++++++++++++++++++++++--------------------
3 files changed, 221 insertions(+), 168 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -4364,6 +4364,7 @@ dependencies = [
  "rsa",
  "serde 1.0.125",
  "smol",
+ "tempdir",
 ]
 
 [[package]]

zed-rpc/Cargo.toml 🔗

@@ -21,3 +21,4 @@ prost-build = { git="https://github.com/sfackler/prost", rev="082f3e65874fe91382
 
 [dev-dependencies]
 smol = "1.2.5"
+tempdir = "0.3.7"

zed-rpc/src/peer.rs 🔗

@@ -29,7 +29,6 @@ struct Connection {
     writer: Mutex<MessageStream<BoxedWriter>>,
     response_channels: Mutex<HashMap<u32, oneshot::Sender<proto::Envelope>>>,
     next_message_id: AtomicU32,
-    _close_barrier: barrier::Sender,
 }
 
 type MessageHandler = Box<
@@ -53,7 +52,7 @@ impl<T> TypedEnvelope<T> {
 }
 
 pub struct Peer {
-    connections: RwLock<HashMap<ConnectionId, Arc<Connection>>>,
+    connections: RwLock<HashMap<ConnectionId, (Arc<Connection>, barrier::Sender)>>,
     message_handlers: RwLock<Vec<MessageHandler>>,
     handler_types: Mutex<HashSet<TypeId>>,
     next_connection_id: AtomicU32,
@@ -106,7 +105,7 @@ impl Peer {
     pub async fn add_connection<Conn>(
         self: &Arc<Self>,
         conn: Conn,
-    ) -> (ConnectionId, impl Future<Output = ()>)
+    ) -> (ConnectionId, impl Future<Output = Result<()>>)
     where
         Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
     {
@@ -119,13 +118,12 @@ impl Peer {
             writer: Mutex::new(MessageStream::new(Box::pin(conn.clone()))),
             response_channels: Default::default(),
             next_message_id: Default::default(),
-            _close_barrier: close_tx,
         });
 
         self.connections
             .write()
             .await
-            .insert(connection_id, connection.clone());
+            .insert(connection_id, (connection.clone(), close_tx));
 
         let this = self.clone();
         let handler_future = async move {
@@ -178,8 +176,9 @@ impl Peer {
                     }
                     Either::Left((Err(error), _)) => {
                         log::warn!("received invalid RPC message: {}", error);
+                        Err(error)?;
                     }
-                    Either::Right(_) => break,
+                    Either::Right(_) => return Ok(()),
                 }
             }
         };
@@ -199,13 +198,7 @@ impl Peer {
         let this = self.clone();
         let (tx, mut rx) = oneshot::channel();
         async move {
-            let connection = this
-                .connections
-                .read()
-                .await
-                .get(&connection_id)
-                .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
-                .clone();
+            let connection = this.connection(connection_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -236,13 +229,7 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let connection = this
-                .connections
-                .read()
-                .await
-                .get(&connection_id)
-                .ok_or_else(|| anyhow!("unknown connection: {}", connection_id.0))?
-                .clone();
+            let connection = this.connection(connection_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -263,13 +250,7 @@ impl Peer {
     ) -> impl Future<Output = Result<()>> {
         let this = self.clone();
         async move {
-            let connection = this
-                .connections
-                .read()
-                .await
-                .get(&request.connection_id)
-                .ok_or_else(|| anyhow!("unknown connection: {}", request.connection_id.0))?
-                .clone();
+            let connection = this.connection(request.connection_id).await?;
             let message_id = connection
                 .next_message_id
                 .fetch_add(1, atomic::Ordering::SeqCst);
@@ -282,146 +263,216 @@ impl Peer {
             Ok(())
         }
     }
+
+    async fn connection(&self, id: ConnectionId) -> Result<Arc<Connection>> {
+        Ok(self
+            .connections
+            .read()
+            .await
+            .get(&id)
+            .ok_or_else(|| anyhow!("unknown connection: {}", id.0))?
+            .0
+            .clone())
+    }
 }
 
-// #[cfg(test)]
-// mod tests {
-//     use super::*;
-//     use smol::{
-//         future::poll_once,
-//         io::AsyncWriteExt,
-//         net::unix::{UnixListener, UnixStream},
-//     };
-//     use std::{future::Future, io};
-//     use tempdir::TempDir;
-
-//     #[gpui::test]
-//     async fn test_request_response(cx: gpui::TestAppContext) {
-//         let executor = cx.read(|app| app.background_executor().clone());
-//         let socket_dir_path = TempDir::new("request-response").unwrap();
-//         let socket_path = socket_dir_path.path().join(".sock");
-//         let listener = UnixListener::bind(&socket_path).unwrap();
-//         let client_conn = UnixStream::connect(&socket_path).await.unwrap();
-//         let (server_conn, _) = listener.accept().await.unwrap();
-
-//         let mut server_stream = MessageStream::new(server_conn);
-//         let client = Peer::new();
-//         let (connection_id, handler) = client.add_connection(client_conn).await;
-//         executor.spawn(handler).detach();
-
-//         let client_req = client.request(
-//             connection_id,
-//             proto::Auth {
-//                 user_id: 42,
-//                 access_token: "token".to_string(),
-//             },
-//         );
-//         smol::pin!(client_req);
-//         let server_req = send_recv(&mut client_req, server_stream.read_message())
-//             .await
-//             .unwrap();
-//         assert_eq!(
-//             server_req.payload,
-//             Some(proto::envelope::Payload::Auth(proto::Auth {
-//                 user_id: 42,
-//                 access_token: "token".to_string()
-//             }))
-//         );
-
-//         // Respond to another request to ensure requests are properly matched up.
-//         server_stream
-//             .write_message(
-//                 &proto::AuthResponse {
-//                     credentials_valid: false,
-//                 }
-//                 .into_envelope(1000, Some(999)),
-//             )
-//             .await
-//             .unwrap();
-//         server_stream
-//             .write_message(
-//                 &proto::AuthResponse {
-//                     credentials_valid: true,
-//                 }
-//                 .into_envelope(1001, Some(server_req.id)),
-//             )
-//             .await
-//             .unwrap();
-//         assert_eq!(
-//             client_req.await.unwrap(),
-//             proto::AuthResponse {
-//                 credentials_valid: true
-//             }
-//         );
-//     }
-
-//     #[gpui::test]
-//     async fn test_disconnect(cx: gpui::TestAppContext) {
-//         let executor = cx.read(|app| app.background_executor().clone());
-//         let socket_dir_path = TempDir::new("drop-client").unwrap();
-//         let socket_path = socket_dir_path.path().join(".sock");
-//         let listener = UnixListener::bind(&socket_path).unwrap();
-//         let client_conn = UnixStream::connect(&socket_path).await.unwrap();
-//         let (mut server_conn, _) = listener.accept().await.unwrap();
-
-//         let client = Peer::new();
-//         let (connection_id, handler) = client.add_connection(client_conn).await;
-//         executor.spawn(handler).detach();
-//         client.disconnect(connection_id).await;
-
-//         // Try sending an empty payload over and over, until the client is dropped and hangs up.
-//         loop {
-//             match server_conn.write(&[]).await {
-//                 Ok(_) => {}
-//                 Err(err) => {
-//                     if err.kind() == io::ErrorKind::BrokenPipe {
-//                         break;
-//                     }
-//                 }
-//             }
-//         }
-//     }
-
-//     #[gpui::test]
-//     async fn test_io_error(cx: gpui::TestAppContext) {
-//         let executor = cx.read(|app| app.background_executor().clone());
-//         let socket_dir_path = TempDir::new("io-error").unwrap();
-//         let socket_path = socket_dir_path.path().join(".sock");
-//         let _listener = UnixListener::bind(&socket_path).unwrap();
-//         let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
-//         client_conn.close().await.unwrap();
-
-//         let client = Peer::new();
-//         let (connection_id, handler) = client.add_connection(client_conn).await;
-//         executor.spawn(handler).detach();
-//         let err = client
-//             .request(
-//                 connection_id,
-//                 proto::Auth {
-//                     user_id: 42,
-//                     access_token: "token".to_string(),
-//                 },
-//             )
-//             .await
-//             .unwrap_err();
-//         assert_eq!(
-//             err.downcast_ref::<io::Error>().unwrap().kind(),
-//             io::ErrorKind::BrokenPipe
-//         );
-//     }
-
-//     async fn send_recv<S, R, O>(mut sender: S, receiver: R) -> O
-//     where
-//         S: Unpin + Future,
-//         R: Future<Output = O>,
-//     {
-//         smol::pin!(receiver);
-//         loop {
-//             poll_once(&mut sender).await;
-//             match poll_once(&mut receiver).await {
-//                 Some(message) => break message,
-//                 None => continue,
-//             }
-//         }
-//     }
-// }
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use smol::{
+        io::AsyncWriteExt,
+        net::unix::{UnixListener, UnixStream},
+    };
+    use std::io;
+    use tempdir::TempDir;
+
+    #[test]
+    fn test_request_response() {
+        smol::block_on(async move {
+            // create socket
+            let socket_dir_path = TempDir::new("test-request-response").unwrap();
+            let socket_path = socket_dir_path.path().join("test.sock");
+            let listener = UnixListener::bind(&socket_path).unwrap();
+
+            // create 2 clients connected to 1 server
+            let server = Peer::new();
+            let client1 = Peer::new();
+            let client2 = Peer::new();
+            let (client1_conn_id, f1) = client1
+                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
+                .await;
+            let (client2_conn_id, f2) = client2
+                .add_connection(UnixStream::connect(&socket_path).await.unwrap())
+                .await;
+            let (_, f3) = server
+                .add_connection(listener.accept().await.unwrap().0)
+                .await;
+            let (_, f4) = server
+                .add_connection(listener.accept().await.unwrap().0)
+                .await;
+            smol::spawn(f1).detach();
+            smol::spawn(f2).detach();
+            smol::spawn(f3).detach();
+            smol::spawn(f4).detach();
+
+            // define the expected requests and responses
+            let request1 = proto::OpenWorktree {
+                worktree_id: 101,
+                access_token: "first-worktree-access-token".to_string(),
+            };
+            let response1 = proto::OpenWorktreeResponse {
+                worktree: Some(proto::Worktree {
+                    paths: vec![b"path/one".to_vec()],
+                }),
+            };
+            let request2 = proto::OpenWorktree {
+                worktree_id: 102,
+                access_token: "second-worktree-access-token".to_string(),
+            };
+            let response2 = proto::OpenWorktreeResponse {
+                worktree: Some(proto::Worktree {
+                    paths: vec![b"path/two".to_vec(), b"path/three".to_vec()],
+                }),
+            };
+            let request3 = proto::OpenBuffer {
+                worktree_id: 102,
+                path: b"path/two".to_vec(),
+            };
+            let response3 = proto::OpenBufferResponse {
+                buffer: Some(proto::Buffer {
+                    id: 1001,
+                    path: b"path/two".to_vec(),
+                    content: b"path/two content".to_vec(),
+                    history: vec![],
+                }),
+            };
+            let request4 = proto::OpenBuffer {
+                worktree_id: 101,
+                path: b"path/one".to_vec(),
+            };
+            let response4 = proto::OpenBufferResponse {
+                buffer: Some(proto::Buffer {
+                    id: 1002,
+                    path: b"path/one".to_vec(),
+                    content: b"path/one content".to_vec(),
+                    history: vec![],
+                }),
+            };
+
+            // on the server, respond to two requests for each client
+            let mut open_buffer_rx = server.add_message_handler::<proto::OpenBuffer>().await;
+            let mut open_worktree_rx = server.add_message_handler::<proto::OpenWorktree>().await;
+            let (mut server_done_tx, mut server_done_rx) = oneshot::channel::<()>();
+            smol::spawn({
+                let request1 = request1.clone();
+                let request2 = request2.clone();
+                let request3 = request3.clone();
+                let request4 = request4.clone();
+                let response1 = response1.clone();
+                let response2 = response2.clone();
+                let response3 = response3.clone();
+                let response4 = response4.clone();
+                async move {
+                    let msg = open_worktree_rx.recv().await.unwrap();
+                    assert_eq!(msg.payload, request1);
+                    server.respond(msg, response1.clone()).await.unwrap();
+
+                    let msg = open_worktree_rx.recv().await.unwrap();
+                    assert_eq!(msg.payload, request2.clone());
+                    server.respond(msg, response2.clone()).await.unwrap();
+
+                    let msg = open_buffer_rx.recv().await.unwrap();
+                    assert_eq!(msg.payload, request3.clone());
+                    server.respond(msg, response3.clone()).await.unwrap();
+
+                    let msg = open_buffer_rx.recv().await.unwrap();
+                    assert_eq!(msg.payload, request4.clone());
+                    server.respond(msg, response4.clone()).await.unwrap();
+
+                    server_done_tx.send(()).await.unwrap();
+                }
+            })
+            .detach();
+
+            assert_eq!(
+                client1.request(client1_conn_id, request1).await.unwrap(),
+                response1
+            );
+            assert_eq!(
+                client2.request(client2_conn_id, request2).await.unwrap(),
+                response2
+            );
+            assert_eq!(
+                client2.request(client2_conn_id, request3).await.unwrap(),
+                response3
+            );
+            assert_eq!(
+                client1.request(client1_conn_id, request4).await.unwrap(),
+                response4
+            );
+
+            client1.disconnect(client1_conn_id).await;
+            client2.disconnect(client1_conn_id).await;
+
+            server_done_rx.recv().await.unwrap();
+        });
+    }
+
+    #[test]
+    fn test_disconnect() {
+        smol::block_on(async move {
+            let socket_dir_path = TempDir::new("drop-client").unwrap();
+            let socket_path = socket_dir_path.path().join(".sock");
+            let listener = UnixListener::bind(&socket_path).unwrap();
+            let client_conn = UnixStream::connect(&socket_path).await.unwrap();
+            let (mut server_conn, _) = listener.accept().await.unwrap();
+
+            let client = Peer::new();
+            let (connection_id, handler) = client.add_connection(client_conn).await;
+            smol::spawn(handler).detach();
+            client.disconnect(connection_id).await;
+
+            // Try sending an empty payload over and over, until the client is dropped and hangs up.
+            loop {
+                match server_conn.write(&[]).await {
+                    Ok(_) => {}
+                    Err(err) => {
+                        if err.kind() == io::ErrorKind::BrokenPipe {
+                            break;
+                        }
+                    }
+                }
+            }
+        });
+    }
+
+    #[test]
+    fn test_io_error() {
+        smol::block_on(async move {
+            let socket_dir_path = TempDir::new("io-error").unwrap();
+            let socket_path = socket_dir_path.path().join(".sock");
+            let _listener = UnixListener::bind(&socket_path).unwrap();
+            let mut client_conn = UnixStream::connect(&socket_path).await.unwrap();
+            client_conn.close().await.unwrap();
+
+            let client = Peer::new();
+            let (connection_id, handler) = client.add_connection(client_conn).await;
+            smol::spawn(handler).detach();
+
+            let err = client
+                .request(
+                    connection_id,
+                    proto::Auth {
+                        user_id: 42,
+                        access_token: "token".to_string(),
+                    },
+                )
+                .await
+                .unwrap_err();
+            assert_eq!(
+                err.downcast_ref::<io::Error>().unwrap().kind(),
+                io::ErrorKind::BrokenPipe
+            );
+        });
+    }
+}