diff --git a/Cargo.lock b/Cargo.lock index 647f8218bc7c919e08dafb3c676a78c62a0a1b10..e9b3aa4b7ddd5ac46d34159818ab0307699c7a4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2865,6 +2865,7 @@ dependencies = [ "chrono", "clock", "cloud_api_client", + "cloud_api_types", "cloud_llm_client", "collections", "credentials_provider", @@ -2878,6 +2879,7 @@ dependencies = [ "http_client", "http_client_tls", "httparse", + "language_model", "log", "objc2-foundation", "parking_lot", @@ -9335,7 +9337,6 @@ dependencies = [ "anthropic", "anyhow", "base64 0.22.1", - "client", "cloud_api_client", "cloud_api_types", "cloud_llm_client", diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 036a6f1030c43b16d51f864a1d0176891e90b772..9808b95dd0812f9a857da8a9c39e78fde40af1f9 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -6,7 +6,7 @@ use acp_thread::{ use agent_client_protocol::{self as acp}; use agent_settings::AgentProfileId; use anyhow::Result; -use client::{Client, UserStore}; +use client::{Client, RefreshLlmTokenListener, UserStore}; use collections::IndexMap; use context_server::{ContextServer, ContextServerCommand, ContextServerId}; use feature_flags::FeatureFlagAppExt as _; @@ -3253,7 +3253,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) { let clock = Arc::new(clock::FakeSystemClock::new()); let client = Client::new(clock, http_client, cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store, client.clone(), cx); LanguageModelRegistry::test(cx); }); @@ -3982,7 +3983,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest { cx.set_http_client(Arc::new(http_client)); let client = Client::production(cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store, client.clone(), cx); } }; diff --git a/crates/agent_servers/src/e2e_tests.rs b/crates/agent_servers/src/e2e_tests.rs index 956d106df2a260bd2eb31c14f4f1f1705bf74cd6..a08778b25564d79eaef9fca21d1190ff418f9935 100644 --- a/crates/agent_servers/src/e2e_tests.rs +++ b/crates/agent_servers/src/e2e_tests.rs @@ -1,6 +1,7 @@ use crate::{AgentServer, AgentServerDelegate}; use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus}; use agent_client_protocol as acp; +use client::RefreshLlmTokenListener; use futures::{FutureExt, StreamExt, channel::mpsc, select}; use gpui::AppContext; use gpui::{Entity, TestAppContext}; @@ -413,7 +414,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc { cx.set_http_client(Arc::new(http_client)); let client = client::Client::production(cx); let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx)); - language_model::init(user_store, client, cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); #[cfg(test)] project::agent_server_store::AllAgentServersSettings::override_global( diff --git a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs index 4e3dd63b0337f9be54b550f4f4a6a5ca2e7cdd42..b97583377a00d28ea1a8aae6a1380cff3b69e6a0 100644 --- a/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs +++ b/crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs @@ -815,7 +815,7 @@ mod tests { cx.set_global(store); theme_settings::init(theme::LoadThemes::JustBase, cx); - language_model::init_settings(cx); + language_model::init(cx); editor::init(cx); }); diff --git a/crates/agent_ui/src/agent_diff.rs b/crates/agent_ui/src/agent_diff.rs index 2e709c0be3297e270119c048c7b8e25e7958ee69..cafe1baabb8bd321282ad8ac030dc58f31ed86aa 100644 --- a/crates/agent_ui/src/agent_diff.rs +++ b/crates/agent_ui/src/agent_diff.rs @@ -1808,7 +1808,7 @@ mod tests { cx.set_global(settings_store); prompt_store::init(cx); theme_settings::init(theme::LoadThemes::JustBase, cx); - language_model::init_settings(cx); + language_model::init(cx); }); let fs = FakeFs::new(cx.executor()); @@ -1965,7 +1965,7 @@ mod tests { cx.set_global(settings_store); prompt_store::init(cx); theme_settings::init(theme::LoadThemes::JustBase, cx); - language_model::init_settings(cx); + language_model::init(cx); workspace::register_project_item::(cx); }); diff --git a/crates/agent_ui/src/inline_assistant.rs b/crates/agent_ui/src/inline_assistant.rs index 01543b657fc2d00fbf8c68cd96c6329d2f4952d6..ed34962a05d2243c0a64e7640b3727765629cbe3 100644 --- a/crates/agent_ui/src/inline_assistant.rs +++ b/crates/agent_ui/src/inline_assistant.rs @@ -2114,7 +2114,8 @@ pub mod evals { client::init(&client, cx); workspace::init(app_state.clone(), cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store, client.clone(), cx); cx.set_global(inline_assistant); diff --git a/crates/client/Cargo.toml b/crates/client/Cargo.toml index 27b54a42148076020348a8952ea4af489053ecb0..7bbaccb22e0e6c7508240186103e216f83be2f0c 100644 --- a/crates/client/Cargo.toml +++ b/crates/client/Cargo.toml @@ -22,6 +22,7 @@ base64.workspace = true chrono = { workspace = true, features = ["serde"] } clock.workspace = true cloud_api_client.workspace = true +cloud_api_types.workspace = true cloud_llm_client.workspace = true collections.workspace = true credentials_provider.workspace = true @@ -35,6 +36,7 @@ gpui_tokio.workspace = true http_client.workspace = true http_client_tls.workspace = true httparse = "1.10" +language_model.workspace = true log.workspace = true parking_lot.workspace = true paths.workspace = true diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 4edee9f0db0bb65ae8d1a70d79de68dbc0f96bd0..2fc926d777978035f11b5506a4746617d06e59c9 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -1,6 +1,7 @@ #[cfg(any(test, feature = "test-support"))] pub mod test; +mod llm_token; mod proxy; pub mod telemetry; pub mod user; @@ -13,8 +14,9 @@ use async_tungstenite::tungstenite::{ http::{HeaderValue, Request, StatusCode}, }; use clock::SystemClock; -use cloud_api_client::CloudApiClient; use cloud_api_client::websocket_protocol::MessageToClient; +use cloud_api_client::{ClientApiError, CloudApiClient}; +use cloud_api_types::OrganizationId; use credentials_provider::CredentialsProvider; use feature_flags::FeatureFlagAppExt as _; use futures::{ @@ -24,6 +26,7 @@ use futures::{ }; use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions}; use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env}; +use language_model::LlmApiToken; use parking_lot::{Mutex, RwLock}; use postage::watch; use proxy::connect_proxy_stream; @@ -51,6 +54,7 @@ use tokio::net::TcpStream; use url::Url; use util::{ConnectionResult, ResultExt}; +pub use llm_token::*; pub use rpc::*; pub use telemetry_events::Event; pub use user::*; @@ -1517,6 +1521,66 @@ impl Client { }) } + pub async fn acquire_llm_token( + &self, + llm_token: &LlmApiToken, + organization_id: Option, + ) -> Result { + let system_id = self.telemetry().system_id().map(|x| x.to_string()); + let cloud_client = self.cloud_client(); + match llm_token + .acquire(&cloud_client, system_id, organization_id) + .await + { + Ok(token) => Ok(token), + Err(ClientApiError::Unauthorized) => { + self.request_sign_out(); + Err(ClientApiError::Unauthorized).context("Failed to create LLM token") + } + Err(err) => Err(anyhow::Error::from(err)), + } + } + + pub async fn refresh_llm_token( + &self, + llm_token: &LlmApiToken, + organization_id: Option, + ) -> Result { + let system_id = self.telemetry().system_id().map(|x| x.to_string()); + let cloud_client = self.cloud_client(); + match llm_token + .refresh(&cloud_client, system_id, organization_id) + .await + { + Ok(token) => Ok(token), + Err(ClientApiError::Unauthorized) => { + self.request_sign_out(); + return Err(ClientApiError::Unauthorized).context("Failed to create LLM token"); + } + Err(err) => return Err(anyhow::Error::from(err)), + } + } + + pub async fn clear_and_refresh_llm_token( + &self, + llm_token: &LlmApiToken, + organization_id: Option, + ) -> Result { + let system_id = self.telemetry().system_id().map(|x| x.to_string()); + let cloud_client = self.cloud_client(); + match llm_token + .clear_and_refresh(&cloud_client, system_id, organization_id) + .await + { + Ok(token) => Ok(token), + Err(ClientApiError::Unauthorized) => { + self.request_sign_out(); + return Err(ClientApiError::Unauthorized).context("Failed to create LLM token"); + } + Err(err) => return Err(anyhow::Error::from(err)), + } + } + pub async fn sign_out(self: &Arc, cx: &AsyncApp) { self.state.write().credentials = None; self.cloud_client.clear_credentials(); diff --git a/crates/client/src/llm_token.rs b/crates/client/src/llm_token.rs new file mode 100644 index 0000000000000000000000000000000000000000..f62aa6dd4dc3462bc3a0f6f46c35f0e4e5499816 --- /dev/null +++ b/crates/client/src/llm_token.rs @@ -0,0 +1,116 @@ +use super::{Client, UserStore}; +use cloud_api_types::websocket_protocol::MessageToClient; +use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; +use gpui::{ + App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription, +}; +use language_model::LlmApiToken; +use std::sync::Arc; + +pub trait NeedsLlmTokenRefresh { + /// Returns whether the LLM token needs to be refreshed. + fn needs_llm_token_refresh(&self) -> bool; +} + +impl NeedsLlmTokenRefresh for http_client::Response { + fn needs_llm_token_refresh(&self) -> bool { + self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some() + || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some() + } +} + +enum TokenRefreshMode { + Refresh, + ClearAndRefresh, +} + +pub fn global_llm_token(cx: &App) -> LlmApiToken { + RefreshLlmTokenListener::global(cx) + .read(cx) + .llm_api_token + .clone() +} + +struct GlobalRefreshLlmTokenListener(Entity); + +impl Global for GlobalRefreshLlmTokenListener {} + +pub struct LlmTokenRefreshedEvent; + +pub struct RefreshLlmTokenListener { + client: Arc, + user_store: Entity, + llm_api_token: LlmApiToken, + _subscription: Subscription, +} + +impl EventEmitter for RefreshLlmTokenListener {} + +impl RefreshLlmTokenListener { + pub fn register(client: Arc, user_store: Entity, cx: &mut App) { + let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx)); + cx.set_global(GlobalRefreshLlmTokenListener(listener)); + } + + pub fn global(cx: &App) -> Entity { + GlobalRefreshLlmTokenListener::global(cx).0.clone() + } + + fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { + client.add_message_to_client_handler({ + let this = cx.weak_entity(); + move |message, cx| { + if let Some(this) = this.upgrade() { + Self::handle_refresh_llm_token(this, message, cx); + } + } + }); + + let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| { + if matches!(event, super::user::Event::OrganizationChanged) { + this.refresh(TokenRefreshMode::ClearAndRefresh, cx); + } + }); + + Self { + client, + user_store, + llm_api_token: LlmApiToken::default(), + _subscription: subscription, + } + } + + fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context) { + let client = self.client.clone(); + let llm_api_token = self.llm_api_token.clone(); + let organization_id = self + .user_store + .read(cx) + .current_organization() + .map(|organization| organization.id.clone()); + cx.spawn(async move |this, cx| { + match mode { + TokenRefreshMode::Refresh => { + client + .refresh_llm_token(&llm_api_token, organization_id) + .await?; + } + TokenRefreshMode::ClearAndRefresh => { + client + .clear_and_refresh_llm_token(&llm_api_token, organization_id) + .await?; + } + } + this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent)) + }) + .detach_and_log_err(cx); + } + + fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) { + match message { + MessageToClient::UserUpdated => { + this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx)); + } + } + } +} diff --git a/crates/edit_prediction/src/capture_example.rs b/crates/edit_prediction/src/capture_example.rs index d21df7868162d279cb18aeea3ef04d4ea9d7be7f..f1decb5343bf0137005c66432c6c1ca72a12b1b7 100644 --- a/crates/edit_prediction/src/capture_example.rs +++ b/crates/edit_prediction/src/capture_example.rs @@ -1,6 +1,8 @@ use crate::{StoredEvent, example_spec::ExampleSpec}; use anyhow::Result; use buffer_diff::BufferDiffSnapshot; +#[cfg(test)] +use client::RefreshLlmTokenListener; use collections::HashMap; use gpui::{App, Entity, Task}; use language::Buffer; @@ -548,7 +550,8 @@ mod tests { let http_client = FakeHttpClient::with_404_response(); let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); EditPredictionStore::global(&client, &user_store, cx); }) } diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index b9e9ee66e91dfd0392f5068c962ff8f041f1fbee..699f286985548bc9a9dd626b85f11bc8fde3c1af 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -1,5 +1,8 @@ use anyhow::Result; -use client::{Client, EditPredictionUsage, UserStore}; +use client::{ + Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore, + global_llm_token as global_llm_api_token, +}; use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody}; use cloud_llm_client::predict_edits_v3::{ PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, @@ -31,7 +34,7 @@ use heapless::Vec as ArrayVec; use language::language_settings::all_language_settings; use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; -use language_model::{LlmApiToken, NeedsLlmTokenRefresh}; +use language_model::LlmApiToken; use project::{DisableAiSettings, Project, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; @@ -748,7 +751,7 @@ impl EditPredictionStore { pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { let data_collection_choice = Self::load_data_collection_choice(cx); - let llm_token = LlmApiToken::global(cx); + let llm_token = global_llm_api_token(cx); let (reject_tx, reject_rx) = mpsc::unbounded(); cx.background_spawn({ @@ -877,7 +880,9 @@ impl EditPredictionStore { let experiments = cx .background_spawn(async move { let http_client = client.http_client(); - let token = llm_token.acquire(&client, organization_id).await?; + let token = client + .acquire_llm_token(&llm_token, organization_id.clone()) + .await?; let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?; let request = http_client::Request::builder() .method(Method::GET) @@ -2539,12 +2544,15 @@ impl EditPredictionStore { Res: DeserializeOwned, { let http_client = client.http_client(); - let mut token = if require_auth { - Some(llm_token.acquire(&client, organization_id.clone()).await?) + Some( + client + .acquire_llm_token(&llm_token, organization_id.clone()) + .await?, + ) } else { - llm_token - .acquire(&client, organization_id.clone()) + client + .acquire_llm_token(&llm_token, organization_id.clone()) .await .ok() }; @@ -2588,7 +2596,11 @@ impl EditPredictionStore { return Ok((serde_json::from_slice(&body)?, usage)); } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() { did_retry = true; - token = Some(llm_token.refresh(&client, organization_id.clone()).await?); + token = Some( + client + .refresh_llm_token(&llm_token, organization_id.clone()) + .await?, + ); } else { let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; diff --git a/crates/edit_prediction/src/edit_prediction_tests.rs b/crates/edit_prediction/src/edit_prediction_tests.rs index 6fe61338e764a40aec9cf6f3191f1191bafe9200..1ba8b27aa785024a47a09c3299a1f3786a028ccf 100644 --- a/crates/edit_prediction/src/edit_prediction_tests.rs +++ b/crates/edit_prediction/src/edit_prediction_tests.rs @@ -1,6 +1,6 @@ use super::*; use crate::udiff::apply_diff_to_string; -use client::{UserStore, test::FakeServer}; +use client::{RefreshLlmTokenListener, UserStore, test::FakeServer}; use clock::FakeSystemClock; use clock::ReplicaId; use cloud_api_types::{CreateLlmTokenResponse, LlmToken}; @@ -23,7 +23,7 @@ use language::{ Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet, DiagnosticSeverity, Operation, Point, Selection, SelectionGoal, }; -use language_model::RefreshLlmTokenListener; + use lsp::LanguageServerId; use parking_lot::Mutex; use pretty_assertions::{assert_eq, assert_matches}; @@ -2439,7 +2439,8 @@ fn init_test_with_fake_client( client.cloud_client().set_credentials(1, "test".into()); let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); let ep_store = EditPredictionStore::global(&client, &user_store, cx); ( @@ -2891,7 +2892,7 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx)); let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx))); cx.update(|cx| { - language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); }); let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx)); diff --git a/crates/edit_prediction_cli/src/headless.rs b/crates/edit_prediction_cli/src/headless.rs index 3a204a7052f8a41d6e7c2c49860b62f588358644..48b7381020f48d868d9f6413ef343b30718e5be6 100644 --- a/crates/edit_prediction_cli/src/headless.rs +++ b/crates/edit_prediction_cli/src/headless.rs @@ -1,4 +1,4 @@ -use client::{Client, ProxySettings, UserStore}; +use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore}; use db::AppDatabase; use extension::ExtensionHostProxy; use fs::RealFs; @@ -109,7 +109,8 @@ pub fn init(cx: &mut App) -> EpAppState { debug_adapter_extension::init(extension_host_proxy.clone(), cx); language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone()); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx); prompt_store::init(cx); diff --git a/crates/eval_cli/src/headless.rs b/crates/eval_cli/src/headless.rs index 72feaacbae270224240f1da9e6e6c1008ba97c84..0ddd99e8f8abd9dbd73e1d7461526f3e7cb24f11 100644 --- a/crates/eval_cli/src/headless.rs +++ b/crates/eval_cli/src/headless.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; use std::sync::Arc; -use client::{Client, ProxySettings, UserStore}; +use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore}; use db::AppDatabase; use extension::ExtensionHostProxy; use fs::RealFs; @@ -108,7 +108,8 @@ pub fn init(cx: &mut App) -> Arc { let extension_host_proxy = ExtensionHostProxy::global(cx); debug_adapter_extension::init(extension_host_proxy.clone(), cx); language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone()); - language_model::init(user_store.clone(), client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); language_models::init(user_store.clone(), client.clone(), cx); languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx); prompt_store::init(cx); diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index c449b11211b7fafcb1d8cc5c70936183d853917f..4712d86dff6c44f9cdd8576a08349ccfa7d0ecca 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -20,7 +20,6 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true credentials_provider.workspace = true base64.workspace = true -client.workspace = true cloud_api_client.workspace = true cloud_api_types.workspace = true cloud_llm_client.workspace = true diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index cf6c11e766888e519a86fc1d71dec2b6b0ebf9cf..3f309b7b1d4152c54324efaaf0ad3bdb7035eea4 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -11,12 +11,10 @@ pub mod tool_schema; pub mod fake_provider; use anyhow::{Result, anyhow}; -use client::Client; -use client::UserStore; use cloud_llm_client::CompletionRequestStatus; use futures::FutureExt; use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; -use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window}; +use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window}; use http_client::{StatusCode, http}; use icons::IconName; use parking_lot::Mutex; @@ -39,12 +37,7 @@ pub use crate::tool_schema::LanguageModelToolSchemaFormat; pub use env_var::{EnvVar, env_var}; pub use provider::*; -pub fn init(user_store: Entity, client: Arc, cx: &mut App) { - init_settings(cx); - RefreshLlmTokenListener::register(client, user_store, cx); -} - -pub fn init_settings(cx: &mut App) { +pub fn init(cx: &mut App) { registry::init(cx); } diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index a1362d78292082522f4e883efe42b2ca1e0a0300..db926aab1f70a46a4e70b1b67c2c9e4c4f465c2c 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,16 +1,9 @@ use std::fmt; use std::sync::Arc; -use anyhow::{Context as _, Result}; -use client::Client; -use client::UserStore; use cloud_api_client::ClientApiError; +use cloud_api_client::CloudApiClient; use cloud_api_types::OrganizationId; -use cloud_api_types::websocket_protocol::MessageToClient; -use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME}; -use gpui::{ - App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription, -}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -30,18 +23,12 @@ impl fmt::Display for PaymentRequiredError { pub struct LlmApiToken(Arc>>); impl LlmApiToken { - pub fn global(cx: &App) -> Self { - RefreshLlmTokenListener::global(cx) - .read(cx) - .llm_api_token - .clone() - } - pub async fn acquire( &self, - client: &Arc, + client: &CloudApiClient, + system_id: Option, organization_id: Option, - ) -> Result { + ) -> Result { let lock = self.0.upgradable_read().await; if let Some(token) = lock.as_ref() { Ok(token.to_string()) @@ -49,6 +36,7 @@ impl LlmApiToken { Self::fetch( RwLockUpgradableReadGuard::upgrade(lock).await, client, + system_id, organization_id, ) .await @@ -57,10 +45,11 @@ impl LlmApiToken { pub async fn refresh( &self, - client: &Arc, + client: &CloudApiClient, + system_id: Option, organization_id: Option, - ) -> Result { - Self::fetch(self.0.write().await, client, organization_id).await + ) -> Result { + Self::fetch(self.0.write().await, client, system_id, organization_id).await } /// Clears the existing token before attempting to fetch a new one. @@ -69,28 +58,22 @@ impl LlmApiToken { /// leave a token for the wrong organization. pub async fn clear_and_refresh( &self, - client: &Arc, + client: &CloudApiClient, + system_id: Option, organization_id: Option, - ) -> Result { + ) -> Result { let mut lock = self.0.write().await; *lock = None; - Self::fetch(lock, client, organization_id).await + Self::fetch(lock, client, system_id, organization_id).await } async fn fetch( mut lock: RwLockWriteGuard<'_, Option>, - client: &Arc, + client: &CloudApiClient, + system_id: Option, organization_id: Option, - ) -> Result { - let system_id = client - .telemetry() - .system_id() - .map(|system_id| system_id.to_string()); - - let result = client - .cloud_client() - .create_llm_token(system_id, organization_id) - .await; + ) -> Result { + let result = client.create_llm_token(system_id, organization_id).await; match result { Ok(response) => { *lock = Some(response.token.0.clone()); @@ -98,112 +81,7 @@ impl LlmApiToken { } Err(err) => { *lock = None; - match err { - ClientApiError::Unauthorized => { - client.request_sign_out(); - Err(err).context("Failed to create LLM token") - } - ClientApiError::Other(err) => Err(err), - } - } - } - } -} - -pub trait NeedsLlmTokenRefresh { - /// Returns whether the LLM token needs to be refreshed. - fn needs_llm_token_refresh(&self) -> bool; -} - -impl NeedsLlmTokenRefresh for http_client::Response { - fn needs_llm_token_refresh(&self) -> bool { - self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some() - || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some() - } -} - -enum TokenRefreshMode { - Refresh, - ClearAndRefresh, -} - -struct GlobalRefreshLlmTokenListener(Entity); - -impl Global for GlobalRefreshLlmTokenListener {} - -pub struct LlmTokenRefreshedEvent; - -pub struct RefreshLlmTokenListener { - client: Arc, - user_store: Entity, - llm_api_token: LlmApiToken, - _subscription: Subscription, -} - -impl EventEmitter for RefreshLlmTokenListener {} - -impl RefreshLlmTokenListener { - pub fn register(client: Arc, user_store: Entity, cx: &mut App) { - let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx)); - cx.set_global(GlobalRefreshLlmTokenListener(listener)); - } - - pub fn global(cx: &App) -> Entity { - GlobalRefreshLlmTokenListener::global(cx).0.clone() - } - - fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { - client.add_message_to_client_handler({ - let this = cx.weak_entity(); - move |message, cx| { - if let Some(this) = this.upgrade() { - Self::handle_refresh_llm_token(this, message, cx); - } - } - }); - - let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| { - if matches!(event, client::user::Event::OrganizationChanged) { - this.refresh(TokenRefreshMode::ClearAndRefresh, cx); - } - }); - - Self { - client, - user_store, - llm_api_token: LlmApiToken::default(), - _subscription: subscription, - } - } - - fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context) { - let client = self.client.clone(); - let llm_api_token = self.llm_api_token.clone(); - let organization_id = self - .user_store - .read(cx) - .current_organization() - .map(|organization| organization.id.clone()); - cx.spawn(async move |this, cx| { - match mode { - TokenRefreshMode::Refresh => { - llm_api_token.refresh(&client, organization_id).await?; - } - TokenRefreshMode::ClearAndRefresh => { - llm_api_token - .clear_and_refresh(&client, organization_id) - .await?; - } - } - this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent)) - }) - .detach_and_log_err(cx); - } - - fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) { - match message { - MessageToClient::UserUpdated => { - this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx)); + Err(err) } } } diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index f9372a4d7ea9c078c58f633cc58bd5597ef49212..d1daec196aa4ffe5d940c8087669d32c9c64f4f3 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -1,7 +1,10 @@ use ai_onboarding::YoungAccountBanner; use anthropic::AnthropicModelMode; use anyhow::{Context as _, Result, anyhow}; -use client::{Client, UserStore, zed_urls}; +use client::{ + Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore, + global_llm_token as global_llm_api_token, zed_urls, +}; use cloud_api_types::{OrganizationId, Plan}; use cloud_llm_client::{ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME, @@ -24,10 +27,9 @@ use language_model::{ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, - LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh, - OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, - RefreshLlmTokenListener, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, ZED_CLOUD_PROVIDER_ID, - ZED_CLOUD_PROVIDER_NAME, + LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID, + OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, + ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME, }; use release_channel::AppVersion; use schemars::JsonSchema; @@ -111,7 +113,7 @@ impl State { cx: &mut Context, ) -> Self { let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx); - let llm_api_token = LlmApiToken::global(cx); + let llm_api_token = global_llm_api_token(cx); Self { client: client.clone(), llm_api_token, @@ -226,7 +228,9 @@ impl State { organization_id: Option, ) -> Result { let http_client = &client.http_client(); - let token = llm_api_token.acquire(&client, organization_id).await?; + let token = client + .acquire_llm_token(&llm_api_token, organization_id) + .await?; let request = http_client::Request::builder() .method(Method::GET) @@ -414,8 +418,8 @@ impl CloudLanguageModel { ) -> Result { let http_client = &client.http_client(); - let mut token = llm_api_token - .acquire(&client, organization_id.clone()) + let mut token = client + .acquire_llm_token(&llm_api_token, organization_id.clone()) .await?; let mut refreshed_token = false; @@ -447,8 +451,8 @@ impl CloudLanguageModel { } if !refreshed_token && response.needs_llm_token_refresh() { - token = llm_api_token - .refresh(&client, organization_id.clone()) + token = client + .refresh_llm_token(&llm_api_token, organization_id.clone()) .await?; refreshed_token = true; continue; @@ -713,7 +717,9 @@ impl LanguageModel for CloudLanguageModel { into_google(request, model_id.clone(), GoogleModelMode::Default); async move { let http_client = &client.http_client(); - let token = llm_api_token.acquire(&client, organization_id).await?; + let token = client + .acquire_llm_token(&llm_api_token, organization_id) + .await?; let request_body = CountTokensBody { provider: cloud_llm_client::LanguageModelProvider::Google, diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index 17addd24d445a666138a1b37fef872beedd07aed..b72c6f1f7a4302de17ceb5c6ee8b953ab4c2bf04 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -1,13 +1,13 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; -use client::{Client, UserStore}; +use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token as global_llm_api_token}; use cloud_api_types::OrganizationId; use cloud_llm_client::{WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Task}; use http_client::{HttpClient, Method}; -use language_model::{LlmApiToken, NeedsLlmTokenRefresh}; +use language_model::LlmApiToken; use web_search::{WebSearchProvider, WebSearchProviderId}; pub struct CloudWebSearchProvider { @@ -30,7 +30,7 @@ pub struct State { impl State { pub fn new(client: Arc, user_store: Entity, cx: &mut Context) -> Self { - let llm_api_token = LlmApiToken::global(cx); + let llm_api_token = global_llm_api_token(cx); Self { client, @@ -73,8 +73,8 @@ async fn perform_web_search( let http_client = &client.http_client(); let mut retries_remaining = MAX_RETRIES; - let mut token = llm_api_token - .acquire(&client, organization_id.clone()) + let mut token = client + .acquire_llm_token(&llm_api_token, organization_id.clone()) .await?; loop { @@ -100,8 +100,8 @@ async fn perform_web_search( response.body_mut().read_to_string(&mut body).await?; return Ok(serde_json::from_str(&body)?); } else if response.needs_llm_token_refresh() { - token = llm_api_token - .refresh(&client, organization_id.clone()) + token = client + .refresh_llm_token(&llm_api_token, organization_id.clone()) .await?; retries_remaining -= 1; } else { diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 303f21b8ffa62f9d9f380d9c18beecd77775df20..7e06eb9b9a0dcbeb2ba1844bbb86b0c23a0a4822 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -10,7 +10,7 @@ use agent_ui::AgentPanel; use anyhow::{Context as _, Error, Result}; use clap::Parser; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; -use client::{Client, ProxySettings, UserStore, parse_zed_link}; +use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore, parse_zed_link}; use collab_ui::channel_view::ChannelView; use collections::HashMap; use crashes::InitCrashHandler; @@ -664,7 +664,12 @@ fn main() { ); copilot_ui::init(&app_state, cx); - language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); + language_model::init(cx); + RefreshLlmTokenListener::register( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); acp_tools::init(cx); zed::telemetry_log::init(cx); diff --git a/crates/zed/src/visual_test_runner.rs b/crates/zed/src/visual_test_runner.rs index e5713e90df397a01af850af55338897f9d437e55..6e457d5addaeba7cd629482757b1b6c5e9c6838b 100644 --- a/crates/zed/src/visual_test_runner.rs +++ b/crates/zed/src/visual_test_runner.rs @@ -201,7 +201,12 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()> }); prompt_store::init(cx); let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx); - language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); + language_model::init(cx); + client::RefreshLlmTokenListener::register( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); git_ui::init(cx); project::AgentRegistryStore::init_global( diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 01e2354849f3a70399c680c44bd1a3cfbeb64dc4..1e6d86addf8623f74ab418f238fe85fe3f976d86 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -5015,7 +5015,12 @@ mod tests { cx, ); image_viewer::init(cx); - language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); + language_model::init(cx); + client::RefreshLlmTokenListener::register( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ); language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx); web_search::init(cx); git_graph::init(cx); diff --git a/crates/zed/src/zed/edit_prediction_registry.rs b/crates/zed/src/zed/edit_prediction_registry.rs index 8c9e74a42e6c3ddb2b340ac58da39752009825f0..d09dc07af839a681cea96d43217c4217927864d5 100644 --- a/crates/zed/src/zed/edit_prediction_registry.rs +++ b/crates/zed/src/zed/edit_prediction_registry.rs @@ -313,7 +313,12 @@ mod tests { let app_state = cx.update(|cx| { let app_state = AppState::test(cx); client::init(&app_state.client, cx); - language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx); + language_model::init(cx); + client::RefreshLlmTokenListener::register( + app_state.client.clone(), + app_state.user_store.clone(), + cx, + ); editor::init(cx); app_state });