open_router.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6pub use settings::DataCollection;
  7pub use settings::ModelMode;
  8pub use settings::OpenRouterAvailableModel as AvailableModel;
  9pub use settings::OpenRouterProvider as Provider;
 10use std::{convert::TryFrom, io, time::Duration};
 11use strum::EnumString;
 12use thiserror::Error;
 13
 14pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
 15
 16fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
 17    if let Some(reset) = headers.get("X-RateLimit-Reset") {
 18        if let Ok(s) = reset.to_str() {
 19            if let Ok(epoch_ms) = s.parse::<u64>() {
 20                let now = std::time::SystemTime::now()
 21                    .duration_since(std::time::UNIX_EPOCH)
 22                    .unwrap_or_default()
 23                    .as_millis() as u64;
 24                if epoch_ms > now {
 25                    return Some(std::time::Duration::from_millis(epoch_ms - now));
 26                }
 27            }
 28        }
 29    }
 30    None
 31}
 32
 33fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 34    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 35}
 36
 37#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 38#[serde(rename_all = "lowercase")]
 39pub enum Role {
 40    User,
 41    Assistant,
 42    System,
 43    Tool,
 44}
 45
 46impl TryFrom<String> for Role {
 47    type Error = anyhow::Error;
 48
 49    fn try_from(value: String) -> Result<Self> {
 50        match value.as_str() {
 51            "user" => Ok(Self::User),
 52            "assistant" => Ok(Self::Assistant),
 53            "system" => Ok(Self::System),
 54            "tool" => Ok(Self::Tool),
 55            _ => Err(anyhow!("invalid role '{value}'")),
 56        }
 57    }
 58}
 59
 60impl From<Role> for String {
 61    fn from(val: Role) -> Self {
 62        match val {
 63            Role::User => "user".to_owned(),
 64            Role::Assistant => "assistant".to_owned(),
 65            Role::System => "system".to_owned(),
 66            Role::Tool => "tool".to_owned(),
 67        }
 68    }
 69}
 70
 71#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 72#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 73pub struct Model {
 74    pub name: String,
 75    pub display_name: Option<String>,
 76    pub max_tokens: u64,
 77    pub supports_tools: Option<bool>,
 78    pub supports_images: Option<bool>,
 79    #[serde(default)]
 80    pub mode: ModelMode,
 81    pub provider: Option<Provider>,
 82}
 83
 84impl Model {
 85    pub fn default_fast() -> Self {
 86        Self::new(
 87            "openrouter/auto",
 88            Some("Auto Router"),
 89            Some(2000000),
 90            Some(true),
 91            Some(false),
 92            Some(ModelMode::Default),
 93            None,
 94        )
 95    }
 96
 97    pub fn default() -> Self {
 98        Self::default_fast()
 99    }
100
101    pub fn new(
102        name: &str,
103        display_name: Option<&str>,
104        max_tokens: Option<u64>,
105        supports_tools: Option<bool>,
106        supports_images: Option<bool>,
107        mode: Option<ModelMode>,
108        provider: Option<Provider>,
109    ) -> Self {
110        Self {
111            name: name.to_owned(),
112            display_name: display_name.map(|s| s.to_owned()),
113            max_tokens: max_tokens.unwrap_or(2000000),
114            supports_tools,
115            supports_images,
116            mode: mode.unwrap_or(ModelMode::Default),
117            provider,
118        }
119    }
120
121    pub fn id(&self) -> &str {
122        &self.name
123    }
124
125    pub fn display_name(&self) -> &str {
126        self.display_name.as_ref().unwrap_or(&self.name)
127    }
128
129    pub fn max_token_count(&self) -> u64 {
130        self.max_tokens
131    }
132
133    pub fn max_output_tokens(&self) -> Option<u64> {
134        None
135    }
136
137    pub fn supports_tool_calls(&self) -> bool {
138        self.supports_tools.unwrap_or(false)
139    }
140
141    pub fn supports_parallel_tool_calls(&self) -> bool {
142        false
143    }
144}
145
/// Body of a chat-completions request, serialized to JSON by [`stream_completion`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model identifier, e.g. `openrouter/auto` (see [`Model::id`]).
    pub model: String,
    /// Conversation history, in order.
    pub messages: Vec<RequestMessage>,
    /// Whether to request a streamed response. [`stream_completion`] parses the body
    /// as `data: ` server-sent events, so it presumably expects `true` — confirm.
    pub stream: bool,
    /// Optional cap on generated tokens; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Stop sequences; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature; always serialized.
    pub temperature: f32,
    /// Tool-selection strategy; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether several tool calls may be emitted at once; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Reasoning configuration; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    /// Usage-accounting options (see [`RequestUsage`]); always serialized.
    pub usage: RequestUsage,
    /// Provider routing preferences.
    ///
    /// NOTE(review): unlike the other optional fields this has no
    /// `skip_serializing_if`, so `None` is sent as `"provider": null`.
    pub provider: Option<Provider>,
}

