Add helper method for checking if the LLM token needs to be refreshed (#47511)

Marshall Bowers created this pull request

This PR adds a new `needs_llm_token_refresh` helper method for checking
if the LLM token needs to be refreshed.

We were duplicating the check for the `x-zed-expired-token` header in a
number of spots, and that check will be gaining an additional case soon.
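
The consolidation turns the repeated header inspection into a single trait-method call. A minimal before/after sketch, distilled from the call sites in the diffs below (the `response` and token variables stand in for the locals at each site):

```rust
// Before: each call site inspected the response headers directly.
if response
    .headers()
    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
    .is_some()
{
    token = llm_api_token.refresh(&client).await?;
}

// After: the shared helper encapsulates the check.
if response.needs_llm_token_refresh() {
    token = llm_api_token.refresh(&client).await?;
}
```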

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/edit_prediction.rs  | 12 +++---------
crates/language_model/src/model/cloud_model.rs | 12 ++++++++++++
crates/language_models/src/provider/cloud.rs   | 16 +++++-----------
crates/web_search_providers/src/cloud.rs       | 10 +++-------
4 files changed, 23 insertions(+), 27 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs

@@ -5,7 +5,7 @@ use cloud_llm_client::predict_edits_v3::{
     PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
 };
 use cloud_llm_client::{
-    EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
+    EditPredictionRejectReason, EditPredictionRejection,
     MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
     PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
 };
@@ -29,7 +29,7 @@ use gpui::{
 use language::language_settings::all_language_settings;
 use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
 use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
 use project::{Project, ProjectPath, WorktreeId};
 use release_channel::AppVersion;
 use semver::Version;
@@ -2046,13 +2046,7 @@ impl EditPredictionStore {
                 let mut body = Vec::new();
                 response.body_mut().read_to_end(&mut body).await?;
                 return Ok((serde_json::from_slice(&body)?, usage));
-            } else if !did_retry
-                && token.is_some()
-                && response
-                    .headers()
-                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
-                    .is_some()
-            {
+            } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
                 did_retry = true;
                 token = Some(llm_token.refresh(&client).await?);
             } else {

crates/language_model/src/model/cloud_model.rs

@@ -4,6 +4,7 @@ use std::sync::Arc;
 use anyhow::Result;
 use client::Client;
 use cloud_api_types::websocket_protocol::MessageToClient;
+use cloud_llm_client::EXPIRED_LLM_TOKEN_HEADER_NAME;
 use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
 use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
 use thiserror::Error;
@@ -52,6 +53,17 @@ impl LlmApiToken {
     }
 }
 
+pub trait NeedsLlmTokenRefresh {
+    /// Returns whether the LLM token needs to be refreshed.
+    fn needs_llm_token_refresh(&self) -> bool;
+}
+
+impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
+    fn needs_llm_token_refresh(&self) -> bool {
+        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
+    }
+}
+
 struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
 
 impl Global for GlobalRefreshLlmTokenListener {}
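
With the check behind a trait, the additional case mentioned above can land as a one-line change to this impl. A hypothetical sketch of that shape (the second header name is invented for illustration and is not part of this PR):

```rust
impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
    fn needs_llm_token_refresh(&self) -> bool {
        // Existing case: the server marked the current token as expired.
        self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
            // Hypothetical future case: another server-set signal that should
            // also trigger a refresh (header name invented for this sketch).
            || self.headers().get("x-zed-some-other-refresh-signal").is_some()
    }
}
```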

crates/language_models/src/provider/cloud.rs

@@ -5,9 +5,8 @@ use chrono::{DateTime, Utc};
 use client::{Client, UserStore, zed_urls};
 use cloud_llm_client::{
     CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody,
-    CompletionEvent, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
-    ListModelsResponse, Plan, PlanV2, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
-    ZED_VERSION_HEADER_NAME,
+    CompletionEvent, CountTokensBody, CountTokensResponse, ListModelsResponse, Plan, PlanV2,
+    SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
 };
 use feature_flags::{CloudThinkingToggleFeatureFlag, FeatureFlagAppExt as _};
 use futures::{
@@ -22,8 +21,8 @@ use language_model::{
     LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
-    LanguageModelToolSchemaFormat, LlmApiToken, PaymentRequiredError, RateLimiter,
-    RefreshLlmTokenListener,
+    LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh, PaymentRequiredError,
+    RateLimiter, RefreshLlmTokenListener,
 };
 use release_channel::AppVersion;
 use schemars::JsonSchema;
@@ -425,12 +424,7 @@ impl CloudLanguageModel {
                 });
             }
 
-            if !refreshed_token
-                && response
-                    .headers()
-                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
-                    .is_some()
-            {
+            if !refreshed_token && response.needs_llm_token_refresh() {
                 token = llm_api_token.refresh(&client).await?;
                 refreshed_token = true;
                 continue;

crates/web_search_providers/src/cloud.rs

@@ -2,11 +2,11 @@ use std::sync::Arc;
 
 use anyhow::{Context as _, Result};
 use client::Client;
-use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
+use cloud_llm_client::{WebSearchBody, WebSearchResponse};
 use futures::AsyncReadExt as _;
 use gpui::{App, AppContext, Context, Entity, Subscription, Task};
 use http_client::{HttpClient, Method};
-use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
 use web_search::{WebSearchProvider, WebSearchProviderId};
 
 pub struct CloudWebSearchProvider {
@@ -99,11 +99,7 @@ async fn perform_web_search(
             let mut body = String::new();
             response.body_mut().read_to_string(&mut body).await?;
             return Ok(serde_json::from_str(&body)?);
-        } else if response
-            .headers()
-            .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
-            .is_some()
-        {
+        } else if response.needs_llm_token_refresh() {
             token = llm_api_token.refresh(&client).await?;
             retries_remaining -= 1;
         } else {
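
All three call sites follow the same refresh-and-retry shape. A distilled sketch of that pattern, with `send_request` standing in for each site's request construction (not verbatim from any one file):

```rust
// Sketch of the shared retry shape used by the call sites above.
let mut did_retry = false;
loop {
    let mut response = send_request(&token).await?;
    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        return Ok(serde_json::from_slice(&body)?);
    } else if !did_retry && response.needs_llm_token_refresh() {
        // The server signaled a stale LLM token: refresh once, then retry.
        token = llm_api_token.refresh(&client).await?;
        did_retry = true;
    } else {
        anyhow::bail!("request failed with status {}", response.status());
    }
}
```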