open_router.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6pub use settings::DataCollection;
  7pub use settings::ModelMode;
  8pub use settings::OpenRouterAvailableModel as AvailableModel;
  9pub use settings::OpenRouterProvider as Provider;
 10use std::{convert::TryFrom, io, time::Duration};
 11use strum::EnumString;
 12use thiserror::Error;
 13
/// Base URL of the public OpenRouter API (v1).
///
/// The request functions in this module take a base URL as `api_url` and append endpoint
/// paths such as `/chat/completions` and `/models/user` to it.
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
 15
 16fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
 17    if let Some(reset) = headers.get("X-RateLimit-Reset") {
 18        if let Ok(s) = reset.to_str() {
 19            if let Ok(epoch_ms) = s.parse::<u64>() {
 20                let now = std::time::SystemTime::now()
 21                    .duration_since(std::time::UNIX_EPOCH)
 22                    .unwrap_or_default()
 23                    .as_millis() as u64;
 24                if epoch_ms > now {
 25                    return Some(std::time::Duration::from_millis(epoch_ms - now));
 26                }
 27            }
 28        }
 29    }
 30    None
 31}
 32
 33fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 34    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 35}
 36
 37#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 38#[serde(rename_all = "lowercase")]
 39pub enum Role {
 40    User,
 41    Assistant,
 42    System,
 43    Tool,
 44}
 45
 46impl TryFrom<String> for Role {
 47    type Error = anyhow::Error;
 48
 49    fn try_from(value: String) -> Result<Self> {
 50        match value.as_str() {
 51            "user" => Ok(Self::User),
 52            "assistant" => Ok(Self::Assistant),
 53            "system" => Ok(Self::System),
 54            "tool" => Ok(Self::Tool),
 55            _ => Err(anyhow!("invalid role '{value}'")),
 56        }
 57    }
 58}
 59
 60impl From<Role> for String {
 61    fn from(val: Role) -> Self {
 62        match val {
 63            Role::User => "user".to_owned(),
 64            Role::Assistant => "assistant".to_owned(),
 65            Role::System => "system".to_owned(),
 66            Role::Tool => "tool".to_owned(),
 67        }
 68    }
 69}
 70
 71#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 72#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 73pub struct Model {
 74    pub name: String,
 75    pub display_name: Option<String>,
 76    pub max_tokens: u64,
 77    pub supports_tools: Option<bool>,
 78    pub supports_images: Option<bool>,
 79    #[serde(default)]
 80    pub mode: ModelMode,
 81    pub provider: Option<Provider>,
 82}
 83
 84impl Model {
 85    pub fn default() -> Self {
 86        Self::new(
 87            "openrouter/auto",
 88            Some("Auto Router"),
 89            Some(2000000),
 90            Some(true),
 91            Some(false),
 92            Some(ModelMode::Default),
 93            None,
 94        )
 95    }
 96
 97    pub fn new(
 98        name: &str,
 99        display_name: Option<&str>,
100        max_tokens: Option<u64>,
101        supports_tools: Option<bool>,
102        supports_images: Option<bool>,
103        mode: Option<ModelMode>,
104        provider: Option<Provider>,
105    ) -> Self {
106        Self {
107            name: name.to_owned(),
108            display_name: display_name.map(|s| s.to_owned()),
109            max_tokens: max_tokens.unwrap_or(2000000),
110            supports_tools,
111            supports_images,
112            mode: mode.unwrap_or(ModelMode::Default),
113            provider,
114        }
115    }
116
117    pub fn id(&self) -> &str {
118        &self.name
119    }
120
121    pub fn display_name(&self) -> &str {
122        self.display_name.as_ref().unwrap_or(&self.name)
123    }
124
125    pub fn max_token_count(&self) -> u64 {
126        self.max_tokens
127    }
128
129    pub fn max_output_tokens(&self) -> Option<u64> {
130        None
131    }
132
133    pub fn supports_tool_calls(&self) -> bool {
134        self.supports_tools.unwrap_or(false)
135    }
136
137    pub fn supports_parallel_tool_calls(&self) -> bool {
138        false
139    }
140}
141
142#[derive(Debug, Serialize, Deserialize)]
143pub struct Request {
144    pub model: String,
145    pub messages: Vec<RequestMessage>,
146    pub stream: bool,
147    #[serde(default, skip_serializing_if = "Option::is_none")]
148    pub max_tokens: Option<u64>,
149    #[serde(default, skip_serializing_if = "Vec::is_empty")]
150    pub stop: Vec<String>,
151    pub temperature: f32,
152    #[serde(default, skip_serializing_if = "Option::is_none")]
153    pub tool_choice: Option<ToolChoice>,
154    #[serde(default, skip_serializing_if = "Option::is_none")]
155    pub parallel_tool_calls: Option<bool>,
156    #[serde(default, skip_serializing_if = "Vec::is_empty")]
157    pub tools: Vec<ToolDefinition>,
158    #[serde(default, skip_serializing_if = "Option::is_none")]
159    pub reasoning: Option<Reasoning>,
160    pub usage: RequestUsage,
161    pub provider: Option<Provider>,
162}
163
/// Usage-accounting options, sent as the `usage` field of a [`Request`].
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    /// Serialized as `usage.include`. Presumably asks OpenRouter to report token usage in the
    /// response; confirm against the OpenRouter API docs.
    pub include: bool,
}

