language_models: Add OpenRouterError and map OpenRouter errors to LanguageModelCompletionError (#34227)

Umesh Yadav created

Improves error handling for OpenRouter and adds automatic retries, as with
Anthropic, for a few of the status codes.
Release Notes:

- Improves error messages for the OpenRouter provider
- Automatically retries when OpenRouter returns a rate-limit or server error

Change summary

Cargo.lock                                         |   3 
crates/language_model/Cargo.toml                   |   1 
crates/language_model/src/language_model.rs        |  67 +++
crates/language_models/src/provider/open_router.rs |  63 +-
crates/open_router/Cargo.toml                      |   2 
crates/open_router/src/open_router.rs              | 350 ++++++++++-----
6 files changed, 334 insertions(+), 152 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -9147,6 +9147,7 @@ dependencies = [
  "icons",
  "image",
  "log",
+ "open_router",
  "parking_lot",
  "proto",
  "schemars",
@@ -11222,6 +11223,8 @@ dependencies = [
  "schemars",
  "serde",
  "serde_json",
+ "strum 0.27.1",
+ "thiserror 2.0.12",
  "workspace-hack",
 ]
 

crates/language_model/Cargo.toml 🔗

@@ -17,6 +17,7 @@ test-support = []
 
 [dependencies]
 anthropic = { workspace = true, features = ["schemars"] }
+open_router.workspace = true
 anyhow.workspace = true
 base64.workspace = true
 client.workspace = true

crates/language_model/src/language_model.rs 🔗

@@ -17,6 +17,7 @@ use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
 use http_client::{StatusCode, http};
 use icons::IconName;
+use open_router::OpenRouterError;
 use parking_lot::Mutex;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize, de::DeserializeOwned};
@@ -347,6 +348,72 @@ impl From<anthropic::ApiError> for LanguageModelCompletionError {
     }
 }
 
