diff --git a/Cargo.lock b/Cargo.lock index 7d26e24a8d46081157165dff92a9cd820e615054..d11ca902d1edbdc838071bfdb7df10dea88f9c81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8872,6 +8872,7 @@ dependencies = [ "icons", "image", "log", + "open_ai", "open_router", "parking_lot", "proto", @@ -11025,6 +11026,7 @@ dependencies = [ "serde_json", "settings", "strum 0.27.2", + "thiserror 2.0.17", ] [[package]] diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 4d40a063b604b405f7bcb29a3457956e1dd5541d..7c6470f4fa0c1eac847c1194e967b451093a76ad 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -29,6 +29,7 @@ http_client.workspace = true icons.workspace = true image.workspace = true log.workspace = true +open_ai = { workspace = true, features = ["schemars"] } open_router.workspace = true parking_lot.workspace = true proto.workspace = true diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index 94f6ec33f15062dd53b4122ca9d9dcac3fbff83d..4f0eed34331980ec0fd499c6a77e49e94b524fe0 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -345,6 +345,27 @@ impl From for LanguageModelCompletionError { } } +impl From for LanguageModelCompletionError { + fn from(error: open_ai::RequestError) -> Self { + match error { + open_ai::RequestError::HttpResponseError { + provider, + status_code, + body, + headers, + } => { + let retry_after = headers + .get(http::header::RETRY_AFTER) + .and_then(|val| val.to_str().ok()?.parse::().ok()) + .map(Duration::from_secs); + + Self::from_http_status(provider.into(), status_code, body, retry_after) + } + open_ai::RequestError::Other(e) => Self::Other(e), + } + } +} + impl From for LanguageModelCompletionError { fn from(error: OpenRouterError) -> Self { let provider = LanguageModelProviderName::new("OpenRouter"); diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index cabd78c35be58667fd799fe34de07e1d1bfa5808..792d280950ceafa24cdf5e4104b80dd49bd45f3f 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -226,12 +226,17 @@ impl OpenAiLanguageModel { }; let future = self.request_limiter.stream(async move { + let provider = PROVIDER_NAME; let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { - provider: PROVIDER_NAME, - }); + return Err(LanguageModelCompletionError::NoApiKey { provider }); }; - let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let request = stream_completion( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); let response = request.await?; Ok(response) }); diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs index 4ed0de851244d65b0f838c582ccdffe763d6775f..a30c8bfa5d3a728d6dd388f8e768cd470ee9736d 100644 --- a/crates/language_models/src/provider/open_ai_compatible.rs +++ b/crates/language_models/src/provider/open_ai_compatible.rs @@ -205,8 +205,13 @@ impl OpenAiCompatibleLanguageModel { &self, request: open_ai::Request, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> - { + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { let http_client = self.http_client.clone(); let Ok((api_key, api_url)) = self.state.read_with(cx, |state, _cx| { @@ -216,7 +221,7 @@ impl OpenAiCompatibleLanguageModel { state.settings.api_url.clone(), ) }) else { - return future::ready(Err(anyhow!("App state dropped"))).boxed(); + return future::ready(Err(anyhow!("App state dropped").into())).boxed(); }; let provider = self.provider_name.clone(); @@ -224,7 +229,13 @@ impl OpenAiCompatibleLanguageModel { let Some(api_key) = api_key else { return Err(LanguageModelCompletionError::NoApiKey { provider }); }; - let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let request = stream_completion( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); let response = request.await?; Ok(response) }); diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs index 20db24274aae0249efcfc897cb1bdfdcce8f1220..061dc1799922c03952b1a96e2785425f61bcf00b 100644 --- a/crates/language_models/src/provider/vercel.rs +++ b/crates/language_models/src/provider/vercel.rs @@ -220,13 +220,17 @@ impl VercelLanguageModel { }; let future = self.request_limiter.stream(async move { + let provider = PROVIDER_NAME; let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { - provider: PROVIDER_NAME, - }); + return Err(LanguageModelCompletionError::NoApiKey { provider }); }; - let request = - open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let request = open_ai::stream_completion( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); let response = request.await?; Ok(response) }); diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs index e7ee71ba86e202fe17d567923f4b04d3c886ae08..cc54dfa0dd8a3f2ca6ab2b769a779afa8e73988b 100644 --- a/crates/language_models/src/provider/x_ai.rs +++ b/crates/language_models/src/provider/x_ai.rs @@ -211,25 +211,34 @@ impl XAiLanguageModel { &self, request: open_ai::Request, cx: &AsyncApp, - ) -> BoxFuture<'static, Result>>> - { + ) -> BoxFuture< + 'static, + Result< + futures::stream::BoxStream<'static, Result>, + LanguageModelCompletionError, + >, + > { let http_client = self.http_client.clone(); let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| { let api_url = XAiLanguageModelProvider::api_url(cx); (state.api_key_state.key(&api_url), api_url) }) else { - return future::ready(Err(anyhow!("App state dropped"))).boxed(); + return future::ready(Err(anyhow!("App state dropped").into())).boxed(); }; let future = self.request_limiter.stream(async move { + let provider = PROVIDER_NAME; let Some(api_key) = api_key else { - return Err(LanguageModelCompletionError::NoApiKey { - provider: PROVIDER_NAME, - }); + return Err(LanguageModelCompletionError::NoApiKey { provider }); }; - let request = - open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request); + let request = open_ai::stream_completion( + http_client.as_ref(), + provider.0.as_str(), + &api_url, + &api_key, + request, + ); let response = request.await?; Ok(response) }); diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml index 49284eff79c11414c0811abd107f7c16ca701179..037ca14437cd13a6fc4bfe76dafb113c6a9f1482 100644 --- a/crates/open_ai/Cargo.toml +++ b/crates/open_ai/Cargo.toml @@ -25,3 +25,4 @@ serde.workspace = true serde_json.workspace = true settings.workspace = true strum.workspace = true +thiserror.workspace = true diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index e1f58fe95a487f5be650d758df32b8097ee578e4..aaeee01c9c74f8592ccfffa01893f9333f120e89 100644 --- a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -1,11 +1,15 @@ use anyhow::{Context as _, Result, anyhow}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; -use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use http_client::{ + AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode, + http::{HeaderMap, HeaderValue}, +}; use serde::{Deserialize, Serialize}; use serde_json::Value; pub use settings::OpenAiReasoningEffort as ReasoningEffort; use std::{convert::TryFrom, future::Future}; use strum::EnumIter; +use thiserror::Error; pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1"; @@ -441,8 +445,21 @@ pub struct ChoiceDelta { pub finish_reason: Option, } +#[derive(Error, Debug)] +pub enum RequestError { + #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")] + HttpResponseError { + provider: String, + status_code: StatusCode, + body: String, + headers: HeaderMap, + }, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + #[derive(Serialize, Deserialize, Debug)] -pub struct OpenAiError { +pub struct ResponseStreamError { message: String, } @@ -450,7 +467,7 @@ pub struct OpenAiError { #[serde(untagged)] pub enum ResponseStreamResult { Ok(ResponseStreamEvent), - Err { error: OpenAiError }, + Err { error: ResponseStreamError }, } #[derive(Serialize, Deserialize, Debug)] @@ -461,10 +478,11 @@ pub struct ResponseStreamEvent { pub async fn stream_completion( client: &dyn HttpClient, + provider_name: &str, api_url: &str, api_key: &str, request: Request, -) -> Result>> { +) -> Result>, RequestError> { let uri = format!("{api_url}/chat/completions"); let request_builder = HttpRequest::builder() .method(Method::POST) @@ -472,7 +490,12 @@ pub async fn stream_completion( .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", api_key.trim())); - let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let request = request_builder + .body(AsyncBody::from( + serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?, + )) + .map_err(|e| RequestError::Other(e.into()))?; + let mut response = client.send(request).await?; if response.status().is_success() { let reader = BufReader::new(response.into_body()); @@ -508,27 +531,18 @@ pub async fn stream_completion( .boxed()) } else { let mut body = String::new(); - response.body_mut().read_to_string(&mut body).await?; - - #[derive(Deserialize)] - struct OpenAiResponse { - error: OpenAiError, - } - - match serde_json::from_str::(&body) { - Ok(response) if !response.error.message.is_empty() => Err(anyhow!( - "API request to {} failed: {}", - api_url, - response.error.message, - )), - - _ => anyhow::bail!( - "API request to {} failed with status {}: {}", - api_url, - response.status(), - body, - ), - } + response + .body_mut() + .read_to_string(&mut body) + .await + .map_err(|e| RequestError::Other(e.into()))?; + + Err(RequestError::HttpResponseError { + provider: provider_name.to_owned(), + status_code: response.status(), + body, + headers: response.headers().clone(), + }) } }