@@ -17,6 +17,7 @@ use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
use http_client::{StatusCode, http};
use icons::IconName;
+use open_router::OpenRouterError;
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
@@ -347,6 +348,72 @@ impl From<anthropic::ApiError> for LanguageModelCompletionError {
}
}
+impl From<OpenRouterError> for LanguageModelCompletionError {
+ fn from(error: OpenRouterError) -> Self {
+ let provider = LanguageModelProviderName::new("OpenRouter");
+ match error {
+ OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
+ OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
+ OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
+ OpenRouterError::DeserializeResponse(error) => {
+ Self::DeserializeResponse { provider, error }
+ }
+ OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
+ OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
+ provider,
+ retry_after: Some(retry_after),
+ },
+ OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
+ provider,
+ retry_after,
+ },
+ OpenRouterError::ApiError(api_error) => api_error.into(),
+ }
+ }
+}
+
+impl From<open_router::ApiError> for LanguageModelCompletionError {
+ fn from(error: open_router::ApiError) -> Self {
+ use open_router::ApiErrorCode::*;
+ let provider = LanguageModelProviderName::new("OpenRouter");
+ match error.code {
+ InvalidRequestError => Self::BadRequestFormat {
+ provider,
+ message: error.message,
+ },
+ AuthenticationError => Self::AuthenticationError {
+ provider,
+ message: error.message,
+ },
+ PaymentRequiredError => Self::AuthenticationError {
+ provider,
+ message: format!("Payment required: {}", error.message),
+ },
+ PermissionError => Self::PermissionError {
+ provider,
+ message: error.message,
+ },
+ RequestTimedOut => Self::HttpResponseError {
+ provider,
+ status_code: StatusCode::REQUEST_TIMEOUT,
+ message: error.message,
+ },
+ RateLimitError => Self::RateLimitExceeded {
+ provider,
+ retry_after: None,
+ },
+ ApiError => Self::ApiInternalServerError {
+ provider,
+ message: error.message,
+ },
+ OverloadedError => Self::ServerOverloaded {
+ provider,
+ retry_after: None,
+ },
+ }
+ }
+}
+
/// Indicates the format used to define the input schema for a language model tool.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LanguageModelToolSchemaFormat {
@@ -152,6 +152,7 @@ impl State {
.open_router
.api_url
.clone();
+
cx.spawn(async move |this, cx| {
let (api_key, from_env) = if let Ok(api_key) = std::env::var(OPENROUTER_API_KEY_VAR) {
(api_key, true)
@@ -161,11 +162,11 @@ impl State {
.await?
.ok_or(AuthenticateError::CredentialsNotFound)?;
(
- String::from_utf8(api_key)
- .context(format!("invalid {} API key", PROVIDER_NAME))?,
+ String::from_utf8(api_key).context("invalid {PROVIDER_NAME} API key")?,
false,
)
};
+
this.update(cx, |this, cx| {
this.api_key = Some(api_key);
this.api_key_from_env = from_env;
@@ -183,7 +184,9 @@ impl State {
let api_url = settings.api_url.clone();
cx.spawn(async move |this, cx| {
- let models = list_models(http_client.as_ref(), &api_url).await?;
+ let models = list_models(http_client.as_ref(), &api_url)
+ .await
+ .map_err(|e| anyhow::anyhow!("OpenRouter error: {:?}", e))?;
this.update(cx, |this, cx| {
this.available_models = models;
@@ -334,27 +337,37 @@ impl OpenRouterLanguageModel {
&self,
request: open_router::Request,
cx: &AsyncApp,
- ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
- {
+ ) -> BoxFuture<
+ 'static,
+ Result<
+ futures::stream::BoxStream<
+ 'static,
+ Result<ResponseStreamEvent, open_router::OpenRouterError>,
+ >,
+ LanguageModelCompletionError,
+ >,
+ > {
let http_client = self.http_client.clone();
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).open_router;
(state.api_key.clone(), settings.api_url.clone())
}) else {
- return futures::future::ready(Err(anyhow!(
- "App state dropped: Unable to read API key or API URL from the application state"
- )))
+ return futures::future::ready(Err(LanguageModelCompletionError::Other(anyhow!(
+ "App state dropped"
+ ))))
.boxed();
};
- let future = self.request_limiter.stream(async move {
- let api_key = api_key.ok_or_else(|| anyhow!("Missing OpenRouter API Key"))?;
+ async move {
+ let Some(api_key) = api_key else {
+ return Err(LanguageModelCompletionError::NoApiKey {
+ provider: PROVIDER_NAME,
+ });
+ };
let request = stream_completion(http_client.as_ref(), &api_url, &api_key, request);
- let response = request.await?;
- Ok(response)
- });
-
- async move { Ok(future.await?.boxed()) }.boxed()
+ request.await.map_err(Into::into)
+ }
+ .boxed()
}
}
@@ -435,12 +448,12 @@ impl LanguageModel for OpenRouterLanguageModel {
>,
> {
let request = into_open_router(request, &self.model, self.max_output_tokens());
- let completions = self.stream_completion(request, cx);
- async move {
- let mapper = OpenRouterEventMapper::new();
- Ok(mapper.map_stream(completions.await?).boxed())
- }
- .boxed()
+ let request = self.stream_completion(request, cx);
+ let future = self.request_limiter.stream(async move {
+ let response = request.await?;
+ Ok(OpenRouterEventMapper::new().map_stream(response))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
}
}
@@ -608,13 +621,17 @@ impl OpenRouterEventMapper {
pub fn map_stream(
mut self,
- events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+ events: Pin<
+ Box<
+ dyn Send + Stream<Item = Result<ResponseStreamEvent, open_router::OpenRouterError>>,
+ >,
+ >,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
{
events.flat_map(move |event| {
futures::stream::iter(match event {
Ok(event) => self.map_event(event),
- Err(error) => vec![Err(LanguageModelCompletionError::from(anyhow!(error)))],
+ Err(error) => vec![Err(error.into())],
})
})
}
@@ -1,12 +1,31 @@
-use anyhow::{Context, Result, anyhow};
+use anyhow::{Result, anyhow};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
-use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
use serde::{Deserialize, Serialize};
use serde_json::Value;
-use std::convert::TryFrom;
+use std::{convert::TryFrom, io, time::Duration};
+use strum::EnumString;
+use thiserror::Error;
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
+fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
+ if let Some(reset) = headers.get("X-RateLimit-Reset") {
+ if let Ok(s) = reset.to_str() {
+ if let Ok(epoch_ms) = s.parse::<u64>() {
+ let now = std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap_or_default()
+ .as_millis() as u64;
+ if epoch_ms > now {
+ return Some(std::time::Duration::from_millis(epoch_ms - now));
+ }
+ }
+ }
+ }
+ None
+}
+
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
}
@@ -413,76 +432,12 @@ pub struct ModelArchitecture {
pub input_modalities: Vec<String>,
}
-pub async fn complete(
- client: &dyn HttpClient,
- api_url: &str,
- api_key: &str,
- request: Request,
-) -> Result<Response> {
- let uri = format!("{api_url}/chat/completions");
- let request_builder = HttpRequest::builder()
- .method(Method::POST)
- .uri(uri)
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key.trim()))
- .header("HTTP-Referer", "https://zed.dev")
- .header("X-Title", "Zed Editor");
-
- let mut request_body = request;
- request_body.stream = false;
-
- let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
- let mut response = client.send(request).await?;
-
- if response.status().is_success() {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- let response: Response = serde_json::from_str(&body)?;
- Ok(response)
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- #[derive(Deserialize)]
- struct OpenRouterResponse {
- error: OpenRouterError,
- }
-
- #[derive(Deserialize)]
- struct OpenRouterError {
- message: String,
- #[serde(default)]
- code: String,
- }
-
- match serde_json::from_str::<OpenRouterResponse>(&body) {
- Ok(response) if !response.error.message.is_empty() => {
- let error_message = if !response.error.code.is_empty() {
- format!("{}: {}", response.error.code, response.error.message)
- } else {
- response.error.message
- };
-
- Err(anyhow!(
- "Failed to connect to OpenRouter API: {}",
- error_message
- ))
- }
- _ => Err(anyhow!(
- "Failed to connect to OpenRouter API: {} {}",
- response.status(),
- body,
- )),
- }
- }
-}
-
pub async fn stream_completion(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: Request,
-) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
+) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
let uri = format!("{api_url}/chat/completions");
let request_builder = HttpRequest::builder()
.method(Method::POST)
@@ -492,8 +447,15 @@ pub async fn stream_completion(
.header("HTTP-Referer", "https://zed.dev")
.header("X-Title", "Zed Editor");
- let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
- let mut response = client.send(request).await?;
+ let request = request_builder
+ .body(AsyncBody::from(
+ serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
+ ))
+ .map_err(OpenRouterError::BuildRequestBody)?;
+ let mut response = client
+ .send(request)
+ .await
+ .map_err(OpenRouterError::HttpSend)?;
if response.status().is_success() {
let reader = BufReader::new(response.into_body());
@@ -513,86 +475,85 @@ pub async fn stream_completion(
match serde_json::from_str::<ResponseStreamEvent>(line) {
Ok(response) => Some(Ok(response)),
Err(error) => {
- #[derive(Deserialize)]
- struct ErrorResponse {
- error: String,
- }
-
- match serde_json::from_str::<ErrorResponse>(line) {
- Ok(err_response) => Some(Err(anyhow!(err_response.error))),
- Err(_) => {
- if line.trim().is_empty() {
- None
- } else {
- Some(Err(anyhow!(
- "Failed to parse response: {}. Original content: '{}'",
- error, line
- )))
- }
- }
+ if line.trim().is_empty() {
+ None
+ } else {
+ Some(Err(OpenRouterError::DeserializeResponse(error)))
}
}
}
}
}
- Err(error) => Some(Err(anyhow!(error))),
+ Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
}
})
.boxed())
} else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
+ let code = ApiErrorCode::from_status(response.status().as_u16());
- #[derive(Deserialize)]
- struct OpenRouterResponse {
- error: OpenRouterError,
- }
-
- #[derive(Deserialize)]
- struct OpenRouterError {
- message: String,
- #[serde(default)]
- code: String,
- }
-
- match serde_json::from_str::<OpenRouterResponse>(&body) {
- Ok(response) if !response.error.message.is_empty() => {
- let error_message = if !response.error.code.is_empty() {
- format!("{}: {}", response.error.code, response.error.message)
- } else {
- response.error.message
- };
-
- Err(anyhow!(
- "Failed to connect to OpenRouter API: {}",
- error_message
- ))
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(OpenRouterError::ReadResponse)?;
+
+ let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
+ Ok(OpenRouterErrorResponse { error }) => error,
+ Err(_) => OpenRouterErrorBody {
+ code: response.status().as_u16(),
+ message: body,
+ metadata: None,
+ },
+ };
+
+ match code {
+ ApiErrorCode::RateLimitError => {
+ let retry_after = extract_retry_after(response.headers());
+ Err(OpenRouterError::RateLimit {
+ retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
+ })
+ }
+ ApiErrorCode::OverloadedError => {
+ let retry_after = extract_retry_after(response.headers());
+ Err(OpenRouterError::ServerOverloaded { retry_after })
}
- _ => Err(anyhow!(
- "Failed to connect to OpenRouter API: {} {}",
- response.status(),
- body,
- )),
+ _ => Err(OpenRouterError::ApiError(ApiError {
+ code: code,
+ message: error_response.message,
+ })),
}
}
}
-pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
+pub async fn list_models(
+ client: &dyn HttpClient,
+ api_url: &str,
+) -> Result<Vec<Model>, OpenRouterError> {
let uri = format!("{api_url}/models");
let request_builder = HttpRequest::builder()
.method(Method::GET)
.uri(uri)
.header("Accept", "application/json");
- let request = request_builder.body(AsyncBody::default())?;
- let mut response = client.send(request).await?;
+ let request = request_builder
+ .body(AsyncBody::default())
+ .map_err(OpenRouterError::BuildRequestBody)?;
+ let mut response = client
+ .send(request)
+ .await
+ .map_err(OpenRouterError::HttpSend)?;
let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(OpenRouterError::ReadResponse)?;
if response.status().is_success() {
let response: ListModelsResponse =
- serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
+ serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
let models = response
.data
@@ -637,10 +598,141 @@ pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<M
Ok(models)
} else {
- Err(anyhow!(
- "Failed to connect to OpenRouter API: {} {}",
- response.status(),
- body,
- ))
+ let code = ApiErrorCode::from_status(response.status().as_u16());
+
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .map_err(OpenRouterError::ReadResponse)?;
+
+ let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
+ Ok(OpenRouterErrorResponse { error }) => error,
+ Err(_) => OpenRouterErrorBody {
+ code: response.status().as_u16(),
+ message: body,
+ metadata: None,
+ },
+ };
+
+ match code {
+ ApiErrorCode::RateLimitError => {
+ let retry_after = extract_retry_after(response.headers());
+ Err(OpenRouterError::RateLimit {
+ retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
+ })
+ }
+ ApiErrorCode::OverloadedError => {
+ let retry_after = extract_retry_after(response.headers());
+ Err(OpenRouterError::ServerOverloaded { retry_after })
+ }
+ _ => Err(OpenRouterError::ApiError(ApiError {
+ code: code,
+ message: error_response.message,
+ })),
+ }
+ }
+}
+
+#[derive(Debug)]
+pub enum OpenRouterError {
+ /// Failed to serialize the HTTP request body to JSON
+ SerializeRequest(serde_json::Error),
+
+ /// Failed to construct the HTTP request body
+ BuildRequestBody(http::Error),
+
+ /// Failed to send the HTTP request
+ HttpSend(anyhow::Error),
+
+ /// Failed to deserialize the response from JSON
+ DeserializeResponse(serde_json::Error),
+
+ /// Failed to read from response stream
+ ReadResponse(io::Error),
+
+ /// Rate limit exceeded
+ RateLimit { retry_after: Duration },
+
+ /// Server overloaded
+ ServerOverloaded { retry_after: Option<Duration> },
+
+ /// API returned an error response
+ ApiError(ApiError),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OpenRouterErrorBody {
+ pub code: u16,
+ pub message: String,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct OpenRouterErrorResponse {
+ pub error: OpenRouterErrorBody,
+}
+
+#[derive(Debug, Serialize, Deserialize, Error)]
+#[error("OpenRouter API Error: {code}: {message}")]
+pub struct ApiError {
+ pub code: ApiErrorCode,
+ pub message: String,
+}
+
+/// An OpenROuter API error code.
+/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
+#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
+#[strum(serialize_all = "snake_case")]
+pub enum ApiErrorCode {
+ /// 400: Bad Request (invalid or missing params, CORS)
+ InvalidRequestError,
+ /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
+ AuthenticationError,
+ /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
+ PaymentRequiredError,
+ /// 403: Your chosen model requires moderation and your input was flagged
+ PermissionError,
+ /// 408: Your request timed out
+ RequestTimedOut,
+ /// 429: You are being rate limited
+ RateLimitError,
+ /// 502: Your chosen model is down or we received an invalid response from it
+ ApiError,
+ /// 503: There is no available model provider that meets your routing requirements
+ OverloadedError,
+}
+
+impl std::fmt::Display for ApiErrorCode {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let s = match self {
+ ApiErrorCode::InvalidRequestError => "invalid_request_error",
+ ApiErrorCode::AuthenticationError => "authentication_error",
+ ApiErrorCode::PaymentRequiredError => "payment_required_error",
+ ApiErrorCode::PermissionError => "permission_error",
+ ApiErrorCode::RequestTimedOut => "request_timed_out",
+ ApiErrorCode::RateLimitError => "rate_limit_error",
+ ApiErrorCode::ApiError => "api_error",
+ ApiErrorCode::OverloadedError => "overloaded_error",
+ };
+ write!(f, "{s}")
+ }
+}
+
+impl ApiErrorCode {
+ pub fn from_status(status: u16) -> Self {
+ match status {
+ 400 => ApiErrorCode::InvalidRequestError,
+ 401 => ApiErrorCode::AuthenticationError,
+ 402 => ApiErrorCode::PaymentRequiredError,
+ 403 => ApiErrorCode::PermissionError,
+ 408 => ApiErrorCode::RequestTimedOut,
+ 429 => ApiErrorCode::RateLimitError,
+ 502 => ApiErrorCode::ApiError,
+ 503 => ApiErrorCode::OverloadedError,
+ _ => ApiErrorCode::ApiError,
+ }
}
}