1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6pub use settings::DataCollection;
7pub use settings::ModelMode;
8pub use settings::OpenRouterAvailableModel as AvailableModel;
9pub use settings::OpenRouterProvider as Provider;
10use std::{convert::TryFrom, io, time::Duration};
11use strum::EnumString;
12use thiserror::Error;
13
/// Base URL of the OpenRouter HTTP API (v1).
///
/// `stream_completion` and `list_models` append endpoint paths
/// (`/chat/completions`, `/models/user`) to whatever `api_url` they are given.
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
15
16fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
17 if let Some(reset) = headers.get("X-RateLimit-Reset") {
18 if let Ok(s) = reset.to_str() {
19 if let Ok(epoch_ms) = s.parse::<u64>() {
20 let now = std::time::SystemTime::now()
21 .duration_since(std::time::UNIX_EPOCH)
22 .unwrap_or_default()
23 .as_millis() as u64;
24 if epoch_ms > now {
25 return Some(std::time::Duration::from_millis(epoch_ms - now));
26 }
27 }
28 }
29 }
30 None
31}
32
/// Returns `true` when `opt` is `None` or holds a collection with no elements.
///
/// Used as a serde `skip_serializing_if` predicate for optional list fields.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        None => true,
        Some(items) => items.as_ref().is_empty(),
    }
}
36
/// The author role attached to a chat message.
///
/// Serde (de)serializes each variant as its lowercase name, matching the
/// string conversions implemented for this type below.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
    /// Serialized as `"system"`.
    System,
    /// Serialized as `"tool"`.
    Tool,
}
45
46impl TryFrom<String> for Role {
47 type Error = anyhow::Error;
48
49 fn try_from(value: String) -> Result<Self> {
50 match value.as_str() {
51 "user" => Ok(Self::User),
52 "assistant" => Ok(Self::Assistant),
53 "system" => Ok(Self::System),
54 "tool" => Ok(Self::Tool),
55 _ => Err(anyhow!("invalid role '{value}'")),
56 }
57 }
58}
59
60impl From<Role> for String {
61 fn from(val: Role) -> Self {
62 match val {
63 Role::User => "user".to_owned(),
64 Role::Assistant => "assistant".to_owned(),
65 Role::System => "system".to_owned(),
66 Role::Tool => "tool".to_owned(),
67 }
68 }
69}
70
/// Description of a model that can be selected for completions.
///
/// Note that the `impl Model` block also defines an inherent `default()`
/// alongside the derived `Default` implementation.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Model identifier; returned by `Model::id`.
    pub name: String,
    /// Optional human-readable label; `Model::display_name` falls back to `name`.
    pub display_name: Option<String>,
    /// Token limit reported by `Model::max_token_count`.
    pub max_tokens: u64,
    /// Whether the model supports tool calls; `None` is treated as `false`
    /// by `Model::supports_tool_calls`.
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image input, when known.
    pub supports_images: Option<bool>,
    /// Model mode; takes `ModelMode`'s default when absent from serialized data.
    #[serde(default)]
    pub mode: ModelMode,
    /// Optional provider settings (type re-exported from `settings`).
    pub provider: Option<Provider>,
}
83
84impl Model {
85 pub fn default_fast() -> Self {
86 Self::new(
87 "openrouter/auto",
88 Some("Auto Router"),
89 Some(2000000),
90 Some(true),
91 Some(false),
92 Some(ModelMode::Default),
93 None,
94 )
95 }
96
97 pub fn default() -> Self {
98 Self::default_fast()
99 }
100
101 pub fn new(
102 name: &str,
103 display_name: Option<&str>,
104 max_tokens: Option<u64>,
105 supports_tools: Option<bool>,
106 supports_images: Option<bool>,
107 mode: Option<ModelMode>,
108 provider: Option<Provider>,
109 ) -> Self {
110 Self {
111 name: name.to_owned(),
112 display_name: display_name.map(|s| s.to_owned()),
113 max_tokens: max_tokens.unwrap_or(2000000),
114 supports_tools,
115 supports_images,
116 mode: mode.unwrap_or(ModelMode::Default),
117 provider,
118 }
119 }
120
121 pub fn id(&self) -> &str {
122 &self.name
123 }
124
125 pub fn display_name(&self) -> &str {
126 self.display_name.as_ref().unwrap_or(&self.name)
127 }
128
129 pub fn max_token_count(&self) -> u64 {
130 self.max_tokens
131 }
132
133 pub fn max_output_tokens(&self) -> Option<u64> {
134 None
135 }
136
137 pub fn supports_tool_calls(&self) -> bool {
138 self.supports_tools.unwrap_or(false)
139 }
140
141 pub fn supports_parallel_tool_calls(&self) -> bool {
142 false
143 }
144}
145
/// JSON body sent to the `/chat/completions` endpoint by `stream_completion`.
///
/// Fields carrying `skip_serializing_if` are omitted from the JSON when unset or empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Identifier of the model to run.
    pub model: String,
    /// Messages sent with the request.
    pub messages: Vec<RequestMessage>,
    /// Streaming flag sent with the request.
    pub stream: bool,
    /// Optional token limit; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature; always serialized.
    pub temperature: f32,
    /// Tool selection setting; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Parallel tool call setting; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tool definitions sent with the request; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Reasoning configuration; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    /// Usage-reporting options; always serialized.
    pub usage: RequestUsage,
    /// Provider settings. No `skip_serializing_if` is set, so `None` is
    /// serialized as JSON `null`.
    pub provider: Option<Provider>,
}
167
/// The `usage` object of a [`Request`].
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    /// Presumably asks the API to include token usage in its response — confirm
    /// against the OpenRouter API reference. Defaults to `false`.
    pub include: bool,
}
172
/// Value of the `tool_choice` request field.
///
/// The unit variants serialize as lowercase strings; `Other` is untagged and
/// serializes as the wrapped tool definition itself.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    /// Serialized as `"auto"`.
    Auto,
    /// Serialized as `"required"`.
    Required,
    /// Serialized as `"none"`.
    None,
    /// A specific tool definition, serialized without a wrapping tag.
    #[serde(untagged)]
    Other(ToolDefinition),
}
182
/// A tool definition sent with a request.
///
/// Serialized as an internally tagged object whose `type` field holds the
/// snake_case variant name (`"function"`).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    /// A function tool, described by its [`FunctionDefinition`].
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
190
/// Describes a function tool.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Name of the function.
    pub name: String,
    /// Optional description of the function.
    pub description: Option<String>,
    /// Optional arbitrary JSON describing the parameters (presumably a JSON
    /// Schema — confirm against the API reference).
    pub parameters: Option<Value>,
}
198
/// The `reasoning` object of a [`Request`].
///
/// Every field is optional and is left out of the serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Reasoning {
    /// Reasoning effort setting, passed through as a free-form string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    /// Token limit for reasoning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Presumably asks the API to leave reasoning text out of the response —
    /// confirm against the API reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    /// Flag that switches reasoning on or off.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
