1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6pub use settings::DataCollection;
7pub use settings::ModelMode;
8pub use settings::OpenRouterAvailableModel as AvailableModel;
9pub use settings::OpenRouterProvider as Provider;
10use std::{convert::TryFrom, io, time::Duration};
11use strum::EnumString;
12use thiserror::Error;
13
/// Base URL of the public OpenRouter REST API (v1).
///
/// `stream_completion` and `list_models` take the base URL as a parameter and
/// append endpoint paths (`/chat/completions`, `/models/user`) to it.
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
15
16fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
17 if let Some(reset) = headers.get("X-RateLimit-Reset") {
18 if let Ok(s) = reset.to_str() {
19 if let Ok(epoch_ms) = s.parse::<u64>() {
20 let now = std::time::SystemTime::now()
21 .duration_since(std::time::UNIX_EPOCH)
22 .unwrap_or_default()
23 .as_millis() as u64;
24 if epoch_ms > now {
25 return Some(std::time::Duration::from_millis(epoch_ms - now));
26 }
27 }
28 }
29 }
30 None
31}
32
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a serde `skip_serializing_if` predicate for optional collections.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
36
/// Author of a chat message.
///
/// Serialized as a lowercase string: `"user"`, `"assistant"`, `"system"` or `"tool"`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
45
46impl TryFrom<String> for Role {
47 type Error = anyhow::Error;
48
49 fn try_from(value: String) -> Result<Self> {
50 match value.as_str() {
51 "user" => Ok(Self::User),
52 "assistant" => Ok(Self::Assistant),
53 "system" => Ok(Self::System),
54 "tool" => Ok(Self::Tool),
55 _ => Err(anyhow!("invalid role '{value}'")),
56 }
57 }
58}
59
60impl From<Role> for String {
61 fn from(val: Role) -> Self {
62 match val {
63 Role::User => "user".to_owned(),
64 Role::Assistant => "assistant".to_owned(),
65 Role::System => "system".to_owned(),
66 Role::Tool => "tool".to_owned(),
67 }
68 }
69}
70
/// An OpenRouter model as surfaced to the rest of the application.
///
/// Built either through [`Model::new`] or from `/models/user` entries in `list_models`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// OpenRouter model id (returned by `id()`), e.g. `"openrouter/auto"`.
    pub name: String,
    /// Human-readable label; `display_name()` falls back to `name` when `None`.
    pub display_name: Option<String>,
    /// Context window size in tokens (filled from `context_length` by `list_models`).
    pub max_tokens: u64,
    /// Tool-call support; `supports_tool_calls()` treats `None` as `false`.
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image input.
    pub supports_images: Option<bool>,
    // Missing from serialized data => `ModelMode::default()`.
    #[serde(default)]
    pub mode: ModelMode,
    /// Optional provider settings (`settings::OpenRouterProvider`).
    pub provider: Option<Provider>,
}
83
84impl Model {
85 pub fn default() -> Self {
86 Self::new(
87 "openrouter/auto",
88 Some("Auto Router"),
89 Some(2000000),
90 Some(true),
91 Some(false),
92 Some(ModelMode::Default),
93 None,
94 )
95 }
96
97 pub fn new(
98 name: &str,
99 display_name: Option<&str>,
100 max_tokens: Option<u64>,
101 supports_tools: Option<bool>,
102 supports_images: Option<bool>,
103 mode: Option<ModelMode>,
104 provider: Option<Provider>,
105 ) -> Self {
106 Self {
107 name: name.to_owned(),
108 display_name: display_name.map(|s| s.to_owned()),
109 max_tokens: max_tokens.unwrap_or(2000000),
110 supports_tools,
111 supports_images,
112 mode: mode.unwrap_or(ModelMode::Default),
113 provider,
114 }
115 }
116
117 pub fn id(&self) -> &str {
118 &self.name
119 }
120
121 pub fn display_name(&self) -> &str {
122 self.display_name.as_ref().unwrap_or(&self.name)
123 }
124
125 pub fn max_token_count(&self) -> u64 {
126 self.max_tokens
127 }
128
129 pub fn max_output_tokens(&self) -> Option<u64> {
130 None
131 }
132
133 pub fn supports_tool_calls(&self) -> bool {
134 self.supports_tools.unwrap_or(false)
135 }
136
137 pub fn supports_parallel_tool_calls(&self) -> bool {
138 false
139 }
140}
141
/// JSON body of a `POST /chat/completions` request (serialized by `stream_completion`).
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// Whether the response should be streamed as server-sent events.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    pub usage: RequestUsage,
    // Unlike the optional fields above, this has no `skip_serializing_if`,
    // so `None` is sent as `"provider": null`.
    pub provider: Option<Provider>,
}