/// How the model may choose tools.
///
/// The unit variants serialize as `"auto"`, `"required"` and `"none"`. Because `Other` is
/// `#[serde(untagged)]`, it serializes as the wrapped [`ToolDefinition`] object itself.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}

/// A tool offered to the model, serialized as `{"type": "function", "function": { ... }}`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for this `allow(dead_code)` is not visible here; confirm it is
    // still needed.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// A callable function: its name, an optional description, and optional parameters.
///
/// `description` and `parameters` have no `skip_serializing_if`, so `None` is sent as `null`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// Arbitrary JSON, presumably a JSON Schema for the arguments (TODO: confirm with callers).
    pub parameters: Option<Value>,
}

/// Reasoning options for a [`Request`]. Each field is omitted from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Reasoning {
    /// Effort level as a free-form string; no validation happens here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
206
/// A chat message, internally tagged by `role`
/// (`"assistant"`, `"user"`, `"system"` or `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// May be `None`. It has no `skip_serializing_if`, so `None` is sent as `null`.
        content: Option<MessageContent>,
        /// Omitted from the JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        /// Arbitrary JSON kept as an untyped `serde_json::Value`; omitted when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_details: Option<serde_json::Value>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        /// Presumably the `id` of the [`ToolCall`] this message answers (TODO: confirm).
        tool_call_id: String,
    },
}

/// Message content: either one plain string or a list of typed parts.
///
/// Because the enum is `#[serde(untagged)]`, `Plain` serializes as a bare JSON string and
/// `Multipart` as a bare JSON array. Deserialization tries `Plain` first.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
235
236impl MessageContent {
237    pub fn empty() -> Self {
238        Self::Plain(String::new())
239    }
240
241    pub fn push_part(&mut self, part: MessagePart) {
242        match self {
243            Self::Plain(text) if text.is_empty() => {
244                *self = Self::Multipart(vec![part]);
245            }
246            Self::Plain(text) => {
247                let text_part = MessagePart::Text {
248                    text: std::mem::take(text),
249                };
250                *self = Self::Multipart(vec![text_part, part]);
251            }
252            Self::Multipart(parts) => parts.push(part),
253        }
254    }
255}
256
257impl From<Vec<MessagePart>> for MessageContent {
258    fn from(parts: Vec<MessagePart>) -> Self {
259        if parts.len() == 1
260            && let MessagePart::Text { text } = &parts[0]
261        {
262            return Self::Plain(text.clone());
263        }
264        Self::Multipart(parts)
265    }
266}
267
268impl From<String> for MessageContent {
269    fn from(text: String) -> Self {
270        Self::Plain(text)
271    }
272}
273
274impl From<&str> for MessageContent {
275    fn from(text: &str) -> Self {
276        Self::Plain(text.to_string())
277    }
278}
279
280impl MessageContent {
281    pub fn as_text(&self) -> Option<&str> {
282        match self {
283            Self::Plain(text) => Some(text),
284            Self::Multipart(parts) if parts.len() == 1 => {
285                if let MessagePart::Text { text } = &parts[0] {
286                    Some(text)
287                } else {
288                    None
289                }
290            }
291            _ => None,
292        }
293    }
294
295    pub fn to_text(&self) -> String {
296        match self {
297            Self::Plain(text) => text.clone(),
298            Self::Multipart(parts) => parts
299                .iter()
300                .filter_map(|part| {
301                    if let MessagePart::Text { text } = part {
302                        Some(text.as_str())
303                    } else {
304                        None
305                    }
306                })
307                .collect::<Vec<_>>()
308                .join(""),
309        }
310    }
311}
312
/// One part of a multipart message, tagged by `type` (`"text"` or `"image_url"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    /// Tagged `"image_url"`; the explicit rename overrides the `"image"` spelling that
    /// `snake_case` would produce. The image itself is a plain string field here.
    // NOTE(review): OpenAI-style APIs document `image_url` as an object (`{"url": ...}`);
    // confirm OpenRouter accepts the bare-string form.
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}