210
/// A single chat message.
///
/// Serialized as an internally tagged object whose `role` field holds the
/// lowercase variant name.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// Message with role `"assistant"`.
    Assistant {
        /// Message content; may be absent.
        content: Option<MessageContent>,
        /// Tool calls attached to the message; omitted from JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    /// Message with role `"user"`.
    User {
        content: MessageContent,
    },
    /// Message with role `"system"`.
    System {
        content: MessageContent,
    },
    /// Message with role `"tool"`.
    Tool {
        content: MessageContent,
        /// Id of the tool call this message belongs to.
        tool_call_id: String,
    },
}
230
/// Content of a message: either a bare string or a list of typed parts.
///
/// Untagged: `Plain` serializes as a JSON string and `Multipart` as a JSON array.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// A single text string.
    Plain(String),
    /// An ordered list of parts (text and/or images).
    Multipart(Vec<MessagePart>),
}
237
238impl MessageContent {
239 pub fn empty() -> Self {
240 Self::Plain(String::new())
241 }
242
243 pub fn push_part(&mut self, part: MessagePart) {
244 match self {
245 Self::Plain(text) if text.is_empty() => {
246 *self = Self::Multipart(vec![part]);
247 }
248 Self::Plain(text) => {
249 let text_part = MessagePart::Text {
250 text: std::mem::take(text),
251 };
252 *self = Self::Multipart(vec![text_part, part]);
253 }
254 Self::Multipart(parts) => parts.push(part),
255 }
256 }
257}
258
259impl From<Vec<MessagePart>> for MessageContent {
260 fn from(parts: Vec<MessagePart>) -> Self {
261 if parts.len() == 1
262 && let MessagePart::Text { text } = &parts[0]
263 {
264 return Self::Plain(text.clone());
265 }
266 Self::Multipart(parts)
267 }
268}
269
270impl From<String> for MessageContent {
271 fn from(text: String) -> Self {
272 Self::Plain(text)
273 }
274}
275
276impl From<&str> for MessageContent {
277 fn from(text: &str) -> Self {
278 Self::Plain(text.to_string())
279 }
280}
281
282impl MessageContent {
283 pub fn as_text(&self) -> Option<&str> {
284 match self {
285 Self::Plain(text) => Some(text),
286 Self::Multipart(parts) if parts.len() == 1 => {
287 if let MessagePart::Text { text } = &parts[0] {
288 Some(text)
289 } else {
290 None
291 }
292 }
293 _ => None,
294 }
295 }
296
297 pub fn to_text(&self) -> String {
298 match self {
299 Self::Plain(text) => text.clone(),
300 Self::Multipart(parts) => parts
301 .iter()
302 .filter_map(|part| {
303 if let MessagePart::Text { text } = part {
304 Some(text.as_str())
305 } else {
306 None
307 }
308 })
309 .collect::<Vec<_>>()
310 .join(""),
311 }
312 }
313}
314
/// One element of multipart message content.
///
/// Serialized as an internally tagged object whose `type` field is `"text"`
/// or `"image_url"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    /// A text part (`"type": "text"`).
    Text {
        text: String,
    },
    /// An image part (`"type": "image_url"`); the image is carried as a string
    /// in the `image_url` field.
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}
326
/// A tool call attached to an assistant message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Identifier of the call.
    pub id: String,
    /// The call payload; flattened, so its fields sit next to `id` in the JSON.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