/// Usage-accounting options, serialized as `"usage": {"include": <bool>}`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    pub include: bool,
}
168
/// How the model may use tools.
///
/// Serialized as `"auto"`, `"required"` or `"none"`, or as a specific tool
/// definition without a variant tag (`#[serde(untagged)]`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}

/// A tool offered to the model, internally tagged by `"type"` (currently only `"function"`).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for this allow is not visible here; confirm it is still needed.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// Name, description and parameter description of a callable function tool.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    // Presumably a JSON Schema object describing the arguments — confirm against callers.
    pub parameters: Option<Value>,
}

/// Reasoning ("thinking") options; every `None` field is omitted from the JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Reasoning {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
206
/// A chat message, internally tagged by a lowercase `"role"` field.
///
/// Also used for the `message` of a non-streaming [`Choice`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        content: Option<MessageContent>,
        // Omitted from the JSON when there are no tool calls.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        // Opaque JSON passed through unchanged; omitted when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_details: Option<serde_json::Value>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}

/// Message content.
///
/// Either a bare string or a list of typed parts. The enum is untagged, so it
/// serializes as a JSON string or a JSON array.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
235
236impl MessageContent {
237 pub fn empty() -> Self {
238 Self::Plain(String::new())
239 }
240
241 pub fn push_part(&mut self, part: MessagePart) {
242 match self {
243 Self::Plain(text) if text.is_empty() => {
244 *self = Self::Multipart(vec![part]);
245 }
246 Self::Plain(text) => {
247 let text_part = MessagePart::Text {
248 text: std::mem::take(text),
249 };
250 *self = Self::Multipart(vec![text_part, part]);
251 }
252 Self::Multipart(parts) => parts.push(part),
253 }
254 }
255}
256
257impl From<Vec<MessagePart>> for MessageContent {
258 fn from(parts: Vec<MessagePart>) -> Self {
259 if parts.len() == 1
260 && let MessagePart::Text { text } = &parts[0]
261 {
262 return Self::Plain(text.clone());
263 }
264 Self::Multipart(parts)
265 }
266}
267
268impl From<String> for MessageContent {
269 fn from(text: String) -> Self {
270 Self::Plain(text)
271 }
272}
273
274impl From<&str> for MessageContent {
275 fn from(text: &str) -> Self {
276 Self::Plain(text.to_string())
277 }
278}
279
280impl MessageContent {
281 pub fn as_text(&self) -> Option<&str> {
282 match self {
283 Self::Plain(text) => Some(text),
284 Self::Multipart(parts) if parts.len() == 1 => {
285 if let MessagePart::Text { text } = &parts[0] {
286 Some(text)
287 } else {
288 None
289 }
290 }
291 _ => None,
292 }
293 }
294
295 pub fn to_text(&self) -> String {
296 match self {
297 Self::Plain(text) => text.clone(),
298 Self::Multipart(parts) => parts
299 .iter()
300 .filter_map(|part| {
301 if let MessagePart::Text { text } = part {
302 Some(text.as_str())
303 } else {
304 None
305 }
306 })
307 .collect::<Vec<_>>()
308 .join(""),
309 }
310 }
311}
312
/// One part of multipart content, tagged by `"type"` (`"text"` or `"image_url"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    // NOTE(review): `image_url` is serialized as a bare string. OpenAI-compatible
    // APIs document it as an object with a `url` field — confirm OpenRouter accepts this form.
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}