+impl From<OpenRouterError> for LanguageModelCompletionError {
+    fn from(error: OpenRouterError) -> Self {
+        let provider = LanguageModelProviderName::new("OpenRouter");
+        match error {
+            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
+            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
+            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
+            OpenRouterError::DeserializeResponse(error) => {
+                Self::DeserializeResponse { provider, error }
+            }
+            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
+            OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
+                provider,
+                retry_after: Some(retry_after),
+            },
+            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
+                provider,
+                retry_after,
+            },
+            OpenRouterError::ApiError(api_error) => api_error.into(),
+        }
+    }
+}
+
+impl From<open_router::ApiError> for LanguageModelCompletionError {
+    fn from(error: open_router::ApiError) -> Self {
+        use open_router::ApiErrorCode::*;
+        let provider = LanguageModelProviderName::new("OpenRouter");
+        match error.code {
+            InvalidRequestError => Self::BadRequestFormat {
+                provider,
+                message: error.message,
+            },
+            AuthenticationError => Self::AuthenticationError {
+                provider,
+                message: error.message,
+            },
+            PaymentRequiredError => Self::AuthenticationError {
+                provider,
+                message: format!("Payment required: {}", error.message),
+            },
+            PermissionError => Self::PermissionError {
+                provider,
+                message: error.message,
+            },
+            RequestTimedOut => Self::HttpResponseError {
+                provider,
+                status_code: StatusCode::REQUEST_TIMEOUT,
+                message: error.message,
+            },
+            RateLimitError => Self::RateLimitExceeded {
+                provider,
+                retry_after: None,
+            },
+            ApiError => Self::ApiInternalServerError {
+                provider,
+                message: error.message,
+            },
+            OverloadedError => Self::ServerOverloaded {
+                provider,
+                retry_after: None,
+            },
+        }
+    }
+}
+
 /// Indicates the format used to define the input schema for a language model tool.
 #[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
 pub enum LanguageModelToolSchemaFormat {

crates/language_models/src/provider/open_router.rs 🔗

@@ -152,6 +152,7 @@ impl State {
             .open_router
             .api_url
             .clone();
+
         cx.spawn(async move |this, cx| {
             let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
                 (api_key, true)
@@ -161,11 +162,11 @@ impl State {
                     .await?
                     .ok_or(AuthenticateError::CredentialsNotFound)?;
                 (
-                    String::from_utf8(api_key)
-                        .context(format!("invalid {} API key", PROVIDER_NAME))?,
+                    String::from_utf8(api_key).with_context(|| format!("invalid {PROVIDER_NAME} API key"))?,
                     false,
                 )
             };
+
             this.update(cx, |this, cx| {
                 this.api_key = Some(api_key);
                 this.api_key_from_env = from_env;
@@ -183,7 +184,9 @@ impl State {
         let api_url = settings.api_url.clone();
 
         cx.spawn(async move |this, cx| {
-            let models = list_models(http_client.as_ref(), &api_url).await?;
+            let models = list_models(http_client.as_ref(), &api_url)
+                .await
+                .map_err(|e| anyhow::anyhow!("OpenRouter error: {:?}", e))?;
 
             this.update(cx, |this, cx| {
                 this.available_models = models;
@@ -334,27 +337,37 @@ impl OpenRouterLanguageModel {
         &self,
         request: open_router::Request,
         cx: &AsyncApp,
-    ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
-    {
+    ) -> BoxFuture<
+        'static,
+        Result<
+            futures::stream::BoxStream<
+                'static,
+                Result<ResponseStreamEvent, open_router::OpenRouterError>,
+            >,
+            LanguageModelCompletionError,
+        >,
+    > {
         let http_client = self.http_client.clone();
         let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
             let settings = &AllLanguageModelSettings::get_global(cx).open_router;
             (state.api_key.clone(), settings.api_url.clone())
         }) else {
-            return futures::future::ready(Err(anyhow!(
-                "App state dropped: Unable to read API key or API URL from the application state"
-            )))
+            return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!(
+                "App state dropped"
+            ))))
             .boxed();
         };
 
-        let future = self.request_limiter.stream(async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?;
+        async move {
+            let Some(api_key) = api_key else {
+                return Err(LanguageModelCompletionError::NoApiKey {
+                    provider: PROVIDER_NAME,
+                });
+            };
             let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
-            let response = request.await?;
-            Ok(response)
-        });
-
-        async move { Ok(future.await?.boxed()) }.boxed()
+            request.await.map_err(Into::into)
+        }
+        .boxed()
     }
 }
 
@@ -435,12 +448,12 @@ impl LanguageModel for OpenRouterLanguageModel {
         >,
     > {
         let request = into_open_router(request, &self.model, self.max_output_tokens());
-        let completions = self.stream_completion(request, cx);
-        async move {
-            let mapper = OpenRouterEventMapper::new();
-            Ok(mapper.map_stream(completions.await?).boxed())
-        }
-        .boxed()
+        let request = self.stream_completion(request, cx);
+        let future = self.request_limiter.stream(async move {
+            let response = request.await?;
+            Ok(OpenRouterEventMapper::new().map_stream(response))
+        });
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 }
 
@@ -608,13 +621,17 @@ impl OpenRouterEventMapper {
 
     pub fn map_stream(
         mut self,
-        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+        events: Pin<
+            Box<
+                dyn Send + Stream<Item = Result<ResponseStreamEvent, open_router::OpenRouterError>>,
+            >,
+        >,
     ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
     {
         events.flat_map(move |event| {
             futures::stream::iter(match event {
                 Ok(event) => self.map_event(event),
-                Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
+                Err(error) => vec![Err(error.into())],
             })
         })
     }

crates/open_router/Cargo.toml 🔗

@@ -22,4 +22,6 @@ http_client.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
+thiserror.workspace = true
+strum.workspace = true
 workspace-hack.workspace = true

crates/open_router/src/open_router.rs 🔗

@@ -1,12 +1,31 @@
-use anyhow::{Context, Result, anyhow};
+use anyhow::{Result, anyhow};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
-use std::convert::TryFrom;
+use std::{convert::TryFrom, io, time::Duration};
+use strum::EnumString;
+use thiserror::Error;
 
 pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
 
+fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
+    if let Some(reset) = headers.get("X-RateLimit-Reset") {
+        if let Ok(s) = reset.to_str() {
+            if let Ok(epoch_ms) = s.parse::<u64>() {
+                let now = std::time::SystemTime::now()
+                    .duration_since(std::time::UNIX_EPOCH)
+                    .unwrap_or_default()
+                    .as_millis() as u64;
+                if epoch_ms > now {
+                    return Some(std::time::Duration::from_millis(epoch_ms - now));
+                }
+            }
+        }
+    }
+    None
+}
+
 fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
     opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 }
@@ -413,76 +432,12 @@ pub struct ModelArchitecture {
     pub input_modalities: Vec<String>,
 }
 
-pub async fn complete(
-    client: &dyn HttpClient,
-    api_url: &str,
-    api_key: &str,
-    request: Request,
-) -> Result<Response> {
-    let uri = format!("{api_url}/chat/completions");
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(uri)
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key.trim()))
-        .header("HTTP-Referer", "https://zed.dev")
-        .header("X-Title", "Zed Editor");
-
-    let mut request_body = request;
-    request_body.stream = false;
-
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
-    let mut response = client.send(request).await?;
-
-    if response.status().is_success() {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-        let response: Response = serde_json::from_str(&body)?;
-        Ok(response)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenRouterResponse {
-            error: OpenRouterError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenRouterError {
-            message: String,
-            #[serde(default)]
-            code: String,
-        }
-
-        match serde_json::from_str::<OpenRouterResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => {
-                let error_message = if !response.error.code.is_empty() {
-                    format!("{}: {}", response.error.code, response.error.message)
-                } else {
-                    response.error.message
-                };
-
-                Err(anyhow!(
-                    "Failed to connect to OpenRouter API: {}",
-                    error_message
-                ))
-            }
-            _ => Err(anyhow!(
-                "Failed to connect to OpenRouter API: {} {}",
-                response.status(),
-                body,
-            )),
-        }
-    }
-}
-
 pub async fn stream_completion(
     client: &dyn HttpClient,
     api_url: &str,
     api_key: &str,
     request: Request,
-) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
+) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
     let uri = format!("{api_url}/chat/completions");
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
@@ -492,8 +447,15 @@ pub async fn stream_completion(
         .header("HTTP-Referer", "https://zed.dev")
         .header("X-Title", "Zed Editor");
 
-    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
-    let mut response = client.send(request).await?;
+    let request = request_builder
+        .body(AsyncBody::from(
+            serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
+        ))
+        .map_err(OpenRouterError::BuildRequestBody)?;
+    let mut response = client
+        .send(request)
+        .await
+        .map_err(OpenRouterError::HttpSend)?;
 
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
@@ -513,86 +475,85 @@ pub async fn stream_completion(
                             match serde_json::from_str::<ResponseStreamEvent>(line) {
                                 Ok(response) => Some(Ok(response)),
                                 Err(error) => {
-                                    #[derive(Deserialize)]
-                                    struct ErrorResponse {
-                                        error: String,
-                                    }
-
-                                    match serde_json::from_str::<ErrorResponse>(line) {
-                                        Ok(err_response) => Some(Err(anyhow!(err_response.error))),
-                                        Err(_) => {
-                                            if line.trim().is_empty() {
-                                                None
-                                            } else {
-                                                Some(Err(anyhow!(
-                                                    "Failed to parse response: {}. Original content: '{}'",
-                                                    error, line
-                                                )))
-                                            }
-                                        }
+                                    if line.trim().is_empty() {
+                                        None
+                                    } else {
+                                        Some(Err(OpenRouterError::DeserializeResponse(error)))
                                     }
                                 }
                             }
                         }
                     }
-                    Err(error) => Some(Err(anyhow!(error))),
+                    Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
                 }
             })
             .boxed())
     } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