333
/// Payload of a [`ToolCall`], tagged by a lowercase `type` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    /// A function call (`"type": "function"`).
    Function { function: FunctionContent },
}
339
/// Name and arguments of a function tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    /// Name of the function.
    pub name: String,
    /// Arguments as a raw string (presumably JSON-encoded — confirm with callers).
    pub arguments: String,
}
345
/// The `delta` payload of a streamed choice; every field may be absent.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    /// Role of the message, when present.
    pub role: Option<Role>,
    /// Content text carried by this delta, when present.
    pub content: Option<String>,
    /// Reasoning text carried by this delta, when present.
    pub reasoning: Option<String>,
    /// Tool call chunks; defaults to `None` when missing and is left out of
    /// serialized output when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
354
/// A tool call entry inside a streamed delta.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Index of the tool call (presumably used to correlate chunks of the same
    /// call — verify against the consumer).
    pub index: usize,
    /// Call identifier, when present in this chunk.
    pub id: Option<String>,
    /// Function name/arguments, when present in this chunk.
    pub function: Option<FunctionChunk>,
}
361
/// Function data carried by a [`ToolCallChunk`]; both fields are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    /// Function name, when present.
    pub name: Option<String>,
    /// Arguments text, when present.
    pub arguments: Option<String>,
}
367
/// Token usage figures reported in a response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    /// Prompt token count.
    pub prompt_tokens: u64,
    /// Completion token count.
    pub completion_tokens: u64,
    /// Total token count as reported by the API.
    pub total_tokens: u64,
}
374
/// One choice inside a [`ResponseStreamEvent`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    /// Index of the choice.
    pub index: u32,
    /// Payload for this choice.
    pub delta: ResponseMessageDelta,
    /// Finish reason as a free-form string, when present.
    pub finish_reason: Option<String>,
}
381
/// One decoded `data:` line of the streaming response produced by
/// `stream_completion`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    /// Event/completion id; optional and omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Creation timestamp as reported by the API. Stored as `u32` here,
    /// whereas `Response::created` uses `u64`.
    pub created: u32,
    /// Model identifier reported by the API.
    pub model: String,
    /// Choices carried by this event.
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, when the event carries it.
    pub usage: Option<Usage>,
}
391
/// A complete (non-streamed) chat completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    /// Response id.
    pub id: String,
    /// Value of the `object` field reported by the API.
    pub object: String,
    /// Creation timestamp as reported by the API.
    pub created: u64,
    /// Model identifier reported by the API.
    pub model: String,
    /// Choices in the response.
    pub choices: Vec<Choice>,
    /// Token usage for the request.
    pub usage: Usage,
}
401
/// One choice inside a [`Response`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    /// Index of the choice.
    pub index: u32,
    /// The message for this choice, decoded with the same type used for request messages.
    pub message: RequestMessage,
    /// Finish reason as a free-form string, when present.
    pub finish_reason: Option<String>,
}
408
/// Body of the model listing response decoded by `list_models`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    /// The listed models, taken from the `data` array.
    pub data: Vec<ModelEntry>,
}
413
/// One model entry in a [`ListModelsResponse`].
///
/// NOTE(review): this type derives only `Deserialize`, so the
/// `skip_serializing_if` attributes below appear to have no effect — confirm
/// before relying on them.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    /// Model identifier; becomes `Model::name` in `list_models`.
    pub id: String,
    /// Display name; `list_models` keeps only the text after the last `:`.
    pub name: String,
    /// Value of the `created` field reported by the API.
    pub created: usize,
    /// Model description.
    pub description: String,
    /// Context length, when reported; `list_models` falls back to 2,000,000.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// Supported parameter names; `list_models` looks for `"tools"` and
    /// `"reasoning"`. Defaults to empty when missing.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
    /// Architecture details, when reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ModelArchitecture>,
}
427
/// The `architecture` object of a [`ModelEntry`].
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelArchitecture {
    /// Input modality names; `list_models` looks for `"image"`. Defaults to
    /// empty when missing.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_modalities: Vec<String>,
}
433
434pub async fn stream_completion(
435 client: &dyn HttpClient,
436 api_url: &str,
437 api_key: &str,
438 request: Request,
439) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
440 let uri = format!("{api_url}/chat/completions");
441 let request_builder = HttpRequest::builder()
442 .method(Method::POST)
443 .uri(uri)
444 .header("Content-Type", "application/json")
445 .header("Authorization", format!("Bearer {}", api_key))
446 .header("HTTP-Referer", "https://zed.dev")
447 .header("X-Title", "Zed Editor");
448
449 let request = request_builder
450 .body(AsyncBody::from(
451 serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
452 ))
453 .map_err(OpenRouterError::BuildRequestBody)?;
454 let mut response = client
455 .send(request)
456 .await
457 .map_err(OpenRouterError::HttpSend)?;
458
459 if response.status().is_success() {
460 let reader = BufReader::new(response.into_body());
461 Ok(reader
462 .lines()
463 .filter_map(|line| async move {
464 match line {
465 Ok(line) => {
466 if line.starts_with(':') {
467 return None;
468 }
469
470 let line = line.strip_prefix("data: ")?;
471 if line == "[DONE]" {
472 None
473 } else {
474 match serde_json::from_str::<ResponseStreamEvent>(line) {
475 Ok(response) => Some(Ok(response)),
476 Err(error) => {
477 if line.trim().is_empty() {
478 None
479 } else {
480 Some(Err(OpenRouterError::DeserializeResponse(error)))
481 }
482 }
483 }
484 }
485 }
486 Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
487 }
488 })
489 .boxed())
490 } else {
491 let code = ApiErrorCode::from_status(response.status().as_u16());
492
493 let mut body = String::new();
494 response
495 .body_mut()
496 .read_to_string(&mut body)
497 .await
498 .map_err(OpenRouterError::ReadResponse)?;
499
500 let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
501 Ok(OpenRouterErrorResponse { error }) => error,
502 Err(_) => OpenRouterErrorBody {
503 code: response.status().as_u16(),
504 message: body,
505 metadata: None,
506 },
507 };
508
509 match code {
510 ApiErrorCode::RateLimitError => {
511 let retry_after = extract_retry_after(response.headers());
512 Err(OpenRouterError::RateLimit {
513 retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
514 })
515 }
516 ApiErrorCode::OverloadedError => {
517 let retry_after = extract_retry_after(response.headers());
518 Err(OpenRouterError::ServerOverloaded { retry_after })
519 }
520 _ => Err(OpenRouterError::ApiError(ApiError {
521 code: code,
522 message: error_response.message,
523 })),
524 }
525 }
526}
527
528pub async fn list_models(
529 client: &dyn HttpClient,
530 api_url: &str,
531 api_key: &str,
532) -> Result<Vec<Model>, OpenRouterError> {
533 let uri = format!("{api_url}/models/user");
534 let request_builder = HttpRequest::builder()
535 .method(Method::GET)
536 .uri(uri)
537 .header("Accept", "application/json")
538 .header("Authorization", format!("Bearer {}", api_key))
539 .header("HTTP-Referer", "https://zed.dev")
540 .header("X-Title", "Zed Editor");
541
542 let request = request_builder
543 .body(AsyncBody::default())
544 .map_err(OpenRouterError::BuildRequestBody)?;
545 let mut response = client
546 .send(request)
547 .await
548 .map_err(OpenRouterError::HttpSend)?;
549
550 let mut body = String::new();
551 response
552 .body_mut()
553 .read_to_string(&mut body)
554 .await
555 .map_err(OpenRouterError::ReadResponse)?;
556
557 if response.status().is_success() {
558 let response: ListModelsResponse =
559 serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
560
561 let models = response
562 .data
563 .into_iter()
564 .map(|entry| Model {
565 name: entry.id,
566 // OpenRouter returns display names in the format "provider_name: model_name".
567 // When displayed in the UI, these names can get truncated from the right.
568 // Since users typically already know the provider, we extract just the model name
569 // portion (after the colon) to create a more concise and user-friendly label
570 // for the model dropdown in the agent panel.
571 display_name: Some(
572 entry
573 .name
574 .split(':')
575 .next_back()
576 .unwrap_or(&entry.name)
577 .trim()
578 .to_string(),
579 ),
580 max_tokens: entry.context_length.unwrap_or(2000000),
581 supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
582 supports_images: Some(
583 entry
584 .architecture
585 .as_ref()
586 .map(|arch| arch.input_modalities.contains(&"image".to_string()))
587 .unwrap_or(false),
588 ),
589 mode: if entry
590 .supported_parameters
591 .contains(&"reasoning".to_string())
592 {
593 ModelMode::Thinking {
594 budget_tokens: Some(4_096),
595 }
596 } else {
597 ModelMode::Default
598 },
599 provider: None,
600 })
601 .collect();
602
603 Ok(models)
604 } else {
605 let code = ApiErrorCode::from_status(response.status().as_u16());
606
607 let mut body = String::new();
608 response
609 .body_mut()
610 .read_to_string(&mut body)
611 .await
612 .map_err(OpenRouterError::ReadResponse)?;
613
614 let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
615 Ok(OpenRouterErrorResponse { error }) => error,
616 Err(_) => OpenRouterErrorBody {
617 code: response.status().as_u16(),
618 message: body,
619 metadata: None,
620 },
621 };
622
623 match code {
624 ApiErrorCode::RateLimitError => {
625 let retry_after = extract_retry_after(response.headers());
626 Err(OpenRouterError::RateLimit {
627 retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
628 })
629 }
630 ApiErrorCode::OverloadedError => {
631 let retry_after = extract_retry_after(response.headers());
632 Err(OpenRouterError::ServerOverloaded { retry_after })
633 }
634 _ => Err(OpenRouterError::ApiError(ApiError {
635 code: code,
636 message: error_response.message,
637 })),
638 }
639 }
640}
641
/// Errors returned by `stream_completion` and `list_models`.
///
/// The first five variants wrap failures that occur while building, sending
/// or decoding a request; the last three are derived from a non-success HTTP
/// status returned by the server.
#[derive(Debug)]
pub enum OpenRouterError {
    /// Failed to serialize the HTTP request body to JSON
    SerializeRequest(serde_json::Error),

