Close connection when `RpcClient` is dropped and add unit tests

Antonio Scandurra created

Change summary

zed/src/lib.rs        |   6 
zed/src/rpc_client.rs | 181 ++++++++++++++++++++++++++++++++++++--------
2 files changed, 149 insertions(+), 38 deletions(-)

Detailed changes

zed/src/lib.rs 🔗

@@ -1,7 +1,7 @@
 use anyhow::{anyhow, Context, Result};
 use gpui::{AsyncAppContext, MutableAppContext, Task};
 use rpc_client::RpcClient;
-use std::{convert::TryFrom, net::Shutdown, time::Duration};
+use std::{convert::TryFrom, time::Duration};
 use tiny_http::{Header, Response, Server};
 use url::Url;
 use util::SurfResultExt;
@@ -60,9 +60,7 @@ fn share_worktree(_: &(), cx: &mut MutableAppContext) {
         // a TLS stream using `native-tls`.
         let stream = smol::net::TcpStream::connect(rpc_address).await?;
 
-        let mut rpc_client = RpcClient::new(stream, executor, |stream| {
-            stream.shutdown(Shutdown::Read).ok();
-        });
+        let mut rpc_client = RpcClient::new(stream, executor);
 
         let auth_response = rpc_client
             .request(proto::from_client::Auth {

zed/src/rpc_client.rs 🔗

@@ -5,60 +5,81 @@ use postage::{
     oneshot,
     prelude::{Sink, Stream},
 };
-use smol::prelude::{AsyncRead, AsyncWrite};
+use smol::{
+    future::FutureExt,
+    io::WriteHalf,
+    prelude::{AsyncRead, AsyncWrite},
+};
 use std::{collections::HashMap, sync::Arc};
 use zed_rpc::proto::{self, MessageStream, RequestMessage, SendMessage, ServerMessage};
 
-pub struct RpcClient<Conn, ShutdownFn>
-where
-    ShutdownFn: FnMut(&mut Conn),
-{
-    stream: MessageStream<Conn>,
+pub struct RpcClient<Conn> {
+    stream: MessageStream<WriteHalf<Conn>>,
     response_channels: Arc<Mutex<HashMap<i32, oneshot::Sender<proto::from_server::Variant>>>>,
     next_message_id: i32,
-    shutdown_fn: ShutdownFn,
+    _drop_tx: oneshot::Sender<()>,
 }
 
-impl<Conn, ShutdownFn> RpcClient<Conn, ShutdownFn>
+impl<Conn> RpcClient<Conn>
 where
     Conn: Clone + AsyncRead + AsyncWrite + Unpin + Send + 'static,
-    ShutdownFn: FnMut(&mut Conn),
 {
-    pub fn new(conn: Conn, executor: Arc<Background>, shutdown_fn: ShutdownFn) -> Self {
+    pub fn new(conn: Conn, executor: Arc<Background>) -> Self {
+        let (conn_rx, conn_tx) = smol::io::split(conn);
+        let (drop_tx, mut drop_rx) = oneshot::channel();
         let response_channels = Arc::new(Mutex::new(HashMap::new()));
-
-        let result = Self {
+        let client = Self {
             next_message_id: 0,
-            stream: MessageStream::new(conn.clone()),
+            stream: MessageStream::new(conn_tx),
             response_channels: response_channels.clone(),
-            shutdown_fn,
+            _drop_tx: drop_tx,
         };
 
         executor
             .spawn::<Result<()>, _>(async move {
-                let mut stream = MessageStream::new(conn);
+                enum Message {
+                    Message(proto::FromServer),
+                    ClientDropped,
+                }
+
+                let mut stream = MessageStream::new(conn_rx);
+                let client_dropped = async move {
+                    assert!(drop_rx.recv().await.is_none());
+                    Ok(Message::ClientDropped) as Result<_>
+                };
+                smol::pin!(client_dropped);
                 loop {
-                    let message = stream.read_message::<proto::FromServer>().await?;
-                    if let Some(variant) = message.variant {
-                        if let Some(request_id) = message.request_id {
-                            let tx = response_channels.lock().remove(&request_id);
-                            if let Some(mut tx) = tx {
-                                tx.send(variant).await?;
+                    let message = async {
+                        Ok(Message::Message(
+                            stream.read_message::<proto::FromServer>().await?,
+                        ))
+                    };
+
+                    match message.race(&mut client_dropped).await? {
+                        Message::Message(message) => {
+                            if let Some(variant) = message.variant {
+                                if let Some(request_id) = message.request_id {
+                                    let tx = response_channels.lock().remove(&request_id);
+                                    if let Some(mut tx) = tx {
+                                        tx.send(variant).await?;
+                                    } else {
+                                        log::warn!(
+                                            "received RPC response to unknown request id {}",
+                                            request_id
+                                        );
+                                    }
+                                }
                             } else {
-                                log::warn!(
-                                    "received RPC response to unknown request id {}",
-                                    request_id
-                                );
+                                log::warn!("received RPC message with no content");
                             }
                         }
-                    } else {
-                        log::warn!("received RPC message with no content");
+                        Message::ClientDropped => break Ok(()),
                     }
                 }
             })
             .detach();
 
-        result
+        client
     }
 
     pub async fn request<T: RequestMessage>(&mut self, req: T) -> Result<T::Response> {
@@ -87,11 +108,103 @@ where
     }
 }
 
-impl<Conn, ShutdownFn> Drop for RpcClient<Conn, ShutdownFn>
-where
-    ShutdownFn: FnMut(&mut Conn),
-{
-    fn drop(&mut self) {
-        (self.shutdown_fn)(self.stream.inner_mut())
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use smol::{
+        future::poll_once,
+        io::AsyncWriteExt,
+        net::unix::{UnixListener, UnixStream},
+    };
+    use std::{future::Future, io};
+    use tempdir::TempDir;
+
+    #[gpui::test]
+    async fn test_request_response(cx: gpui::TestAppContext) {
+        let executor = cx.read(|app| app.background_executor().clone());
+        let socket_dir_path = TempDir::new("request-response-socket").unwrap();
+        let socket_path = socket_dir_path.path().join(".sock");
+        let listener = UnixListener::bind(&socket_path).unwrap();
+        let client_conn = UnixStream::connect(&socket_path).await.unwrap();
+        let (server_conn, _) = listener.accept().await.unwrap();
+
+        let mut server_stream = MessageStream::new(server_conn);
+        let mut client = RpcClient::new(client_conn, executor.clone());
+
+        let client_req = client.request(proto::from_client::Auth {
+            user_id: 42,
+            access_token: "token".to_string(),
+        });
+        smol::pin!(client_req);
+        let server_req = send_recv(
+            &mut client_req,
+            server_stream.read_message::<proto::FromClient>(),
+        )
+        .await
+        .unwrap();
+        assert_eq!(
+            server_req.variant,
+            Some(proto::from_client::Variant::Auth(
+                proto::from_client::Auth {
+                    user_id: 42,
+                    access_token: "token".to_string()
+                }
+            ))
+        );
+
+        server_stream
+            .write_message(&proto::FromServer {
+                request_id: Some(server_req.id),
+                variant: Some(proto::from_server::Variant::AuthResponse(
+                    proto::from_server::AuthResponse {
+                        credentials_valid: true,
+                    },
+                )),
+            })
+            .await
+            .unwrap();
+        assert_eq!(
+            client_req.await.unwrap(),
+            proto::from_server::AuthResponse {
+                credentials_valid: true
+            }
+        );
+    }
+
+    #[gpui::test]
+    async fn test_drop_client(cx: gpui::TestAppContext) {
+        let executor = cx.read(|app| app.background_executor().clone());
+        let socket_dir_path = TempDir::new("request-response-socket").unwrap();
+        let socket_path = socket_dir_path.path().join(".sock");
+        let listener = UnixListener::bind(&socket_path).unwrap();
+        let client_conn = UnixStream::connect(&socket_path).await.unwrap();
+        let (mut server_conn, _) = listener.accept().await.unwrap();
+
+        let client = RpcClient::new(client_conn, executor.clone());
+        drop(client);
+
+        // Try sending an empty payload over and over, until the client is dropped and hangs up.
+        let error = loop {
+            match server_conn.write(&[0]).await {
+                Ok(_) => continue,
+                Err(err) => break err,
+            }
+        };
+        assert_eq!(error.kind(), io::ErrorKind::BrokenPipe);
+    }
+
+    async fn send_recv<S, R, O>(mut sender: S, receiver: R) -> O
+    where
+        S: Unpin + Future,
+        R: Future<Output = O>,
+    {
+        smol::pin!(receiver);
+        loop {
+            poll_once(&mut sender).await;
+            match poll_once(&mut receiver).await {
+                Some(message) => break message,
+                None => continue,
+            }
+        }
     }
 }