diff --git a/Cargo.lock b/Cargo.lock index c935d47f58b42f22c93dad55047c70d60b80b4db..acc8a72e88b077127650333b00c237a71831ca28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -255,6 +255,7 @@ dependencies = [ "serde", "serde_json", "strum", + "thiserror", "tokio", ] diff --git a/crates/anthropic/Cargo.toml b/crates/anthropic/Cargo.toml index 2d7da211a40a2cc260149ecc8c7eaa72d68d4db8..4628d3db809cf8f33672bfb37ad78ca723265aaf 100644 --- a/crates/anthropic/Cargo.toml +++ b/crates/anthropic/Cargo.toml @@ -24,6 +24,7 @@ schemars = { workspace = true, optional = true } serde.workspace = true serde_json.workspace = true strum.workspace = true +thiserror.workspace = true [dev-dependencies] tokio.workspace = true diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 5bb16eeb2c73180330199af5a221e817b4e9d949..0ceee553d21eeb5ac33e5a716ca7d549baf0ae46 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,12 +1,14 @@ mod supported_countries; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; +use std::str::FromStr; use std::time::Duration; -use strum::EnumIter; +use strum::{EnumIter, EnumString}; +use thiserror::Error; pub use supported_countries::*; @@ -96,7 +98,7 @@ pub async fn complete( api_url: &str, api_key: &str, request: Request, -) -> Result { +) -> Result { let uri = format!("{api_url}/v1/messages"); let request_builder = HttpRequest::builder() .method(Method::POST) @@ -106,24 +108,40 @@ pub async fn complete( .header("X-Api-Key", api_key) .header("Content-Type", "application/json"); - let serialized_request = serde_json::to_string(&request)?; - let request = request_builder.body(AsyncBody::from(serialized_request))?; + let serialized_request = + serde_json::to_string(&request).context("failed to serialize request")?; + let request = request_builder + .body(AsyncBody::from(serialized_request)) + .context("failed to construct request body")?; - let mut response = client.send(request).await?; + let mut response = client + .send(request) + .await + .context("failed to send request to Anthropic")?; if response.status().is_success() { let mut body = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; - let response_message: Response = serde_json::from_slice(&body)?; + response + .body_mut() + .read_to_end(&mut body) + .await + .context("failed to read response body")?; + let response_message: Response = + serde_json::from_slice(&body).context("failed to deserialize response body")?; Ok(response_message) } else { let mut body = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; - let body_str = std::str::from_utf8(&body)?; - Err(anyhow!( + response + .body_mut() + .read_to_end(&mut body) + .await + .context("failed to read response body")?; + let body_str = + std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?; + Err(AnthropicError::Other(anyhow!( "Failed to connect to API: {} {}", response.status(), body_str - )) + ))) } } @@ -133,7 +151,7 @@ pub async fn stream_completion( api_key: &str, request: Request, low_speed_timeout: Option, -) -> Result>> { +) -> Result>, AnthropicError> { let request = StreamingRequest { base: request, stream: true, @@ -149,10 +167,16 @@ pub async fn stream_completion( if let Some(low_speed_timeout) = low_speed_timeout { request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); } - let serialized_request = serde_json::to_string(&request)?; - let request = request_builder.body(AsyncBody::from(serialized_request))?; - - let mut response = client.send(request).await?; + let serialized_request = + serde_json::to_string(&request).context("failed to serialize request")?; + let request = request_builder + .body(AsyncBody::from(serialized_request)) + .context("failed to construct request body")?; + + let mut response = client + .send(request) + .await + .context("failed to send request to Anthropic")?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); Ok(reader @@ -163,36 +187,41 @@ pub async fn stream_completion( let line = line.strip_prefix("data: ")?; match serde_json::from_str(line) { Ok(response) => Some(Ok(response)), - Err(error) => Some(Err(anyhow!(error))), + Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))), } } - Err(error) => Some(Err(anyhow!(error))), + Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))), } }) .boxed()) } else { let mut body = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; + response + .body_mut() + .read_to_end(&mut body) + .await + .context("failed to read response body")?; - let body_str = std::str::from_utf8(&body)?; + let body_str = + std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?; match serde_json::from_str::(body_str) { - Ok(Event::Error { error }) => Err(api_error_to_err(error)), - Ok(_) => Err(anyhow!( + Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)), + Ok(_) => Err(AnthropicError::Other(anyhow!( "Unexpected success response while expecting an error: '{body_str}'", - )), - Err(_) => Err(anyhow!( + ))), + Err(_) => Err(AnthropicError::Other(anyhow!( "Failed to connect to API: {} {}", response.status(), body_str, - )), + ))), } } } pub fn extract_text_from_events( - response: impl Stream>, -) -> impl Stream> { + response: impl Stream>, +) -> impl Stream> { response.filter_map(|response| async move { match response { Ok(response) => match response { @@ -204,7 +233,7 @@ pub fn extract_text_from_events( ContentDelta::TextDelta { text } => Some(Ok(text)), _ => None, }, - Event::Error { error } => Some(Err(api_error_to_err(error))), + Event::Error { error } => Some(Err(AnthropicError::ApiError(error))), _ => None, }, Err(error) => Some(Err(error)), @@ -212,15 +241,6 @@ pub fn extract_text_from_events( }) } -fn api_error_to_err( - ApiError { - error_type, - message, - }: ApiError, -) -> anyhow::Error { - anyhow!("API error. Type: '{error_type}', message: '{message}'",) -} - #[derive(Debug, Serialize, Deserialize)] pub struct Message { pub role: Role, @@ -374,9 +394,53 @@ pub struct MessageDelta { pub stop_sequence: Option, } +#[derive(Error, Debug)] +pub enum AnthropicError { + #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)] + ApiError(ApiError), + #[error("{0}")] + Other(#[from] anyhow::Error), +} + #[derive(Debug, Serialize, Deserialize)] pub struct ApiError { #[serde(rename = "type")] pub error_type: String, pub message: String, } + +/// An Anthropic API error code. +/// https://docs.anthropic.com/en/api/errors#http-errors +#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)] +#[strum(serialize_all = "snake_case")] +pub enum ApiErrorCode { + /// 400 - `invalid_request_error`: There was an issue with the format or content of your request. + InvalidRequestError, + /// 401 - `authentication_error`: There's an issue with your API key. + AuthenticationError, + /// 403 - `permission_error`: Your API key does not have permission to use the specified resource. + PermissionError, + /// 404 - `not_found_error`: The requested resource was not found. + NotFoundError, + /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes. + RequestTooLarge, + /// 429 - `rate_limit_error`: Your account has hit a rate limit. + RateLimitError, + /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems. + ApiError, + /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded. + OverloadedError, +} + +impl ApiError { + pub fn code(&self) -> Option { + ApiErrorCode::from_str(&self.error_type).ok() + } + + pub fn is_rate_limit_error(&self) -> bool { + match self.error_type.as_str() { + "rate_limit_error" => true, + _ => false, + } + } +} diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index e406526fbf22a8cb52948016a642be2b6e45b517..fa743619418ea5b5f98d01fbc6f5ff9bdb05e0fe 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -197,7 +197,20 @@ async fn perform_completion( request, None, ) - .await?; + .await + .map_err(|err| match err { + anthropic::AnthropicError::ApiError(ref api_error) => { + if api_error.code() == Some(anthropic::ApiErrorCode::RateLimitError) { + return Error::http( + StatusCode::TOO_MANY_REQUESTS, + "Upstream Anthropic rate limit exceeded.".to_string(), + ); + } + + Error::Internal(anyhow!(err)) + } + anthropic::AnthropicError::Other(err) => Error::Internal(err), + })?; chunks .map(move |event| { diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs index 2fae6e4b099be292d09e1c5223edd0309e3b9b77..d65b789dc95f1d828222cc47930326c30d6178d3 100644 --- a/crates/language_model/src/provider/anthropic.rs +++ b/crates/language_model/src/provider/anthropic.rs @@ -3,6 +3,7 @@ use crate::{ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role, }; +use anthropic::AnthropicError; use anyhow::{anyhow, Context as _, Result}; use collections::BTreeMap; use editor::{Editor, EditorElement, EditorStyle}; @@ -259,7 +260,9 @@ impl AnthropicModel { async move { let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?; - anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await + anthropic::complete(http_client.as_ref(), &api_url, &api_key, request) + .await + .context("failed to retrieve completion") } .boxed() } @@ -268,7 +271,8 @@ impl AnthropicModel { &self, request: anthropic::Request, cx: &AsyncAppContext, - ) -> BoxFuture<'static, Result>>> { + ) -> BoxFuture<'static, Result>>> + { let http_client = self.http_client.clone(); let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| { @@ -291,7 +295,7 @@ impl AnthropicModel { request, low_speed_timeout, ); - request.await + request.await.context("failed to stream completion") } .boxed() } @@ -338,10 +342,16 @@ impl LanguageModel for AnthropicModel { let request = request.into_anthropic(self.model.id().into()); let request = self.stream_completion(request, cx); let future = self.request_limiter.stream(async move { - let response = request.await?; + let response = request.await.map_err(|err| anyhow!(err))?; Ok(anthropic::extract_text_from_events(response)) }); - async move { Ok(future.await?.boxed()) }.boxed() + async move { + Ok(future + .await? + .map(|result| result.map_err(|err| anyhow!(err))) + .boxed()) + } + .boxed() } fn use_any_tool( diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 1f1ae9295643d5391932ae3d5eef6e0e18b5b212..96be497068e54c18c539a9746b79493fbe7e46fa 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -4,7 +4,8 @@ use crate::{ LanguageModelName, LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel, }; -use anyhow::{anyhow, bail, Result}; +use anthropic::AnthropicError; +use anyhow::{anyhow, bail, Context as _, Result}; use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME}; use collections::BTreeMap; use feature_flags::{FeatureFlagAppExt, LanguageModels}; @@ -446,16 +447,23 @@ impl LanguageModel for CloudLanguageModel { match body.read_line(&mut buffer).await { Ok(0) => Ok(None), Ok(_) => { - let event: anthropic::Event = serde_json::from_str(&buffer)?; + let event: anthropic::Event = serde_json::from_str(&buffer) + .context("failed to parse Anthropic event")?; Ok(Some((event, body))) } - Err(e) => Err(e.into()), + Err(err) => Err(AnthropicError::Other(err.into())), } }); Ok(anthropic::extract_text_from_events(stream)) }); - async move { Ok(future.await?.boxed()) }.boxed() + async move { + Ok(future + .await? + .map(|result| result.map_err(|err| anyhow!(err))) + .boxed()) + } + .boxed() } CloudModel::OpenAi(model) => { let client = self.client.clone();