open_router.rs

  1use anyhow::{Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6pub use settings::DataCollection;
  7pub use settings::ModelMode;
  8pub use settings::OpenRouterAvailableModel as AvailableModel;
  9pub use settings::OpenRouterProvider as Provider;
 10use std::{convert::TryFrom, io, time::Duration};
 11use strum::EnumString;
 12use thiserror::Error;
 13
/// Base URL of the OpenRouter v1 REST API.
///
/// The request functions in this file append endpoint paths such as `/chat/completions`
/// and `/models/user` to the `api_url` they are given.
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
 15
 16fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
 17    if let Some(reset) = headers.get("X-RateLimit-Reset") {
 18        if let Ok(s) = reset.to_str() {
 19            if let Ok(epoch_ms) = s.parse::<u64>() {
 20                let now = std::time::SystemTime::now()
 21                    .duration_since(std::time::UNIX_EPOCH)
 22                    .unwrap_or_default()
 23                    .as_millis() as u64;
 24                if epoch_ms > now {
 25                    return Some(std::time::Duration::from_millis(epoch_ms - now));
 26                }
 27            }
 28        }
 29    }
 30    None
 31}
 32
 33fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 34    opt.as_ref().is_none_or(|v| v.as_ref().is_empty())
 35}
 36
/// Author of a chat message.
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`, `"tool"`). These are the
/// same spellings accepted by `TryFrom<String>` and produced by `From<Role> for String`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 45
 46impl TryFrom<String> for Role {
 47    type Error = anyhow::Error;
 48
 49    fn try_from(value: String) -> Result<Self> {
 50        match value.as_str() {
 51            "user" => Ok(Self::User),
 52            "assistant" => Ok(Self::Assistant),
 53            "system" => Ok(Self::System),
 54            "tool" => Ok(Self::Tool),
 55            _ => Err(anyhow!("invalid role '{value}'")),
 56        }
 57    }
 58}
 59
 60impl From<Role> for String {
 61    fn from(val: Role) -> Self {
 62        match val {
 63            Role::User => "user".to_owned(),
 64            Role::Assistant => "assistant".to_owned(),
 65            Role::System => "system".to_owned(),
 66            Role::Tool => "tool".to_owned(),
 67        }
 68    }
 69}
 70
 71#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 72#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 73pub struct Model {
 74    pub name: String,
 75    pub display_name: Option<String>,
 76    pub max_tokens: u64,
 77    pub supports_tools: Option<bool>,
 78    pub supports_images: Option<bool>,
 79    #[serde(default)]
 80    pub mode: ModelMode,
 81    pub provider: Option<Provider>,
 82}
 83
 84impl Model {
 85    pub fn default_fast() -> Self {
 86        Self::new(
 87            "openrouter/auto",
 88            Some("Auto Router"),
 89            Some(2000000),
 90            Some(true),
 91            Some(false),
 92            Some(ModelMode::Default),
 93            None,
 94        )
 95    }
 96
 97    pub fn default() -> Self {
 98        Self::default_fast()
 99    }
100
101    pub fn new(
102        name: &str,
103        display_name: Option<&str>,
104        max_tokens: Option<u64>,
105        supports_tools: Option<bool>,
106        supports_images: Option<bool>,
107        mode: Option<ModelMode>,
108        provider: Option<Provider>,
109    ) -> Self {
110        Self {
111            name: name.to_owned(),
112            display_name: display_name.map(|s| s.to_owned()),
113            max_tokens: max_tokens.unwrap_or(2000000),
114            supports_tools,
115            supports_images,
116            mode: mode.unwrap_or(ModelMode::Default),
117            provider,
118        }
119    }
120
121    pub fn id(&self) -> &str {
122        &self.name
123    }
124
125    pub fn display_name(&self) -> &str {
126        self.display_name.as_ref().unwrap_or(&self.name)
127    }
128
129    pub fn max_token_count(&self) -> u64 {
130        self.max_tokens
131    }
132
133    pub fn max_output_tokens(&self) -> Option<u64> {
134        None
135    }
136
137    pub fn supports_tool_calls(&self) -> bool {
138        self.supports_tools.unwrap_or(false)
139    }
140
141    pub fn supports_parallel_tool_calls(&self) -> bool {
142        false
143    }
144}
145
146#[derive(Debug, Serialize, Deserialize)]
147pub struct Request {
148    pub model: String,
149    pub messages: Vec<RequestMessage>,
150    pub stream: bool,
151    #[serde(default, skip_serializing_if = "Option::is_none")]
152    pub max_tokens: Option<u64>,
153    #[serde(default, skip_serializing_if = "Vec::is_empty")]
154    pub stop: Vec<String>,
155    pub temperature: f32,
156    #[serde(default, skip_serializing_if = "Option::is_none")]
157    pub tool_choice: Option<ToolChoice>,
158    #[serde(default, skip_serializing_if = "Option::is_none")]
159    pub parallel_tool_calls: Option<bool>,
160    #[serde(default, skip_serializing_if = "Vec::is_empty")]
161    pub tools: Vec<ToolDefinition>,
162    #[serde(default, skip_serializing_if = "Option::is_none")]
163    pub reasoning: Option<Reasoning>,
164    pub usage: RequestUsage,
165    pub provider: Option<Provider>,
166}
167
/// The `usage` object of a [`Request`]; it is always serialized.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    /// Presumably asks OpenRouter to report token usage in the response — confirm against the API docs.
    pub include: bool,
}
172
/// The `tool_choice` request field.
///
/// The unit variants serialize as the lowercase strings `"auto"`, `"required"` and `"none"`.
/// `Other` is untagged, so it serializes as the contained [`ToolDefinition`] itself.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
182
/// A tool offered to the model, internally tagged by `"type"` (currently only `"function"`).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for this `allow(dead_code)` is not visible here. The variant
    // is public, so the attribute may be a leftover — confirm before removing.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
190
/// A callable function offered to the model.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name the model uses when calling it.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Arbitrary JSON describing the parameters (presumably a JSON Schema object — confirm).
    pub parameters: Option<Value>,
}
198
/// The `reasoning` object of a [`Request`]. Each field is left out of the JSON when `None`.
// NOTE(review): `max_tokens` is `u32` here, whereas token counts elsewhere in this file
// (`Request::max_tokens`, `Usage`) are `u64` — confirm the narrower type is intentional.
#[derive(Debug, Serialize, Deserialize)]
pub struct Reasoning {
    /// Free-form effort level (presumably values such as `"low"`/`"high"` — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    /// Token budget for reasoning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Presumably hides reasoning text from the response while still reasoning — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    /// Explicitly turns reasoning on or off.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
210
/// A chat message, internally tagged by a lowercase `"role"` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// May be absent, e.g. when the turn consists only of tool calls.
        content: Option<MessageContent>,
        /// Omitted from the JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        /// Opaque JSON, presumably echoed back from a previous response's
        /// `reasoning_details` — confirm. Omitted when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_details: Option<serde_json::Value>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        /// Id of the tool call this message answers (presumably `ToolCall::id`).
        tool_call_id: String,
    },
}
232
/// Message content: either a plain string or a list of typed parts.
///
/// Untagged, so `Plain` serializes as a JSON string and `Multipart` as a JSON array.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
239
240impl MessageContent {
241    pub fn empty() -> Self {
242        Self::Plain(String::new())
243    }
244
245    pub fn push_part(&mut self, part: MessagePart) {
246        match self {
247            Self::Plain(text) if text.is_empty() => {
248                *self = Self::Multipart(vec![part]);
249            }
250            Self::Plain(text) => {
251                let text_part = MessagePart::Text {
252                    text: std::mem::take(text),
253                };
254                *self = Self::Multipart(vec![text_part, part]);
255            }
256            Self::Multipart(parts) => parts.push(part),
257        }
258    }
259}
260
261impl From<Vec<MessagePart>> for MessageContent {
262    fn from(parts: Vec<MessagePart>) -> Self {
263        if parts.len() == 1
264            && let MessagePart::Text { text } = &parts[0]
265        {
266            return Self::Plain(text.clone());
267        }
268        Self::Multipart(parts)
269    }
270}
271
272impl From<String> for MessageContent {
273    fn from(text: String) -> Self {
274        Self::Plain(text)
275    }
276}
277
278impl From<&str> for MessageContent {
279    fn from(text: &str) -> Self {
280        Self::Plain(text.to_string())
281    }
282}
283
284impl MessageContent {
285    pub fn as_text(&self) -> Option<&str> {
286        match self {
287            Self::Plain(text) => Some(text),
288            Self::Multipart(parts) if parts.len() == 1 => {
289                if let MessagePart::Text { text } = &parts[0] {
290                    Some(text)
291                } else {
292                    None
293                }
294            }
295            _ => None,
296        }
297    }
298
299    pub fn to_text(&self) -> String {
300        match self {
301            Self::Plain(text) => text.clone(),
302            Self::Multipart(parts) => parts
303                .iter()
304                .filter_map(|part| {
305                    if let MessagePart::Text { text } = part {
306                        Some(text.as_str())
307                    } else {
308                        None
309                    }
310                })
311                .collect::<Vec<_>>()
312                .join(""),
313        }
314    }
315}
316
/// One element of a [`MessageContent::Multipart`] list, tagged by `"type"`
/// (`"text"` or `"image_url"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    // NOTE(review): `image_url` serializes as a bare string. OpenAI-style APIs document an
    // object (`{"image_url": {"url": ...}}`) — confirm OpenRouter accepts this shape.
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}
328
/// A tool call made by the assistant.
///
/// The `type` tag and `function` payload of [`ToolCallContent`] are flattened next to `id`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
335
/// Payload of a [`ToolCall`], tagged by `"type"` (currently only `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
341
/// The function invoked by a tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments as a raw string (presumably JSON-encoded — confirm).
    pub arguments: String,
    /// Opaque provider-specific signature, omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
