Allow rpc_url to be assigned on Client with test-support feature (#13430)

Nathan Sobo created

Also, allow proto messages to be deserialized. This is to support
translating these messages JS types in a new server implementation based
on CloudFlare durable objects.

Release Notes:

- N/A

Change summary

crates/client/src/client.rs | 89 +++++++++++++++++++++++++-------------
crates/proto/build.rs       |  2 
crates/proto/src/proto.rs   |  2 
3 files changed, 60 insertions(+), 33 deletions(-)

Detailed changes

crates/client/src/client.rs 🔗

@@ -217,6 +217,9 @@ pub struct Client {
             >,
         >,
     >,
+
+    #[cfg(any(test, feature = "test-support"))]
+    rpc_url: RwLock<Option<Url>>,
 }
 
 #[derive(Error, Debug)]
@@ -527,6 +530,8 @@ impl Client {
             authenticate: Default::default(),
             #[cfg(any(test, feature = "test-support"))]
             establish_connection: Default::default(),
+            #[cfg(any(test, feature = "test-support"))]
+            rpc_url: RwLock::default(),
         })
     }
 
@@ -584,6 +589,12 @@ impl Client {
         self
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn override_rpc_url(&self, url: Url) -> &Self {
+        *self.rpc_url.write() = Some(url);
+        self
+    }
+
     pub fn global(cx: &AppContext) -> Arc<Self> {
         cx.global::<GlobalClient>().0.clone()
     }
@@ -1086,38 +1097,50 @@ impl Client {
         self.establish_websocket_connection(credentials, cx)
     }
 
-    async fn get_rpc_url(
+    fn rpc_url(
+        &self,
         http: Arc<HttpClientWithUrl>,
         release_channel: Option<ReleaseChannel>,
-    ) -> Result<Url> {
-        if let Some(url) = &*ZED_RPC_URL {
-            return Url::parse(url).context("invalid rpc url");
-        }
+    ) -> impl Future<Output = Result<Url>> {
+        #[cfg(any(test, feature = "test-support"))]
+        let url_override = self.rpc_url.read().clone();
 
-        let mut url = http.build_url("/rpc");
-        if let Some(preview_param) =
-            release_channel.and_then(|channel| channel.release_query_param())
-        {
-            url += "?";
-            url += preview_param;
-        }
-        let response = http.get(&url, Default::default(), false).await?;
-        let collab_url = if response.status().is_redirection() {
-            response
-                .headers()
-                .get("Location")
-                .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
-                .to_str()
-                .map_err(EstablishConnectionError::other)?
-                .to_string()
-        } else {
-            Err(anyhow!(
-                "unexpected /rpc response status {}",
-                response.status()
-            ))?
-        };
+        async move {
+            #[cfg(any(test, feature = "test-support"))]
+            if let Some(url) = url_override {
+                return Ok(url);
+            }
+
+            if let Some(url) = &*ZED_RPC_URL {
+                return Url::parse(url).context("invalid rpc url");
+            }
+
+            let mut url = http.build_url("/rpc");
+            if let Some(preview_param) =
+                release_channel.and_then(|channel| channel.release_query_param())
+            {
+                url += "?";
+                url += preview_param;
+            }
 
-        Url::parse(&collab_url).context("invalid rpc url")
+            let response = http.get(&url, Default::default(), false).await?;
+            let collab_url = if response.status().is_redirection() {
+                response
+                    .headers()
+                    .get("Location")
+                    .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
+                    .to_str()
+                    .map_err(EstablishConnectionError::other)?
+                    .to_string()
+            } else {
+                Err(anyhow!(
+                    "unexpected /rpc response status {}",
+                    response.status()
+                ))?
+            };
+
+            Url::parse(&collab_url).context("invalid rpc url")
+        }
     }
 
     fn establish_websocket_connection(
@@ -1144,8 +1167,9 @@ impl Client {
             );
 
         let http = self.http.clone();
+        let rpc_url = self.rpc_url(http, release_channel);
         cx.background_executor().spawn(async move {
-            let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
+            let mut rpc_url = rpc_url.await?;
             let rpc_host = rpc_url
                 .host_str()
                 .zip(rpc_url.port_or_known_default())
@@ -1186,6 +1210,7 @@ impl Client {
         cx: &AsyncAppContext,
     ) -> Task<Result<Credentials>> {
         let http = self.http.clone();
+        let this = self.clone();
         cx.spawn(|cx| async move {
             let background = cx.background_executor().clone();
 
@@ -1215,7 +1240,8 @@ impl Client {
                     {
                         eprintln!("authenticate as admin {login}, {token}");
 
-                        return Self::authenticate_as_admin(http, login.clone(), token.clone())
+                        return this
+                            .authenticate_as_admin(http, login.clone(), token.clone())
                             .await;
                     }
 
@@ -1303,6 +1329,7 @@ impl Client {
     }
 
     async fn authenticate_as_admin(
+        self: &Arc<Self>,
         http: Arc<HttpClientWithUrl>,
         login: String,
         mut api_token: String,
@@ -1319,7 +1346,7 @@ impl Client {
 
         // Use the collab server's admin API to retrieve the id
         // of the impersonated user.
-        let mut url = Self::get_rpc_url(http.clone(), None).await?;
+        let mut url = self.rpc_url(http.clone(), None).await?;
         url.set_path("/user");
         url.set_query(Some(&format!("github_login={login}")));
         let request = Request::get(url.as_str())

crates/proto/build.rs 🔗

@@ -1,7 +1,7 @@
 fn main() {
     let mut build = prost_build::Config::new();
     build
-        .type_attribute(".", "#[derive(serde::Serialize)]")
+        .type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]")
         .compile_protos(&["proto/zed.proto"], &["proto"])
         .unwrap();
 }

crates/proto/src/proto.rs 🔗

@@ -8,7 +8,7 @@ pub use error::*;
 pub use typed_envelope::*;
 
 use collections::HashMap;
-pub use prost::Message;
+pub use prost::{DecodeError, Message};
 use serde::Serialize;
 use std::any::{Any, TypeId};
 use std::time::Instant;