Created by Agus Zubiaga

Release Notes:
- agent: Handle context window exceeded errors from Anthropic
 Cargo.lock                                       |  1 +
 crates/agent/src/message_editor.rs               | 42 ++++++++++++--
 crates/agent/src/thread.rs                       | 51 ++++++++++++++---
 crates/agent/src/thread_store.rs                 |  7 ++
 crates/anthropic/src/anthropic.rs                | 50 +++++++++++++++++
 crates/language_model/src/language_model.rs      |  8 ++
 crates/language_models/Cargo.toml                |  1 +
 crates/language_models/src/provider/anthropic.rs | 24 +++++++-
 crates/language_models/src/provider/cloud.rs     | 34 +++++++++--
 9 files changed, 190 insertions(+), 28 deletions(-)
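How the pieces below fit together: Anthropic rejects an oversized request with an `invalid_request_error` whose message starts with `prompt is too long: `, the provider parses the token count out of that message and wraps it in the new typed `LanguageModelKnownError::ContextWindowLimitExceeded`, and the agent thread downcasts the `anyhow::Error` back to that variant to record the overflow and drive the UI callout. A self-contained sketch of that round trip, using simplified stand-in types rather than the crate's own (requires only `anyhow` and `thiserror`):

```rust
use anyhow::anyhow;
use thiserror::Error;

// Stand-in for LanguageModelKnownError from crates/language_model.
#[derive(Debug, Error)]
enum KnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}

// Mirrors parse_prompt_too_long from crates/anthropic/src/anthropic.rs below.
fn parse_prompt_too_long(message: &str) -> Option<usize> {
    message
        .strip_prefix("prompt is too long: ")?
        .split_once(" tokens")?
        .0
        .parse::<usize>()
        .ok()
}

fn main() {
    // The shape of the message Anthropic returns for an oversized prompt.
    let message = "prompt is too long: 220000 tokens > 200000";

    // Provider side: turn the string error into a typed one inside anyhow.
    let err: anyhow::Error = match parse_prompt_too_long(message) {
        Some(tokens) => anyhow!(KnownError::ContextWindowLimitExceeded { tokens }),
        None => anyhow!("unrecognized API error: {message}"),
    };

    // Agent side: recover the typed error with a downcast, as thread.rs does.
    if let Some(KnownError::ContextWindowLimitExceeded { tokens }) =
        err.downcast_ref::<KnownError>()
    {
        println!("context window exceeded at {tokens} tokens");
    }
}
```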
Cargo.lock
@@ -7708,6 +7708,7 @@ dependencies = [
"smol",
"strum",
"theme",
+ "thiserror 2.0.12",
"tiktoken-rs",
"tokio",
"ui",
crates/agent/src/message_editor.rs
@@ -761,13 +761,29 @@ impl MessageEditor {
})
}
- fn render_reaching_token_limit(&self, line_height: Pixels, cx: &mut Context<Self>) -> Div {
+ fn render_token_limit_callout(
+ &self,
+ line_height: Pixels,
+ token_usage_ratio: TokenUsageRatio,
+ cx: &mut Context<Self>,
+ ) -> Div {
+ let heading = if token_usage_ratio == TokenUsageRatio::Exceeded {
+ "Thread reached the token limit"
+ } else {
+ "Thread reaching the token limit soon"
+ };
+
h_flex()
.p_2()
.gap_2()
.flex_wrap()
.justify_between()
- .bg(cx.theme().status().warning_background.opacity(0.1))
+ .bg(
+ if token_usage_ratio == TokenUsageRatio::Exceeded {
+ cx.theme().status().error_background.opacity(0.1)
+ } else {
+ cx.theme().status().warning_background.opacity(0.1)
+ })
.border_t_1()
.border_color(cx.theme().colors().border)
.child(
@@ -779,15 +795,21 @@ impl MessageEditor {
.h(line_height)
.justify_center()
.child(
- Icon::new(IconName::Warning)
- .color(Color::Warning)
- .size(IconSize::XSmall),
+ if token_usage_ratio == TokenUsageRatio::Exceeded {
+ Icon::new(IconName::X)
+ .color(Color::Error)
+ .size(IconSize::XSmall)
+ } else {
+ Icon::new(IconName::Warning)
+ .color(Color::Warning)
+ .size(IconSize::XSmall)
+ }
),
)
.child(
v_flex()
.mr_auto()
- .child(Label::new("Thread reaching the token limit soon").size(LabelSize::Small))
+ .child(Label::new(heading).size(LabelSize::Small))
.child(
Label::new(
"Start a new thread from a summary to continue the conversation.",
@@ -875,7 +897,13 @@ impl Render for MessageEditor {
.child(self.render_editor(font_size, line_height, window, cx))
.when(
total_token_usage.ratio != TokenUsageRatio::Normal,
- |parent| parent.child(self.render_reaching_token_limit(line_height, cx)),
+ |parent| {
+ parent.child(self.render_token_limit_callout(
+ line_height,
+ total_token_usage.ratio,
+ cx,
+ ))
+ },
)
}
}
crates/agent/src/thread.rs
@@ -15,10 +15,11 @@ use futures::{FutureExt, StreamExt as _};
use git::repository::DiffType;
use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
use language_model::{
- ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
- LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
- LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
- PaymentRequiredError, Role, StopReason, TokenUsage,
+ ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
+ LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
+ LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
+ LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
+ Role, StopReason, TokenUsage,
};
use project::Project;
use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
@@ -228,7 +229,7 @@ pub struct TotalTokenUsage {
pub ratio: TokenUsageRatio,
}
-#[derive(Default, PartialEq, Eq)]
+#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
#[default]
Normal,
@@ -260,11 +261,20 @@ pub struct Thread {
pending_checkpoint: Option<ThreadCheckpoint>,
initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
cumulative_token_usage: TokenUsage,
+ exceeded_window_error: Option<ExceededWindowError>,
feedback: Option<ThreadFeedback>,
message_feedback: HashMap<MessageId, ThreadFeedback>,
last_auto_capture_at: Option<Instant>,
}
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ExceededWindowError {
+ /// Model used when last message exceeded context window
+ model_id: LanguageModelId,
+ /// Token count including last message
+ token_count: usize,
+}
+
impl Thread {
pub fn new(
project: Entity<Project>,
@@ -301,6 +311,7 @@ impl Thread {
.shared()
},
cumulative_token_usage: TokenUsage::default(),
+ exceeded_window_error: None,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
@@ -367,6 +378,7 @@ impl Thread {
action_log: cx.new(|_| ActionLog::new(project)),
initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
cumulative_token_usage: serialized.cumulative_token_usage,
+ exceeded_window_error: None,
feedback: None,
message_feedback: HashMap::default(),
last_auto_capture_at: None,
@@ -817,6 +829,7 @@ impl Thread {
initial_project_snapshot,
cumulative_token_usage: this.cumulative_token_usage.clone(),
detailed_summary_state: this.detailed_summary_state.clone(),
+ exceeded_window_error: this.exceeded_window_error.clone(),
})
})
}
@@ -1129,6 +1142,20 @@ impl Thread {
cx.emit(ThreadEvent::ShowError(
ThreadError::MaxMonthlySpendReached,
));
+ } else if let Some(known_error) =
+ error.downcast_ref::<LanguageModelKnownError>()
+ {
+ match known_error {
+ LanguageModelKnownError::ContextWindowLimitExceeded {
+ tokens,
+ } => {
+ thread.exceeded_window_error = Some(ExceededWindowError {
+ model_id: model.id(),
+ token_count: *tokens,
+ });
+ cx.notify();
+ }
+ }
} else {
let error_message = error
.chain()
@@ -1784,10 +1811,6 @@ impl Thread {
&self.project
}
- pub fn cumulative_token_usage(&self) -> TokenUsage {
- self.cumulative_token_usage.clone()
- }
-
pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
return;
@@ -1840,6 +1863,16 @@ impl Thread {
let max = model.model.max_token_count();
+ if let Some(exceeded_error) = &self.exceeded_window_error {
+ if model.model.id() == exceeded_error.model_id {
+ return TotalTokenUsage {
+ total: exceeded_error.token_count,
+ max,
+ ratio: TokenUsageRatio::Exceeded,
+ };
+ }
+ }
+
#[cfg(debug_assertions)]
let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
.unwrap_or("0.8".to_string())
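The hunk above is cut off before the ratio computation itself, but the `Exceeded` early return plus the 0.8 `ZED_THREAD_WARNING_THRESHOLD` default imply thresholding along these lines. A sketch under those assumptions; the middle variant's actual name is not shown in the diff, so `Warning` is assumed here:

```rust
// Assumed shape of TokenUsageRatio; the diff shows Normal and Exceeded only.
#[derive(Debug, Default, PartialEq, Eq)]
enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}

// Warn once usage passes the threshold; report Exceeded at or past the window.
fn ratio(total: usize, max: usize, warning_threshold: f32) -> TokenUsageRatio {
    if total >= max {
        TokenUsageRatio::Exceeded
    } else if total as f32 >= max as f32 * warning_threshold {
        TokenUsageRatio::Warning
    } else {
        TokenUsageRatio::Normal
    }
}

fn main() {
    // 0.8 matches the ZED_THREAD_WARNING_THRESHOLD default in the hunk above.
    assert_eq!(ratio(100_000, 200_000, 0.8), TokenUsageRatio::Normal);
    assert_eq!(ratio(170_000, 200_000, 0.8), TokenUsageRatio::Warning);
    assert_eq!(ratio(220_000, 200_000, 0.8), TokenUsageRatio::Exceeded);
}
```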
crates/agent/src/thread_store.rs
@@ -27,7 +27,9 @@ use serde::{Deserialize, Serialize};
use settings::{Settings as _, SettingsStore};
use util::ResultExt as _;
-use crate::thread::{DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadId};
+use crate::thread::{
+ DetailedSummaryState, ExceededWindowError, MessageId, ProjectSnapshot, Thread, ThreadId,
+};
const RULES_FILE_NAMES: [&'static str; 6] = [
".rules",
@@ -491,6 +493,8 @@ pub struct SerializedThread {
pub cumulative_token_usage: TokenUsage,
#[serde(default)]
pub detailed_summary_state: DetailedSummaryState,
+ #[serde(default)]
+ pub exceeded_window_error: Option<ExceededWindowError>,
}
impl SerializedThread {
@@ -577,6 +581,7 @@ impl LegacySerializedThread {
initial_project_snapshot: self.initial_project_snapshot,
cumulative_token_usage: TokenUsage::default(),
detailed_summary_state: DetailedSummaryState::default(),
+ exceeded_window_error: None,
}
}
}
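On backward compatibility of the serialized format: serde's derive already treats a missing key for an `Option` field as `None`, and the explicit `#[serde(default)]` on the new field spells that intent out, so threads saved before this change still load. A quick standalone check with a simplified stand-in field (requires `serde` and `serde_json`):

```rust
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct SerializedThread {
    summary: String,
    // Simplified stand-in for Option<ExceededWindowError>.
    #[serde(default)]
    exceeded_window_error: Option<usize>,
}

fn main() {
    // A payload saved before this PR has no "exceeded_window_error" key,
    // yet still deserializes, with the field falling back to None.
    let old = r#"{ "summary": "Fix flaky test" }"#;
    let thread: SerializedThread = serde_json::from_str(old).unwrap();
    assert!(thread.exceeded_window_error.is_none());
}
```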
crates/anthropic/src/anthropic.rs
@@ -724,4 +724,54 @@ impl ApiError {
pub fn is_rate_limit_error(&self) -> bool {
matches!(self.error_type.as_str(), "rate_limit_error")
}
+
+ pub fn match_window_exceeded(&self) -> Option<usize> {
+ let Some(ApiErrorCode::InvalidRequestError) = self.code() else {
+ return None;
+ };
+
+ parse_prompt_too_long(&self.message)
+ }
+}
+
+pub fn parse_prompt_too_long(message: &str) -> Option<usize> {
+ message
+ .strip_prefix("prompt is too long: ")?
+ .split_once(" tokens")?
+ .0
+ .parse::<usize>()
+ .ok()
+}
+
+#[test]
+fn test_match_window_exceeded() {
+ let error = ApiError {
+ error_type: "invalid_request_error".to_string(),
+ message: "prompt is too long: 220000 tokens > 200000".to_string(),
+ };
+ assert_eq!(error.match_window_exceeded(), Some(220_000));
+
+ let error = ApiError {
+ error_type: "invalid_request_error".to_string(),
+ message: "prompt is too long: 1234953 tokens".to_string(),
+ };
+ assert_eq!(error.match_window_exceeded(), Some(1234953));
+
+ let error = ApiError {
+ error_type: "invalid_request_error".to_string(),
+ message: "not a prompt length error".to_string(),
+ };
+ assert_eq!(error.match_window_exceeded(), None);
+
+ let error = ApiError {
+ error_type: "rate_limit_error".to_string(),
+ message: "prompt is too long: 12345 tokens".to_string(),
+ };
+ assert_eq!(error.match_window_exceeded(), None);
+
+ let error = ApiError {
+ error_type: "invalid_request_error".to_string(),
+ message: "prompt is too long: invalid tokens".to_string(),
+ };
+ assert_eq!(error.match_window_exceeded(), None);
}
crates/language_model/src/language_model.rs
@@ -278,6 +278,12 @@ pub trait LanguageModel: Send + Sync {
}
}
+#[derive(Debug, Error)]
+pub enum LanguageModelKnownError {
+ #[error("Context window limit exceeded ({tokens})")]
+ ContextWindowLimitExceeded { tokens: usize },
+}
+
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
fn name() -> String;
fn description() -> String;
@@ -347,7 +353,7 @@ pub trait LanguageModelProviderState: 'static {
}
}
-#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
+#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd, Serialize, Deserialize)]
pub struct LanguageModelId(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug, Ord, PartialOrd)]
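For reference, the `#[error(...)]` attribute on the new enum above is what gives it its `Display` output, which is what surfaces through the `error.chain()` formatting on the thread side. A quick standalone check (requires only `thiserror`):

```rust
use thiserror::Error;

#[derive(Debug, Error)]
enum LanguageModelKnownError {
    #[error("Context window limit exceeded ({tokens})")]
    ContextWindowLimitExceeded { tokens: usize },
}

fn main() {
    let err = LanguageModelKnownError::ContextWindowLimitExceeded { tokens: 220_000 };
    // thiserror interpolates the named field into the Display impl.
    assert_eq!(err.to_string(), "Context window limit exceeded (220000)");
}
```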
crates/language_models/Cargo.toml
@@ -47,6 +47,7 @@ settings.workspace = true
smol.workspace = true
strum.workspace = true
theme.workspace = true
+thiserror.workspace = true
tiktoken-rs.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
ui.workspace = true
crates/language_models/src/provider/anthropic.rs
@@ -13,8 +13,9 @@ use gpui::{
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
- LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, MessageContent, RateLimiter, Role,
+ LanguageModelKnownError, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+ LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, MessageContent,
+ RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema;
@@ -454,7 +455,12 @@ impl LanguageModel for AnthropicModel {
);
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
- let response = request.await.map_err(|err| anyhow!(err))?;
+ let response = request
+ .await
+ .map_err(|err| match err.downcast::<AnthropicError>() {
+ Ok(anthropic_err) => anthropic_err_to_anyhow(anthropic_err),
+ Err(err) => anyhow!(err),
+ })?;
Ok(map_to_language_model_completion_events(response))
});
async move { Ok(future.await?.boxed()) }.boxed()
@@ -746,7 +752,7 @@ pub fn map_to_language_model_completion_events(
_ => {}
},
Err(err) => {
- return Some((vec![Err(anyhow!(err))], state));
+ return Some((vec![Err(anthropic_err_to_anyhow(err))], state));
}
}
}
@@ -757,6 +763,16 @@ pub fn map_to_language_model_completion_events(
.flat_map(futures::stream::iter)
}
+pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
+ if let AnthropicError::ApiError(api_err) = &err {
+ if let Some(tokens) = api_err.match_window_exceeded() {
+ return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
+ }
+ }
+
+ anyhow!(err)
+}
+
/// Updates usage data by preferring counts from `new`.
fn update_usage(usage: &mut Usage, new: &Usage) {
if let Some(input_tokens) = new.input_tokens {
crates/language_models/src/provider/cloud.rs
@@ -1,4 +1,4 @@
-use anthropic::{AnthropicError, AnthropicModelMode};
+use anthropic::{AnthropicError, AnthropicModelMode, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::{
Client, EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
@@ -14,7 +14,7 @@ use gpui::{AnyElement, AnyView, App, AsyncApp, Context, Entity, Subscription, Ta
use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
use language_model::{
AuthenticateError, CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId,
- LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelKnownError, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelProviderTosView, LanguageModelRequest,
LanguageModelToolSchemaFormat, RateLimiter, ZED_CLOUD_PROVIDER_ID,
};
@@ -33,6 +33,7 @@ use std::{
time::Duration,
};
use strum::IntoEnumIterator;
+use thiserror::Error;
use ui::{TintColor, prelude::*};
use crate::AllLanguageModelSettings;
@@ -575,14 +576,19 @@ impl CloudLanguageModel {
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
- return Err(anyhow!(
- "cloud language model completion failed with status {status}: {body}",
- ));
+ return Err(anyhow!(ApiError { status, body }));
}
}
}
}
+#[derive(Debug, Error)]
+#[error("cloud language model completion failed with status {status}: {body}")]
+struct ApiError {
+ status: StatusCode,
+ body: String,
+}
+
impl LanguageModel for CloudLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -696,7 +702,23 @@ impl LanguageModel for CloudLanguageModel {
)?)?,
},
)
- .await?;
+ .await
+ .map_err(|err| match err.downcast::<ApiError>() {
+ Ok(api_err) => {
+ if api_err.status == StatusCode::BAD_REQUEST {
+ if let Some(tokens) = parse_prompt_too_long(&api_err.body) {
+ return anyhow!(
+ LanguageModelKnownError::ContextWindowLimitExceeded {
+ tokens
+ }
+ );
+ }
+ }
+ anyhow!(api_err)
+ }
+ Err(err) => anyhow!(err),
+ })?;
+
Ok(
crate::provider::anthropic::map_to_language_model_completion_events(
Box::pin(response_lines(response).map_err(AnthropicError::Other)),
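One detail worth calling out in the cloud path above: `err.downcast::<ApiError>()` consumes the `anyhow::Error` and yields the concrete `ApiError` on success, so the known 400 case can be translated into the typed error while everything else is rewrapped or passed through unchanged. A reduced sketch of that pattern, with a stand-in `u16` status instead of `http::StatusCode` (requires `anyhow` and `thiserror`):

```rust
use anyhow::anyhow;
use thiserror::Error;

#[derive(Debug, Error)]
#[error("cloud language model completion failed with status {status}: {body}")]
struct ApiError {
    status: u16, // stand-in for http::StatusCode
    body: String,
}

fn translate(err: anyhow::Error) -> anyhow::Error {
    match err.downcast::<ApiError>() {
        // Known shape: a 400 whose body carries Anthropic's overflow message.
        Ok(api_err)
            if api_err.status == 400
                && api_err.body.starts_with("prompt is too long:") =>
        {
            anyhow!("context window limit exceeded")
        }
        // Any other ApiError: rewrap so the message is preserved.
        Ok(api_err) => anyhow!(api_err),
        // Not an ApiError at all: pass the anyhow::Error through untouched.
        Err(other) => other,
    }
}

fn main() {
    let err = anyhow!(ApiError {
        status: 400,
        body: "prompt is too long: 220000 tokens > 200000".into(),
    });
    println!("{}", translate(err));
}
```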