@@ -217,6 +217,9 @@ pub struct Client {
>,
>,
>,
+
+ #[cfg(any(test, feature = "test-support"))]
+ rpc_url: RwLock<Option<Url>>,
}
#[derive(Error, Debug)]
@@ -527,6 +530,8 @@ impl Client {
authenticate: Default::default(),
#[cfg(any(test, feature = "test-support"))]
establish_connection: Default::default(),
+ #[cfg(any(test, feature = "test-support"))]
+ rpc_url: RwLock::default(),
})
}
@@ -584,6 +589,12 @@ impl Client {
self
}
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn override_rpc_url(&self, url: Url) -> &Self {
+ *self.rpc_url.write() = Some(url);
+ self
+ }
+
pub fn global(cx: &AppContext) -> Arc<Self> {
cx.global::<GlobalClient>().0.clone()
}
@@ -1086,38 +1097,50 @@ impl Client {
self.establish_websocket_connection(credentials, cx)
}
- async fn get_rpc_url(
+ fn rpc_url(
+ &self,
http: Arc<HttpClientWithUrl>,
release_channel: Option<ReleaseChannel>,
- ) -> Result<Url> {
- if let Some(url) = &*ZED_RPC_URL {
- return Url::parse(url).context("invalid rpc url");
- }
+ ) -> impl Future<Output = Result<Url>> {
+ #[cfg(any(test, feature = "test-support"))]
+ let url_override = self.rpc_url.read().clone();
- let mut url = http.build_url("/rpc");
- if let Some(preview_param) =
- release_channel.and_then(|channel| channel.release_query_param())
- {
- url += "?";
- url += preview_param;
- }
- let response = http.get(&url, Default::default(), false).await?;
- let collab_url = if response.status().is_redirection() {
- response
- .headers()
- .get("Location")
- .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
- .to_str()
- .map_err(EstablishConnectionError::other)?
- .to_string()
- } else {
- Err(anyhow!(
- "unexpected /rpc response status {}",
- response.status()
- ))?
- };
+ async move {
+ #[cfg(any(test, feature = "test-support"))]
+ if let Some(url) = url_override {
+ return Ok(url);
+ }
+
+ if let Some(url) = &*ZED_RPC_URL {
+ return Url::parse(url).context("invalid rpc url");
+ }
+
+ let mut url = http.build_url("/rpc");
+ if let Some(preview_param) =
+ release_channel.and_then(|channel| channel.release_query_param())
+ {
+ url += "?";
+ url += preview_param;
+ }
- Url::parse(&collab_url).context("invalid rpc url")
+ let response = http.get(&url, Default::default(), false).await?;
+ let collab_url = if response.status().is_redirection() {
+ response
+ .headers()
+ .get("Location")
+ .ok_or_else(|| anyhow!("missing location header in /rpc response"))?
+ .to_str()
+ .map_err(EstablishConnectionError::other)?
+ .to_string()
+ } else {
+ Err(anyhow!(
+ "unexpected /rpc response status {}",
+ response.status()
+ ))?
+ };
+
+ Url::parse(&collab_url).context("invalid rpc url")
+ }
}
fn establish_websocket_connection(
@@ -1144,8 +1167,9 @@ impl Client {
);
let http = self.http.clone();
+ let rpc_url = self.rpc_url(http, release_channel);
cx.background_executor().spawn(async move {
- let mut rpc_url = Self::get_rpc_url(http, release_channel).await?;
+ let mut rpc_url = rpc_url.await?;
let rpc_host = rpc_url
.host_str()
.zip(rpc_url.port_or_known_default())
@@ -1186,6 +1210,7 @@ impl Client {
cx: &AsyncAppContext,
) -> Task<Result<Credentials>> {
let http = self.http.clone();
+ let this = self.clone();
cx.spawn(|cx| async move {
let background = cx.background_executor().clone();
@@ -1215,7 +1240,8 @@ impl Client {
{
eprintln!("authenticate as admin {login}, {token}");
- return Self::authenticate_as_admin(http, login.clone(), token.clone())
+ return this
+ .authenticate_as_admin(http, login.clone(), token.clone())
.await;
}
@@ -1303,6 +1329,7 @@ impl Client {
}
async fn authenticate_as_admin(
+ self: &Arc<Self>,
http: Arc<HttpClientWithUrl>,
login: String,
mut api_token: String,
@@ -1319,7 +1346,7 @@ impl Client {
// Use the collab server's admin API to retrieve the id
// of the impersonated user.
- let mut url = Self::get_rpc_url(http.clone(), None).await?;
+ let mut url = self.rpc_url(http.clone(), None).await?;
url.set_path("/user");
url.set_query(Some(&format!("github_login={login}")));
let request = Request::get(url.as_str())
@@ -1,7 +1,7 @@
fn main() {
let mut build = prost_build::Config::new();
build
- .type_attribute(".", "#[derive(serde::Serialize)]")
+ .type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]")
.compile_protos(&["proto/zed.proto"], &["proto"])
.unwrap();
}