349
/// Incremental message content carried by one streamed [`ChoiceDelta`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    pub reasoning: Option<String>,
    /// Omitted when serializing if absent or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    /// Opaque reasoning metadata, passed through as raw JSON.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_details: Option<serde_json::Value>,
}
360
/// A fragment of a tool call inside a streamed delta.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Which tool call this fragment belongs to (presumably used to merge fragments across chunks).
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
367
/// Partial function name and arguments within a [`ToolCallChunk`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    #[serde(default)]
    pub thought_signature: Option<String>,
}
375
376#[derive(Serialize, Deserialize, Debug)]
377pub struct Usage {
378    pub prompt_tokens: u64,
379    pub completion_tokens: u64,
380    pub total_tokens: u64,
381}
382
383#[derive(Serialize, Deserialize, Debug)]
384pub struct ChoiceDelta {
385    pub index: u32,
386    pub delta: ResponseMessageDelta,
387    pub finish_reason: Option<String>,
388}
389
390#[derive(Serialize, Deserialize, Debug)]
391pub struct ResponseStreamEvent {
392    #[serde(default, skip_serializing_if = "Option::is_none")]
393    pub id: Option<String>,
394    pub created: u32,
395    pub model: String,
396    pub choices: Vec<ChoiceDelta>,
397    pub usage: Option<Usage>,
398}
399
400#[derive(Serialize, Deserialize, Debug)]
401pub struct Response {
402    pub id: String,
403    pub object: String,
404    pub created: u64,
405    pub model: String,
406    pub choices: Vec<Choice>,
407    pub usage: Usage,
408}
409
410#[derive(Serialize, Deserialize, Debug)]
411pub struct Choice {
412    pub index: u32,
413    pub message: RequestMessage,
414    pub finish_reason: Option<String>,
415}
416
/// Envelope of the model listing endpoint: a `data` array of [`ModelEntry`] values.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}
421
422#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
423pub struct ModelEntry {
424    pub id: String,
425    pub name: String,
426    pub created: usize,
427    pub description: String,
428    #[serde(default, skip_serializing_if = "Option::is_none")]
429    pub context_length: Option<u64>,
430    #[serde(default, skip_serializing_if = "Vec::is_empty")]
431    pub supported_parameters: Vec<String>,
432    #[serde(default, skip_serializing_if = "Option::is_none")]
433    pub architecture: Option<ModelArchitecture>,
434}
435
436#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
437pub struct ModelArchitecture {
438    #[serde(default, skip_serializing_if = "Vec::is_empty")]
439    pub input_modalities: Vec<String>,
440}
441
442pub async fn stream_completion(
443    client: &dyn HttpClient,
444    api_url: &str,
445    api_key: &str,
446    request: Request,
447) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
448    let uri = format!("{api_url}/chat/completions");
449    let request_builder = HttpRequest::builder()
450        .method(Method::POST)
451        .uri(uri)
452        .header("Content-Type", "application/json")
453        .header("Authorization", format!("Bearer {}", api_key))
454        .header("HTTP-Referer", "https://zed.dev")
455        .header("X-Title", "Zed Editor");
456
457    let request = request_builder
458        .body(AsyncBody::from(
459            serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
460        ))
461        .map_err(OpenRouterError::BuildRequestBody)?;
462    let mut response = client
463        .send(request)
464        .await
465        .map_err(OpenRouterError::HttpSend)?;
466
467    if response.status().is_success() {
468        let reader = BufReader::new(response.into_body());
469        Ok(reader
470            .lines()
471            .filter_map(|line| async move {
472                match line {
473                    Ok(line) => {
474                        if line.starts_with(':') {
475                            return None;
476                        }
477
478                        let line = line.strip_prefix("data: ")?;
479                        if line == "[DONE]" {
480                            None
481                        } else {
482                            match serde_json::from_str::<ResponseStreamEvent>(line) {
483                                Ok(response) => Some(Ok(response)),
484                                Err(error) => {
485                                    if line.trim().is_empty() {
486                                        None
487                                    } else {
488                                        Some(Err(OpenRouterError::DeserializeResponse(error)))
489                                    }
490                                }
491                            }
492                        }
493                    }
494                    Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
495                }
496            })
497            .boxed())
498    } else {
499        let code = ApiErrorCode::from_status(response.status().as_u16());
500
501        let mut body = String::new();
502        response
503            .body_mut()
504            .read_to_string(&mut body)
505            .await
506            .map_err(OpenRouterError::ReadResponse)?;
507
508        let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
509            Ok(OpenRouterErrorResponse { error }) => error,
510            Err(_) => OpenRouterErrorBody {
511                code: response.status().as_u16(),
512                message: body,
513                metadata: None,
514            },
515        };
516
517        match code {
518            ApiErrorCode::RateLimitError => {
519                let retry_after = extract_retry_after(response.headers());
520                Err(OpenRouterError::RateLimit {
521                    retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
522                })
523            }
524            ApiErrorCode::OverloadedError => {
525                let retry_after = extract_retry_after(response.headers());
526                Err(OpenRouterError::ServerOverloaded { retry_after })
527            }
528            _ => Err(OpenRouterError::ApiError(ApiError {
529                code: code,
530                message: error_response.message,
531            })),
532        }
533    }
534}
535
536pub async fn list_models(
537    client: &dyn HttpClient,
538    api_url: &str,
539    api_key: &str,
540) -> Result<Vec<Model>, OpenRouterError> {
541    let uri = format!("{api_url}/models/user");
542    let request_builder = HttpRequest::builder()
543        .method(Method::GET)
544        .uri(uri)
545        .header("Accept", "application/json")
546        .header("Authorization", format!("Bearer {}", api_key))
547        .header("HTTP-Referer", "https://zed.dev")
548        .header("X-Title", "Zed Editor");
549
550    let request = request_builder
551        .body(AsyncBody::default())
552        .map_err(OpenRouterError::BuildRequestBody)?;
553    let mut response = client
554        .send(request)
555        .await
556        .map_err(OpenRouterError::HttpSend)?;
557
558    let mut body = String::new();
559    response
560        .body_mut()
561        .read_to_string(&mut body)
562        .await
563        .map_err(OpenRouterError::ReadResponse)?;
564
565    if response.status().is_success() {
566        let response: ListModelsResponse =
567            serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
568
569        let models = response
570            .data
571            .into_iter()
572            .map(|entry| Model {
573                name: entry.id,
574                // OpenRouter returns display names in the format "provider_name: model_name".
575                // When displayed in the UI, these names can get truncated from the right.
576                // Since users typically already know the provider, we extract just the model name
577                // portion (after the colon) to create a more concise and user-friendly label
578                // for the model dropdown in the agent panel.
579                display_name: Some(
580                    entry
581                        .name
582                        .split(':')
583                        .next_back()
584                        .unwrap_or(&entry.name)
585                        .trim()
586                        .to_string(),
587                ),
588                max_tokens: entry.context_length.unwrap_or(2000000),
589                supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
590                supports_images: Some(
591                    entry
592                        .architecture
593                        .as_ref()
594                        .map(|arch| arch.input_modalities.contains(&"image".to_string()))
595                        .unwrap_or(false),
596                ),
597                mode: if entry
598                    .supported_parameters
599                    .contains(&"reasoning".to_string())
600                {
601                    ModelMode::Thinking {
602                        budget_tokens: Some(4_096),
603                    }
604                } else {
605                    ModelMode::Default
606                },
607                provider: None,
608            })
609            .collect();
610
611        Ok(models)
612    } else {
613        let code = ApiErrorCode::from_status(response.status().as_u16());
614
615        let mut body = String::new();
616        response
617            .body_mut()
618            .read_to_string(&mut body)
619            .await
620            .map_err(OpenRouterError::ReadResponse)?;
621
622        let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
623            Ok(OpenRouterErrorResponse { error }) => error,
624            Err(_) => OpenRouterErrorBody {
625                code: response.status().as_u16(),
626                message: body,
627                metadata: None,
628            },
629        };
630
631        match code {
632            ApiErrorCode::RateLimitError => {
633                let retry_after = extract_retry_after(response.headers());
634                Err(OpenRouterError::RateLimit {
635                    retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
636                })
637            }
638            ApiErrorCode::OverloadedError => {
639                let retry_after = extract_retry_after(response.headers());
640                Err(OpenRouterError::ServerOverloaded { retry_after })
641            }
642            _ => Err(OpenRouterError::ApiError(ApiError {
643                code: code,
644                message: error_response.message,
645            })),
646        }
647    }
648}
649
650#[derive(Debug)]
651pub enum OpenRouterError {
652    /// Failed to serialize the HTTP request body to JSON
653    SerializeRequest(serde_json::Error),
654
655    /// Failed to construct the HTTP request body
656    BuildRequestBody(http::Error),
657
658    /// Failed to send the HTTP request
659    HttpSend(anyhow::Error),
660
661    /// Failed to deserialize the response from JSON
662    DeserializeResponse(serde_json::Error),
663
664    /// Failed to read from response stream
665    ReadResponse(io::Error),
666
667    /// Rate limit exceeded
668    RateLimit { retry_after: Duration },
669
670    /// Server overloaded
671    ServerOverloaded { retry_after: Option<Duration> },
672
673    /// API returned an error response
674    ApiError(ApiError),
675}
676
/// The `error` object of an OpenRouter error response.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
    /// Numeric error code (presumably mirrors the HTTP status — confirm).
    pub code: u16,
    /// Human-readable description of the failure.
    pub message: String,
    /// Optional free-form details, omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}
