From a30e4d52288a9c5f7b035b42265a63a6c87274be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=20Houl=C3=A9?= <13155277+tomhoule@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:09:55 +0100 Subject: [PATCH] language_model: Clear the LlmApiToken first on org switch (#51826) When we switch organizations, we try and refresh the token. If the token refresh fails, we are left with the old LlmApiToken, which is for the wrong organization. In this commit, we make sure to clear the old token before trying a refresh on organization switch. Release Notes: - N/A --------- Co-authored-by: Neel --- .../language_model/src/model/cloud_model.rs | 51 +++++++++++++++---- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 527d24ec18c0f9ef08576a71fe92562dd94d4afd..f6ad907483e5946652752895d0a48ec129660b0b 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -63,6 +63,20 @@ impl LlmApiToken { Self::fetch(self.0.write().await, client, organization_id).await } + /// Clears the existing token before attempting to fetch a new one. + /// + /// Used when switching organizations so that a failed refresh doesn't + /// leave a token for the wrong organization. + pub async fn clear_and_refresh( + &self, + client: &Arc, + organization_id: Option, + ) -> Result { + let mut lock = self.0.write().await; + *lock = None; + Self::fetch(lock, client, organization_id).await + } + async fn fetch( mut lock: RwLockWriteGuard<'_, Option>, client: &Arc, @@ -82,13 +96,16 @@ impl LlmApiToken { *lock = Some(response.token.0.clone()); Ok(response.token.0) } - Err(err) => match err { - ClientApiError::Unauthorized => { - client.request_sign_out(); - Err(err).context("Failed to create LLM token") + Err(err) => { + *lock = None; + match err { + ClientApiError::Unauthorized => { + client.request_sign_out(); + Err(err).context("Failed to create LLM token") + } + ClientApiError::Other(err) => Err(err), } - ClientApiError::Other(err) => Err(err), - }, + } } } } @@ -105,6 +122,11 @@ impl NeedsLlmTokenRefresh for http_client::Response { } } +enum TokenRefreshMode { + Refresh, + ClearAndRefresh, +} + struct GlobalRefreshLlmTokenListener(Entity); impl Global for GlobalRefreshLlmTokenListener {} @@ -140,7 +162,7 @@ impl RefreshLlmTokenListener { let subscription = cx.subscribe(&user_store, |this, _user_store, event, cx| { if matches!(event, client::user::Event::OrganizationChanged) { - this.refresh(cx); + this.refresh(TokenRefreshMode::ClearAndRefresh, cx); } }); @@ -152,7 +174,7 @@ impl RefreshLlmTokenListener { } } - fn refresh(&self, cx: &mut Context) { + fn refresh(&self, mode: TokenRefreshMode, cx: &mut Context) { let client = self.client.clone(); let llm_api_token = self.llm_api_token.clone(); let organization_id = self @@ -161,7 +183,16 @@ impl RefreshLlmTokenListener { .current_organization() .map(|organization| organization.id.clone()); cx.spawn(async move |this, cx| { - llm_api_token.refresh(&client, organization_id).await?; + match mode { + TokenRefreshMode::Refresh => { + llm_api_token.refresh(&client, organization_id).await?; + } + TokenRefreshMode::ClearAndRefresh => { + llm_api_token + .clear_and_refresh(&client, organization_id) + .await?; + } + } this.update(cx, |_this, cx| cx.emit(LlmTokenRefreshedEvent)) }) .detach_and_log_err(cx); @@ -170,7 +201,7 @@ impl RefreshLlmTokenListener { fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) { match message { MessageToClient::UserUpdated => { - this.update(cx, |this, cx| this.refresh(cx)); + this.update(cx, |this, cx| this.refresh(TokenRefreshMode::Refresh, cx)); } } }