diff --git a/Cargo.lock b/Cargo.lock index 4e1ea7131b42f7d3056561f40bfd3a809c6ecf00..9c4a2f4b86d2b404a1e6e3b59af98d11ed86d4d8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3035,6 +3035,7 @@ dependencies = [ "http_client", "parking_lot", "serde_json", + "thiserror 2.0.17", "yawc", ] @@ -9105,6 +9106,7 @@ dependencies = [ "anyhow", "base64 0.22.1", "client", + "cloud_api_client", "cloud_api_types", "cloud_llm_client", "collections", diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 55d7eb5dce2be133e8c48dd23e8f34ebb68c50bf..2847af6072f18306e70feb87e692e836cfbfc02d 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -19,11 +19,12 @@ use credentials_provider::CredentialsProvider; use feature_flags::FeatureFlagAppExt as _; use futures::{ AsyncReadExt, FutureExt, SinkExt, Stream, StreamExt, TryFutureExt as _, TryStreamExt, - channel::oneshot, future::BoxFuture, + channel::{mpsc, oneshot}, + future::BoxFuture, }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env}; -use parking_lot::RwLock; +use parking_lot::{Mutex, RwLock}; use postage::watch; use proxy::connect_proxy_stream; use rand::prelude::*; @@ -195,8 +196,10 @@ pub struct Client { telemetry: Arc, credentials_provider: ClientCredentialsProvider, state: RwLock, - handler_set: parking_lot::Mutex, - message_to_client_handlers: parking_lot::Mutex>, + handler_set: Mutex, + message_to_client_handlers: Mutex>, + sign_out_tx: mpsc::UnboundedSender<()>, + _handle_sign_out: Mutex>>, #[allow(clippy::type_complexity)] #[cfg(any(test, feature = "test-support"))] @@ -527,7 +530,8 @@ impl Client { http: Arc, cx: &mut App, ) -> Arc { - Arc::new(Self { + let (sign_out_tx, mut sign_out_rx) = mpsc::unbounded(); + let this = Arc::new(Self { id: AtomicU64::new(0), peer: Peer::new(0), telemetry: Telemetry::new(clock, http.clone(), cx), @@ -536,7 +540,9 @@ impl Client { credentials_provider: ClientCredentialsProvider::new(cx), state: Default::default(), handler_set: Default::default(), - message_to_client_handlers: parking_lot::Mutex::new(Vec::new()), + message_to_client_handlers: Mutex::new(Vec::new()), + sign_out_tx, + _handle_sign_out: Mutex::new(None), #[cfg(any(test, feature = "test-support"))] authenticate: Default::default(), @@ -544,7 +550,19 @@ impl Client { establish_connection: Default::default(), #[cfg(any(test, feature = "test-support"))] rpc_url: RwLock::default(), - }) + }); + this._handle_sign_out.lock().replace(cx.spawn({ + let weak_client = Arc::downgrade(&this); + async move |cx| { + while sign_out_rx.next().await.is_some() { + if let Some(client) = weak_client.upgrade() { + client.sign_out(&cx).await; + } + } + } + })); + + this } pub fn production(cx: &mut App) -> Arc { @@ -1519,6 +1537,11 @@ impl Client { } } + /// Requests a sign out to be performed asynchronously. + pub fn request_sign_out(&self) { + self.sign_out_tx.unbounded_send(()).ok(); + } + pub fn disconnect(self: &Arc, cx: &AsyncApp) { self.peer.teardown(); self.set_status(Status::SignedOut, cx); @@ -1706,7 +1729,7 @@ impl ProtoClient for Client { self.peer.send_dynamic(connection_id, envelope) } - fn message_handler_set(&self) -> &parking_lot::Mutex { + fn message_handler_set(&self) -> &Mutex { &self.handler_set } diff --git a/crates/cloud_api_client/Cargo.toml b/crates/cloud_api_client/Cargo.toml index 9dc009bf2e59ba848c93a6ebc65be566a2aabd55..78c684e3e54ee29a5f3f3ae5620d4a52b445f92e 100644 --- a/crates/cloud_api_client/Cargo.toml +++ b/crates/cloud_api_client/Cargo.toml @@ -20,4 +20,5 @@ gpui_tokio.workspace = true http_client.workspace = true parking_lot.workspace = true serde_json.workspace = true +thiserror.workspace = true yawc.workspace = true diff --git a/crates/cloud_api_client/src/cloud_api_client.rs b/crates/cloud_api_client/src/cloud_api_client.rs index 35f34f5d436150b1bca641c425c0f9380bc0783b..f485e2d20c619715ea342fccd2a5cec0ecaa6f4e 100644 --- a/crates/cloud_api_client/src/cloud_api_client.rs +++ b/crates/cloud_api_client/src/cloud_api_client.rs @@ -11,6 +11,7 @@ use gpui_tokio::Tokio; use http_client::http::request; use http_client::{AsyncBody, HttpClientWithUrl, HttpRequestExt, Method, Request, StatusCode}; use parking_lot::RwLock; +use thiserror::Error; use yawc::WebSocket; use crate::websocket::Connection; @@ -20,6 +21,14 @@ struct Credentials { access_token: String, } +#[derive(Debug, Error)] +pub enum ClientApiError { + #[error("Unauthorized")] + Unauthorized, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + pub struct CloudApiClient { credentials: RwLock>, http_client: Arc, @@ -58,7 +67,9 @@ impl CloudApiClient { build_request(req, body, credentials) } - pub async fn get_authenticated_user(&self) -> Result { + pub async fn get_authenticated_user( + &self, + ) -> Result { let request = self.build_request( Request::builder().method(Method::GET).uri( self.http_client @@ -71,19 +82,31 @@ impl CloudApiClient { let mut response = self.http_client.send(request).await?; if !response.status().is_success() { + if response.status() == StatusCode::UNAUTHORIZED { + return Err(ClientApiError::Unauthorized); + } + let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; + response + .body_mut() + .read_to_string(&mut body) + .await + .context("failed to read response body")?; - anyhow::bail!( + return Err(ClientApiError::Other(anyhow::anyhow!( "Failed to get authenticated user.\nStatus: {:?}\nBody: {body}", response.status() - ) + ))); } let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; + response + .body_mut() + .read_to_string(&mut body) + .await + .context("failed to read response body")?; - Ok(serde_json::from_str(&body)?) + Ok(serde_json::from_str(&body).context("failed to parse response body")?) } pub fn connect(&self, cx: &App) -> Result>> { @@ -118,7 +141,7 @@ impl CloudApiClient { pub async fn create_llm_token( &self, system_id: Option, - ) -> Result { + ) -> Result { let request_builder = Request::builder() .method(Method::POST) .uri( @@ -135,19 +158,31 @@ impl CloudApiClient { let mut response = self.http_client.send(request).await?; if !response.status().is_success() { + if response.status() == StatusCode::UNAUTHORIZED { + return Err(ClientApiError::Unauthorized); + } + let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; + response + .body_mut() + .read_to_string(&mut body) + .await + .context("failed to read response body")?; - anyhow::bail!( + return Err(ClientApiError::Other(anyhow::anyhow!( "Failed to create LLM token.\nStatus: {:?}\nBody: {body}", response.status() - ) + ))); } let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; + response + .body_mut() + .read_to_string(&mut body) + .await + .context("failed to read response body")?; - Ok(serde_json::from_str(&body)?) + Ok(serde_json::from_str(&body).context("failed to parse response body")?) } pub async fn validate_credentials(&self, user_id: u32, access_token: &str) -> Result { diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index e472521074109216bd243f5875dcc325cc9b3fed..a586458e41bab0b12c5f92849659ed33c18f5a68 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -21,6 +21,7 @@ anyhow.workspace = true credentials_provider.workspace = true base64.workspace = true client.workspace = true +cloud_api_client.workspace = true cloud_api_types.workspace = true cloud_llm_client.workspace = true collections.workspace = true diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 1ecb4a3f028a5c7e792ea42d1c4bd141764ffbdd..18e099b4d6fc62867bf35fbd1d4573093af44744 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,8 +1,9 @@ use std::fmt; use std::sync::Arc; -use anyhow::Result; +use anyhow::{Context as _, Result}; use client::Client; +use cloud_api_client::ClientApiError; use cloud_api_types::websocket_protocol::MessageToClient; use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _}; @@ -47,9 +48,20 @@ impl LlmApiToken { .system_id() .map(|system_id| system_id.to_string()); - let response = client.cloud_client().create_llm_token(system_id).await?; - *lock = Some(response.token.0.clone()); - Ok(response.token.0) + let result = client.cloud_client().create_llm_token(system_id).await; + match result { + Ok(response) => { + *lock = Some(response.token.0.clone()); + Ok(response.token.0) + } + Err(err) => match err { + ClientApiError::Unauthorized => { + client.request_sign_out(); + Err(err).context("Failed to create LLM token") + } + ClientApiError::Other(err) => Err(err), + }, + } } }