open_router.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::{convert::TryFrom, io, time::Duration};
  7use strum::EnumString;
  8use thiserror::Error;
  9
 10pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
 11
 12fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
 13    if let Some(reset) = headers.get("X-RateLimit-Reset") {
 14        if let Ok(s) = reset.to_str() {
 15            if let Ok(epoch_ms) = s.parse::<u64>() {
 16                let now = std::time::SystemTime::now()
 17                    .duration_since(std::time::UNIX_EPOCH)
 18                    .unwrap_or_default()
 19                    .as_millis() as u64;
 20                if epoch_ms > now {
 21                    return Some(std::time::Duration::from_millis(epoch_ms - now));
 22                }
 23            }
 24        }
 25    }
 26    None
 27}
 28
 29fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 30    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 31}
 32
/// Role of the participant a chat message is attributed to.
///
/// Serialized and deserialized as the lowercase variant name
/// (`"user"`, `"assistant"`, `"system"`, `"tool"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 41
 42impl TryFrom<String> for Role {
 43    type Error = anyhow::Error;
 44
 45    fn try_from(value: String) -> Result<Self> {
 46        match value.as_str() {
 47            "user" => Ok(Self::User),
 48            "assistant" => Ok(Self::Assistant),
 49            "system" => Ok(Self::System),
 50            "tool" => Ok(Self::Tool),
 51            _ => Err(anyhow!("invalid role '{value}'")),
 52        }
 53    }
 54}
 55
 56impl From<Role> for String {
 57    fn from(val: Role) -> Self {
 58        match val {
 59            Role::User => "user".to_owned(),
 60            Role::Assistant => "assistant".to_owned(),
 61            Role::System => "system".to_owned(),
 62            Role::Tool => "tool".to_owned(),
 63        }
 64    }
 65}
 66
/// Description of a model: identifier, optional display label, token limit,
/// optional capability flags, and mode.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Model identifier; this is what [`Model::id`] returns.
    pub name: String,
    /// Optional human-readable label; [`Model::display_name`] falls back to `name` when `None`.
    pub display_name: Option<String>,
    /// Token limit returned by [`Model::max_token_count`].
    pub max_tokens: u64,
    /// Whether tool calls are supported; `None` is treated as `false` by
    /// [`Model::supports_tool_calls`].
    pub supports_tools: Option<bool>,
    /// Whether image input is supported, when known.
    pub supports_images: Option<bool>,
    /// Model mode; falls back to [`ModelMode::Default`] when absent from deserialized input.
    #[serde(default)]
    pub mode: ModelMode,
}
 78
