@@ -884,27 +884,28 @@ impl Client {
let old_credentials = self.state.read().credentials.clone();
if let Some(old_credentials) = old_credentials {
- self.cloud_client.set_credentials(
- old_credentials.user_id as u32,
- old_credentials.access_token.clone(),
- );
-
- // Fetch the authenticated user with the old credentials, to ensure they are still valid.
- if self.cloud_client.get_authenticated_user().await.is_ok() {
+ if self
+ .cloud_client
+ .validate_credentials(
+ old_credentials.user_id as u32,
+ &old_credentials.access_token,
+ )
+ .await?
+ {
credentials = Some(old_credentials);
}
}
if credentials.is_none() && try_provider {
if let Some(stored_credentials) = self.credentials_provider.read_credentials(cx).await {
- self.cloud_client.set_credentials(
- stored_credentials.user_id as u32,
- stored_credentials.access_token.clone(),
- );
-
- // Fetch the authenticated user with the stored credentials, and
- // clear them from the credentials provider if that fails.
- if self.cloud_client.get_authenticated_user().await.is_ok() {
+ if self
+ .cloud_client
+ .validate_credentials(
+ stored_credentials.user_id as u32,
+ &stored_credentials.access_token,
+ )
+ .await?
+ {
credentials = Some(stored_credentials);
} else {
self.credentials_provider
@@ -1709,7 +1710,7 @@ pub fn parse_zed_link<'a>(link: &'a str, cx: &App) -> Option<&'a str> {
#[cfg(test)]
mod tests {
use super::*;
- use crate::test::FakeServer;
+ use crate::test::{FakeServer, parse_authorization_header};
use clock::FakeSystemClock;
use gpui::{AppContext as _, BackgroundExecutor, TestAppContext};
@@ -1835,6 +1836,75 @@ mod tests {
));
}
+ #[gpui::test(iterations = 10)]
+ async fn test_reauthenticate_only_if_unauthorized(cx: &mut TestAppContext) {
+ init_test(cx);
+ let auth_count = Arc::new(Mutex::new(0));
+ let http_client = FakeHttpClient::create(|_request| async move {
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body("".into())
+ .unwrap())
+ });
+ let client =
+ cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client.clone(), cx));
+ client.override_authenticate({
+ let auth_count = auth_count.clone();
+ move |cx| {
+ let auth_count = auth_count.clone();
+ cx.background_spawn(async move {
+ *auth_count.lock() += 1;
+ Ok(Credentials {
+ user_id: 1,
+ access_token: auth_count.lock().to_string(),
+ })
+ })
+ }
+ });
+
+ let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
+ assert_eq!(*auth_count.lock(), 1);
+ assert_eq!(credentials.access_token, "1");
+
+ // If credentials are still valid, signing in doesn't trigger authentication.
+ let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
+ assert_eq!(*auth_count.lock(), 1);
+ assert_eq!(credentials.access_token, "1");
+
+ // If the server is unavailable, signing in doesn't trigger authentication.
+ http_client
+ .as_fake()
+ .replace_handler(|_, _request| async move {
+ Ok(http_client::Response::builder()
+ .status(503)
+ .body("".into())
+ .unwrap())
+ });
+ client.sign_in(false, &cx.to_async()).await.unwrap_err();
+ assert_eq!(*auth_count.lock(), 1);
+
+ // If credentials became invalid, signing in triggers authentication.
+ http_client
+ .as_fake()
+ .replace_handler(|_, request| async move {
+ let credentials = parse_authorization_header(&request).unwrap();
+ if credentials.access_token == "2" {
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body("".into())
+ .unwrap())
+ } else {
+ Ok(http_client::Response::builder()
+ .status(401)
+ .body("".into())
+ .unwrap())
+ }
+ });
+ let credentials = client.sign_in(false, &cx.to_async()).await.unwrap();
+ assert_eq!(*auth_count.lock(), 2);
+ assert_eq!(credentials.access_token, "2");
+ }
+
#[gpui::test(iterations = 10)]
async fn test_authenticating_more_than_once(
cx: &mut TestAppContext,
@@ -1,10 +1,10 @@
use std::sync::Arc;
-use anyhow::{Result, anyhow};
+use anyhow::{Context, Result, anyhow};
pub use cloud_api_types::*;
use futures::AsyncReadExt as _;
use http_client::http::request;
-use http_client::{AsyncBody, HttpClientWithUrl, Method, Request};
+use http_client::{AsyncBody, HttpClientWithUrl, Method, Request, StatusCode};
use parking_lot::RwLock;
struct Credentials {
@@ -40,27 +40,14 @@ impl CloudApiClient {
*self.credentials.write() = None;
}
- fn authorization_header(&self) -> Result<String> {
- let guard = self.credentials.read();
- let credentials = guard
- .as_ref()
- .ok_or_else(|| anyhow!("No credentials provided"))?;
-
- Ok(format!(
- "{} {}",
- credentials.user_id, credentials.access_token
- ))
- }
-
fn build_request(
&self,
req: request::Builder,
body: impl Into<AsyncBody>,
) -> Result<Request<AsyncBody>> {
- Ok(req
- .header("Content-Type", "application/json")
- .header("Authorization", self.authorization_header()?)
- .body(body.into())?)
+ let credentials = self.credentials.read();
+ let credentials = credentials.as_ref().context("no credentials provided")?;
+ build_request(req, body, credentials)
}
pub async fn get_authenticated_user(&self) -> Result<GetAuthenticatedUserResponse> {
@@ -152,4 +139,50 @@ impl CloudApiClient {
Ok(serde_json::from_str(&body)?)
}
+
+ pub async fn validate_credentials(&self, user_id: u32, access_token: &str) -> Result<bool> {
+ let request = build_request(
+ Request::builder().method(Method::GET).uri(
+ self.http_client
+ .build_zed_cloud_url("/client/users/me", &[])?
+ .as_ref(),
+ ),
+ AsyncBody::default(),
+ &Credentials {
+ user_id,
+ access_token: access_token.into(),
+ },
+ )?;
+
+ let mut response = self.http_client.send(request).await?;
+
+ if response.status().is_success() {
+ Ok(true)
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ if response.status() == StatusCode::UNAUTHORIZED {
+ return Ok(false);
+ } else {
+ return Err(anyhow!(
+ "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}",
+ response.status()
+ ));
+ }
+ }
+ }
+}
+
+fn build_request(
+ req: request::Builder,
+ body: impl Into<AsyncBody>,
+ credentials: &Credentials,
+) -> Result<Request<AsyncBody>> {
+ Ok(req
+ .header("Content-Type", "application/json")
+ .header(
+ "Authorization",
+ format!("{} {}", credentials.user_id, credentials.access_token),
+ )
+ .body(body.into())?)
}