/// A tool call made by the assistant.
///
/// The flattened `content` puts `"type"` and `"function"` next to `"id"` in the JSON.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Body of a tool call, tagged by `"type"` (currently only `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// A complete function call: name plus arguments kept as a raw string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Presumably JSON-encoded arguments — the type only fixes it as a string.
    pub arguments: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
345
/// Incremental message content carried by one streamed choice.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    pub reasoning: Option<String>,
    // Omitted when serializing if `None` or an empty list.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_details: Option<serde_json::Value>,
}

/// A fragment of a streamed tool call.
///
/// Presumably fragments sharing the same `index` are accumulated into one call by the consumer — confirm.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// A fragment of a streamed function call; every field may be absent.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    #[serde(default)]
    pub thought_signature: Option<String>,
}
371
/// Token accounting reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

/// One parsed SSE `data:` payload from `stream_completion`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    // Presumably a Unix timestamp in seconds — confirm.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

/// A complete (non-streaming) chat completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One choice of a non-streaming response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
412
/// Body of a `GET /models/user` response.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

/// One model entry returned by `/models/user`.
///
/// NOTE(review): only `Deserialize` is derived, so the `skip_serializing_if`
/// attributes below have no effect. `name`, `created` and `description` are
/// required: an entry missing any of them fails deserialization of the whole list.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    pub id: String,
    pub name: String,
    pub created: usize,
    pub description: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    // Checked for "tools" and "reasoning" by `list_models`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ModelArchitecture>,
}