    /// Failed to construct the HTTP request body
    BuildRequestBody(http::Error),

    /// Failed to send the HTTP request
    HttpSend(anyhow::Error),

    /// Failed to deserialize the response from JSON
    DeserializeResponse(serde_json::Error),

    /// Failed to read from response stream
    ReadResponse(io::Error),

    /// Rate limit exceeded; `retry_after` is the delay derived from the
    /// response headers, or 60 seconds when none could be determined
    RateLimit { retry_after: Duration },

    /// Server overloaded; `retry_after` is set only when the response headers
    /// provided a usable reset time
    ServerOverloaded { retry_after: Option<Duration> },

    /// API returned an error response
    ApiError(ApiError),
}
668
/// The `error` object of an error response body.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
    /// Numeric error code; the fallback path in this file fills it with the
    /// HTTP status when the body is not a structured error.
    pub code: u16,
    /// Error message.
    pub message: String,
    /// Optional extra key/value details; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}
676
/// Top-level shape of a structured error response: `{ "error": { ... } }`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
    /// The error details.
    pub error: OpenRouterErrorBody,
}
681
/// An error reported by the API, classified by [`ApiErrorCode`].
///
/// Displays as `OpenRouter API Error: {code}: {message}`.
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
    /// Error category; in this file it is derived from the HTTP status.
    pub code: ApiErrorCode,
    /// Message taken from the error response, or the raw body when the
    /// response was not a structured error.
    pub message: String,
}
688
/// An OpenRouter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
///
/// `ApiErrorCode::from_status` maps HTTP status codes onto these variants,
/// and the `Display` implementation renders them in snake_case.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400: Bad Request (invalid or missing params, CORS)
    InvalidRequestError,
    /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
    AuthenticationError,
    /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
    PaymentRequiredError,
    /// 403: Your chosen model requires moderation and your input was flagged
    PermissionError,
    /// 408: Your request timed out
    RequestTimedOut,
    /// 429: You are being rate limited
    RateLimitError,
    /// 502: Your chosen model is down or we received an invalid response from it.
    /// Also used by `from_status` for any status code it does not list.
    ApiError,
    /// 503: There is no available model provider that meets your routing requirements
    OverloadedError,
}
711
712impl std::fmt::Display for ApiErrorCode {
713 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
714 let s = match self {
715 ApiErrorCode::InvalidRequestError => "invalid_request_error",
716 ApiErrorCode::AuthenticationError => "authentication_error",
717 ApiErrorCode::PaymentRequiredError => "payment_required_error",
718 ApiErrorCode::PermissionError => "permission_error",
719 ApiErrorCode::RequestTimedOut => "request_timed_out",
720 ApiErrorCode::RateLimitError => "rate_limit_error",
721 ApiErrorCode::ApiError => "api_error",
722 ApiErrorCode::OverloadedError => "overloaded_error",
723 };
724 write!(f, "{s}")
725 }
726}
727
728impl ApiErrorCode {
729 pub fn from_status(status: u16) -> Self {
730 match status {
731 400 => ApiErrorCode::InvalidRequestError,
732 401 => ApiErrorCode::AuthenticationError,
733 402 => ApiErrorCode::PaymentRequiredError,
734 403 => ApiErrorCode::PermissionError,
735 408 => ApiErrorCode::RequestTimedOut,
736 429 => ApiErrorCode::RateLimitError,
737 502 => ApiErrorCode::ApiError,
738 503 => ApiErrorCode::OverloadedError,
739 _ => ApiErrorCode::ApiError,
740 }
741 }
742}