Detailed changes
@@ -5,7 +5,7 @@ use cloud_llm_client::predict_edits_v3::{
PredictEditsV3Request, PredictEditsV3Response, RawCompletionRequest, RawCompletionResponse,
};
use cloud_llm_client::{
- EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason, EditPredictionRejection,
+ EditPredictionRejectReason, EditPredictionRejection,
MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
PredictEditsRequestTrigger, RejectEditPredictionsBodyRef, ZED_VERSION_HEADER_NAME,
};
@@ -29,7 +29,7 @@ use gpui::{
use language::language_settings::all_language_settings;
use language::{Anchor, Buffer, File, Point, TextBufferSnapshot, ToOffset, ToPoint};
use language::{BufferSnapshot, OffsetRangeExt};
-use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
use project::{Project, ProjectPath, WorktreeId};
use release_channel::AppVersion;
use semver::Version;
@@ -2046,13 +2046,7 @@ impl EditPredictionStore {
let mut body = Vec::new();
response.body_mut().read_to_end(&mut body).await?;
return Ok((serde_json::from_slice(&body)?, usage));
- } else if !did_retry
- && token.is_some()
- && response
- .headers()
- .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
- .is_some()
- {
+ } else if !did_retry && token.is_some() && response.needs_llm_token_refresh() {
did_retry = true;
token = Some(llm_token.refresh(&client).await?);
} else {
@@ -4,6 +4,7 @@ use std::sync::Arc;
use anyhow::Result;
use client::Client;
use cloud_api_types::websocket_protocol::MessageToClient;
+use cloud_llm_client::EXPIRED_LLM_TOKEN_HEADER_NAME;
use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _};
use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard};
use thiserror::Error;
@@ -52,6 +53,17 @@ impl LlmApiToken {
}
}
+pub trait NeedsLlmTokenRefresh {
+ /// Returns whether the LLM token needs to be refreshed.
+ fn needs_llm_token_refresh(&self) -> bool;
+}
+
+impl NeedsLlmTokenRefresh for http_client::Response<http_client::AsyncBody> {
+ fn needs_llm_token_refresh(&self) -> bool {
+ self.headers().get(EXPIRED_LLM_TOKEN_HEADER_NAME).is_some()
+ }
+}
+
struct GlobalRefreshLlmTokenListener(Entity<RefreshLlmTokenListener>);
impl Global for GlobalRefreshLlmTokenListener {}
@@ -5,9 +5,8 @@ use chrono::{DateTime, Utc};
use client::{Client, UserStore, zed_urls};
use cloud_llm_client::{
CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, CLIENT_SUPPORTS_X_AI_HEADER_NAME, CompletionBody,
- CompletionEvent, CountTokensBody, CountTokensResponse, EXPIRED_LLM_TOKEN_HEADER_NAME,
- ListModelsResponse, Plan, PlanV2, SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME,
- ZED_VERSION_HEADER_NAME,
+ CompletionEvent, CountTokensBody, CountTokensResponse, ListModelsResponse, Plan, PlanV2,
+ SERVER_SUPPORTS_STATUS_MESSAGES_HEADER_NAME, ZED_VERSION_HEADER_NAME,
};
use feature_flags::{CloudThinkingToggleFeatureFlag, FeatureFlagAppExt as _};
use futures::{
@@ -22,8 +21,8 @@ use language_model::{
LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
- LanguageModelToolSchemaFormat, LlmApiToken, PaymentRequiredError, RateLimiter,
- RefreshLlmTokenListener,
+ LanguageModelToolSchemaFormat, LlmApiToken, NeedsLlmTokenRefresh, PaymentRequiredError,
+ RateLimiter, RefreshLlmTokenListener,
};
use release_channel::AppVersion;
use schemars::JsonSchema;
@@ -425,12 +424,7 @@ impl CloudLanguageModel {
});
}
- if !refreshed_token
- && response
- .headers()
- .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
- .is_some()
- {
+ if !refreshed_token && response.needs_llm_token_refresh() {
token = llm_api_token.refresh(&client).await?;
refreshed_token = true;
continue;
@@ -2,11 +2,11 @@ use std::sync::Arc;
use anyhow::{Context as _, Result};
use client::Client;
-use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
+use cloud_llm_client::{WebSearchBody, WebSearchResponse};
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
use http_client::{HttpClient, Method};
-use language_model::{LlmApiToken, RefreshLlmTokenListener};
+use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
use web_search::{WebSearchProvider, WebSearchProviderId};
pub struct CloudWebSearchProvider {
@@ -99,11 +99,7 @@ async fn perform_web_search(
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
return Ok(serde_json::from_str(&body)?);
- } else if response
- .headers()
- .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
- .is_some()
- {
+ } else if response.needs_llm_token_refresh() {
token = llm_api_token.refresh(&client).await?;
retries_remaining -= 1;
} else {