language_model: Invert dep on client

Jakub Konka created

Change summary

Cargo.lock                                                        |   3 
crates/agent/src/tests/mod.rs                                     |   8 
crates/agent_servers/src/e2e_tests.rs                             |   4 
crates/agent_ui/src/agent_configuration/add_llm_provider_modal.rs |   2 
crates/agent_ui/src/agent_diff.rs                                 |   4 
crates/agent_ui/src/inline_assistant.rs                           |   3 
crates/client/Cargo.toml                                          |   2 
crates/client/src/client.rs                                       |  66 
crates/client/src/llm_token.rs                                    | 116 
crates/edit_prediction/src/capture_example.rs                     |   5 
crates/edit_prediction/src/edit_prediction.rs                     |  30 
crates/edit_prediction/src/edit_prediction_tests.rs               |   9 
crates/edit_prediction_cli/src/headless.rs                        |   5 
crates/eval_cli/src/headless.rs                                   |   5 
crates/language_model/Cargo.toml                                  |   1 
crates/language_model/src/language_model.rs                       |  11 
crates/language_model/src/model/cloud_model.rs                    | 158 
crates/language_models/src/provider/cloud.rs                      |  30 
crates/web_search_providers/src/cloud.rs                          |  14 
crates/zed/src/main.rs                                            |   9 
crates/zed/src/visual_test_runner.rs                              |   7 
crates/zed/src/zed.rs                                             |   7 
crates/zed/src/zed/edit_prediction_registry.rs                    |   7 
23 files changed, 304 insertions(+), 202 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2865,6 +2865,7 @@ dependencies = [
  "chrono",
  "clock",
  "cloud_api_client",
+ "cloud_api_types",
  "cloud_llm_client",
  "collections",
  "credentials_provider",
@@ -2878,6 +2879,7 @@ dependencies = [
  "http_client",
  "http_client_tls",
  "httparse",
+ "language_model",
  "log",
  "objc2-foundation",
  "parking_lot",
@@ -9335,7 +9337,6 @@ dependencies = [
  "anthropic",
  "anyhow",
  "base64 0.22.1",
- "client",
  "cloud_api_client",
  "cloud_api_types",
  "cloud_llm_client",

crates/agent/src/tests/mod.rs 🔗

@@ -6,7 +6,7 @@ use acp_thread::{
 use agent_client_protocol::{self as acp};
 use agent_settings::AgentProfileId;
 use anyhow::Result;
-use client::{Client, UserStore};
+use client::{Client, RefreshLlmTokenListener, UserStore};
 use collections::IndexMap;
 use context_server::{ContextServer, ContextServerCommand, ContextServerId};
 use feature_flags::FeatureFlagAppExt as _;
@@ -3253,7 +3253,8 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
         let clock = Arc::new(clock::FakeSystemClock::new());
         let client = Client::new(clock, http_client, cx);
         let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-        language_model::init(user_store.clone(), client.clone(), cx);
+        language_model::init(cx);
+        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
         language_models::init(user_store, client.clone(), cx);
         LanguageModelRegistry::test(cx);
     });
@@ -3982,7 +3983,8 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
                 cx.set_http_client(Arc::new(http_client));
                 let client = Client::production(cx);
                 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-                language_model::init(user_store.clone(), client.clone(), cx);
+                language_model::init(cx);
+                RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
                 language_models::init(user_store, client.clone(), cx);
             }
         };

crates/agent_servers/src/e2e_tests.rs 🔗

@@ -1,6 +1,7 @@
 use crate::{AgentServer, AgentServerDelegate};
 use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
 use agent_client_protocol as acp;
+use client::RefreshLlmTokenListener;
 use futures::{FutureExt, StreamExt, channel::mpsc, select};
 use gpui::AppContext;
 use gpui::{Entity, TestAppContext};
@@ -413,7 +414,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
         cx.set_http_client(Arc::new(http_client));
         let client = client::Client::production(cx);
         let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
-        language_model::init(user_store, client, cx);
+        language_model::init(cx);
+        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
 
         #[cfg(test)]
         project::agent_server_store::AllAgentServersSettings::override_global(

crates/agent_ui/src/agent_diff.rs 🔗

@@ -1808,7 +1808,7 @@ mod tests {
             cx.set_global(settings_store);
             prompt_store::init(cx);
             theme_settings::init(theme::LoadThemes::JustBase, cx);
-            language_model::init_settings(cx);
+            language_model::init(cx);
         });
 
         let fs = FakeFs::new(cx.executor());
