Refresh LLM API token on organization change (#50931)

Neel and Tom Houlé created

Emit client-side organization changed events through
`RefreshLlmTokenListener` so it produces the same `RefreshLlmTokenEvent`
used for server-pushed `UserUpdated` messages.

This keeps token refresh fan-out in one place.

Closes CLO-383.

Release Notes:

- N/A

---------

Co-authored-by: Tom Houlé <tom@tomhoule.com>

Change summary

crates/agent/src/edit_agent/evals.rs                |  2 
crates/agent/src/tests/mod.rs                       |  4 +-
crates/agent_servers/src/e2e_tests.rs               |  4 +
crates/agent_ui/src/inline_assistant.rs             |  2 
crates/client/src/user.rs                           | 18 +++++++++-
crates/edit_prediction/src/capture_example.rs       |  2 
crates/edit_prediction/src/edit_prediction_tests.rs |  9 +++--
crates/edit_prediction_cli/src/headless.rs          |  2 
crates/eval/src/eval.rs                             |  2 
crates/eval_cli/src/headless.rs                     |  2 
crates/language_model/src/language_model.rs         |  7 ++-
crates/language_model/src/model/cloud_model.rs      | 25 +++++++++++---
crates/title_bar/src/title_bar.rs                   |  4 +-
crates/zed/src/main.rs                              |  2 
crates/zed/src/visual_test_runner.rs                |  2 
crates/zed/src/zed.rs                               |  2 
crates/zed/src/zed/edit_prediction_registry.rs      |  2 
17 files changed, 61 insertions(+), 30 deletions(-)

Detailed changes

crates/agent/src/edit_agent/evals.rs 🔗

@@ -1423,7 +1423,7 @@ impl EditAgentTest {
             let client = Client::production(cx);
             let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
             settings::init(cx);
-            language_model::init(client.clone(), cx);
+            language_model::init(user_store.clone(), client.clone(), cx);
             language_models::init(user_store, client.clone(), cx);
         });
 

crates/agent/src/tests/mod.rs 🔗

@@ -3167,7 +3167,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
         let clock = Arc::new(clock::FakeSystemClock::new());
         let client = Client::new(clock, http_client, cx);
         let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-        language_model::init(client.clone(), cx);
+        language_model::init(user_store.clone(), client.clone(), cx);
         language_models::init(user_store, client.clone(), cx);
         LanguageModelRegistry::test(cx);
     });
@@ -3791,7 +3791,7 @@ async fn setup(cx: &mut TestAppContext, model: TestModel) -> ThreadTest {
                 cx.set_http_client(Arc::new(http_client));
                 let client = Client::production(cx);
                 let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-                language_model::init(client.clone(), cx);
+                language_model::init(user_store.clone(), client.clone(), cx);
                 language_models::init(user_store, client.clone(), cx);
             }
         };

crates/agent_servers/src/e2e_tests.rs 🔗

@@ -2,6 +2,7 @@ use crate::{AgentServer, AgentServerDelegate};
 use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
 use agent_client_protocol as acp;
 use futures::{FutureExt, StreamExt, channel::mpsc, select};
+use gpui::AppContext;
 use gpui::{Entity, TestAppContext};
 use indoc::indoc;
 use project::{FakeFs, Project};
@@ -408,7 +409,8 @@ pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
         let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
         cx.set_http_client(Arc::new(http_client));
         let client = client::Client::production(cx);
