From fb90b12073aabd8c753e048d0fe0a185a1de2b2c Mon Sep 17 00:00:00 2001
From: Tim McLean
Date: Thu, 13 Nov 2025 09:15:46 -0500
Subject: [PATCH] Add retry support for OpenAI-compatible LLM providers
 (#37891)

Automatically retry the agent's LLM completion requests when the
provider returns 429 Too Many Requests. Uses the Retry-After header to
determine the retry delay if it is available.

Many providers are frequently overloaded or have low rate limits. These
providers are essentially unusable without automatic retries.

Tested with Cerebras configured via openai_compatible.

Related: #31531

Release Notes:

- Added automatic retries for OpenAI-compatible LLM providers

---------

Co-authored-by: Bennet Bo Fenner
---
 Cargo.lock                                    |  2 +
 crates/language_model/Cargo.toml              |  1 +
 crates/language_model/src/language_model.rs   | 21 ++++++
 .../language_models/src/provider/open_ai.rs   | 13 ++--
 .../src/provider/open_ai_compatible.rs        | 19 ++++--
 crates/language_models/src/provider/vercel.rs | 14 ++--
 crates/language_models/src/provider/x_ai.rs   | 25 ++++---
 crates/open_ai/Cargo.toml                     |  1 +
 crates/open_ai/src/open_ai.rs                 | 66 +++++++++++--------
 9 files changed, 115 insertions(+), 47 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 7d26e24a8d46081157165dff92a9cd820e615054..d11ca902d1edbdc838071bfdb7df10dea88f9c81 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8872,6 +8872,7 @@ dependencies = [
  "icons",
  "image",
  "log",
+ "open_ai",
  "open_router",
  "parking_lot",
  "proto",
@@ -11025,6 +11026,7 @@ dependencies = [
  "serde_json",
  "settings",
  "strum 0.27.2",
+ "thiserror 2.0.17",
 ]
 
 [[package]]
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
index 4d40a063b604b405f7bcb29a3457956e1dd5541d..7c6470f4fa0c1eac847c1194e967b451093a76ad 100644
--- a/crates/language_model/Cargo.toml
+++ b/crates/language_model/Cargo.toml
@@ -29,6 +29,7 @@ http_client.workspace = true
 icons.workspace = true
 image.workspace = true
 log.workspace = true
+open_ai = { workspace = true, features = ["schemars"] }
 open_router.workspace = true
 parking_lot.workspace = true
 proto.workspace = true
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 94f6ec33f15062dd53b4122ca9d9dcac3fbff83d..4f0eed34331980ec0fd499c6a77e49e94b524fe0 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -345,6 +345,27 @@
     }
 }
 
+impl From<open_ai::RequestError> for LanguageModelCompletionError {
+    fn from(error: open_ai::RequestError) -> Self {
+        match error {
+            open_ai::RequestError::HttpResponseError {
+                provider,
+                status_code,
+                body,
+                headers,
+            } => {
+                let retry_after = headers
+                    .get(http::header::RETRY_AFTER)
+                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
+                    .map(Duration::from_secs);
+
+                Self::from_http_status(provider.into(), status_code, body, retry_after)
+            }
+            open_ai::RequestError::Other(e) => Self::Other(e),
+        }
+    }
+}
+
 impl From<OpenRouterError> for LanguageModelCompletionError {
     fn from(error: OpenRouterError) -> Self {
         let provider = LanguageModelProviderName::new("OpenRouter");
diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs
index cabd78c35be58667fd799fe34de07e1d1bfa5808..792d280950ceafa24cdf5e4104b80dd49bd45f3f 100644
--- a/crates/language_models/src/provider/open_ai.rs
+++ b/crates/language_models/src/provider/open_ai.rs
@@ -226,12 +226,17 @@ impl OpenAiLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
+            let provider = PROVIDER_NAME;
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
+                return Err(LanguageModelCompletionError::NoApiKey { provider });
             };
-            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = stream_completion(
+                http_client.as_ref(),
+                provider.0.as_str(),
+                &api_url,
+                &api_key,
+                request,
+            );
             let response = request.await?;
             Ok(response)
         });
diff --git a/crates/language_models/src/provider/open_ai_compatible.rs b/crates/language_models/src/provider/open_ai_compatible.rs
index 4ed0de851244d65b0f838c582ccdffe763d6775f..a30c8bfa5d3a728d6dd388f8e768cd470ee9736d 100644
--- a/crates/language_models/src/provider/open_ai_compatible.rs
+++ b/crates/language_models/src/provider/open_ai_compatible.rs
@@ -205,8 +205,13 @@ impl OpenAiCompatibleLanguageModel {
         &self,
         request: open_ai::Request,
         cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
-    {
+    ) -> BoxFuture<
+        'static,
+        Result<
+            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
+            LanguageModelCompletionError,
+        >,
+    > {
         let http_client = self.http_client.clone();
 
         let Ok((api_key, api_url)) = self.state.read_with(cx, |state, _cx| {
@@ -216,7 +221,7 @@ impl OpenAiCompatibleLanguageModel {
             (
                 state.api_key_state.key(&state.settings.api_url),
                 state.settings.api_url.clone(),
             )
         }) else {
-            return future::ready(Err(anyhow!("App state dropped"))).boxed();
+            return future::ready(Err(anyhow!("App state dropped").into())).boxed();
         };
 
         let provider = self.provider_name.clone();
@@ -224,7 +229,13 @@ impl OpenAiCompatibleLanguageModel {
             let Some(api_key) = api_key else {
                 return Err(LanguageModelCompletionError::NoApiKey { provider });
             };
-            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = stream_completion(
+                http_client.as_ref(),
+                provider.0.as_str(),
+                &api_url,
+                &api_key,
+                request,
+            );
             let response = request.await?;
             Ok(response)
         });
diff --git a/crates/language_models/src/provider/vercel.rs b/crates/language_models/src/provider/vercel.rs
index 20db24274aae0249efcfc897cb1bdfdcce8f1220..061dc1799922c03952b1a96e2785425f61bcf00b 100644
--- a/crates/language_models/src/provider/vercel.rs
+++ b/crates/language_models/src/provider/vercel.rs
@@ -220,13 +220,17 @@ impl VercelLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
+            let provider = PROVIDER_NAME;
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
+                return Err(LanguageModelCompletionError::NoApiKey { provider });
             };
-            let request =
-                open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = open_ai::stream_completion(
+                http_client.as_ref(),
+                provider.0.as_str(),
+                &api_url,
+                &api_key,
+                request,
+            );
             let response = request.await?;
             Ok(response)
         });
diff --git a/crates/language_models/src/provider/x_ai.rs b/crates/language_models/src/provider/x_ai.rs
index e7ee71ba86e202fe17d567923f4b04d3c886ae08..cc54dfa0dd8a3f2ca6ab2b769a779afa8e73988b 100644
--- a/crates/language_models/src/provider/x_ai.rs
+++ b/crates/language_models/src/provider/x_ai.rs
@@ -211,25 +211,34 @@ impl XAiLanguageModel {
         &self,
         request: open_ai::Request,
         cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
-    {
+    ) -> BoxFuture<
+        'static,
+        Result<
+            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
+            LanguageModelCompletionError,
+        >,
+    > {
         let http_client = self.http_client.clone();
 
         let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
             let api_url = XAiLanguageModelProvider::api_url(cx);
             (state.api_key_state.key(&api_url), api_url)
         }) else {
-            return future::ready(Err(anyhow!("App state dropped"))).boxed();
+            return future::ready(Err(anyhow!("App state dropped").into())).boxed();
         };
 
         let future = self.request_limiter.stream(async move {
+            let provider = PROVIDER_NAME;
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
+                return Err(LanguageModelCompletionError::NoApiKey { provider });
             };
-            let request =
-                open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = open_ai::stream_completion(
+                http_client.as_ref(),
+                provider.0.as_str(),
+                &api_url,
+                &api_key,
+                request,
+            );
             let response = request.await?;
             Ok(response)
         });
diff --git a/crates/open_ai/Cargo.toml b/crates/open_ai/Cargo.toml
index 49284eff79c11414c0811abd107f7c16ca701179..037ca14437cd13a6fc4bfe76dafb113c6a9f1482 100644
--- a/crates/open_ai/Cargo.toml
+++ b/crates/open_ai/Cargo.toml
@@ -25,3 +25,4 @@ serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
 strum.workspace = true
+thiserror.workspace = true
diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs
index e1f58fe95a487f5be650d758df32b8097ee578e4..aaeee01c9c74f8592ccfffa01893f9333f120e89 100644
--- a/crates/open_ai/src/open_ai.rs
+++ b/crates/open_ai/src/open_ai.rs
@@ -1,11 +1,15 @@
 use anyhow::{Context as _, Result, anyhow};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use http_client::{
+    AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
+    http::{HeaderMap, HeaderValue},
+};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 pub use settings::OpenAiReasoningEffort as ReasoningEffort;
 use std::{convert::TryFrom, future::Future};
 use strum::EnumIter;
+use thiserror::Error;
 
 pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 
@@ -441,8 +445,21 @@ pub struct ChoiceDelta {
     pub finish_reason: Option<String>,
 }
 
+#[derive(Error, Debug)]
+pub enum RequestError {
+    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
+    HttpResponseError {
+        provider: String,
+        status_code: StatusCode,
+        body: String,
+        headers: HeaderMap<HeaderValue>,
+    },
+    #[error(transparent)]
+    Other(#[from] anyhow::Error),
+}
+
 #[derive(Serialize, Deserialize, Debug)]
-pub struct OpenAiError {
+pub struct ResponseStreamError {
     message: String,
 }
 
@@ -450,7 +467,7 @@
 #[serde(untagged)]
 pub enum ResponseStreamResult {
     Ok(ResponseStreamEvent),
-    Err { error: OpenAiError },
+    Err { error: ResponseStreamError },
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -461,10 +478,11 @@ pub struct ResponseStreamEvent {
 pub async fn stream_completion(
     client: &dyn HttpClient,
+    provider_name: &str,
     api_url: &str,
     api_key: &str,
     request: Request,
-) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
+) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
     let uri = format!("{api_url}/chat/completions");
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)
@@ -472,7 +490,12 @@ pub async fn stream_completion(
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key.trim()));
 
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let request = request_builder
+        .body(AsyncBody::from(
+            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
+        ))
+        .map_err(|e| RequestError::Other(e.into()))?;
+
     let mut response = client.send(request).await?;
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
@@ -508,27 +531,18 @@ pub async fn stream_completion(
             .boxed())
     } else {
         let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAiResponse {
-            error: OpenAiError,
-        }
-
-        match serde_json::from_str::<OpenAiResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
-                "API request to {} failed: {}",
-                api_url,
-                response.error.message,
-            )),
-
-            _ => anyhow::bail!(
-                "API request to {} failed with status {}: {}",
-                api_url,
-                response.status(),
-                body,
-            ),
-        }
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        Err(RequestError::HttpResponseError {
+            provider: provider_name.to_owned(),
+            status_code: response.status(),
+            body,
+            headers: response.headers().clone(),
+        })
     }
 }
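
Reviewer note (not part of the patch): the diff surfaces the 429 status and its
Retry-After value through `open_ai::RequestError`, and `from_http_status` maps
that into the retryable completion error; the retry loop itself is driven by how
the agent handles that error and is outside this diff. Below is a minimal
standalone sketch of the header parsing. The function name `parse_retry_after`
is illustrative, not from the Zed codebase, but the logic mirrors the new
`From<open_ai::RequestError>` impl: only the delta-seconds form of Retry-After
is recognized, so an HTTP-date value yields None and callers fall back to their
default backoff.

    use std::time::Duration;

    /// Mirrors the patch: `headers.get(RETRY_AFTER)` followed by
    /// `.to_str().ok()?.parse::<u64>().ok()` and `Duration::from_secs`.
    fn parse_retry_after(header: Option<&str>) -> Option<Duration> {
        header?.parse::<u64>().ok().map(Duration::from_secs)
    }

    fn main() {
        // The delta-seconds form of Retry-After is honored...
        assert_eq!(parse_retry_after(Some("30")), Some(Duration::from_secs(30)));
        // ...but the HTTP-date form parses to None, leaving the retry
        // delay to whatever default the caller applies.
        assert_eq!(parse_retry_after(Some("Fri, 14 Nov 2025 09:15:46 GMT")), None);
        // A missing header likewise produces no explicit delay.
        assert_eq!(parse_retry_after(None), None);
    }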