/// Mode a model runs in. The `Default` variant is also the `Default::default()` value.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum ModelMode {
    /// Regular mode with no extra settings.
    #[default]
    Default,
    /// Thinking mode with an optional token budget.
    Thinking {
        /// Optional token budget; `None` leaves it unspecified.
        budget_tokens: Option<u32>,
    },
}
 88
 89impl Model {
 90    pub fn default_fast() -> Self {
 91        Self::new(
 92            "openrouter/auto",
 93            Some("Auto Router"),
 94            Some(2000000),
 95            Some(true),
 96            Some(false),
 97            Some(ModelMode::Default),
 98        )
 99    }
100
101    pub fn default() -> Self {
102        Self::default_fast()
103    }
104
105    pub fn new(
106        name: &str,
107        display_name: Option<&str>,
108        max_tokens: Option<u64>,
109        supports_tools: Option<bool>,
110        supports_images: Option<bool>,
111        mode: Option<ModelMode>,
112    ) -> Self {
113        Self {
114            name: name.to_owned(),
115            display_name: display_name.map(|s| s.to_owned()),
116            max_tokens: max_tokens.unwrap_or(2000000),
117            supports_tools,
118            supports_images,
119            mode: mode.unwrap_or(ModelMode::Default),
120        }
121    }
122
123    pub fn id(&self) -> &str {
124        &self.name
125    }
126
127    pub fn display_name(&self) -> &str {
128        self.display_name.as_ref().unwrap_or(&self.name)
129    }
130
131    pub fn max_token_count(&self) -> u64 {
132        self.max_tokens
133    }
134
135    pub fn max_output_tokens(&self) -> Option<u64> {
136        None
137    }
138
139    pub fn supports_tool_calls(&self) -> bool {
140        self.supports_tools.unwrap_or(false)
141    }
142
143    pub fn supports_parallel_tool_calls(&self) -> bool {
144        false
145    }
146}
147
/// JSON body that `stream_completion` serializes and posts to `/chat/completions`.
///
/// Optional fields are left out of the serialized output when `None`, and list
/// fields are left out when empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Identifier of the model to use.
    pub model: String,
    /// Messages making up the conversation.
    pub messages: Vec<RequestMessage>,
    /// Streaming flag, serialized as-is.
    pub stream: bool,
    /// Optional token limit; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature; always serialized.
    pub temperature: f32,
    /// Optional tool-choice setting; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Optional parallel-tool-calls flag; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tool definitions; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Optional reasoning settings; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    /// Usage-reporting settings; always serialized.
    pub usage: RequestUsage,
}
168
/// The `usage` object of a [`Request`]. Defaults to `include: false`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    /// Presumably asks the API to include token usage in the response — TODO confirm
    /// against the OpenRouter documentation.
    pub include: bool,
}
173
/// Value of the `tool_choice` request field.
///
/// The unit variants serialize as lowercase strings (`"auto"`, `"required"`, `"none"`).
/// `Other` is untagged, so it serializes as the wrapped [`ToolDefinition`] itself.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    /// A specific tool definition, written without a variant wrapper.
    #[serde(untagged)]
    Other(ToolDefinition),
}
183
/// A tool definition, internally tagged by a `type` field with snake_case variant
/// names (so `Function` is written as `"type": "function"`).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    /// A function tool carrying its [`FunctionDefinition`].
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
191
/// Definition of a function tool.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Optional description of the function.
    pub description: Option<String>,
    /// Optional parameters as arbitrary JSON — presumably a JSON Schema object; verify
    /// against the API documentation.
    pub parameters: Option<Value>,
}
199
/// The optional `reasoning` object of a [`Request`].
///
/// Every field is optional and is omitted from the serialized output when `None`.
/// The meaning of each field is defined by the OpenRouter API, not by this file.
#[derive(Debug, Serialize, Deserialize)]
pub struct Reasoning {
    /// Optional effort setting, as a free-form string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    /// Optional token limit for reasoning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Optional exclude flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    /// Optional enabled flag.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
211
/// A chat message, internally tagged by a lowercase `role` field
/// (`"assistant"`, `"user"`, `"system"`, `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// Assistant message: optional content plus any tool calls.
    Assistant {
        content: Option<MessageContent>,
        /// Omitted from the serialized output when empty; defaults to empty when absent.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    /// User message.
    User {
        content: MessageContent,
    },
    /// System message.
    System {
        content: MessageContent,
    },
    /// Tool message, carrying the id of the tool call it belongs to.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
231
/// Content of a message: either a plain string or a list of parts.
///
/// Untagged: `Plain` serializes as a bare string and `Multipart` as an array.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// A single text string.
    Plain(String),
    /// An ordered list of parts.
    Multipart(Vec<MessagePart>),
}
238
239impl MessageContent {
240    pub fn empty() -> Self {
241        Self::Plain(String::new())
242    }
243
244    pub fn push_part(&mut self, part: MessagePart) {
245        match self {
246            Self::Plain(text) if text.is_empty() => {
247                *self = Self::Multipart(vec![part]);
248            }
249            Self::Plain(text) => {
250                let text_part = MessagePart::Text {
251                    text: std::mem::take(text),
252                };
253                *self = Self::Multipart(vec![text_part, part]);
254            }
255            Self::Multipart(parts) => parts.push(part),
256        }
257    }
258}
259
260impl From<Vec<MessagePart>> for MessageContent {
261    fn from(parts: Vec<MessagePart>) -> Self {
262        if parts.len() == 1
263            && let MessagePart::Text { text } = &parts[0]
264        {
265            return Self::Plain(text.clone());
266        }
267        Self::Multipart(parts)
268    }
269}
270
271impl From<String> for MessageContent {
272    fn from(text: String) -> Self {
273        Self::Plain(text)
274    }
275}
276
277impl From<&str> for MessageContent {
278    fn from(text: &str) -> Self {
279        Self::Plain(text.to_string())
280    }
281}
282
283impl MessageContent {
284    pub fn as_text(&self) -> Option<&str> {
285        match self {
286            Self::Plain(text) => Some(text),
287            Self::Multipart(parts) if parts.len() == 1 => {
288                if let MessagePart::Text { text } = &parts[0] {
289                    Some(text)
290                } else {
291                    None
292                }
293            }
294            _ => None,
295        }
296    }
297
298    pub fn to_text(&self) -> String {
299        match self {
300            Self::Plain(text) => text.clone(),
301            Self::Multipart(parts) => parts
302                .iter()
303                .filter_map(|part| {
304                    if let MessagePart::Text { text } = part {
305                        Some(text.as_str())
306                    } else {
307                        None
308                    }
309                })
310                .collect::<Vec<_>>()
311                .join(""),
312        }
313    }
314}
315
/// One part of multipart message content, internally tagged by a `type` field.
///
/// `Text` is written as `"type": "text"` and `Image` as `"type": "image_url"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    /// A piece of text.
    Text {
        text: String,
    },
    /// An image, referenced by a string in the `image_url` field.
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}
327
/// A tool call attached to an assistant message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Identifier of the tool call.
    pub id: String,
    /// Call payload; flattened, so its fields sit next to `id` in the serialized object.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