@@ -1965,7 +1965,7 @@ mod tests {
             cx.set_global(settings_store);
             prompt_store::init(cx);
             theme_settings::init(theme::LoadThemes::JustBase, cx);
-            language_model::init_settings(cx);
+            language_model::init(cx);
             workspace::register_project_item::<Editor>(cx);
         });
 

crates/agent_ui/src/inline_assistant.rs 🔗

@@ -2114,7 +2114,8 @@ pub mod evals {
             client::init(&client, cx);
             workspace::init(app_state.clone(), cx);
             let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-            language_model::init(user_store.clone(), client.clone(), cx);
+            language_model::init(cx);
+            RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
             language_models::init(user_store, client.clone(), cx);
 
             cx.set_global(inline_assistant);

crates/client/Cargo.toml 🔗

@@ -22,6 +22,7 @@ base64.workspace = true
 chrono = { workspace = true, features = ["serde"] }
 clock.workspace = true
 cloud_api_client.workspace = true
+cloud_api_types.workspace = true
 cloud_llm_client.workspace = true
 collections.workspace = true
 credentials_provider.workspace = true
@@ -35,6 +36,7 @@ gpui_tokio.workspace = true
 http_client.workspace = true
 http_client_tls.workspace = true
 httparse = "1.10"
+language_model.workspace = true
 log.workspace = true
 parking_lot.workspace = true
 paths.workspace = true

crates/client/src/client.rs 🔗

@@ -1,6 +1,7 @@
 #[cfg(any(test, feature = "test-support"))]
 pub mod test;
 
+mod llm_token;
 mod proxy;
 pub mod telemetry;
 pub mod user;
@@ -13,8 +14,9 @@ use async_tungstenite::tungstenite::{
     http::{HeaderValue, Request, StatusCode},
 };
 use clock::SystemClock;
-use cloud_api_client::CloudApiClient;
 use cloud_api_client::websocket_protocol::MessageToClient;
+use cloud_api_client::{ClientApiError, CloudApiClient};
+use cloud_api_types::OrganizationId;
 use credentials_provider::CredentialsProvider;
 use feature_flags::FeatureFlagAppExt as _;
 use futures::{
@@ -24,6 +26,7 @@ use futures::{
 };
 use gpui::{App, AsyncApp, Entity, Global, Task, WeakEntity, actions};
 use http_client::{HttpClient, HttpClientWithUrl, http, read_proxy_from_env};
+use language_model::LlmApiToken;
 use parking_lot::{Mutex, RwLock};
 use postage::watch;
 use proxy::connect_proxy_stream;
@@ -51,6 +54,7 @@ use tokio::net::TcpStream;
 use url::Url;
 use util::{ConnectionResult, ResultExt};
 
+pub use llm_token::*;
 pub use rpc::*;
 pub use telemetry_events::Event;
 pub use user::*;
@@ -1517,6 +1521,66 @@ impl Client {
         })
     }
 
