Detailed changes
@@ -1495,27 +1495,76 @@ impl Thread {
thread.update(cx, |thread, cx| {
let event = match event {
Ok(event) => event,
- Err(LanguageModelCompletionError::BadInputJson {
- id,
- tool_name,
- raw_input: invalid_input_json,
- json_parse_error,
- }) => {
- thread.receive_invalid_tool_json(
- id,
- tool_name,
- invalid_input_json,
- json_parse_error,
- window,
- cx,
- );
- return Ok(());
- }
- Err(LanguageModelCompletionError::Other(error)) => {
- return Err(error);
- }
- Err(err @ LanguageModelCompletionError::RateLimit(..)) => {
- return Err(err.into());
+ Err(error) => {
+ match error {
+ LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
+ anyhow::bail!(LanguageModelKnownError::RateLimitExceeded { retry_after });
+ }
+ LanguageModelCompletionError::Overloaded => {
+ anyhow::bail!(LanguageModelKnownError::Overloaded);
+ }
+                        LanguageModelCompletionError::ApiInternalServerError => {
+ anyhow::bail!(LanguageModelKnownError::ApiInternalServerError);
+ }
+ LanguageModelCompletionError::PromptTooLarge { tokens } => {
+ let tokens = tokens.unwrap_or_else(|| {
+ // We didn't get an exact token count from the API, so fall back on our estimate.
+ thread.total_token_usage()
+ .map(|usage| usage.total)
+ .unwrap_or(0)
+ // We know the context window was exceeded in practice, so if our estimate was
+ // lower than max tokens, the estimate was wrong; return that we exceeded by 1.
+ .max(model.max_token_count().saturating_add(1))
+ });
+
+ anyhow::bail!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens })
+ }
+ LanguageModelCompletionError::ApiReadResponseError(io_error) => {
+ anyhow::bail!(LanguageModelKnownError::ReadResponseError(io_error));
+ }
+ LanguageModelCompletionError::UnknownResponseFormat(error) => {
+ anyhow::bail!(LanguageModelKnownError::UnknownResponseFormat(error));
+ }
+ LanguageModelCompletionError::HttpResponseError { status, ref body } => {
+ if let Some(known_error) = LanguageModelKnownError::from_http_response(status, body) {
+ anyhow::bail!(known_error);
+ } else {
+ return Err(error.into());
+ }
+ }
+ LanguageModelCompletionError::DeserializeResponse(error) => {
+ anyhow::bail!(LanguageModelKnownError::DeserializeResponse(error));
+ }
+ LanguageModelCompletionError::BadInputJson {
+ id,
+ tool_name,
+ raw_input: invalid_input_json,
+ json_parse_error,
+ } => {
+ thread.receive_invalid_tool_json(
+ id,
+ tool_name,
+ invalid_input_json,
+ json_parse_error,
+ window,
+ cx,
+ );
+ return Ok(());
+ }
+ // These are all errors we can't automatically attempt to recover from (e.g. by retrying)
+ err @ LanguageModelCompletionError::BadRequestFormat |
+ err @ LanguageModelCompletionError::AuthenticationError |
+ err @ LanguageModelCompletionError::PermissionError |
+ err @ LanguageModelCompletionError::ApiEndpointNotFound |
+ err @ LanguageModelCompletionError::SerializeRequest(_) |
+ err @ LanguageModelCompletionError::BuildRequestBody(_) |
+ err @ LanguageModelCompletionError::HttpSend(_) => {
+ anyhow::bail!(err);
+ }
+ LanguageModelCompletionError::Other(error) => {
+ return Err(error);
+ }
+ }
}
};
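
The arms above raise typed `LanguageModelKnownError` values through `anyhow::bail!`, and the handler in the next hunk recovers them with `downcast_ref`. A minimal standalone sketch of that round-trip, using a stand-in error type and assuming the `anyhow` and `thiserror` crates:

```rust
use std::time::Duration;

use thiserror::Error;

#[derive(Error, Debug)]
#[error("rate limit exceeded, retry after {retry_after:?}")]
struct RateLimitExceeded {
    retry_after: Duration,
}

fn request() -> anyhow::Result<()> {
    // bail! wraps the typed error in an anyhow::Error without erasing its type...
    anyhow::bail!(RateLimitExceeded {
        retry_after: Duration::from_secs(4),
    });
}

fn main() {
    let error = request().unwrap_err();
    // ...so a handler further up the stack can still branch on the concrete type.
    if let Some(rate_limit) = error.downcast_ref::<RateLimitExceeded>() {
        println!("retrying after {:?}", rate_limit.retry_after);
    }
}
```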
@@ -1751,6 +1800,18 @@ impl Thread {
project.set_agent_location(None, cx);
});
+ fn emit_generic_error(error: &anyhow::Error, cx: &mut Context<Thread>) {
+ let error_message = error
+ .chain()
+ .map(|err| err.to_string())
+ .collect::<Vec<_>>()
+ .join("\n");
+ cx.emit(ThreadEvent::ShowError(ThreadError::Message {
+ header: "Error interacting with language model".into(),
+                message: SharedString::from(error_message),
+ }));
+ }
+
if error.is::<PaymentRequiredError>() {
cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
} else if let Some(error) =
@@ -1763,26 +1824,34 @@ impl Thread {
error.downcast_ref::<LanguageModelKnownError>()
{
match known_error {
- LanguageModelKnownError::ContextWindowLimitExceeded {
- tokens,
- } => {
+ LanguageModelKnownError::ContextWindowLimitExceeded { tokens } => {
thread.exceeded_window_error = Some(ExceededWindowError {
model_id: model.id(),
token_count: *tokens,
});
cx.notify();
}
+ LanguageModelKnownError::RateLimitExceeded { .. } => {
+ // In the future we will report the error to the user, wait retry_after, and then retry.
+ emit_generic_error(error, cx);
+ }
+ LanguageModelKnownError::Overloaded => {
+ // In the future we will wait and then retry, up to N times.
+ emit_generic_error(error, cx);
+ }
+ LanguageModelKnownError::ApiInternalServerError => {
+ // In the future we will retry the request, but only once.
+ emit_generic_error(error, cx);
+ }
+ LanguageModelKnownError::ReadResponseError(_) |
+ LanguageModelKnownError::DeserializeResponse(_) |
+ LanguageModelKnownError::UnknownResponseFormat(_) => {
+                        // In the future we will attempt to re-roll the response, but only once.
+ emit_generic_error(error, cx);
+ }
}
} else {
- let error_message = error
- .chain()
- .map(|err| err.to_string())
- .collect::<Vec<_>>()
- .join("\n");
- cx.emit(ThreadEvent::ShowError(ThreadError::Message {
- header: "Error interacting with language model".into(),
- message: SharedString::from(error_message.clone()),
- }));
+ emit_generic_error(error, cx);
}
thread.cancel_last_completion(window, cx);
@@ -1,10 +1,11 @@
+use std::io;
use std::str::FromStr;
use std::time::Duration;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::http::{HeaderMap, HeaderValue};
+use http_client::http::{self, HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use strum::{EnumIter, EnumString};
@@ -336,7 +337,7 @@ pub async fn complete(
let uri = format!("{api_url}/v1/messages");
let beta_headers = Model::from_id(&request.model)
.map(|model| model.beta_headers())
- .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
+ .unwrap_or_else(|_| Model::DEFAULT_BETA_HEADERS.join(","));
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
@@ -346,39 +347,30 @@ pub async fn complete(
.header("Content-Type", "application/json");
let serialized_request =
- serde_json::to_string(&request).context("failed to serialize request")?;
+ serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
let request = request_builder
.body(AsyncBody::from(serialized_request))
- .context("failed to construct request body")?;
+ .map_err(AnthropicError::BuildRequestBody)?;
let mut response = client
.send(request)
.await
- .context("failed to send request to Anthropic")?;
- if response.status().is_success() {
- let mut body = Vec::new();
- response
- .body_mut()
- .read_to_end(&mut body)
- .await
- .context("failed to read response body")?;
- let response_message: Response =
- serde_json::from_slice(&body).context("failed to deserialize response body")?;
- Ok(response_message)
+ .map_err(AnthropicError::HttpSend)?;
+ let status = response.status();
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(AnthropicError::ReadResponse)?;
+
+ if status.is_success() {
+        serde_json::from_str(&body).map_err(AnthropicError::DeserializeResponse)
} else {
- let mut body = Vec::new();
- response
- .body_mut()
- .read_to_end(&mut body)
- .await
- .context("failed to read response body")?;
- let body_str =
- std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
- Err(AnthropicError::Other(anyhow!(
- "Failed to connect to API: {} {}",
- response.status(),
- body_str
- )))
+ Err(AnthropicError::HttpResponseError {
+ status: status.as_u16(),
+ body,
+ })
}
}
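
`complete` now fails with structured `AnthropicError` variants rather than a stringly-typed `anyhow` chain, so call sites can branch on the failure mode. A hypothetical caller sketch, assuming `result` holds the `Result<Response, AnthropicError>` that `complete` returns:

```rust
match result {
    Ok(_response) => { /* use the completion */ }
    Err(AnthropicError::HttpResponseError { status, body }) => {
        eprintln!("API returned HTTP {status}: {body}");
    }
    Err(AnthropicError::DeserializeResponse(error)) => {
        eprintln!("response was not valid JSON: {error}");
    }
    Err(error) => eprintln!("request failed: {error:?}"),
}
```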
@@ -491,7 +483,7 @@ pub async fn stream_completion_with_rate_limit_info(
let uri = format!("{api_url}/v1/messages");
let beta_headers = Model::from_id(&request.base.model)
.map(|model| model.beta_headers())
- .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
+ .unwrap_or_else(|_| Model::DEFAULT_BETA_HEADERS.join(","));
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
@@ -500,15 +492,15 @@ pub async fn stream_completion_with_rate_limit_info(
.header("X-Api-Key", api_key)
.header("Content-Type", "application/json");
let serialized_request =
- serde_json::to_string(&request).context("failed to serialize request")?;
+ serde_json::to_string(&request).map_err(AnthropicError::SerializeRequest)?;
let request = request_builder
.body(AsyncBody::from(serialized_request))
- .context("failed to construct request body")?;
+ .map_err(AnthropicError::BuildRequestBody)?;
let mut response = client
.send(request)
.await
- .context("failed to send request to Anthropic")?;
+ .map_err(AnthropicError::HttpSend)?;
let rate_limits = RateLimitInfo::from_headers(response.headers());
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
@@ -520,37 +512,31 @@ pub async fn stream_completion_with_rate_limit_info(
let line = line.strip_prefix("data: ")?;
match serde_json::from_str(line) {
Ok(response) => Some(Ok(response)),
- Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
+ Err(error) => Some(Err(AnthropicError::DeserializeResponse(error))),
}
}
- Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
+ Err(error) => Some(Err(AnthropicError::ReadResponse(error))),
}
})
.boxed();
Ok((stream, Some(rate_limits)))
} else if let Some(retry_after) = rate_limits.retry_after {
- Err(AnthropicError::RateLimit(retry_after))
+ Err(AnthropicError::RateLimit { retry_after })
} else {
- let mut body = Vec::new();
+ let mut body = String::new();
response
.body_mut()
- .read_to_end(&mut body)
+ .read_to_string(&mut body)
.await
- .context("failed to read response body")?;
-
- let body_str =
- std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
+ .map_err(AnthropicError::ReadResponse)?;
- match serde_json::from_str::<Event>(body_str) {
+ match serde_json::from_str::<Event>(&body) {
Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
- Ok(_) => Err(AnthropicError::Other(anyhow!(
- "Unexpected success response while expecting an error: '{body_str}'",
- ))),
- Err(_) => Err(AnthropicError::Other(anyhow!(
- "Failed to connect to API: {} {}",
- response.status(),
- body_str,
- ))),
+ Ok(_) => Err(AnthropicError::UnexpectedResponseFormat(body)),
+ Err(_) => Err(AnthropicError::HttpResponseError {
+ status: response.status().as_u16(),
+                body,
+ }),
}
}
}
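
The stream handling above follows server-sent-events framing: each payload arrives on a `data: {json}` line, and anything else (comments, keep-alives) is skipped. A standalone sketch of that parse step with a stand-in event type:

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct Event {
    #[serde(rename = "type")]
    kind: String,
}

fn parse_sse_line(line: &str) -> Option<Result<Event, serde_json::Error>> {
    // Lines without the "data: " prefix carry no event payload.
    let payload = line.strip_prefix("data: ")?;
    Some(serde_json::from_str(payload))
}

fn main() {
    println!("{:?}", parse_sse_line(r#"data: {"type":"message_stop"}"#));
    println!("{:?}", parse_sse_line(": keep-alive"));
}
```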
@@ -797,17 +783,38 @@ pub struct MessageDelta {
pub stop_sequence: Option<String>,
}
-#[derive(Error, Debug)]
+#[derive(Debug)]
pub enum AnthropicError {
- #[error("rate limit exceeded, retry after {0:?}")]
- RateLimit(Duration),
- #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
+ /// Failed to serialize the HTTP request body to JSON
+ SerializeRequest(serde_json::Error),
+
+ /// Failed to construct the HTTP request body
+ BuildRequestBody(http::Error),
+
+ /// Failed to send the HTTP request
+ HttpSend(anyhow::Error),
+
+ /// Failed to deserialize the response from JSON
+ DeserializeResponse(serde_json::Error),
+
+ /// Failed to read from response stream
+ ReadResponse(io::Error),
+
+ /// HTTP error response from the API
+ HttpResponseError { status: u16, body: String },
+
+ /// Rate limit exceeded
+ RateLimit { retry_after: Duration },
+
+ /// API returned an error response
ApiError(ApiError),
- #[error("{0}")]
- Other(#[from] anyhow::Error),
+
+ /// Unexpected response format
+ UnexpectedResponseFormat(String),
}
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Error)]
+#[error("Anthropic API Error: {error_type}: {message}")]
pub struct ApiError {
#[serde(rename = "type")]
pub error_type: String,
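
With the `thiserror` derive moved onto `ApiError`, its `Display` output is fixed by the `#[error(...)]` attribute above. A quick hypothetical test, assuming `ApiError` is in scope:

```rust
#[test]
fn api_error_display() {
    let error = ApiError {
        error_type: "overloaded_error".into(),
        message: "Overloaded".into(),
    };
    assert_eq!(
        error.to_string(),
        "Anthropic API Error: overloaded_error: Overloaded"
    );
}
```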
@@ -1659,13 +1659,13 @@ async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) ->
Ok(result) => return Ok(result),
Err(err) => match err.downcast::<LanguageModelCompletionError>() {
Ok(err) => match err {
- LanguageModelCompletionError::RateLimit(duration) => {
+ LanguageModelCompletionError::RateLimitExceeded { retry_after } => {
// Wait for the duration supplied, with some jitter to avoid all requests being made at the same time.
- let jitter = duration.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
+ let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
eprintln!(
- "Attempt #{attempt}: Rate limit exceeded. Retry after {duration:?} + jitter of {jitter:?}"
+ "Attempt #{attempt}: Rate limit exceeded. Retry after {retry_after:?} + jitter of {jitter:?}"
);
- Timer::after(duration + jitter).await;
+ Timer::after(retry_after + jitter).await;
continue;
}
_ => return Err(err.into()),
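
The jitter keeps clients that hit the rate limit at the same moment from retrying in lockstep. A standalone sketch of the same arithmetic, assuming the `rand` crate:

```rust
use std::time::Duration;

use rand::Rng;

fn main() {
    let retry_after = Duration::from_secs(4);
    // Add a random extra delay in [0, retry_after) so concurrent clients
    // spread their retries out instead of stampeding the API together.
    let jitter = retry_after.mul_f64(rand::thread_rng().gen_range(0.0..1.0));
    println!("sleeping {:?} before retrying", retry_after + jitter);
}
```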
@@ -8,19 +8,21 @@ mod telemetry;
#[cfg(any(test, feature = "test-support"))]
pub mod fake_provider;
+use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::Result;
use client::Client;
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
+use http_client::http;
use icons::IconName;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
-use std::fmt;
use std::ops::{Add, Sub};
use std::sync::Arc;
use std::time::Duration;
+use std::{fmt, io};
use thiserror::Error;
use util::serde::is_default;
use zed_llm_client::CompletionRequestStatus;
@@ -34,6 +36,10 @@ pub use crate::telemetry::*;
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
+/// If we get a rate limit error that doesn't tell us when we can retry,
+/// default to waiting this long before retrying.
+const DEFAULT_RATE_LIMIT_RETRY_AFTER: Duration = Duration::from_secs(4);
+
pub fn init(client: Arc<Client>, cx: &mut App) {
init_settings(cx);
RefreshLlmTokenListener::register(client.clone(), cx);
@@ -70,8 +76,8 @@ pub enum LanguageModelCompletionEvent {
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
- #[error("rate limit exceeded, retry after {0:?}")]
- RateLimit(Duration),
+ #[error("rate limit exceeded, retry after {retry_after:?}")]
+ RateLimitExceeded { retry_after: Duration },
#[error("received bad input JSON")]
BadInputJson {
id: LanguageModelToolUseId,
@@ -79,8 +85,78 @@ pub enum LanguageModelCompletionError {
raw_input: Arc<str>,
json_parse_error: String,
},
+ #[error("language model provider's API is overloaded")]
+ Overloaded,
#[error(transparent)]
Other(#[from] anyhow::Error),
+ #[error("invalid request format to language model provider's API")]
+ BadRequestFormat,
+ #[error("authentication error with language model provider's API")]
+ AuthenticationError,
+ #[error("permission error with language model provider's API")]
+ PermissionError,
+ #[error("language model provider API endpoint not found")]
+ ApiEndpointNotFound,
+ #[error("prompt too large for context window")]
+ PromptTooLarge { tokens: Option<u64> },
+ #[error("internal server error in language model provider's API")]
+ ApiInternalServerError,
+ #[error("I/O error reading response from language model provider's API: {0:?}")]
+ ApiReadResponseError(io::Error),
+ #[error("HTTP response error from language model provider's API: status {status} - {body:?}")]
+ HttpResponseError { status: u16, body: String },
+ #[error("error serializing request to language model provider API: {0}")]
+ SerializeRequest(serde_json::Error),
+ #[error("error building request body to language model provider API: {0}")]
+ BuildRequestBody(http::Error),
+ #[error("error sending HTTP request to language model provider API: {0}")]
+ HttpSend(anyhow::Error),
+ #[error("error deserializing language model provider API response: {0}")]
+ DeserializeResponse(serde_json::Error),
+ #[error("unexpected language model provider API response format: {0}")]
+ UnknownResponseFormat(String),
+}
+
+impl From<AnthropicError> for LanguageModelCompletionError {
+ fn from(error: AnthropicError) -> Self {
+ match error {
+ AnthropicError::SerializeRequest(error) => Self::SerializeRequest(error),
+ AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody(error),
+ AnthropicError::HttpSend(error) => Self::HttpSend(error),
+ AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse(error),
+ AnthropicError::ReadResponse(error) => Self::ApiReadResponseError(error),
+ AnthropicError::HttpResponseError { status, body } => {
+ Self::HttpResponseError { status, body }
+ }
+ AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { retry_after },
+ AnthropicError::ApiError(api_error) => api_error.into(),
+ AnthropicError::UnexpectedResponseFormat(error) => Self::UnknownResponseFormat(error),
+ }
+ }
+}
+
+impl From<anthropic::ApiError> for LanguageModelCompletionError {
+ fn from(error: anthropic::ApiError) -> Self {
+ use anthropic::ApiErrorCode::*;
+
+ match error.code() {
+ Some(code) => match code {
+ InvalidRequestError => LanguageModelCompletionError::BadRequestFormat,
+ AuthenticationError => LanguageModelCompletionError::AuthenticationError,
+ PermissionError => LanguageModelCompletionError::PermissionError,
+ NotFoundError => LanguageModelCompletionError::ApiEndpointNotFound,
+ RequestTooLarge => LanguageModelCompletionError::PromptTooLarge {
+ tokens: parse_prompt_too_long(&error.message),
+ },
+ RateLimitError => LanguageModelCompletionError::RateLimitExceeded {
+ retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
+ },
+ ApiError => LanguageModelCompletionError::ApiInternalServerError,
+ OverloadedError => LanguageModelCompletionError::Overloaded,
+ },
+ None => LanguageModelCompletionError::Other(error.into()),
+ }
+ }
}
/// Indicates the format used to define the input schema for a language model tool.
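
These `From` impls are what let provider code propagate errors with `?` or `map_err(Into::into)`, as the Anthropic provider hunk further down does. A minimal standalone sketch of the idiom with stand-in types:

```rust
#[derive(Debug)]
struct ProviderError(String);

#[derive(Debug)]
struct CompletionError(String);

impl From<ProviderError> for CompletionError {
    fn from(error: ProviderError) -> Self {
        CompletionError(format!("provider: {}", error.0))
    }
}

fn provider_call() -> Result<(), ProviderError> {
    Err(ProviderError("overloaded".into()))
}

fn completion_call() -> Result<(), CompletionError> {
    // `?` applies the From impl, so no explicit map_err is needed here.
    provider_call()?;
    Ok(())
}

fn main() {
    println!("{:?}", completion_call());
}
```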
@@ -319,6 +395,33 @@ pub trait LanguageModel: Send + Sync {
pub enum LanguageModelKnownError {
#[error("Context window limit exceeded ({tokens})")]
ContextWindowLimitExceeded { tokens: u64 },
+ #[error("Language model provider's API is currently overloaded")]
+ Overloaded,
+ #[error("Language model provider's API encountered an internal server error")]
+ ApiInternalServerError,
+ #[error("I/O error while reading response from language model provider's API: {0:?}")]
+ ReadResponseError(io::Error),
+ #[error("Error deserializing response from language model provider's API: {0:?}")]
+ DeserializeResponse(serde_json::Error),
+ #[error("Language model provider's API returned a response in an unknown format")]
+ UnknownResponseFormat(String),
+ #[error("Rate limit exceeded for language model provider's API; retry in {retry_after:?}")]
+ RateLimitExceeded { retry_after: Duration },
+}
+
+impl LanguageModelKnownError {
+ /// Attempts to map an HTTP response status code to a known error type.
+ /// Returns None if the status code doesn't map to a specific known error.
+ pub fn from_http_response(status: u16, _body: &str) -> Option<Self> {
+ match status {
+ 429 => Some(Self::RateLimitExceeded {
+ retry_after: DEFAULT_RATE_LIMIT_RETRY_AFTER,
+ }),
+ 503 => Some(Self::Overloaded),
+ 500..=599 => Some(Self::ApiInternalServerError),
+ _ => None,
+ }
+ }
}
pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
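
A hypothetical test sketch of the status-code mapping. Note that the `503` arm sits ahead of the `500..=599` range, so an overloaded API is reported as `Overloaded` rather than as a generic internal server error:

```rust
#[test]
fn known_error_from_http_status() {
    use LanguageModelKnownError::*;

    assert!(matches!(
        LanguageModelKnownError::from_http_response(429, ""),
        Some(RateLimitExceeded { .. })
    ));
    assert!(matches!(
        LanguageModelKnownError::from_http_response(503, ""),
        Some(Overloaded)
    ));
    assert!(matches!(
        LanguageModelKnownError::from_http_response(500, ""),
        Some(ApiInternalServerError)
    ));
    // 4xx statuses other than 429 don't map to a known error.
    assert!(LanguageModelKnownError::from_http_response(404, "").is_none());
}
```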
@@ -16,10 +16,10 @@ use gpui::{
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCacheConfiguration,
- LanguageModelCompletionError, LanguageModelId, LanguageModelKnownError, LanguageModelName,
- LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, LanguageModelToolChoice,
- LanguageModelToolResultContent, MessageContent, RateLimiter, Role,
+ LanguageModelCompletionError, LanguageModelId, LanguageModelName, LanguageModelProvider,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModelRequest, LanguageModelToolChoice, LanguageModelToolResultContent, MessageContent,
+ RateLimiter, Role,
};
use language_model::{LanguageModelCompletionEvent, LanguageModelToolUse, StopReason};
use schemars::JsonSchema;
@@ -407,14 +407,7 @@ impl AnthropicModel {
let api_key = api_key.context("Missing Anthropic API Key")?;
let request =
anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
- request.await.map_err(|err| match err {
- AnthropicError::RateLimit(duration) => {
- LanguageModelCompletionError::RateLimit(duration)
- }
- err @ (AnthropicError::ApiError(..) | AnthropicError::Other(..)) => {
- LanguageModelCompletionError::Other(anthropic_err_to_anyhow(err))
- }
- })
+ request.await.map_err(Into::into)
}
.boxed()
}
@@ -714,7 +707,7 @@ impl AnthropicEventMapper {
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
- Err(error) => vec![Err(LanguageModelCompletionError::Other(anyhow!(error)))],
+ Err(error) => vec![Err(error.into())],
})
})
}
@@ -859,9 +852,7 @@ impl AnthropicEventMapper {
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
}
Event::Error { error } => {
- vec![Err(LanguageModelCompletionError::Other(anyhow!(
- AnthropicError::ApiError(error)
- )))]
+ vec![Err(error.into())]
}
_ => Vec::new(),
}
@@ -874,16 +865,6 @@ struct RawToolUse {
input_json: String,
}
-pub fn anthropic_err_to_anyhow(err: AnthropicError) -> anyhow::Error {
- if let AnthropicError::ApiError(api_err) = &err {
- if let Some(tokens) = api_err.match_window_exceeded() {
- return anyhow!(LanguageModelKnownError::ContextWindowLimitExceeded { tokens });
- }
- }
-
- anyhow!(err)
-}
-
/// Updates usage data by preferring counts from `new`.
fn update_usage(usage: &mut Usage, new: &Usage) {
if let Some(input_tokens) = new.input_tokens {