334
/// Payload of a [`ToolCall`], internally tagged by a lowercase `type` field
/// (so `Function` is written as `"type": "function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    /// A function call carrying its [`FunctionContent`].
    Function { function: FunctionContent },
}
340
/// Name and arguments of a function call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    /// Name of the function being called.
    pub name: String,
    /// Arguments as a string — presumably JSON-encoded; verify against the API documentation.
    pub arguments: String,
}
346
/// Message payload of a streamed choice (the `delta` field of [`ChoiceDelta`]).
/// Every field is optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    /// Role, when present.
    pub role: Option<Role>,
    /// Content text, when present.
    pub content: Option<String>,
    /// Reasoning text, when present.
    pub reasoning: Option<String>,
    /// Tool call chunks; defaults to `None` when absent, and is omitted from the
    /// serialized output when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
355
/// One tool call entry inside a [`ResponseMessageDelta`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Index of the tool call — presumably used by callers to merge chunks that belong
    /// to the same call; the merging is not shown in this file.
    pub index: usize,
    /// Tool call identifier, when present.
    pub id: Option<String>,
    /// Function details, when present.
    pub function: Option<FunctionChunk>,
}
362
/// Function details of a [`ToolCallChunk`]; both fields are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    /// Function name, when present.
    pub name: Option<String>,
    /// Arguments text, when present.
    pub arguments: Option<String>,
}
368
/// Token counts reported in a response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    /// Prompt token count.
    pub prompt_tokens: u64,
    /// Completion token count.
    pub completion_tokens: u64,
    /// Total token count as reported; this file does not recompute it.
    pub total_tokens: u64,
}
375
/// One choice inside a [`ResponseStreamEvent`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    /// Index of the choice.
    pub index: u32,
    /// Message payload for this event.
    pub delta: ResponseMessageDelta,
    /// Finish reason as a free-form string, when present.
    pub finish_reason: Option<String>,
}
382
/// One event parsed from a `data: ` line of the stream returned by `stream_completion`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    /// Event identifier; defaults to `None` when absent and is omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Creation timestamp — presumably Unix seconds; TODO confirm. Note this is `u32`
    /// while [`Response::created`] is `u64`.
    pub created: u32,
    /// Model name reported in the event.
    pub model: String,
    /// Choices carried by this event.
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, when present.
    pub usage: Option<Usage>,
}
392
/// A complete completion response.
///
/// NOTE(review): nothing in this file constructs or parses this type; it is presumably
/// the non-streaming counterpart of [`ResponseStreamEvent`] — confirm against callers.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    /// Response identifier.
    pub id: String,
    /// Object type string.
    pub object: String,
    /// Creation timestamp — presumably Unix seconds; TODO confirm.
    pub created: u64,
    /// Model name reported in the response.
    pub model: String,
    /// Choices in the response.
    pub choices: Vec<Choice>,
    /// Token usage for the response.
    pub usage: Usage,
}
402
/// One choice inside a [`Response`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    /// Index of the choice.
    pub index: u32,
    /// The message, using the same representation as request messages.
    pub message: RequestMessage,
    /// Finish reason as a free-form string, when present.
    pub finish_reason: Option<String>,
}
409
/// Body of a successful `/models/user` response, as parsed by `list_models`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    /// Model entries; `list_models` converts each one into a [`Model`].
    pub data: Vec<ModelEntry>,
}
414
/// One model entry in a [`ListModelsResponse`].
///
/// NOTE(review): only `Deserialize` is derived, so the `skip_serializing_if`
/// attributes below look like they have no effect.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    /// Model identifier; becomes `Model::name` in `list_models`.
    pub id: String,
    /// Display name; `list_models` keeps only the part after the last `:`.
    pub name: String,
    /// Creation value reported by the API; not read in this file.
    pub created: usize,
    /// Description text; not read in this file.
    pub description: String,
    /// Context length; `list_models` falls back to `2000000` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// Supported parameter names; `list_models` looks for `"tools"` and `"reasoning"`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
    /// Architecture details; `list_models` looks for `"image"` among the input modalities.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ModelArchitecture>,
}
428
/// Architecture details of a [`ModelEntry`].
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelArchitecture {
    /// Input modality names; defaults to empty when absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_modalities: Vec<String>,
}
434
435pub async fn stream_completion(
436    client: &dyn HttpClient,
437    api_url: &str,
438    api_key: &str,
439    request: Request,
440) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
441    let uri = format!("{api_url}/chat/completions");
442    let request_builder = HttpRequest::builder()
443        .method(Method::POST)
444        .uri(uri)
445        .header("Content-Type", "application/json")
446        .header("Authorization", format!("Bearer {}", api_key))
447        .header("HTTP-Referer", "https://zed.dev")
448        .header("X-Title", "Zed Editor");
449
450    let request = request_builder
451        .body(AsyncBody::from(
452            serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
453        ))
454        .map_err(OpenRouterError::BuildRequestBody)?;
455    let mut response = client
456        .send(request)
457        .await
458        .map_err(OpenRouterError::HttpSend)?;
459
460    if response.status().is_success() {
461        let reader = BufReader::new(response.into_body());
462        Ok(reader
463            .lines()
464            .filter_map(|line| async move {
465                match line {
466                    Ok(line) => {
467                        if line.starts_with(':') {
468                            return None;
469                        }
470
471                        let line = line.strip_prefix("data: ")?;
472                        if line == "[DONE]" {
473                            None
474                        } else {
475                            match serde_json::from_str::<ResponseStreamEvent>(line) {
476                                Ok(response) => Some(Ok(response)),
477                                Err(error) => {
478                                    if line.trim().is_empty() {
479                                        None
480                                    } else {
481                                        Some(Err(OpenRouterError::DeserializeResponse(error)))
482                                    }
483                                }
484                            }
485                        }
486                    }
487                    Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
488                }
489            })
490            .boxed())
491    } else {
492        let code = ApiErrorCode::from_status(response.status().as_u16());
493
494        let mut body = String::new();
495        response
496            .body_mut()
497            .read_to_string(&mut body)
498            .await
499            .map_err(OpenRouterError::ReadResponse)?;
500
501        let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
502            Ok(OpenRouterErrorResponse { error }) => error,
503            Err(_) => OpenRouterErrorBody {
504                code: response.status().as_u16(),
505                message: body,
506                metadata: None,
507            },
508        };
509
510        match code {
511            ApiErrorCode::RateLimitError => {
512                let retry_after = extract_retry_after(response.headers());
513                Err(OpenRouterError::RateLimit {
514                    retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
515                })
516            }
517            ApiErrorCode::OverloadedError => {
518                let retry_after = extract_retry_after(response.headers());
519                Err(OpenRouterError::ServerOverloaded { retry_after })
520            }
521            _ => Err(OpenRouterError::ApiError(ApiError {
522                code: code,
523                message: error_response.message,
524            })),
525        }
526    }
527}
528
529pub async fn list_models(
530    client: &dyn HttpClient,
531    api_url: &str,
532    api_key: &str,
533) -> Result<Vec<Model>, OpenRouterError> {
534    let uri = format!("{api_url}/models/user");
535    let request_builder = HttpRequest::builder()
536        .method(Method::GET)
537        .uri(uri)
538        .header("Accept", "application/json")
539        .header("Authorization", format!("Bearer {}", api_key))
540        .header("HTTP-Referer", "https://zed.dev")
541        .header("X-Title", "Zed Editor");
542
543    let request = request_builder
544        .body(AsyncBody::default())
545        .map_err(OpenRouterError::BuildRequestBody)?;
546    let mut response = client
547        .send(request)
548        .await
549        .map_err(OpenRouterError::HttpSend)?;
550
551    let mut body = String::new();
552    response
553        .body_mut()
554        .read_to_string(&mut body)
555        .await
556        .map_err(OpenRouterError::ReadResponse)?;
557
558    if response.status().is_success() {
559        let response: ListModelsResponse =
560            serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
561
562        let models = response
563            .data
564            .into_iter()
565            .map(|entry| Model {
566                name: entry.id,
567                // OpenRouter returns display names in the format "provider_name: model_name".
568                // When displayed in the UI, these names can get truncated from the right.
569                // Since users typically already know the provider, we extract just the model name
570                // portion (after the colon) to create a more concise and user-friendly label
571                // for the model dropdown in the agent panel.
572                display_name: Some(
573                    entry
574                        .name
575                        .split(':')
576                        .next_back()
577                        .unwrap_or(&entry.name)
578                        .trim()
579                        .to_string(),
580                ),
581                max_tokens: entry.context_length.unwrap_or(2000000),
582                supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
583                supports_images: Some(
584                    entry
585                        .architecture
586                        .as_ref()
587                        .map(|arch| arch.input_modalities.contains(&"image".to_string()))
588                        .unwrap_or(false),
589                ),
590                mode: if entry
591                    .supported_parameters
592                    .contains(&"reasoning".to_string())
593                {
594                    ModelMode::Thinking {
595                        budget_tokens: Some(4_096),
596                    }
597                } else {
598                    ModelMode::Default
599                },
600            })
601            .collect();
602
603        Ok(models)
604    } else {
605        let code = ApiErrorCode::from_status(response.status().as_u16());
606
607        let mut body = String::new();
608        response
609            .body_mut()
610            .read_to_string(&mut body)
611            .await
612            .map_err(OpenRouterError::ReadResponse)?;
613
614        let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
615            Ok(OpenRouterErrorResponse { error }) => error,
616            Err(_) => OpenRouterErrorBody {
617                code: response.status().as_u16(),
618                message: body,
619                metadata: None,
620            },
621        };
622
623        match code {
624            ApiErrorCode::RateLimitError => {
625                let retry_after = extract_retry_after(response.headers());
626                Err(OpenRouterError::RateLimit {
627                    retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
628                })
629            }
630            ApiErrorCode::OverloadedError => {
631                let retry_after = extract_retry_after(response.headers());
632                Err(OpenRouterError::ServerOverloaded { retry_after })
633            }
634            _ => Err(OpenRouterError::ApiError(ApiError {
635                code: code,
636                message: error_response.message,
637            })),
638        }
639    }
640}
641
/// Errors returned by the functions in this file that talk to the OpenRouter API
/// (`stream_completion` and `list_models`).
///
/// NOTE(review): only `Debug` is derived here; no `Display` or `std::error::Error`
/// implementation for this type is visible in this file.
#[derive(Debug)]
pub enum OpenRouterError {
    /// Failed to serialize the HTTP request body to JSON
    SerializeRequest(serde_json::Error),