+    pub async fn acquire_llm_token(
+        &self,
+        llm_token: &LlmApiToken,
+        organization_id: Option<OrganizationId>,
+    ) -> Result<String> {
+        let system_id = self.telemetry().system_id().map(|x| x.to_string());
+        let cloud_client = self.cloud_client();
+        match llm_token
+            .acquire(&cloud_client, system_id, organization_id)
+            .await
+        {
+            Ok(token) => Ok(token),
+            Err(ClientApiError::Unauthorized) => {
+                self.request_sign_out();
+                Err(ClientApiError::Unauthorized).context("Failed to create LLM token")
+            }
+            Err(err) => Err(anyhow::Error::from(err)),
+        }
+    }
+
+    pub async fn refresh_llm_token(
+        &self,
+        llm_token: &LlmApiToken,
+        organization_id: Option<OrganizationId>,
+    ) -> Result<String> {
+        let system_id = self.telemetry().system_id().map(|x| x.to_string());
+        let cloud_client = self.cloud_client();
+        match llm_token
+            .refresh(&cloud_client, system_id, organization_id)
+            .await
+        {
+            Ok(token) => Ok(token),
+            Err(ClientApiError::Unauthorized) => {
+                self.request_sign_out();
+                return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
+            }
+            Err(err) => return Err(anyhow::Error::from(err)),
+        }
+    }
+
+    pub async fn clear_and_refresh_llm_token(
+        &self,
+        llm_token: &LlmApiToken,
+        organization_id: Option<OrganizationId>,
+    ) -> Result<String> {
+        let system_id = self.telemetry().system_id().map(|x| x.to_string());
+        let cloud_client = self.cloud_client();
+        match llm_token
+            .clear_and_refresh(&cloud_client, system_id, organization_id)
+            .await
+        {
+            Ok(token) => Ok(token),
+            Err(ClientApiError::Unauthorized) => {
+                self.request_sign_out();
+                return Err(ClientApiError::Unauthorized).context("Failed to create LLM token");
+            }
+            Err(err) => return Err(anyhow::Error::from(err)),
+        }
+    }
+
     pub async fn sign_out(self: &Arc<Self>, cx: &AsyncApp) {
         self.state.write().credentials = None;
         self.cloud_client.clear_credentials();

crates/client/src/llm_token.rs 🔗

@@ -0,0 +1,116 @@
+use super::{Client, UserStore};
+use cloud_api_types::websocket_protocol::MessageToClient;
+use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
+use gpui::{
+    App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
+};
+use language_model::LlmApiToken;
+use std::sync::Arc;
+
+pub trait NeedsLlmTokenRefresh {
+    /// Returns whether the LLM token needs to be refreshed.
+    fn needs_llm_token_refresh(&self) -> bool;
+}
+
+impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
+    fn needs_llm_token_refresh(&self) -> bool {
+        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
+            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
+    }
+}
+
+enum TokenRefreshMode {
+    Refresh,
+    ClearAndRefresh,
+}
+
+pub fn global_llm_token(cx: &App) -> LlmApiToken {
+    RefreshLlmTokenListener::global(cx)
+        .read(cx)
+        .llm_api_token
+        .clone()
+}
+
+struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
+
+impl Global for GlobalRefreshLlmTokenListener {}
+
+pub struct LlmTokenRefreshedEvent;
+
+pub struct RefreshLlmTokenListener {
+    client: Arc<Client>,
+    user_store: Entity<UserStore>,
+    llm_api_token: LlmApiToken,
+    _subscription: Subscription,
+}
+
+impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
+
+impl RefreshLlmTokenListener {
+    pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
+        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
+        cx.set_global(GlobalRefreshLlmTokenListener(listener));
+    }
+
+    pub fn global(cx: &App) -> Entity<Self> {
+        GlobalRefreshLlmTokenListener::global(cx).0.clone()
+    }
+
+    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+        client.add_message_to_client_handler({
+            let this = cx.weak_entity();
+            move |message, cx| {
+                if let Some(this) = this.upgrade() {
+                    Self::handle_refresh_llm_token(this, message, cx);
+                }
+            }
+        });
+
+        let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
+            if matches!(event, super::user::Event::OrganizationChanged) {
+                this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
+            }
+        });
+
+        Self {
+            client,
+            user_store,
+            llm_api_token: LlmApiToken::default(),
+            _subscription: subscription,
+        }
+    }
+
+    fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
+        let client = self.client.clone();
+        let llm_api_token = self.llm_api_token.clone();
+        let organization_id = self
+            .user_store
+            .read(cx)
+            .current_organization()
+            .map(|organization| organization.id.clone());
+        cx.spawn(async move |this, cx| {
+            match mode {
+                TokenRefreshMode::Refresh => {
+                    client
+                        .refresh_llm_token(&llm_api_token, organization_id)
+                        .await?;
+                }
+                TokenRefreshMode::ClearAndRefresh => {
+                    client
+                        .clear_and_refresh_llm_token(&llm_api_token, organization_id)
+                        .await?;
+                }
+            }
+            this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
+        })
+        .detach_and_log_err(cx);
+    }
+
+    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
+        match message {
+            MessageToClient::UserUpdated => {
+                this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
+            }
+        }
+    }
+}

crates/edit_prediction/src/capture_example.rs 🔗