/// The `usage` object of a [`Request`].
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    /// Presumably asks the API to report token [`Usage`] in the response —
    /// confirm against the OpenRouter docs.
    pub include: bool,
}
171}
172
/// How the model should choose tools.
///
/// The unit variants serialize as the lowercase strings `"auto"`, `"required"` and
/// `"none"`; `Other` is untagged and serializes as the wrapped [`ToolDefinition`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}

/// A tool offered to the model, tagged by `"type"` (currently only `"function"`).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): reason for this allow is not visible here; the variant is
    // public and constructed by callers — consider removing it.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// A callable function exposed to the model.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name the model uses when calling it.
    pub name: String,
    /// Optional description shown to the model.
    pub description: Option<String>,
    /// Parameter description as arbitrary JSON (presumably a JSON Schema — confirm).
    pub parameters: Option<Value>,
}

/// Reasoning options of a [`Request`]; each `None` field is omitted from the JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Reasoning {
    /// Reasoning effort level as a free-form string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    /// Token budget for reasoning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Presumably hides reasoning text from the response when `true` — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    /// Explicitly enables or disables reasoning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
210
/// A chat message, internally tagged by a lowercase `"role"` field
/// (`"assistant"`, `"user"`, `"system"`, `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// A model turn; content may be absent when the turn only contains tool calls.
    Assistant {
        content: Option<MessageContent>,
        /// Omitted from the JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// The result of a tool call, linked to it by `tool_call_id`.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
230
/// Message content: either a plain string or a list of typed parts.
///
/// Untagged, so `Plain` serializes as a JSON string and `Multipart` as a JSON array.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
237
238impl MessageContent {
239    pub fn empty() -> Self {
240        Self::Plain(String::new())
241    }
242
243    pub fn push_part(&mut self, part: MessagePart) {
244        match self {
245            Self::Plain(text) if text.is_empty() => {
246                *self = Self::Multipart(vec![part]);
247            }
248            Self::Plain(text) => {
249                let text_part = MessagePart::Text {
250                    text: std::mem::take(text),
251                };
252                *self = Self::Multipart(vec![text_part, part]);
253            }
254            Self::Multipart(parts) => parts.push(part),
255        }
256    }
257}
258
259impl From<Vec<MessagePart>> for MessageContent {
260    fn from(parts: Vec<MessagePart>) -> Self {
261        if parts.len() == 1
262            && let MessagePart::Text { text } = &parts[0]
263        {
264            return Self::Plain(text.clone());
265        }
266        Self::Multipart(parts)
267    }
268}
269
270impl From<String> for MessageContent {
271    fn from(text: String) -> Self {
272        Self::Plain(text)
273    }
274}
275
276impl From<&str> for MessageContent {
277    fn from(text: &str) -> Self {
278        Self::Plain(text.to_string())
279    }
280}
281
282impl MessageContent {
283    pub fn as_text(&self) -> Option<&str> {
284        match self {
285            Self::Plain(text) => Some(text),
286            Self::Multipart(parts) if parts.len() == 1 => {
287                if let MessagePart::Text { text } = &parts[0] {
288                    Some(text)
289                } else {
290                    None
291                }
292            }
293            _ => None,
294        }
295    }
296
297    pub fn to_text(&self) -> String {
298        match self {
299            Self::Plain(text) => text.clone(),
300            Self::Multipart(parts) => parts
301                .iter()
302                .filter_map(|part| {
303                    if let MessagePart::Text { text } = part {
304                        Some(text.as_str())
305                    } else {
306                        None
307                    }
308                })
309                .collect::<Vec<_>>()
310                .join(""),
311        }
312    }
313}
314
/// One part of a multipart message, tagged by `"type"`: `"text"` or `"image_url"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    /// An image given by URL (the exact URL forms accepted are not shown here).
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}
326
/// A complete tool call made by the assistant.
///
/// `content` is flattened, so the JSON looks like
/// `{"id": ..., "type": "function", "function": {...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// The typed body of a [`ToolCall`], tagged by a lowercase `"type"` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// The function name and arguments of a tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Raw argument string (presumably JSON-encoded — confirm).
    pub arguments: String,
}
345
/// Incremental message content carried by one streamed [`ChoiceDelta`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    /// Next fragment of the visible reply text.
    pub content: Option<String>,
    /// Next fragment of reasoning text.
    pub reasoning: Option<String>,
    /// Partial tool calls; omitted when serialized if `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A fragment of a tool call. Chunks sharing `index` belong to the same call;