684
/// Top-level JSON envelope of an OpenRouter error response: `{ "error": { ... } }`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
    pub error: OpenRouterErrorBody,
}
689
/// An error response from the OpenRouter API.
///
/// Displays as `OpenRouter API Error: <code>: <message>`. `<code>` is rendered by
/// [`ApiErrorCode`]'s `Display`, which uses snake_case names such as `rate_limit_error`.
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
    /// Error category, derived from the HTTP status by the request functions in this file.
    pub code: ApiErrorCode,
    /// Message from the JSON error body, or the raw body text if it could not be parsed.
    pub message: String,
}
696
/// An OpenRouter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
///
/// `FromStr` (via strum) and `Display` both use snake_case names such as `rate_limit_error`.
// NOTE(review): the serde derives have no `rename_all`, so JSON uses the Rust variant names
// (e.g. `"RateLimitError"`). Confirm that this asymmetry with strum/`Display` is intended.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400: Bad Request (invalid or missing params, CORS)
    InvalidRequestError,
    /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
    AuthenticationError,
    /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
    PaymentRequiredError,
    /// 403: Your chosen model requires moderation and your input was flagged
    PermissionError,
    /// 408: Your request timed out
    RequestTimedOut,
    /// 429: You are being rate limited
    RateLimitError,
    /// 502: Your chosen model is down or we received an invalid response from it
    /// (also the fallback for unrecognized statuses in `from_status`)
    ApiError,
    /// 503: There is no available model provider that meets your routing requirements
    OverloadedError,
}
719
720impl std::fmt::Display for ApiErrorCode {
721    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
722        let s = match self {
723            ApiErrorCode::InvalidRequestError => "invalid_request_error",
724            ApiErrorCode::AuthenticationError => "authentication_error",
725            ApiErrorCode::PaymentRequiredError => "payment_required_error",
726            ApiErrorCode::PermissionError => "permission_error",
727            ApiErrorCode::RequestTimedOut => "request_timed_out",
728            ApiErrorCode::RateLimitError => "rate_limit_error",
729            ApiErrorCode::ApiError => "api_error",
730            ApiErrorCode::OverloadedError => "overloaded_error",
731        };
732        write!(f, "{s}")
733    }
734}
735
736impl ApiErrorCode {
737    pub fn from_status(status: u16) -> Self {
738        match status {
739            400 => ApiErrorCode::InvalidRequestError,
740            401 => ApiErrorCode::AuthenticationError,
741            402 => ApiErrorCode::PaymentRequiredError,
742            403 => ApiErrorCode::PermissionError,
743            408 => ApiErrorCode::RequestTimedOut,
744            429 => ApiErrorCode::RateLimitError,
745            502 => ApiErrorCode::ApiError,
746            503 => ApiErrorCode::OverloadedError,
747            _ => ApiErrorCode::ApiError,
748        }
749    }
750}