@@ -1,6 +1,8 @@
 use crate::{StoredEvent, example_spec::ExampleSpec};
 use anyhow::Result;
 use buffer_diff::BufferDiffSnapshot;
+#[cfg(test)]
+use client::RefreshLlmTokenListener;
 use collections::HashMap;
 use gpui::{App, Entity, Task};
 use language::Buffer;
@@ -548,7 +550,8 @@ mod tests {
             let http_client = FakeHttpClient::with_404_response();
             let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
             let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-            language_model::init(user_store.clone(), client.clone(), cx);
+            language_model::init(cx);
+            RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
             EditPredictionStore::global(&client, &user_store, cx);
         })
     }

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -1,5 +1,8 @@
 use anyhow::Result;
-use client::{Client, EditPredictionUsage, UserStore};
+use client::{
+    Client, EditPredictionUsage, NeedsLlmTokenRefresh, UserStore,
+    global_llm_token as global_llm_api_token,
+};
 use cloud_api_types::{OrganizationId, SubmitEditPredictionFeedbackBody};
 use cloud_llm_client::predict_edits_v3::{
     PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
@@ -31,7 +34,7 @@ use heapless::Vec as ArrayVec;
 use language::language_settings::all_language_settings;
 use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
 use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
+use language_model::LlmApiToken;
 use project::{DisableAiSettings, Project, ProjectPath, WorktreeId};
 use release_channel::AppVersion;
 use semver::Version;
@@ -748,7 +751,7 @@ impl EditPredictionStore {
     pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
         let data_collection_choice = Self::load_data_collection_choice(cx);
 
-        let llm_token = LlmApiToken::global(cx);
+        let llm_token = global_llm_api_token(cx);
 
         let (reject_tx, reject_rx) = mpsc::unbounded();
         cx.background_spawn({
@@ -877,7 +880,9 @@ impl EditPredictionStore {
             let experiments = cx
                 .background_spawn(async move {
                     let http_client = client.http_client();
-                    let token = llm_token.acquire(&client, organization_id).await?;
+                    let token = client
+                        .acquire_llm_token(&llm_token, organization_id.clone())
+                        .await?;
                     let url = http_client.build_zed_llm_url("/edit_prediction_experiments", &[])?;
                     let request = http_client::Request::builder()
                         .method(Method::GET)
@@ -2539,12 +2544,15 @@ impl EditPredictionStore {
         Res: DeserializeOwned,
     {
         let http_client = client.http_client();
-
         let mut token = if require_auth {
-            Some(llm_token.acquire(&client, organization_id.clone()).await?)
+            Some(
+                client
+                    .acquire_llm_token(&llm_token, organization_id.clone())
+                    .await?,
+            )
         } else {
-            llm_token
-                .acquire(&client, organization_id.clone())
+            client
+                .acquire_llm_token(&llm_token, organization_id.clone())
                 .await
                 .ok()
         };
@@ -2588,7 +2596,11 @@ impl EditPredictionStore {
                 return Ok((serde_json::from_slice(&body)?, usage));
             } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                 did_retry = true;
-                token = Some(llm_token.refresh(&client, organization_id.clone()).await?);
+                token = Some(
+                    client
+                        .refresh_llm_token(&llm_token, organization_id.clone())
+                        .await?,
+                );
             } else {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -1,6 +1,6 @@
 use super::*;
 use crate::udiff::apply_diff_to_string;
-use client::{UserStore, test::FakeServer};
+use client::{RefreshLlmTokenListener, UserStore, test::FakeServer};
 use clock::FakeSystemClock;
 use clock::ReplicaId;
 use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -23,7 +23,7 @@ use language::{
     Anchor, Buffer, Capability, CursorShape, Diagnostic, DiagnosticEntry, DiagnosticSet,
     DiagnosticSeverity, Operation, Point, Selection, SelectionGoal,
 };
-use language_model::RefreshLlmTokenListener;
+
 use lsp::LanguageServerId;
 use parking_lot::Mutex;
 use pretty_assertions::{assert_eq, assert_matches};
@@ -2439,7 +2439,8 @@ fn init_test_with_fake_client(
         client.cloud_client().set_credentials(1, "test".into());
 
         let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-        language_model::init(user_store.clone(), client.clone(), cx);
+        language_model::init(cx);
+        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
         let ep_store = EditPredictionStore::global(&client, &user_store, cx);
 
         (
@@ -2891,7 +2892,7 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
         cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
     let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
     cx.update(|cx| {
-        language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
+        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
     });
 
     let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

crates/edit_prediction_cli/src/headless.rs 🔗

@@ -1,4 +1,4 @@
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
 use db::AppDatabase;
 use extension::ExtensionHostProxy;
 use fs::RealFs;
@@ -109,7 +109,8 @@ pub fn init(cx: &mut App) -> EpAppState {
 
     debug_adapter_extension::init(extension_host_proxy.clone(), cx);
     language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
-    language_model::init(user_store.clone(), client.clone(), cx);
+    language_model::init(cx);
+    RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
     language_models::init(user_store.clone(), client.clone(), cx);
     languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
     prompt_store::init(cx);

crates/eval_cli/src/headless.rs 🔗

@@ -1,7 +1,7 @@
 use std::path::PathBuf;
 use std::sync::Arc;
 
-use client::{Client, ProxySettings, UserStore};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore};
 use db::AppDatabase;
 use extension::ExtensionHostProxy;
 use fs::RealFs;
@@ -108,7 +108,8 @@ pub fn init(cx: &mut App) -> Arc<AgentCliAppState> {
     let extension_host_proxy = ExtensionHostProxy::global(cx);
     debug_adapter_extension::init(extension_host_proxy.clone(), cx);
     language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
-    language_model::init(user_store.clone(), client.clone(), cx);
+    language_model::init(cx);
+    RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
     language_models::init(user_store.clone(), client.clone(), cx);
     languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
     prompt_store::init(cx);

crates/language_model/Cargo.toml 🔗

@@ -20,7 +20,6 @@ anthropic = { workspace = true, features = ["schemars"] }
 anyhow.workspace = true
 credentials_provider.workspace = true
 base64.workspace = true
-client.workspace = true
 cloud_api_client.workspace = true
 cloud_api_types.workspace = true
 cloud_llm_client.workspace = true

crates/language_model/src/language_model.rs 🔗

@@ -11,12 +11,10 @@ pub mod tool_schema;
 pub mod fake_provider;
 
 use anyhow::{Result, anyhow};
-use client::Client;
-use client::UserStore;
 use cloud_llm_client::CompletionRequestStatus;
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
+use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
 use http_client::{StatusCode, http};
 use icons::IconName;
 use parking_lot::Mutex;
@@ -39,12 +37,7 @@ pub use crate::tool_schema::LanguageModelToolSchemaFormat;
 pub use env_var::{EnvVar, env_var};
 pub use provider::*;
 
-pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
-    init_settings(cx);
-    RefreshLlmTokenListener::register(client, user_store, cx);
-}
-
-pub fn init_settings(cx: &mut App) {
+pub fn init(cx: &mut App) {
     registry::init(cx);
 }
 

crates/language_model/src/model/cloud_model.rs 🔗

@@ -1,16 +1,9 @@
 use std::fmt;
 use std::sync::Arc;
 
-use anyhow::{Context as _, Result};
-use client::Client;
-use client::UserStore;
 use cloud_api_client::ClientApiError;
+use cloud_api_client::CloudApiClient;
 use cloud_api_types::OrganizationId;
-use cloud_api_types::websocket_protocol::MessageToClient;
-use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
-use gpui::{
-    App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
-};
 use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 use thiserror::Error;
 
@@ -30,18 +23,12 @@ impl fmt::Display for PaymentRequiredError {
 pub struct LlmApiToken(Arc<RwLock<Option<String>>>);
 
 impl LlmApiToken {
-    pub fn global(cx: &App) -> Self {
-        RefreshLlmTokenListener::global(cx)
-            .read(cx)
-            .llm_api_token
-            .clone()
-    }
-
     pub async fn acquire(
         &self,
-        client: &Arc<Client>,
+        client: &CloudApiClient,
+        system_id: Option<String>,
         organization_id: Option<OrganizationId>,
-    ) -> Result<String> {
+    ) -> Result<String, ClientApiError> {
         let lock = self.0.upgradable_read().await;
         if let Some(token) = lock.as_ref() {
             Ok(token.to_string())
@@ -49,6 +36,7 @@ impl LlmApiToken {
             Self::fetch(
                 RwLockUpgradableReadGuard::upgrade(lock).await,
                 client,
+                system_id,
                 organization_id,
             )
             .await
@@ -57,10 +45,11 @@ impl LlmApiToken {
 
     pub async fn refresh(
         &self,
-        client: &Arc<Client>,
+        client: &CloudApiClient,
+        system_id: Option<String>,
         organization_id: Option<OrganizationId>,
-    ) -> Result<String> {
-        Self::fetch(self.0.write().await, client, organization_id).await
+    ) -> Result<String, ClientApiError> {
+        Self::fetch(self.0.write().await, client, system_id, organization_id).await
     }
 
     /// Clears the existing token before attempting to fetch a new one.
@@ -69,28 +58,22 @@ impl LlmApiToken {
     /// leave a token for the wrong organization.
     pub async fn clear_and_refresh(
         &self,
-        client: &Arc<Client>,
+        client: &CloudApiClient,
+        system_id: Option<String>,
         organization_id: Option<OrganizationId>,
-    ) -> Result<String> {
+    ) -> Result<String, ClientApiError> {
         let mut lock = self.0.write().await;
         *lock = None;
-        Self::fetch(lock, client, organization_id).await
+        Self::fetch(lock, client, system_id, organization_id).await
     }
 
     async fn fetch(
         mut lock: RwLockWriteGuard<'_, Option<String>>,
-        client: &Arc<Client>,
+        client: &CloudApiClient,
+        system_id: Option<String>,
         organization_id: Option<OrganizationId>,
-    ) -> Result<String> {
-        let system_id = client
-            .telemetry()
-            .system_id()
-            .map(|system_id| system_id.to_string());
-
-        let result = client
-            .cloud_client()
-            .create_llm_token(system_id, organization_id)
-            .await;
+    ) -> Result<String, ClientApiError> {
+        let result = client.create_llm_token(system_id, organization_id).await;
         match result {
             Ok(response) => {
                 *lock = Some(response.token.0.clone());
@@ -98,112 +81,7 @@ impl LlmApiToken {
             }
             Err(err) => {
                 *lock = None;
-                match err {
-                    ClientApiError::Unauthorized => {
-                        client.request_sign_out();
-                        Err(err).context("Failed to create LLM token")
-                    }
-                    ClientApiError::Other(err) => Err(err),
-                }
-            }
-        }
-    }
-}
-
-pub trait NeedsLlmTokenRefresh {
-    /// Returns whether the LLM token needs to be refreshed.
-    fn needs_llm_token_refresh(&self) -> bool;
-}
-
-impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
-    fn needs_llm_token_refresh(&self) -> bool {
-        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
-            || self.headers().get(OUTDATED_LLM_TOKEN_HEADER_NAME).is_some()
-    }
-}
-
-enum TokenRefreshMode {
-    Refresh,
-    ClearAndRefresh,
-}
-
-struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
-
-impl Global for GlobalRefreshLlmTokenListener {}
-
-pub struct LlmTokenRefreshedEvent;
-
-pub struct RefreshLlmTokenListener {
-    client: Arc<Client>,
-    user_store: Entity<UserStore>,
-    llm_api_token: LlmApiToken,
-    _subscription: Subscription,
-}
-
-impl EventEmitter<LlmTokenRefreshedEvent> for RefreshLlmTokenListener {}
-
-impl RefreshLlmTokenListener {
-    pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
-        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
-        cx.set_global(GlobalRefreshLlmTokenListener(listener));
-    }
-
-    pub fn global(cx: &App) -> Entity<Self> {
-        GlobalRefreshLlmTokenListener::global(cx).0.clone()
-    }
-
-    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
-        client.add_message_to_client_handler({
-            let this = cx.weak_entity();
-            move |message, cx| {
-                if let Some(this) = this.upgrade() {
-                    Self::handle_refresh_llm_token(this, message, cx);
-                }
-            }
-        });
-
-        let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| {
-            if matches!(event, client::user::Event::OrganizationChanged) {
-                this.refresh(TokenRefreshMode::ClearAndRefresh, cx);
-            }
-        });
-
-        Self {
-            client,
-            user_store,
-            llm_api_token: LlmApiToken::default(),
-            _subscription: subscription,
-        }
-    }
-
-    fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context<Self>) {
-        let client = self.client.clone();
-        let llm_api_token = self.llm_api_token.clone();
-        let organization_id = self
-            .user_store
-            .read(cx)
-            .current_organization()
-            .map(|organization| organization.id.clone());
-        cx.spawn(async move |this, cx| {
-            match mode {
-                TokenRefreshMode::Refresh => {
-                    llm_api_token.refresh(&client, organization_id).await?;
-                }
-                TokenRefreshMode::ClearAndRefresh => {
-                    llm_api_token
-                        .clear_and_refresh(&client, organization_id)
-                        .await?;
-                }
-            }
-            this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent))
-        })
-        .detach_and_log_err(cx);
-    }
-
-    fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {
-        match message {
-            MessageToClient::UserUpdated => {
-                this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx));
+                Err(err)
             }
         }
     }

crates/language_models/src/provider/cloud.rs 🔗

@@ -1,7 +1,10 @@
 use ai_onboarding::YoungAccountBanner;
 use anthropic::AnthropicModelMode;
 use anyhow::{Context as _, Result, anyhow};
-use client::{Client, UserStore, zed_urls};
+use client::{
+    Client, NeedsLlmTokenRefresh, RefreshLlmTokenListener, UserStore,
+    global_llm_token as global_llm_api_token, zed_urls,
+};
 use cloud_api_types::{OrganizationId, Plan};
 use cloud_llm_client::{
     CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_STATUS_STREAM_ENDED_HEADER_NAME,
@@ -24,10 +27,9 @@ use language_model::{
     LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelEffortLevel,
     LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
     LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
-    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh,
-    OPEN_AI_PROVIDER_ID, OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter,
-    RefreshLlmTokenListener, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME, ZED_CLOUD_PROVIDER_ID,
-    ZED_CLOUD_PROVIDER_NAME,
+    LanguageModelToolChoice, LanguageModelToolSchemaFormat, LlmApiToken, OPEN_AI_PROVIDER_ID,
+    OPEN_AI_PROVIDER_NAME, PaymentRequiredError, RateLimiter, X_AI_PROVIDER_ID, X_AI_PROVIDER_NAME,
+    ZED_CLOUD_PROVIDER_ID, ZED_CLOUD_PROVIDER_NAME,
 };
 use release_channel::AppVersion;
 use schemars::JsonSchema;
@@ -111,7 +113,7 @@ impl State {
         cx: &mut Context<Self>,
     ) -> Self {
         let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
-        let llm_api_token = LlmApiToken::global(cx);
+        let llm_api_token = global_llm_api_token(cx);
         Self {
             client: client.clone(),
             llm_api_token,
@@ -226,7 +228,9 @@ impl State {
         organization_id: Option<OrganizationId>,
     ) -> Result<ListModelsResponse> {
         let http_client = &client.http_client();
-        let token = llm_api_token.acquire(&client, organization_id).await?;
+        let token = client
+            .acquire_llm_token(&llm_api_token, organization_id)
+            .await?;
 
         let request = http_client::Request::builder()
             .method(Method::GET)
@@ -414,8 +418,8 @@ impl CloudLanguageModel {
     ) -> Result<PerformLlmCompletionResponse> {
         let http_client = &client.http_client();
 
-        let mut token = llm_api_token
-            .acquire(&client, organization_id.clone())
+        let mut token = client
+            .acquire_llm_token(&llm_api_token, organization_id.clone())
             .await?;
         let mut refreshed_token = false;
 
@@ -447,8 +451,8 @@ impl CloudLanguageModel {
             }
 
             if !refreshed_token && response.needs_llm_token_refresh() {
-                token = llm_api_token
-                    .refresh(&client, organization_id.clone())
+                token = client
+                    .refresh_llm_token(&llm_api_token, organization_id.clone())
                     .await?;
                 refreshed_token = true;
                 continue;
@@ -713,7 +717,9 @@ impl LanguageModel for CloudLanguageModel {
                     into_google(request, model_id.clone(), GoogleModelMode::Default);
                 async move {
                     let http_client = &client.http_client();
-                    let token = llm_api_token.acquire(&client, organization_id).await?;
+                    let token = client
+                        .acquire_llm_token(&llm_api_token, organization_id)
+                        .await?;
 
                     let request_body = CountTokensBody {
                         provider: cloud_llm_client::LanguageModelProvider::Google,

crates/web_search_providers/src/cloud.rs 🔗

@@ -1,13 +1,13 @@
 use std::sync::Arc;
 
 use anyhow::{Context as _, Result};
-use client::{Client, UserStore};
+use client::{Client, NeedsLlmTokenRefresh, UserStore, global_llm_token as global_llm_api_token};
 use cloud_api_types::OrganizationId;
 use cloud_llm_client::{WebSearchBody, WebSearchResponse};
 use futures::AsyncReadExt as _;
 use gpui::{App, AppContext, Context, Entity, Task};
 use http_client::{HttpClient, Method};
-use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
+use language_model::LlmApiToken;
 use web_search::{WebSearchProvider, WebSearchProviderId};
 
 pub struct CloudWebSearchProvider {
@@ -30,7 +30,7 @@ pub struct State {
 
 impl State {
     pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
-        let llm_api_token = LlmApiToken::global(cx);
+        let llm_api_token = global_llm_api_token(cx);
 
         Self {
             client,
@@ -73,8 +73,8 @@ async fn perform_web_search(
 
     let http_client = &client.http_client();
     let mut retries_remaining = MAX_RETRIES;
-    let mut token = llm_api_token
-        .acquire(&client, organization_id.clone())
+    let mut token = client
+        .acquire_llm_token(&llm_api_token, organization_id.clone())
         .await?;
 
     loop {
@@ -100,8 +100,8 @@ async fn perform_web_search(
             response.body_mut().read_to_string(&mut body).await?;
             return Ok(serde_json::from_str(&body)?);
         } else if response.needs_llm_token_refresh() {
-            token = llm_api_token
-                .refresh(&client, organization_id.clone())
+            token = client
+                .refresh_llm_token(&llm_api_token, organization_id.clone())
                 .await?;
             retries_remaining -= 1;
         } else {

crates/zed/src/main.rs 🔗

@@ -10,7 +10,7 @@ use agent_ui::AgentPanel;
 use anyhow::{Context as _, Error, Result};
 use clap::Parser;
 use cli::FORCE_CLI_MODE_ENV_VAR_NAME;
-use client::{Client, ProxySettings, UserStore, parse_zed_link};
+use client::{Client, ProxySettings, RefreshLlmTokenListener, UserStore, parse_zed_link};
 use collab_ui::channel_view::ChannelView;
 use collections::HashMap;
 use crashes::InitCrashHandler;
@@ -664,7 +664,12 @@ fn main() {
         );
 
         copilot_ui::init(&app_state, cx);
-        language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+        language_model::init(cx);
+        RefreshLlmTokenListener::register(
+            app_state.client.clone(),
+            app_state.user_store.clone(),
+            cx,
+        );
         language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
         acp_tools::init(cx);
         zed::telemetry_log::init(cx);

crates/zed/src/visual_test_runner.rs 🔗

@@ -201,7 +201,12 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()>
         });
         prompt_store::init(cx);
         let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx);
-        language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+        language_model::init(cx);
+        client::RefreshLlmTokenListener::register(
+            app_state.client.clone(),
+            app_state.user_store.clone(),
+            cx,
+        );
         language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
         git_ui::init(cx);
         project::AgentRegistryStore::init_global(

crates/zed/src/zed.rs 🔗

@@ -5015,7 +5015,12 @@ mod tests {
                 cx,
             );
             image_viewer::init(cx);
-            language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+            language_model::init(cx);
+            client::RefreshLlmTokenListener::register(
+                app_state.client.clone(),
+                app_state.user_store.clone(),
+                cx,
+            );
             language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
             web_search::init(cx);
             git_graph::init(cx);

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -313,7 +313,12 @@ mod tests {
         let app_state = cx.update(|cx| {
             let app_state = AppState::test(cx);
             client::init(&app_state.client, cx);
-            language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
+            language_model::init(cx);
+            client::RefreshLlmTokenListener::register(
+                app_state.client.clone(),
+                app_state.user_store.clone(),
+                cx,
+            );
             editor::init(cx);
             app_state
         });