/// `id` and the function name presumably arrive only on the first chunk — confirm.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// A fragment of a function call's name and/or arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
367
/// Token accounting reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Set once the choice has finished; the reason is kept as a raw string.
    pub finish_reason: Option<String>,
}

/// One `data: ` payload of the stream returned by [`stream_completion`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Creation time (presumably Unix seconds — confirm).
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, when the API includes it in this event.
    pub usage: Option<Usage>,
}

/// A complete (non-streamed) chat-completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One choice of a non-streamed [`Response`]; the message reuses the request shape.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
408
/// Response body of the `/models/user` endpoint used by [`list_models`].
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

/// One model in a [`ListModelsResponse`].
///
/// NOTE(review): this type only derives `Deserialize`, so the
/// `skip_serializing_if` attributes below have no effect.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    /// Model identifier; becomes [`Model::name`].
    pub id: String,
    /// Display name, formatted as `"provider: model"` per [`list_models`].
    pub name: String,
    pub created: usize,
    pub description: String,
    /// Context window in tokens, if reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// Supported request parameters; [`list_models`] checks for `"tools"` and
    /// `"reasoning"`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ModelArchitecture>,
}

/// Architecture details of a [`ModelEntry`].
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelArchitecture {
    /// Accepted input modalities; [`list_models`] checks for `"image"`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_modalities: Vec<String>,
}
433
434pub async fn stream_completion(
435    client: &dyn HttpClient,
436    api_url: &str,
437    api_key: &str,
438    request: Request,
439) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
440    let uri = format!("{api_url}/chat/completions");
441    let request_builder = HttpRequest::builder()
442        .method(Method::POST)
443        .uri(uri)
444        .header("Content-Type", "application/json")
445        .header("Authorization", format!("Bearer {}", api_key))
446        .header("HTTP-Referer", "https://zed.dev")
447        .header("X-Title", "Zed Editor");
448
449    let request = request_builder
450        .body(AsyncBody::from(
451            serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
452        ))
453        .map_err(OpenRouterError::BuildRequestBody)?;
454    let mut response = client
455        .send(request)
456        .await
457        .map_err(OpenRouterError::HttpSend)?;
458
459    if response.status().is_success() {
460        let reader = BufReader::new(response.into_body());
461        Ok(reader
462            .lines()
463            .filter_map(|line| async move {
464                match line {
465                    Ok(line) => {
466                        if line.starts_with(':') {
467                            return None;
468                        }
469
470                        let line = line.strip_prefix("data: ")?;
471                        if line == "[DONE]" {
472                            None
473                        } else {
474                            match serde_json::from_str::<ResponseStreamEvent>(line) {
475                                Ok(response) => Some(Ok(response)),
476                                Err(error) => {
477                                    if line.trim().is_empty() {
478                                        None
479                                    } else {
480                                        Some(Err(OpenRouterError::DeserializeResponse(error)))
481                                    }
482                                }
483                            }
484                        }
485                    }
486                    Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
487                }
488            })
489            .boxed())
490    } else {
491        let code = ApiErrorCode::from_status(response.status().as_u16());
492
493        let mut body = String::new();
494        response
495            .body_mut()
496            .read_to_string(&mut body)
497            .await
498            .map_err(OpenRouterError::ReadResponse)?;
499
500        let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
501            Ok(OpenRouterErrorResponse { error }) => error,
502            Err(_) => OpenRouterErrorBody {
503                code: response.status().as_u16(),
504                message: body,
505                metadata: None,
506            },
507        };
508
509        match code {
510            ApiErrorCode::RateLimitError => {
511                let retry_after = extract_retry_after(response.headers());
512                Err(OpenRouterError::RateLimit {
513                    retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
514                })
515            }
516            ApiErrorCode::OverloadedError => {
517                let retry_after = extract_retry_after(response.headers());
518                Err(OpenRouterError::ServerOverloaded { retry_after })
519            }
520            _ => Err(OpenRouterError::ApiError(ApiError {
521                code: code,
522                message: error_response.message,
523            })),
524        }
525    }
526}
527
528pub async fn list_models(
529    client: &dyn HttpClient,
530    api_url: &str,
531    api_key: &str,
532) -> Result<Vec<Model>, OpenRouterError> {
533    let uri = format!("{api_url}/models/user");
534    let request_builder = HttpRequest::builder()
535        .method(Method::GET)
536        .uri(uri)
537        .header("Accept", "application/json")
538        .header("Authorization", format!("Bearer {}", api_key))
539        .header("HTTP-Referer", "https://zed.dev")
540        .header("X-Title", "Zed Editor");
541
542    let request = request_builder
543        .body(AsyncBody::default())
544        .map_err(OpenRouterError::BuildRequestBody)?;
545    let mut response = client
546        .send(request)
547        .await
548        .map_err(OpenRouterError::HttpSend)?;
549
550    let mut body = String::new();
551    response
552        .body_mut()
553        .read_to_string(&mut body)
554        .await
555        .map_err(OpenRouterError::ReadResponse)?;
556
557    if response.status().is_success() {
558        let response: ListModelsResponse =
559            serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
560
561        let models = response
562            .data
563            .into_iter()
564            .map(|entry| Model {
565                name: entry.id,
566                // OpenRouter returns display names in the format "provider_name: model_name".
567                // When displayed in the UI, these names can get truncated from the right.
568                // Since users typically already know the provider, we extract just the model name
569                // portion (after the colon) to create a more concise and user-friendly label
570                // for the model dropdown in the agent panel.
571                display_name: Some(
572                    entry
573                        .name
574                        .split(':')
575                        .next_back()
576                        .unwrap_or(&entry.name)
577                        .trim()
578                        .to_string(),
579                ),
580                max_tokens: entry.context_length.unwrap_or(2000000),
581                supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
582                supports_images: Some(
583                    entry
584                        .architecture
585                        .as_ref()
586                        .map(|arch| arch.input_modalities.contains(&"image".to_string()))
587                        .unwrap_or(false),
588                ),
589                mode: if entry
590                    .supported_parameters
591                    .contains(&"reasoning".to_string())
592                {
593                    ModelMode::Thinking {
594                        budget_tokens: Some(4_096),
595                    }
596                } else {
597                    ModelMode::Default
598                },
599                provider: None,
600            })
601            .collect();
602
603        Ok(models)
604    } else {
605        let code = ApiErrorCode::from_status(response.status().as_u16());
606
607        let mut body = String::new();
608        response
609            .body_mut()
610            .read_to_string(&mut body)
611            .await
612            .map_err(OpenRouterError::ReadResponse)?;
613
614        let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
615            Ok(OpenRouterErrorResponse { error }) => error,
616            Err(_) => OpenRouterErrorBody {
617                code: response.status().as_u16(),
618                message: body,
619                metadata: None,
620            },
621        };
622
623        match code {
624            ApiErrorCode::RateLimitError => {
625                let retry_after = extract_retry_after(response.headers());
626                Err(OpenRouterError::RateLimit {
627                    retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
628                })
629            }
630            ApiErrorCode::OverloadedError => {
631                let retry_after = extract_retry_after(response.headers());
632                Err(OpenRouterError::ServerOverloaded { retry_after })
633            }
634            _ => Err(OpenRouterError::ApiError(ApiError {
635                code: code,
636                message: error_response.message,
637            })),
638        }
639    }
640}
641
/// Errors produced by the OpenRouter client functions in this module.
///
/// NOTE(review): this type derives only `Debug` here and does not implement
/// `Display` / `std::error::Error` in the visible code.
#[derive(Debug)]
pub enum OpenRouterError {
    /// Failed to serialize the HTTP request body to JSON
    SerializeRequest(serde_json::Error),

