From 097cfae77e619b537233a35567d11d9261c255b2 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Fri, 23 Jan 2026 15:50:50 -0500 Subject: [PATCH] Add helper method for checking if the LLM token needs to be refreshed (#47511) This PR adds a new `needs_llm_token_refresh` helper method for checking if the LLM token needs to be refreshed. We were duplicating the check for the `x-zed-expired-token` header in a number of spots, and it will be gaining an additional case soon. Release Notes: - N/A --- crates/edit_prediction/src/edit_prediction.rs | 12 +++--------- crates/language_model/src/model/cloud_model.rs | 12 ++++++++++++ crates/language_models/src/provider/cloud.rs | 16 +++++----------- crates/web_search_providers/src/cloud.rs | 10 +++------- 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 098136b42f2c92ddb80a43a46bb29ed7518aff34..38bccbe65bddc1ce1763a0d362a52c0db09be69a 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -5,7 +5,7 @@ use cloud_llm_client::predict_edits_v3::{ PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse, }; use cloud_llm_client::{ - EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection, + EditPredictionRejectReason, EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME, }; @@ -29,7 +29,7 @@ use gpui::{ use language::language_settings::all_language_settings; use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint}; use language::{BufferSnapshot, OffsetRangeExt}; -use language_model::{LlmApiToken, RefreshLlmTokenListener}; +use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener}; use project::{Project, ProjectPath, WorktreeId}; use release_channel::AppVersion; use semver::Version; @@ -2046,13 +2046,7 @@ impl EditPredictionStore { let mut body = Vec::new(); response.body_mut().read_to_end(&mut body).await?; return Ok((serde_json::from_slice(&body)?, usage)); - } else if !did_retry - && token.is_some() - && response - .headers() - .get(EXPIRED_LLM_TOKEN_HEADER_NAME) - .is_some() - { + } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() { did_retry = true; token = Some(llm_token.refresh(&client).await?); } else { diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index b295fee74be60ee16aa32d51c9810ca155b32010..a3c7f6d0d7459b245cd01f3a0f58bbe8bd00539d 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -4,6 +4,7 @@ use std::sync::Arc; use anyhow::Result; use client::Client; use cloud_api_types::websocket_protocol::MessageToClient; +use cloud_llm_client::EXPIRED_LLM_TOKEN_HEADER_NAME; use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -52,6 +53,17 @@ impl LlmApiToken { } } +pub trait NeedsLlmTokenRefresh { + /// Returns whether the LLM token needs to be refreshed. + fn needs_llm_token_refresh(&self) -> bool; +} + +impl NeedsLlmTokenRefresh for http_client::Response { + fn needs_llm_token_refresh(&self) -> bool { + self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some() + } +} + struct GlobalRefreshLlmTokenListener(Entity); impl Global for GlobalRefreshLlmTokenListener {} diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs index d9ee5dd164296a112c7021148a22e651ebd5abd3..642b8c7bc3ad191346cd53c5500b4780478df3cb 100644 --- a/crates/language_models/src/provider/cloud.rs +++ b/crates/language_models/src/provider/cloud.rs @@ -5,9 +5,8 @@ use chrono::{DateTime, Utc}; use client::{Client, UserStore, zed_urls}; use cloud_llm_client::{ CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody, - CompletionEvent, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME, - ListModelsResponse, Plan, PlanV2, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, - ZED_VERSION_HEADER_NAME, + CompletionEvent, CountTokensBody, CountTokensResponse, ListModelsResponse, Plan, PlanV2, + SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME, }; use feature_flags::{CloudThinkingToggleFeatureFlag, FeatureFlagAppExt as _}; use futures::{ @@ -22,8 +21,8 @@ use language_model::{ LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice, - LanguageModelToolSchemaFormat, LlmApiToken, PaymentRequiredError, RateLimiter, - RefreshLlmTokenListener, + LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh, PaymentRequiredError, + RateLimiter, RefreshLlmTokenListener, }; use release_channel::AppVersion; use schemars::JsonSchema; @@ -425,12 +424,7 @@ impl CloudLanguageModel { }); } - if !refreshed_token - && response - .headers() - .get(EXPIRED_LLM_TOKEN_HEADER_NAME) - .is_some() - { + if !refreshed_token && response.needs_llm_token_refresh() { token = llm_api_token.refresh(&client).await?; refreshed_token = true; continue; diff --git a/crates/web_search_providers/src/cloud.rs b/crates/web_search_providers/src/cloud.rs index 75ffb1da63109c802207e80da167cdb0cc3c9a0a..2f3ccdbb52a884471250ad458e8b7922437cb9ae 100644 --- a/crates/web_search_providers/src/cloud.rs +++ b/crates/web_search_providers/src/cloud.rs @@ -2,11 +2,11 @@ use std::sync::Arc; use anyhow::{Context as _, Result}; use client::Client; -use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse}; +use cloud_llm_client::{WebSearchBody, WebSearchResponse}; use futures::AsyncReadExt as _; use gpui::{App, AppContext, Context, Entity, Subscription, Task}; use http_client::{HttpClient, Method}; -use language_model::{LlmApiToken, RefreshLlmTokenListener}; +use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener}; use web_search::{WebSearchProvider, WebSearchProviderId}; pub struct CloudWebSearchProvider { @@ -99,11 +99,7 @@ async fn perform_web_search( let mut body = String::new(); response.body_mut().read_to_string(&mut body).await?; return Ok(serde_json::from_str(&body)?); - } else if response - .headers() - .get(EXPIRED_LLM_TOKEN_HEADER_NAME) - .is_some() - { + } else if response.needs_llm_token_refresh() { token = llm_api_token.refresh(&client).await?; retries_remaining -= 1; } else {