/// A tool call in an assistant message.
///
/// `content` is flattened, so the JSON is `{"id": ..., "type": "function", "function": {...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Kind of tool call, tagged by `type`. Only `"function"` exists.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// The function invoked by a [`ToolCall`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments as a single string, presumably JSON-encoded (TODO: confirm).
    pub arguments: String,
    /// Opaque string whose meaning is not visible here; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
345
/// The message fields carried by one choice of a streamed event. Every field is optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    pub reasoning: Option<String>,
    /// Omitted from serialization when `None` or empty (see `is_none_or_empty`).
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    /// Arbitrary JSON kept as an untyped `serde_json::Value`; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_details: Option<serde_json::Value>,
}

/// A fragment of a tool call within a streamed delta.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Presumably identifies which tool call this fragment belongs to when several are
    /// streamed at once (TODO: confirm against the consumer of these chunks).
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// A fragment of a function call's name and/or arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    #[serde(default)]
    pub thought_signature: Option<String>,
}

/// Token counts reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

/// One `data:` payload of the completion stream, as parsed by [`stream_completion`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    // NOTE(review): `created` is `u32` here but `u64` in `Response`. If it holds Unix seconds,
    // `u32` overflows in 2106.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

/// A complete chat-completion response. Its choices carry whole messages rather than deltas.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One choice of a complete [`Response`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
412
/// Response body of the model-listing endpoint: `{"data": [ ... ]}`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

/// One model as described by the OpenRouter model-listing endpoint.
// NOTE(review): this type derives only `Deserialize`, so the `skip_serializing_if` attributes
// below have no effect. `default` is what makes those fields optional in the input.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    /// Model id; `list_models` stores it as the model's `name`.
    pub id: String,
    /// Display name, typically `"Provider: Model"`. `list_models` keeps only the trimmed text
    /// after the last `:`.
    pub name: String,
    /// `created` value as reported by the API; not read by `list_models`.
    pub created: usize,
    // NOTE(review): required field. An entry without `description` makes the whole
    // `ListModelsResponse` fail to parse.
    pub description: String,
    /// Context window in tokens; `list_models` substitutes 2,000,000 when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// Request parameters the model accepts; `list_models` looks for `"tools"` and `"reasoning"`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
    /// Optional architecture details; see [`ModelArchitecture`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ModelArchitecture>,
}