    /// Failed to construct the HTTP request body
    BuildRequestBody(http::Error),

    /// Failed to send the HTTP request
    HttpSend(anyhow::Error),

    /// Failed to deserialize the response from JSON
    DeserializeResponse(serde_json::Error),

    /// Failed to read from response stream
    ReadResponse(io::Error),

    /// Rate limit exceeded (HTTP 429); `retry_after` is how long to wait
    RateLimit { retry_after: Duration },

    /// Server overloaded (HTTP 503); `retry_after` is set when the server reports it
    ServerOverloaded { retry_after: Option<Duration> },

    /// API returned an error response with any other non-success status
    ApiError(ApiError),
}
668
/// The `error` object in an OpenRouter error response body.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
    /// Numeric error code (presumably mirrors the HTTP status — confirm).
    pub code: u16,
    /// Human-readable error message.
    pub message: String,
    /// Optional extra details, kept as arbitrary JSON values.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}

/// Envelope of an OpenRouter error response: `{"error": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
    pub error: OpenRouterErrorBody,
}

/// An API error whose [`ApiErrorCode`] is derived from the HTTP status.
///
/// Displays as `OpenRouter API Error: {code}: {message}`.
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
    pub code: ApiErrorCode,
    /// Message from the error envelope, or the raw response body if it did not parse.
    pub message: String,
}
688
/// An OpenRouter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
///
/// `Display` and the strum `FromStr` use snake_case names (e.g. `rate_limit_error`).
/// NOTE(review): the serde derives have no `rename_all`, so serde uses the variant
/// names as written (e.g. `RateLimitError`) — confirm that this difference is intended.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400: Bad Request (invalid or missing params, CORS)
    InvalidRequestError,
    /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
    AuthenticationError,
    /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
    PaymentRequiredError,
    /// 403: Your chosen model requires moderation and your input was flagged
    PermissionError,
    /// 408: Your request timed out
    RequestTimedOut,
    /// 429: You are being rate limited
    RateLimitError,
    /// 502: Your chosen model is down or we received an invalid response from it.
    /// Also the fallback for any status without a dedicated code.
    ApiError,
    /// 503: There is no available model provider that meets your routing requirements
    OverloadedError,
}
711
712impl std::fmt::Display for ApiErrorCode {
713    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
714        let s = match self {
715            ApiErrorCode::InvalidRequestError => "invalid_request_error",
716            ApiErrorCode::AuthenticationError => "authentication_error",
717            ApiErrorCode::PaymentRequiredError => "payment_required_error",
718            ApiErrorCode::PermissionError => "permission_error",
719            ApiErrorCode::RequestTimedOut => "request_timed_out",
720            ApiErrorCode::RateLimitError => "rate_limit_error",
721            ApiErrorCode::ApiError => "api_error",
722            ApiErrorCode::OverloadedError => "overloaded_error",
723        };
724        write!(f, "{s}")
725    }
726}
727
728impl ApiErrorCode {
729    pub fn from_status(status: u16) -> Self {
730        match status {
731            400 => ApiErrorCode::InvalidRequestError,
732            401 => ApiErrorCode::AuthenticationError,
733            402 => ApiErrorCode::PaymentRequiredError,
734            403 => ApiErrorCode::PermissionError,
735            408 => ApiErrorCode::RequestTimedOut,
736            429 => ApiErrorCode::RateLimitError,
737            502 => ApiErrorCode::ApiError,
738            503 => ApiErrorCode::OverloadedError,
739            _ => ApiErrorCode::ApiError,
740        }
741    }
742}