Allow rpc client to connect to an in-memory stream

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

zed/src/rpc.rs | 77 ++++++++++++++++++++++++---------------------------
1 file changed, 37 insertions(+), 40 deletions(-)

Detailed changes

zed/src/rpc.rs 🔗

@@ -1,5 +1,6 @@
 use crate::{language::LanguageRegistry, worktree::Worktree};
 use anyhow::{anyhow, Context, Result};
+use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
 use gpui::executor::Background;
 use gpui::{AsyncAppContext, ModelHandle, Task, WeakModelHandle};
 use lazy_static::lazy_static;
@@ -85,24 +86,8 @@ impl Client {
         }
 
         let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
-        self.connect(
-            &ZED_SERVER_URL,
-            user_id.parse()?,
-            access_token,
-            &cx.background(),
-        )
-        .await?;
-        Ok(())
-    }
-
-    pub async fn connect(
-        &self,
-        server_url: &str,
-        user_id: i32,
-        access_token: String,
-        executor: &Arc<Background>,
-    ) -> surf::Result<()> {
-        let connection_id = if let Some(host) = server_url.strip_prefix("https://") {
+        let user_id = user_id.parse()?;
+        if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
             let stream = smol::net::TcpStream::connect(host).await?;
             let (stream, _) = async_tungstenite::async_tls::client_async_tls(
                 format!("wss://{}/rpc", host),
@@ -110,34 +95,46 @@ impl Client {
             )
             .await
             .context("websocket handshake")?;
-            log::info!("connected to rpc address {}", &*ZED_SERVER_URL);
-            let (connection_id, handler) = self.peer.add_connection(stream).await;
-            executor
-                .spawn(async move {
-                    if let Err(error) = handler.run().await {
-                        log::error!("connection error: {:?}", error);
-                    }
-                })
-                .detach();
-            connection_id
-        } else if let Some(host) = server_url.strip_prefix("http://") {
+            log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+            self.add_connection(stream, user_id, access_token, &cx.background())
+                .await?;
+        } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
             let stream = smol::net::TcpStream::connect(host).await?;
             let (stream, _) =
                 async_tungstenite::client_async(format!("ws://{}/rpc", host), stream).await?;
-            log::info!("connected to rpc address {}", &*ZED_SERVER_URL);
-            let (connection_id, handler) = self.peer.add_connection(stream).await;
-            executor
-                .spawn(async move {
-                    if let Err(error) = handler.run().await {
-                        log::error!("connection error: {:?}", error);
-                    }
-                })
-                .detach();
-            connection_id
+            log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+            self.add_connection(stream, user_id, access_token, &cx.background())
+                .await?;
         } else {
-            return Err(anyhow!("invalid server url: {}", server_url))?;
+            return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
         };
 
+        Ok(())
+    }
+
+    pub async fn add_connection<Conn>(
+        &self,
+        conn: Conn,
+        user_id: i32,
+        access_token: String,
+        executor: &Arc<Background>,
+    ) -> surf::Result<()>
+    where
+        Conn: 'static
+            + futures::Sink<WebSocketMessage, Error = WebSocketError>
+            + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
+            + Unpin
+            + Send,
+    {
+        let (connection_id, handler) = self.peer.add_connection(conn).await;
+        executor
+            .spawn(async move {
+                if let Err(error) = handler.run().await {
+                    log::error!("connection error: {:?}", error);
+                }
+            })
+            .detach();
+
         let auth_response = self
             .peer
             .request(