/// Model architecture details. Only the input modalities are read in this module.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelArchitecture {
    /// Accepted input modalities; `list_models` treats `"image"` as image support.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_modalities: Vec<String>,
}
437
438pub async fn stream_completion(
439    client: &dyn HttpClient,
440    api_url: &str,
441    api_key: &str,
442    request: Request,
443) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
444    let uri = format!("{api_url}/chat/completions");
445    let request_builder = HttpRequest::builder()
446        .method(Method::POST)
447        .uri(uri)
448        .header("Content-Type", "application/json")
449        .header("Authorization", format!("Bearer {}", api_key))
450        .header("HTTP-Referer", "https://zed.dev")
451        .header("X-Title", "Zed Editor");
452
453    let request = request_builder
454        .body(AsyncBody::from(
455            serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
456        ))
457        .map_err(OpenRouterError::BuildRequestBody)?;
458    let mut response = client
459        .send(request)
460        .await
461        .map_err(OpenRouterError::HttpSend)?;
462
463    if response.status().is_success() {
464        let reader = BufReader::new(response.into_body());
465        Ok(reader
466            .lines()
467            .filter_map(|line| async move {
468                match line {
469                    Ok(line) => {
470                        if line.starts_with(':') {
471                            return None;
472                        }
473
474                        let line = line.strip_prefix("data: ")?;
475                        if line == "[DONE]" {
476                            None
477                        } else {
478                            match serde_json::from_str::<ResponseStreamEvent>(line) {
479                                Ok(response) => Some(Ok(response)),
480                                Err(error) => {
481                                    if line.trim().is_empty() {
482                                        None
483                                    } else {
484                                        Some(Err(OpenRouterError::DeserializeResponse(error)))
485                                    }
486                                }
487                            }
488                        }
489                    }
490                    Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
491                }
492            })
493            .boxed())
494    } else {
495        let code = ApiErrorCode::from_status(response.status().as_u16());
496
497        let mut body = String::new();
498        response
499            .body_mut()
500            .read_to_string(&mut body)
501            .await
502            .map_err(OpenRouterError::ReadResponse)?;
503
504        let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
505            Ok(OpenRouterErrorResponse { error }) => error,
506            Err(_) => OpenRouterErrorBody {
507                code: response.status().as_u16(),
508                message: body,
509                metadata: None,
510            },
511        };
512
513        match code {
514            ApiErrorCode::RateLimitError => {
515                let retry_after = extract_retry_after(response.headers());
516                Err(OpenRouterError::RateLimit {
517                    retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
518                })
519            }
520            ApiErrorCode::OverloadedError => {
521                let retry_after = extract_retry_after(response.headers());
522                Err(OpenRouterError::ServerOverloaded { retry_after })
523            }
524            _ => Err(OpenRouterError::ApiError(ApiError {
525                code: code,
526                message: error_response.message,
527            })),
528        }
529    }
530}
531
532pub async fn list_models(
533    client: &dyn HttpClient,
534    api_url: &str,
535    api_key: &str,
536) -> Result<Vec<Model>, OpenRouterError> {
537    let uri = format!("{api_url}/models/user");
538    let request_builder = HttpRequest::builder()
539        .method(Method::GET)
540        .uri(uri)
541        .header("Accept", "application/json")
542        .header("Authorization", format!("Bearer {}", api_key))
543        .header("HTTP-Referer", "https://zed.dev")
544        .header("X-Title", "Zed Editor");
545
546    let request = request_builder
547        .body(AsyncBody::default())
548        .map_err(OpenRouterError::BuildRequestBody)?;
549    let mut response = client
550        .send(request)
551        .await
552        .map_err(OpenRouterError::HttpSend)?;
553
554    let mut body = String::new();
555    response
556        .body_mut()
557        .read_to_string(&mut body)
558        .await
559        .map_err(OpenRouterError::ReadResponse)?;
560
561    if response.status().is_success() {
562        let response: ListModelsResponse =
563            serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
564
565        let models = response
566            .data
567            .into_iter()
568            .map(|entry| Model {
569                name: entry.id,
570                // OpenRouter returns display names in the format "provider_name: model_name".
571                // When displayed in the UI, these names can get truncated from the right.
572                // Since users typically already know the provider, we extract just the model name
573                // portion (after the colon) to create a more concise and user-friendly label
574                // for the model dropdown in the agent panel.
575                display_name: Some(
576                    entry
577                        .name
578                        .split(':')
579                        .next_back()
580                        .unwrap_or(&entry.name)
581                        .trim()
582                        .to_string(),
583                ),
584                max_tokens: entry.context_length.unwrap_or(2000000),
585                supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
586                supports_images: Some(
587                    entry
588                        .architecture
589                        .as_ref()
590                        .map(|arch| arch.input_modalities.contains(&"image".to_string()))
591                        .unwrap_or(false),
592                ),
593                mode: if entry
594                    .supported_parameters
595                    .contains(&"reasoning".to_string())
596                {
597                    ModelMode::Thinking {
598                        budget_tokens: Some(4_096),
599                    }
600                } else {
601                    ModelMode::Default
602                },
603                provider: None,
604            })
605            .collect();
606
607        Ok(models)
608    } else {
609        let code = ApiErrorCode::from_status(response.status().as_u16());
610
611        let mut body = String::new();
612        response
613            .body_mut()
614            .read_to_string(&mut body)
615            .await
616            .map_err(OpenRouterError::ReadResponse)?;
617
618        let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
619            Ok(OpenRouterErrorResponse { error }) => error,
620            Err(_) => OpenRouterErrorBody {
621                code: response.status().as_u16(),
622                message: body,
623                metadata: None,
624            },
625        };
626
627        match code {
628            ApiErrorCode::RateLimitError => {
629                let retry_after = extract_retry_after(response.headers());
630                Err(OpenRouterError::RateLimit {
631                    retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
632                })
633            }
634            ApiErrorCode::OverloadedError => {
635                let retry_after = extract_retry_after(response.headers());
636                Err(OpenRouterError::ServerOverloaded { retry_after })
637            }
638            _ => Err(OpenRouterError::ApiError(ApiError {
639                code: code,
640                message: error_response.message,
641            })),
642        }
643    }
644}
645
/// Errors produced by [`stream_completion`] and [`list_models`].
// NOTE(review): only `Debug` is derived here, and no `Display` / `std::error::Error` impl is
// visible in this file, even though the variant docs read like error messages. Confirm
// whether one exists elsewhere before relying on `?` into `anyhow::Error`.
#[derive(Debug)]
pub enum OpenRouterError {
    /// Failed to serialize the HTTP request body to JSON
    SerializeRequest(serde_json::Error),