    /// Failed to construct the HTTP request body
    BuildRequestBody(http::Error),

    /// Failed to send the HTTP request
    HttpSend(anyhow::Error),

    /// Failed to deserialize the response from JSON
    DeserializeResponse(serde_json::Error),

    /// Failed to read from response stream
    ReadResponse(io::Error),

    /// Rate limit exceeded; `retry_after` is how long to wait before retrying
    RateLimit { retry_after: Duration },

    /// Server overloaded; `retry_after` is set only when a retry delay could be determined
    ServerOverloaded { retry_after: Option<Duration> },

    /// API returned an error response
    ApiError(ApiError),
}
668
/// The `error` object inside an [`OpenRouterErrorResponse`].
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
    /// Numeric error code.
    pub code: u16,
    /// Error message text.
    pub message: String,
    /// Optional extra key/value data; defaults to `None` when absent and is omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}
676
/// Top-level JSON shape of an error response: a single `error` object.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
    /// The error details.
    pub error: OpenRouterErrorBody,
}
681
/// An API error: a code plus the server's message.
///
/// Displays as `OpenRouter API Error: {code}: {message}`.
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
    /// Error category.
    pub code: ApiErrorCode,
    /// Message text.
    pub message: String,
}
688
/// An OpenRouter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
///
/// The `EnumString` derive parses the snake_case form of each variant name.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400: Bad Request (invalid or missing params, CORS)
    InvalidRequestError,
    /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
    AuthenticationError,
    /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
    PaymentRequiredError,
    /// 403: Your chosen model requires moderation and your input was flagged
    PermissionError,
    /// 408: Your request timed out
    RequestTimedOut,
    /// 429: You are being rate limited
    RateLimitError,
    /// 502: Your chosen model is down or we received an invalid response from it
    ApiError,
    /// 503: There is no available model provider that meets your routing requirements
    OverloadedError,
}
711
712impl std::fmt::Display for ApiErrorCode {
713    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
714        let s = match self {
715            ApiErrorCode::InvalidRequestError => "invalid_request_error",
716            ApiErrorCode::AuthenticationError => "authentication_error",
717            ApiErrorCode::PaymentRequiredError => "payment_required_error",
718            ApiErrorCode::PermissionError => "permission_error",
719            ApiErrorCode::RequestTimedOut => "request_timed_out",
720            ApiErrorCode::RateLimitError => "rate_limit_error",
721            ApiErrorCode::ApiError => "api_error",
722            ApiErrorCode::OverloadedError => "overloaded_error",
723        };
724        write!(f, "{s}")
725    }
726}
727
728impl ApiErrorCode {
729    pub fn from_status(status: u16) -> Self {
730        match status {
731            400 => ApiErrorCode::InvalidRequestError,
732            401 => ApiErrorCode::AuthenticationError,
733            402 => ApiErrorCode::PaymentRequiredError,
734            403 => ApiErrorCode::PermissionError,
735            408 => ApiErrorCode::RequestTimedOut,
736            429 => ApiErrorCode::RateLimitError,
737            502 => ApiErrorCode::ApiError,
738            503 => ApiErrorCode::OverloadedError,
739            _ => ApiErrorCode::ApiError,
740        }
741    }
742}