Add retry support for OpenAI-compatible LLM providers (#37891)

Created by Tim McLean and Bennet Bo Fenner

Automatically retry the agent's LLM completion requests when the
provider returns 429 Too Many Requests. If the provider includes a
Retry-After header, its value determines the retry delay.
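
The delay decision itself is small: on a 429, prefer the server's Retry-After value (integer seconds) and otherwise fall back to a default backoff. A minimal sketch of that logic, assuming the `http` crate and a hypothetical `retry_delay` helper with a caller-chosen fallback:

```rust
use std::time::Duration;

use http::{HeaderMap, StatusCode, header::RETRY_AFTER};

/// Hypothetical helper: decides whether (and how long) to wait before
/// retrying a completion request. Returns `None` for non-429 responses.
fn retry_delay(status: StatusCode, headers: &HeaderMap, fallback: Duration) -> Option<Duration> {
    if status != StatusCode::TOO_MANY_REQUESTS {
        return None;
    }
    // Retry-After is parsed as integer seconds, mirroring the conversion
    // added in language_model.rs below; the HTTP-date form is not handled.
    let retry_after = headers
        .get(RETRY_AFTER)
        .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
        .map(Duration::from_secs);
    Some(retry_after.unwrap_or(fallback))
}
```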

Many providers are frequently overloaded or have low rate limits. These
providers are essentially unusable without automatic retries.

Tested with Cerebras configured via `openai_compatible`.

Related: #31531 

Release Notes:

- Added automatic retries for OpenAI-compatible LLM providers

---------

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>

Change summary

Cargo.lock                                                |  2 +
crates/language_model/Cargo.toml                          |  1 +
crates/language_model/src/language_model.rs               | 21 ++
crates/language_models/src/provider/open_ai.rs            | 13 +
crates/language_models/src/provider/open_ai_compatible.rs | 19 ++
crates/language_models/src/provider/vercel.rs             | 14 +
crates/language_models/src/provider/x_ai.rs               | 25 ++-
crates/open_ai/Cargo.toml                                 |  1 +
crates/open_ai/src/open_ai.rs                             | 66 +++++---
9 files changed, 115 insertions(+), 47 deletions(-)

Detailed changes

Cargo.lock

@@ -8872,6 +8872,7 @@ dependencies = [
  "icons",
  "image",
  "log",
+ "open_ai",
  "open_router",
  "parking_lot",
  "proto",
@@ -11025,6 +11026,7 @@ dependencies = [
  "serde_json",
  "settings",
  "strum 0.27.2",
+ "thiserror 2.0.17",
 ]
 
 [[package]]

crates/language_model/Cargo.toml

@@ -29,6 +29,7 @@ http_client.workspace = true
 icons.workspace = true
 image.workspace = true
 log.workspace = true
+open_ai = { workspace = true, features = ["schemars"] }
 open_router.workspace = true
 parking_lot.workspace = true
 proto.workspace = true

crates/language_model/src/language_model.rs

@@ -345,6 +345,27 @@ impl From<anthropic::ApiError> for LanguageModelCompletionError {
     }
 }
 
+impl From<open_ai::RequestError> for LanguageModelCompletionError {
+    fn from(error: open_ai::RequestError) -> Self {
+        match error {
+            open_ai::RequestError::HttpResponseError {
+                provider,
+                status_code,
+                body,
+                headers,
+            } => {
+                let retry_after = headers
+                    .get(http::header::RETRY_AFTER)
+                    .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
+                    .map(Duration::from_secs);
+
+                Self::from_http_status(provider.into(), status_code, body, retry_after)
+            }
+            open_ai::RequestError::Other(e) => Self::Other(e),
+        }
+    }
+}
+
 impl From<OpenRouterError> for LanguageModelCompletionError {
     fn from(error: OpenRouterError) -> Self {
         let provider = LanguageModelProviderName::new("OpenRouter");

crates/language_models/src/provider/open_ai.rs

@@ -226,12 +226,17 @@ impl OpenAiLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
+            let provider = PROVIDER_NAME;
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
+                return Err(LanguageModelCompletionError::NoApiKey { provider });
             };
-            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = stream_completion(
+                http_client.as_ref(),
+                provider.0.as_str(),
+                &api_url,
+                &api_key,
+                request,
+            );
             let response = request.await?;
             Ok(response)
         });

crates/language_models/src/provider/open_ai_compatible.rs

@@ -205,8 +205,13 @@ impl OpenAiCompatibleLanguageModel {
         &self,
         request: open_ai::Request,
         cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
-    {
+    ) -> BoxFuture<
+        'static,
+        Result<
+            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
+            LanguageModelCompletionError,
+        >,
+    > {
         let http_client = self.http_client.clone();
 
         let Ok((api_key, api_url)) = self.state.read_with(cx, |state, _cx| {
@@ -216,7 +221,7 @@ impl OpenAiCompatibleLanguageModel {
                 state.settings.api_url.clone(),
             )
         }) else {
-            return future::ready(Err(anyhow!("App state dropped"))).boxed();
+            return future::ready(Err(anyhow!("App state dropped").into())).boxed();
         };
 
         let provider = self.provider_name.clone();
@@ -224,7 +229,13 @@ impl OpenAiCompatibleLanguageModel {
             let Some(api_key) = api_key else {
                 return Err(LanguageModelCompletionError::NoApiKey { provider });
             };
-            let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = stream_completion(
+                http_client.as_ref(),
+                provider.0.as_str(),
+                &api_url,
+                &api_key,
+                request,
+            );
             let response = request.await?;
             Ok(response)
         });

crates/language_models/src/provider/vercel.rs

@@ -220,13 +220,17 @@ impl VercelLanguageModel {
         };
 
         let future = self.request_limiter.stream(async move {
+            let provider = PROVIDER_NAME;
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
+                return Err(LanguageModelCompletionError::NoApiKey { provider });
             };
-            let request =
-                open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = open_ai::stream_completion(
+                http_client.as_ref(),
+                provider.0.as_str(),
+                &api_url,
+                &api_key,
+                request,
+            );
             let response = request.await?;
             Ok(response)
         });

crates/language_models/src/provider/x_ai.rs

@@ -211,25 +211,34 @@ impl XAiLanguageModel {
         &self,
         request: open_ai::Request,
         cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
-    {
+    ) -> BoxFuture<
+        'static,
+        Result<
+            futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>,
+            LanguageModelCompletionError,
+        >,
+    > {
         let http_client = self.http_client.clone();
 
         let Ok((api_key, api_url)) = self.state.read_with(cx, |state, cx| {
             let api_url = XAiLanguageModelProvider::api_url(cx);
             (state.api_key_state.key(&api_url), api_url)
         }) else {
-            return future::ready(Err(anyhow!("App state dropped"))).boxed();
+            return future::ready(Err(anyhow!("App state dropped").into())).boxed();
         };
 
         let future = self.request_limiter.stream(async move {
+            let provider = PROVIDER_NAME;
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
+                return Err(LanguageModelCompletionError::NoApiKey { provider });
             };
-            let request =
-                open_ai::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
+            let request = open_ai::stream_completion(
+                http_client.as_ref(),
+                provider.0.as_str(),
+                &api_url,
+                &api_key,
+                request,
+            );
             let response = request.await?;
             Ok(response)
         });

crates/open_ai/Cargo.toml

@@ -25,3 +25,4 @@ serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
 strum.workspace = true
+thiserror.workspace = true

crates/open_ai/src/open_ai.rs

@@ -1,11 +1,15 @@
 use anyhow::{Context as _, Result, anyhow};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use http_client::{
+    AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode,
+    http::{HeaderMap, HeaderValue},
+};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 pub use settings::OpenAiReasoningEffort as ReasoningEffort;
 use std::{convert::TryFrom, future::Future};
 use strum::EnumIter;
+use thiserror::Error;
 
 pub const OPEN_AI_API_URL: &str = "https://api.openai.com/v1";
 
@@ -441,8 +445,21 @@ pub struct ChoiceDelta {
     pub finish_reason: Option<String>,
 }
 
+#[derive(Error, Debug)]
+pub enum RequestError {
+    #[error("HTTP response error from {provider}'s API: status {status_code} - {body:?}")]
+    HttpResponseError {
+        provider: String,
+        status_code: StatusCode,
+        body: String,
+        headers: HeaderMap<HeaderValue>,
+    },
+    #[error(transparent)]
+    Other(#[from] anyhow::Error),
+}
+
 #[derive(Serialize, Deserialize, Debug)]
-pub struct OpenAiError {
+pub struct ResponseStreamError {
     message: String,
 }
 
@@ -450,7 +467,7 @@ pub struct OpenAiError {
 #[serde(untagged)]
 pub enum ResponseStreamResult {
     Ok(ResponseStreamEvent),
-    Err { error: OpenAiError },
+    Err { error: ResponseStreamError },
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -461,10 +478,11 @@ pub struct ResponseStreamEvent {
 
 pub async fn stream_completion(
     client: &dyn HttpClient,
+    provider_name: &str,
     api_url: &str,
     api_key: &str,
     request: Request,
-) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
+) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>, RequestError> {
     let uri = format!("{api_url}/chat/completions");
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
@@ -472,7 +490,12 @@ pub async fn stream_completion(
         .header("Content-Type", "application/json")
         .header("Authorization", format!("Bearer {}", api_key.trim()));
 
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let request = request_builder
+        .body(AsyncBody::from(
+            serde_json::to_string(&request).map_err(|e| RequestError::Other(e.into()))?,
+        ))
+        .map_err(|e| RequestError::Other(e.into()))?;
+
     let mut response = client.send(request).await?;
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
@@ -508,27 +531,18 @@ pub async fn stream_completion(
             .boxed())
     } else {
         let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAiResponse {
-            error: OpenAiError,
-        }
-
-        match serde_json::from_str::<OpenAiResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
-                "API request to {} failed: {}",
-                api_url,
-                response.error.message,
-            )),
-
-            _ => anyhow::bail!(
-                "API request to {} failed with status {}: {}",
-                api_url,
-                response.status(),
-                body,
-            ),
-        }
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(|e| RequestError::Other(e.into()))?;
+
+        Err(RequestError::HttpResponseError {
+            provider: provider_name.to_owned(),
+            status_code: response.status(),
+            body,
+            headers: response.headers().clone(),
+        })
     }
 }
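
Callers now thread a provider name through `stream_completion`, and the error path keeps the response status and headers. A hedged usage sketch, assuming an `HttpClient`, API URL, key, and `Request` are already set up as in the providers above (`"Cerebras"` is just an example name):

```rust
use futures::StreamExt;
use open_ai::{Request, RequestError, stream_completion};

// Sketch of a call site; the signature mirrors the diff above.
async fn complete(
    client: &dyn http_client::HttpClient,
    api_url: &str,
    api_key: &str,
    request: Request,
) -> Result<(), RequestError> {
    let mut events = stream_completion(client, "Cerebras", api_url, api_key, request).await?;
    // On a 429, the `?` above yields RequestError::HttpResponseError with the
    // status and headers intact, so the From impl in language_model.rs can
    // recover Retry-After and the agent's retry loop can reschedule.
    while let Some(event) = events.next().await {
        let _event = event?; // each stream item is a Result<ResponseStreamEvent>
    }
    Ok(())
}
```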