    /// Failed to construct the HTTP request body
    BuildRequestBody(http::Error),

    /// Failed to send the HTTP request
    HttpSend(anyhow::Error),

    /// Failed to deserialize the response from JSON
    DeserializeResponse(serde_json::Error),

    /// Failed to read from response stream
    ReadResponse(io::Error),

    /// Rate limit exceeded (HTTP 429).
    ///
    /// `retry_after` comes from the `X-RateLimit-Reset` header. It is 60 seconds when that
    /// header is missing, unparsable or not in the future.
    RateLimit { retry_after: Duration },

    /// Server overloaded (HTTP 503).
    ///
    /// `retry_after` is derived from `X-RateLimit-Reset` when that header gives a future time.
    ServerOverloaded { retry_after: Option<Duration> },

    /// API returned an error response
    ApiError(ApiError),
}
672
/// The `error` object of an OpenRouter JSON error response.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
    /// Numeric error code as reported in the body.
    pub code: u16,
    /// Human-readable description; the request functions surface it as [`ApiError::message`].
    pub message: String,
    /// Optional extra details, left untyped; omitted from serialization when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}

/// Envelope of an OpenRouter error response: `{"error": { ... }}`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
    pub error: OpenRouterErrorBody,
}

/// A non-success API response, displayed as `OpenRouter API Error: <code>: <message>`.
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
    /// In this module, derived from the HTTP status via [`ApiErrorCode::from_status`].
    pub code: ApiErrorCode,
    pub message: String,
}
692
/// An OpenRouter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
///
/// `FromStr` (via strum) and `Display` both use snake_case names such as `rate_limit_error`.
// NOTE(review): serde has no `rename_all` here, so serde (de)serializes the variant names as
// written (e.g. `"RateLimitError"`). That differs from the strum/`Display` spelling; confirm
// this is intended.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400: Bad Request (invalid or missing params, CORS)
    InvalidRequestError,
    /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
    AuthenticationError,
    /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
    PaymentRequiredError,
    /// 403: Your chosen model requires moderation and your input was flagged
    PermissionError,
    /// 408: Your request timed out
    RequestTimedOut,
    /// 429: You are being rate limited
    RateLimitError,
    /// 502: Your chosen model is down or we received an invalid response from it.
    /// Also the fallback for any status that `from_status` does not recognize.
    ApiError,
    /// 503: There is no available model provider that meets your routing requirements
    OverloadedError,
}
715
716impl std::fmt::Display for ApiErrorCode {
717    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
718        let s = match self {
719            ApiErrorCode::InvalidRequestError => "invalid_request_error",
720            ApiErrorCode::AuthenticationError => "authentication_error",
721            ApiErrorCode::PaymentRequiredError => "payment_required_error",
722            ApiErrorCode::PermissionError => "permission_error",
723            ApiErrorCode::RequestTimedOut => "request_timed_out",
724            ApiErrorCode::RateLimitError => "rate_limit_error",
725            ApiErrorCode::ApiError => "api_error",
726            ApiErrorCode::OverloadedError => "overloaded_error",
727        };
728        write!(f, "{s}")
729    }
730}
731
732impl ApiErrorCode {
733    pub fn from_status(status: u16) -> Self {
734        match status {
735            400 => ApiErrorCode::InvalidRequestError,
736            401 => ApiErrorCode::AuthenticationError,
737            402 => ApiErrorCode::PaymentRequiredError,
738            403 => ApiErrorCode::PermissionError,
739            408 => ApiErrorCode::RequestTimedOut,
740            429 => ApiErrorCode::RateLimitError,
741            502 => ApiErrorCode::ApiError,
742            503 => ApiErrorCode::OverloadedError,
743            _ => ApiErrorCode::ApiError,
744        }
745    }
746}