+        let code = ApiErrorCode::from_status(response.status().as_u16());
 
-        #[derive(Deserialize)]
-        struct OpenRouterResponse {
-            error: OpenRouterError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenRouterError {
-            message: String,
-            #[serde(default)]
-            code: String,
-        }
-
-        match serde_json::from_str::<OpenRouterResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => {
-                let error_message = if !response.error.code.is_empty() {
-                    format!("{}: {}", response.error.code, response.error.message)
-                } else {
-                    response.error.message
-                };
-
-                Err(anyhow!(
-                    "Failed to connect to OpenRouter API: {}",
-                    error_message
-                ))
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(OpenRouterError::ReadResponse)?;
+
+        let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
+            Ok(OpenRouterErrorResponse { error }) => error,
+            Err(_) => OpenRouterErrorBody {
+                code: response.status().as_u16(),
+                message: body,
+                metadata: None,
+            },
+        };
+
+        match code {
+            ApiErrorCode::RateLimitError => {
+                let retry_after = extract_retry_after(response.headers());
+                Err(OpenRouterError::RateLimit {
+                    retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
+                })
+            }
+            ApiErrorCode::OverloadedError => {
+                let retry_after = extract_retry_after(response.headers());
+                Err(OpenRouterError::ServerOverloaded { retry_after })
             }
-            _ => Err(anyhow!(
-                "Failed to connect to OpenRouter API: {} {}",
-                response.status(),
-                body,
-            )),
+            _ => Err(OpenRouterError::ApiError(ApiError {
+                code: code,
+                message: error_response.message,
+            })),
         }
     }
 }
 
-pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
+pub async fn list_models(
+    client: &dyn HttpClient,
+    api_url: &str,
+) -> Result<Vec<Model>, OpenRouterError> {
     let uri = format!("{api_url}/models");
     let request_builder = HttpRequest::builder()
         .method(Method::GET)
         .uri(uri)
         .header("Accept", "application/json");
 
-    let request = request_builder.body(AsyncBody::default())?;
-    let mut response = client.send(request).await?;
+    let request = request_builder
+        .body(AsyncBody::default())
+        .map_err(OpenRouterError::BuildRequestBody)?;
+    let mut response = client
+        .send(request)
+        .await
+        .map_err(OpenRouterError::HttpSend)?;
 
     let mut body = String::new();
-    response.body_mut().read_to_string(&mut body).await?;
+    response
+        .body_mut()
+        .read_to_string(&mut body)
+        .await
+        .map_err(OpenRouterError::ReadResponse)?;
 
     if response.status().is_success() {
         let response: ListModelsResponse =
-            serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
+            serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
 
         let models = response
             .data
@@ -637,10 +598,141 @@ pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<M
 
         Ok(models)
     } else {
-        Err(anyhow!(
-            "Failed to connect to OpenRouter API: {} {}",
-            response.status(),
-            body,
-        ))
+        let code = ApiErrorCode::from_status(response.status().as_u16());
+
+        let mut body = String::new();
+        response
+            .body_mut()
+            .read_to_string(&mut body)
+            .await
+            .map_err(OpenRouterError::ReadResponse)?;
+
+        let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
+            Ok(OpenRouterErrorResponse { error }) => error,
+            Err(_) => OpenRouterErrorBody {
+                code: response.status().as_u16(),
+                message: body,
+                metadata: None,
+            },
+        };
+
+        match code {
+            ApiErrorCode::RateLimitError => {
+                let retry_after = extract_retry_after(response.headers());
+                Err(OpenRouterError::RateLimit {
+                    retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
+                })
+            }
+            ApiErrorCode::OverloadedError => {
+                let retry_after = extract_retry_after(response.headers());
+                Err(OpenRouterError::ServerOverloaded { retry_after })
+            }
+            _ => Err(OpenRouterError::ApiError(ApiError {
+                code: code,
+                message: error_response.message,
+            })),
+        }
+    }
+}
+
+#[derive(Debug)]
+pub enum OpenRouterError {
+    /// Failed to serialize the HTTP request body to JSON
+    SerializeRequest(serde_json::Error),
+
+    /// Failed to construct the HTTP request body
+    BuildRequestBody(http::Error),
+
+    /// Failed to send the HTTP request
+    HttpSend(anyhow::Error),
+
+    /// Failed to deserialize the response from JSON
+    DeserializeResponse(serde_json::Error),
+
+    /// Failed to read from response stream
+    ReadResponse(io::Error),
+
+    /// Rate limit exceeded
+    RateLimit { retry_after: Duration },
+
+    /// Server overloaded
+    ServerOverloaded { retry_after: Option<Duration> },
+
+    /// API returned an error response
+    ApiError(ApiError),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OpenRouterErrorBody {
+    pub code: u16,
+    pub message: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OpenRouterErrorResponse {
+    pub error: OpenRouterErrorBody,
+}
+
+#[derive(Debug, Serialize, Deserialize, Error)]
+#[error("OpenRouter API Error: {code}: {message}")]
+pub struct ApiError {
+    pub code: ApiErrorCode,
+    pub message: String,
+}
+
+/// An OpenRouter API error code.
+/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
+#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
+#[strum(serialize_all = "snake_case")]
+pub enum ApiErrorCode {
+    /// 400: Bad Request (invalid or missing params, CORS)
+    InvalidRequestError,
+    /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
+    AuthenticationError,
+    /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
+    PaymentRequiredError,
+    /// 403: Your chosen model requires moderation and your input was flagged
+    PermissionError,
+    /// 408: Your request timed out
+    RequestTimedOut,
+    /// 429: You are being rate limited
+    RateLimitError,
+    /// 502: Your chosen model is down or we received an invalid response from it
+    ApiError,
+    /// 503: There is no available model provider that meets your routing requirements
+    OverloadedError,
+}
+
+impl std::fmt::Display for ApiErrorCode {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let s = match self {
+            ApiErrorCode::InvalidRequestError => "invalid_request_error",
+            ApiErrorCode::AuthenticationError => "authentication_error",
+            ApiErrorCode::PaymentRequiredError => "payment_required_error",
+            ApiErrorCode::PermissionError => "permission_error",
+            ApiErrorCode::RequestTimedOut => "request_timed_out",
+            ApiErrorCode::RateLimitError => "rate_limit_error",
+            ApiErrorCode::ApiError => "api_error",
+            ApiErrorCode::OverloadedError => "overloaded_error",
+        };
+        write!(f, "{s}")
+    }
+}
+
+impl ApiErrorCode {
+    pub fn from_status(status: u16) -> Self {
+        match status {
+            400 => ApiErrorCode::InvalidRequestError,
+            401 => ApiErrorCode::AuthenticationError,
+            402 => ApiErrorCode::PaymentRequiredError,
+            403 => ApiErrorCode::PermissionError,
+            408 => ApiErrorCode::RequestTimedOut,
+            429 => ApiErrorCode::RateLimitError,
+            502 => ApiErrorCode::ApiError,
+            503 => ApiErrorCode::OverloadedError,
+            _ => ApiErrorCode::ApiError,
+        }
     }
 }