/// Model architecture details; `input_modalities` is checked for "image" by `list_models`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelArchitecture {
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_modalities: Vec<String>,
}
437
438pub async fn stream_completion(
439 client: &dyn HttpClient,
440 api_url: &str,
441 api_key: &str,
442 request: Request,
443) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
444 let uri = format!("{api_url}/chat/completions");
445 let request_builder = HttpRequest::builder()
446 .method(Method::POST)
447 .uri(uri)
448 .header("Content-Type", "application/json")
449 .header("Authorization", format!("Bearer {}", api_key))
450 .header("HTTP-Referer", "https://zed.dev")
451 .header("X-Title", "Zed Editor");
452
453 let request = request_builder
454 .body(AsyncBody::from(
455 serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
456 ))
457 .map_err(OpenRouterError::BuildRequestBody)?;
458 let mut response = client
459 .send(request)
460 .await
461 .map_err(OpenRouterError::HttpSend)?;
462
463 if response.status().is_success() {
464 let reader = BufReader::new(response.into_body());
465 Ok(reader
466 .lines()
467 .filter_map(|line| async move {
468 match line {
469 Ok(line) => {
470 if line.starts_with(':') {
471 return None;
472 }
473
474 let line = line.strip_prefix("data: ")?;
475 if line == "[DONE]" {
476 None
477 } else {
478 match serde_json::from_str::<ResponseStreamEvent>(line) {
479 Ok(response) => Some(Ok(response)),
480 Err(error) => {
481 if line.trim().is_empty() {
482 None
483 } else {
484 Some(Err(OpenRouterError::DeserializeResponse(error)))
485 }
486 }
487 }
488 }
489 }
490 Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
491 }
492 })
493 .boxed())
494 } else {
495 let code = ApiErrorCode::from_status(response.status().as_u16());
496
497 let mut body = String::new();
498 response
499 .body_mut()
500 .read_to_string(&mut body)
501 .await
502 .map_err(OpenRouterError::ReadResponse)?;
503
504 let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
505 Ok(OpenRouterErrorResponse { error }) => error,
506 Err(_) => OpenRouterErrorBody {
507 code: response.status().as_u16(),
508 message: body,
509 metadata: None,
510 },
511 };
512
513 match code {
514 ApiErrorCode::RateLimitError => {
515 let retry_after = extract_retry_after(response.headers());
516 Err(OpenRouterError::RateLimit {
517 retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
518 })
519 }
520 ApiErrorCode::OverloadedError => {
521 let retry_after = extract_retry_after(response.headers());
522 Err(OpenRouterError::ServerOverloaded { retry_after })
523 }
524 _ => Err(OpenRouterError::ApiError(ApiError {
525 code: code,
526 message: error_response.message,
527 })),
528 }
529 }
530}
531
532pub async fn list_models(
533 client: &dyn HttpClient,
534 api_url: &str,
535 api_key: &str,
536) -> Result<Vec<Model>, OpenRouterError> {
537 let uri = format!("{api_url}/models/user");
538 let request_builder = HttpRequest::builder()
539 .method(Method::GET)
540 .uri(uri)
541 .header("Accept", "application/json")
542 .header("Authorization", format!("Bearer {}", api_key))
543 .header("HTTP-Referer", "https://zed.dev")
544 .header("X-Title", "Zed Editor");
545
546 let request = request_builder
547 .body(AsyncBody::default())
548 .map_err(OpenRouterError::BuildRequestBody)?;
549 let mut response = client
550 .send(request)
551 .await
552 .map_err(OpenRouterError::HttpSend)?;
553
554 let mut body = String::new();
555 response
556 .body_mut()
557 .read_to_string(&mut body)
558 .await
559 .map_err(OpenRouterError::ReadResponse)?;
560
561 if response.status().is_success() {
562 let response: ListModelsResponse =
563 serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
564
565 let models = response
566 .data
567 .into_iter()
568 .map(|entry| Model {
569 name: entry.id,
570 // OpenRouter returns display names in the format "provider_name: model_name".
571 // When displayed in the UI, these names can get truncated from the right.
572 // Since users typically already know the provider, we extract just the model name
573 // portion (after the colon) to create a more concise and user-friendly label
574 // for the model dropdown in the agent panel.
575 display_name: Some(
576 entry
577 .name
578 .split(':')
579 .next_back()
580 .unwrap_or(&entry.name)
581 .trim()
582 .to_string(),
583 ),
584 max_tokens: entry.context_length.unwrap_or(2000000),
585 supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
586 supports_images: Some(
587 entry
588 .architecture
589 .as_ref()
590 .map(|arch| arch.input_modalities.contains(&"image".to_string()))
591 .unwrap_or(false),
592 ),
593 mode: if entry
594 .supported_parameters
595 .contains(&"reasoning".to_string())
596 {
597 ModelMode::Thinking {
598 budget_tokens: Some(4_096),
599 }
600 } else {
601 ModelMode::Default
602 },
603 provider: None,
604 })
605 .collect();
606
607 Ok(models)
608 } else {
609 let code = ApiErrorCode::from_status(response.status().as_u16());
610
611 let mut body = String::new();
612 response
613 .body_mut()
614 .read_to_string(&mut body)
615 .await
616 .map_err(OpenRouterError::ReadResponse)?;
617
618 let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
619 Ok(OpenRouterErrorResponse { error }) => error,
620 Err(_) => OpenRouterErrorBody {
621 code: response.status().as_u16(),
622 message: body,
623 metadata: None,
624 },
625 };
626
627 match code {
628 ApiErrorCode::RateLimitError => {
629 let retry_after = extract_retry_after(response.headers());
630 Err(OpenRouterError::RateLimit {
631 retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
632 })
633 }
634 ApiErrorCode::OverloadedError => {
635 let retry_after = extract_retry_after(response.headers());
636 Err(OpenRouterError::ServerOverloaded { retry_after })
637 }
638 _ => Err(OpenRouterError::ApiError(ApiError {
639 code: code,
640 message: error_response.message,
641 })),
642 }
643 }
644}
645
/// Errors returned by `stream_completion` and `list_models`.
///
/// NOTE(review): only `Debug` is derived here; no `Display`/`std::error::Error`
/// impl is visible in this file.
#[derive(Debug)]
pub enum OpenRouterError {
    /// Failed to serialize the HTTP request body to JSON
    SerializeRequest(serde_json::Error),