-        language_model::init(client, cx);
+        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
+        language_model::init(user_store, client, cx);
 
         #[cfg(test)]
         project::agent_server_store::AllAgentServersSettings::override_global(

crates/agent_ui/src/inline_assistant.rs 🔗

@@ -2120,7 +2120,7 @@ pub mod test {
             client::init(&client, cx);
             workspace::init(app_state.clone(), cx);
             let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
-            language_model::init(client.clone(), cx);
+            language_model::init(user_store.clone(), client.clone(), cx);
             language_models::init(user_store, client.clone(), cx);
 
             cx.set_global(inline_assistant);

crates/client/src/user.rs 🔗

@@ -140,6 +140,7 @@ pub enum Event {
     ParticipantIndicesChanged,
     PrivateUserInfoUpdated,
     PlanUpdated,
+    OrganizationChanged,
 }
 
 #[derive(Clone, Copy)]
@@ -694,8 +695,21 @@ impl UserStore {
         self.current_organization.clone()
     }
 
-    pub fn set_current_organization(&mut self, organization: Arc<Organization>) {
-        self.current_organization.replace(organization);
+    pub fn set_current_organization(
+        &mut self,
+        organization: Arc<Organization>,
+        cx: &mut Context<Self>,
+    ) {
+        let is_same_organization = self
+            .current_organization
+            .as_ref()
+            .is_some_and(|current| current.id == organization.id);
+
+        if !is_same_organization {
+            self.current_organization.replace(organization);
+            cx.emit(Event::OrganizationChanged);
+            cx.notify();
+        }
     }
 
     pub fn organizations(&self) -> &Vec<Arc<Organization>> {

crates/edit_prediction/src/capture_example.rs 🔗

@@ -533,8 +533,8 @@ mod tests {
             zlog::init_test();
             let http_client = FakeHttpClient::with_404_response();
             let client = Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
-            language_model::init(client.clone(), cx);
             let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+            language_model::init(user_store.clone(), client.clone(), cx);
             EditPredictionStore::global(&client, &user_store, cx);
         })
     }

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -1850,9 +1850,8 @@ fn init_test_with_fake_client(
         let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
         client.cloud_client().set_credentials(1, "test".into());
 
-        language_model::init(client.clone(), cx);
-
         let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+        language_model::init(user_store.clone(), client.clone(), cx);
         let ep_store = EditPredictionStore::global(&client, &user_store, cx);
 
         (
@@ -2218,8 +2217,9 @@ async fn make_test_ep_store(
     });
 
     let client = cx.update(|cx| Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
     cx.update(|cx| {
-        RefreshLlmTokenListener::register(client.clone(), cx);
+        RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
     });
     let _server = FakeServer::for_client(42, &client, cx).await;
 
@@ -2301,8 +2301,9 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
 
     let client =
         cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+    let user_store = cx.update(|cx| cx.new(|cx| client::UserStore::new(client.clone(), cx)));
     cx.update(|cx| {
-        language_model::RefreshLlmTokenListener::register(client.clone(), cx);
+        language_model::RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
     });
 
     let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));

crates/edit_prediction_cli/src/headless.rs 🔗

@@ -105,7 +105,7 @@ pub fn init(cx: &mut App) -> EpAppState {
 
     debug_adapter_extension::init(extension_host_proxy.clone(), cx);
     language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
-    language_model::init(client.clone(), cx);
+    language_model::init(user_store.clone(), client.clone(), cx);
     language_models::init(user_store.clone(), client.clone(), cx);
     languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
     prompt_store::init(cx);

crates/eval/src/eval.rs 🔗

@@ -429,7 +429,7 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
     let extension_host_proxy = ExtensionHostProxy::global(cx);
     debug_adapter_extension::init(extension_host_proxy.clone(), cx);
     language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
-    language_model::init(client.clone(), cx);
+    language_model::init(user_store.clone(), client.clone(), cx);
     language_models::init(user_store.clone(), client.clone(), cx);
     languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
     prompt_store::init(cx);

crates/eval_cli/src/headless.rs 🔗

@@ -104,7 +104,7 @@ pub fn init(cx: &mut App) -> Arc<AgentCliAppState> {
     let extension_host_proxy = ExtensionHostProxy::global(cx);
     debug_adapter_extension::init(extension_host_proxy.clone(), cx);
     language_extension::init(LspAccess::Noop, extension_host_proxy, languages.clone());
-    language_model::init(client.clone(), cx);
+    language_model::init(user_store.clone(), client.clone(), cx);
     language_models::init(user_store.clone(), client.clone(), cx);
     languages::init(languages.clone(), fs.clone(), node_runtime.clone(), cx);
     prompt_store::init(cx);

crates/language_model/src/language_model.rs 🔗

@@ -13,10 +13,11 @@ pub mod fake_provider;
 use anthropic::{AnthropicError, parse_prompt_too_long};
 use anyhow::{Result, anyhow};
 use client::Client;
+use client::UserStore;
 use cloud_llm_client::CompletionRequestStatus;
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
+use gpui::{AnyView, App, AsyncApp, Entity, SharedString, Task, Window};
 use http_client::{StatusCode, http};
 use icons::IconName;
 use open_router::OpenRouterError;
@@ -61,9 +62,9 @@ pub const ZED_CLOUD_PROVIDER_ID: LanguageModelProviderId = LanguageModelProvider
 pub const ZED_CLOUD_PROVIDER_NAME: LanguageModelProviderName =
     LanguageModelProviderName::new("Zed");
 
-pub fn init(client: Arc<Client>, cx: &mut App) {
+pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
     init_settings(cx);
-    RefreshLlmTokenListener::register(client, cx);
+    RefreshLlmTokenListener::register(client, user_store, cx);
 }
 
 pub fn init_settings(cx: &mut App) {

crates/language_model/src/model/cloud_model.rs 🔗

@@ -3,11 +3,14 @@ use std::sync::Arc;
 
 use anyhow::{Context as _, Result};
 use client::Client;
+use client::UserStore;
 use cloud_api_client::ClientApiError;
 use cloud_api_types::OrganizationId;
 use cloud_api_types::websocket_protocol::MessageToClient;
 use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, OUTDATED_LLM_TOKEN_HEADER_NAME};
-use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
+use gpui::{
+    App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _, Subscription,
+};
 use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 use thiserror::Error;
 
@@ -101,13 +104,15 @@ impl Global for GlobalRefreshLlmTokenListener {}
 
 pub struct RefreshLlmTokenEvent;
 
-pub struct RefreshLlmTokenListener;
+pub struct RefreshLlmTokenListener {
+    _subscription: Subscription,
+}
 
 impl EventEmitter<RefreshLlmTokenEvent> for RefreshLlmTokenListener {}
 
 impl RefreshLlmTokenListener {
-    pub fn register(client: Arc<Client>, cx: &mut App) {
-        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, cx));
+    pub fn register(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
+        let listener = cx.new(|cx| RefreshLlmTokenListener::new(client, user_store, cx));
         cx.set_global(GlobalRefreshLlmTokenListener(listener));
     }
 
@@ -115,7 +120,7 @@ impl RefreshLlmTokenListener {
         GlobalRefreshLlmTokenListener::global(cx).0.clone()
     }
 
-    fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
+    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
         client.add_message_to_client_handler({
             let this = cx.entity();
             move |message, cx| {
@@ -123,7 +128,15 @@ impl RefreshLlmTokenListener {
             }
         });
 
-        Self
+        let subscription = cx.subscribe(&user_store, |_this, _user_store, event, cx| {
+            if matches!(event, client::user::Event::OrganizationChanged) {
+                cx.emit(RefreshLlmTokenEvent);
+            }
+        });
+
+        Self {
+            _subscription: subscription,
+        }
     }
 
     fn handle_refresh_llm_token(this: Entity<Self>, message: &MessageToClient, cx: &mut App) {

crates/title_bar/src/title_bar.rs 🔗

@@ -1014,9 +1014,9 @@ impl TitleBar {
                                     let user_store = user_store.clone();
                                     let organization = organization.clone();
                                     move |_window, cx| {
-                                        user_store.update(cx, |user_store, _cx| {
+                                        user_store.update(cx, |user_store, cx| {
                                             user_store
-                                                .set_current_organization(organization.clone());
+                                                .set_current_organization(organization.clone(), cx);
                                         });
                                     }
                                 },

crates/zed/src/main.rs 🔗

@@ -657,7 +657,7 @@ fn main() {
         );
 
         copilot_ui::init(&app_state, cx);
-        language_model::init(app_state.client.clone(), cx);
+        language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
         language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
         acp_tools::init(cx);
         zed::telemetry_log::init(cx);

crates/zed/src/visual_test_runner.rs 🔗

@@ -200,7 +200,7 @@ fn run_visual_tests(project_path: PathBuf, update_baseline: bool) -> Result<()>
         });
         prompt_store::init(cx);
         let prompt_builder = prompt_store::PromptBuilder::load(app_state.fs.clone(), false, cx);
-        language_model::init(app_state.client.clone(), cx);
+        language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
         language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
         git_ui::init(cx);
         project::AgentRegistryStore::init_global(

crates/zed/src/zed.rs 🔗

@@ -5024,7 +5024,7 @@ mod tests {
                 cx,
             );
             image_viewer::init(cx);
-            language_model::init(app_state.client.clone(), cx);
+            language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
             language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
             web_search::init(cx);
             git_graph::init(cx);

crates/zed/src/zed/edit_prediction_registry.rs 🔗

@@ -316,7 +316,7 @@ mod tests {
         let app_state = cx.update(|cx| {
             let app_state = AppState::test(cx);
             client::init(&app_state.client, cx);
-            language_model::init(app_state.client.clone(), cx);
+            language_model::init(app_state.user_store.clone(), app_state.client.clone(), cx);
             editor::init(cx);
             app_state
         });