    /// Failed to construct the HTTP request body
    BuildRequestBody(http::Error),

    /// Failed to send the HTTP request
    HttpSend(anyhow::Error),

    /// Failed to deserialize the response from JSON
    DeserializeResponse(serde_json::Error),

    /// Failed to read from response stream
    ReadResponse(io::Error),

    /// Rate limit exceeded (HTTP 429); `retry_after` falls back to 60 s when no usable reset header is present
    RateLimit { retry_after: Duration },

    /// Server overloaded (HTTP 503); `retry_after` is `None` when no usable reset header is present
    ServerOverloaded { retry_after: Option<Duration> },

    /// API returned an error response
    ApiError(ApiError),
}
672
/// The `error` object of an OpenRouter JSON error body.
///
/// Also used as a fallback when the body does not parse, in which case the raw
/// text becomes `message`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
    pub code: u16,
    pub message: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}

/// Top-level OpenRouter error body: `{"error": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
    pub error: OpenRouterErrorBody,
}

/// A non-success API response.
///
/// Displays as `OpenRouter API Error: <code>: <message>`.
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
    pub code: ApiErrorCode,
    pub message: String,
}
692
/// An OpenRouter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
///
/// `Display` and strum's `FromStr` use snake_case names such as `rate_limit_error`.
/// NOTE(review): the serde derives have no `rename_all`, so they (de)serialize the
/// PascalCase variant names (e.g. `RateLimitError`) instead — confirm this is intended.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400: Bad Request (invalid or missing params, CORS)
    InvalidRequestError,
    /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
    AuthenticationError,
    /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
    PaymentRequiredError,
    /// 403: Your chosen model requires moderation and your input was flagged
    PermissionError,
    /// 408: Your request timed out
    RequestTimedOut,
    /// 429: You are being rate limited
    RateLimitError,
    /// 502: Your chosen model is down or we received an invalid response from it
    /// (also used by `from_status` for any unrecognized status)
    ApiError,
    /// 503: There is no available model provider that meets your routing requirements
    OverloadedError,
}
715
716impl std::fmt::Display for ApiErrorCode {
717 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
718 let s = match self {
719 ApiErrorCode::InvalidRequestError => "invalid_request_error",
720 ApiErrorCode::AuthenticationError => "authentication_error",
721 ApiErrorCode::PaymentRequiredError => "payment_required_error",
722 ApiErrorCode::PermissionError => "permission_error",
723 ApiErrorCode::RequestTimedOut => "request_timed_out",
724 ApiErrorCode::RateLimitError => "rate_limit_error",
725 ApiErrorCode::ApiError => "api_error",
726 ApiErrorCode::OverloadedError => "overloaded_error",
727 };
728 write!(f, "{s}")
729 }
730}
731
732impl ApiErrorCode {
733 pub fn from_status(status: u16) -> Self {
734 match status {
735 400 => ApiErrorCode::InvalidRequestError,
736 401 => ApiErrorCode::AuthenticationError,
737 402 => ApiErrorCode::PaymentRequiredError,
738 403 => ApiErrorCode::PermissionError,
739 408 => ApiErrorCode::RequestTimedOut,
740 429 => ApiErrorCode::RateLimitError,
741 502 => ApiErrorCode::ApiError,
742 503 => ApiErrorCode::OverloadedError,
743 _ => ApiErrorCode::ApiError,
744 }
745 }
746}