1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6pub use settings::DataCollection;
7pub use settings::ModelMode;
8pub use settings::OpenRouterAvailableModel as AvailableModel;
9pub use settings::OpenRouterProvider as Provider;
10use std::{convert::TryFrom, io, time::Duration};
11use strum::EnumString;
12use thiserror::Error;
13
/// Default base URL of the OpenRouter HTTP API (no trailing slash).
///
/// Endpoint paths such as `/chat/completions` and `/models/user` are appended
/// to an `api_url` of this shape by the request functions in this file.
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
15
16fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
17 if let Some(reset) = headers.get("X-RateLimit-Reset") {
18 if let Ok(s) = reset.to_str() {
19 if let Ok(epoch_ms) = s.parse::<u64>() {
20 let now = std::time::SystemTime::now()
21 .duration_since(std::time::UNIX_EPOCH)
22 .unwrap_or_default()
23 .as_millis() as u64;
24 if epoch_ms > now {
25 return Some(std::time::Duration::from_millis(epoch_ms - now));
26 }
27 }
28 }
29 }
30 None
31}
32
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a serde `skip_serializing_if` predicate for optional collections.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        None => true,
        Some(items) => items.as_ref().is_empty(),
    }
}
36
/// The author role attached to a chat message.
///
/// Serialized and deserialized as the lowercase variant name (for example `"user"`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
45
46impl TryFrom<String> for Role {
47 type Error = anyhow::Error;
48
49 fn try_from(value: String) -> Result<Self> {
50 match value.as_str() {
51 "user" => Ok(Self::User),
52 "assistant" => Ok(Self::Assistant),
53 "system" => Ok(Self::System),
54 "tool" => Ok(Self::Tool),
55 _ => Err(anyhow!("invalid role '{value}'")),
56 }
57 }
58}
59
60impl From<Role> for String {
61 fn from(val: Role) -> Self {
62 match val {
63 Role::User => "user".to_owned(),
64 Role::Assistant => "assistant".to_owned(),
65 Role::System => "system".to_owned(),
66 Role::Tool => "tool".to_owned(),
67 }
68 }
69}
70
/// Description of a model that requests can be sent to.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Model identifier; this is what [`Model::id`] returns.
    pub name: String,
    /// Optional human-readable label. [`Model::display_name`] falls back to `name` when unset.
    pub display_name: Option<String>,
    /// Token limit reported by [`Model::max_token_count`]. `list_models` fills it
    /// from the API's `context_length`.
    pub max_tokens: u64,
    /// Whether the model supports tool calls, if known (`None` is treated as `false`
    /// by [`Model::supports_tool_calls`]).
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image input, if known.
    pub supports_images: Option<bool>,
    /// Model mode; takes `ModelMode`'s default when absent from the deserialized input.
    #[serde(default)]
    pub mode: ModelMode,
    /// Optional provider settings associated with this model.
    pub provider: Option<Provider>,
}
83
84impl Model {
85 pub fn default_fast() -> Self {
86 Self::new(
87 "openrouter/auto",
88 Some("Auto Router"),
89 Some(2000000),
90 Some(true),
91 Some(false),
92 Some(ModelMode::Default),
93 None,
94 )
95 }
96
97 pub fn default() -> Self {
98 Self::default_fast()
99 }
100
101 pub fn new(
102 name: &str,
103 display_name: Option<&str>,
104 max_tokens: Option<u64>,
105 supports_tools: Option<bool>,
106 supports_images: Option<bool>,
107 mode: Option<ModelMode>,
108 provider: Option<Provider>,
109 ) -> Self {
110 Self {
111 name: name.to_owned(),
112 display_name: display_name.map(|s| s.to_owned()),
113 max_tokens: max_tokens.unwrap_or(2000000),
114 supports_tools,
115 supports_images,
116 mode: mode.unwrap_or(ModelMode::Default),
117 provider,
118 }
119 }
120
121 pub fn id(&self) -> &str {
122 &self.name
123 }
124
125 pub fn display_name(&self) -> &str {
126 self.display_name.as_ref().unwrap_or(&self.name)
127 }
128
129 pub fn max_token_count(&self) -> u64 {
130 self.max_tokens
131 }
132
133 pub fn max_output_tokens(&self) -> Option<u64> {
134 None
135 }
136
137 pub fn supports_tool_calls(&self) -> bool {
138 self.supports_tools.unwrap_or(false)
139 }
140
141 pub fn supports_parallel_tool_calls(&self) -> bool {
142 false
143 }
144}
145
/// Body of a chat completion request.
///
/// `stream_completion` serializes this to JSON and posts it to `/chat/completions`.
/// Optional and empty collection fields marked with `skip_serializing_if` are
/// left out of the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Identifier of the model to use.
    pub model: String,
    /// Conversation messages, in order.
    pub messages: Vec<RequestMessage>,
    /// Whether a streamed response is requested.
    pub stream: bool,
    /// Optional token limit; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature.
    pub temperature: f32,
    /// Tool selection policy; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Parallel tool call setting; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools offered to the model; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Reasoning configuration; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    /// Usage reporting options.
    pub usage: RequestUsage,
    /// Provider settings. NOTE(review): has no `skip_serializing_if`, so `None`
    /// is serialized as `null` — confirm this is intended.
    pub provider: Option<Provider>,
}
167
/// The `usage` object of a [`Request`].
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    /// NOTE(review): presumably asks the API to include token usage in the
    /// response — confirm against the OpenRouter documentation.
    pub include: bool,
}
172
/// Tool selection policy for a [`Request`].
///
/// The unit variants serialize as lowercase strings (`"auto"`, `"required"`,
/// `"none"`); `Other` is untagged and serializes as the wrapped
/// [`ToolDefinition`] itself.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
182
/// A tool that can be offered to the model.
///
/// Internally tagged with a `type` field in snake case, so the single variant
/// serializes with `"type": "function"`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    /// A callable function described by `function`.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
190
/// Description of a function tool.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Optional parameter description as arbitrary JSON.
    /// NOTE(review): presumably a JSON Schema object — confirm.
    pub parameters: Option<Value>,
}
198
/// Reasoning configuration of a [`Request`].
///
/// Every field is optional and is left out of the serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Reasoning {
    /// Reasoning effort level, as a free-form string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    /// Token budget for reasoning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// NOTE(review): presumably hides reasoning output from the response — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    /// Explicit on/off switch for reasoning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
210
/// A single chat message.
///
/// Internally tagged by a lowercase `role` field (`"assistant"`, `"user"`,
/// `"system"`, `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// Message produced by the model.
    Assistant {
        /// Message content, if any.
        content: Option<MessageContent>,
        /// Tool calls made in this message; omitted from JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        /// Opaque reasoning details as arbitrary JSON; omitted when `None`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_details: Option<serde_json::Value>,
    },
    /// Message written by the user.
    User {
        content: MessageContent,
    },
    /// System message.
    System {
        content: MessageContent,
    },
    /// Result of a tool call.
    Tool {
        content: MessageContent,
        /// Identifier of the tool call this message answers.
        tool_call_id: String,
    },
}
232
/// Content of a message: a plain string or a list of parts.
///
/// Untagged, so it serializes as a JSON string (`Plain`) or a JSON array of
/// [`MessagePart`] objects (`Multipart`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    /// Text-only content.
    Plain(String),
    /// Content made of several parts.
    Multipart(Vec<MessagePart>),
}
239
240impl MessageContent {
241 pub fn empty() -> Self {
242 Self::Plain(String::new())
243 }
244
245 pub fn push_part(&mut self, part: MessagePart) {
246 match self {
247 Self::Plain(text) if text.is_empty() => {
248 *self = Self::Multipart(vec![part]);
249 }
250 Self::Plain(text) => {
251 let text_part = MessagePart::Text {
252 text: std::mem::take(text),
253 };
254 *self = Self::Multipart(vec![text_part, part]);
255 }
256 Self::Multipart(parts) => parts.push(part),
257 }
258 }
259}
260
261impl From<Vec<MessagePart>> for MessageContent {
262 fn from(parts: Vec<MessagePart>) -> Self {
263 if parts.len() == 1
264 && let MessagePart::Text { text } = &parts[0]
265 {
266 return Self::Plain(text.clone());
267 }
268 Self::Multipart(parts)
269 }
270}
271
272impl From<String> for MessageContent {
273 fn from(text: String) -> Self {
274 Self::Plain(text)
275 }
276}
277
278impl From<&str> for MessageContent {
279 fn from(text: &str) -> Self {
280 Self::Plain(text.to_string())
281 }
282}
283
284impl MessageContent {
285 pub fn as_text(&self) -> Option<&str> {
286 match self {
287 Self::Plain(text) => Some(text),
288 Self::Multipart(parts) if parts.len() == 1 => {
289 if let MessagePart::Text { text } = &parts[0] {
290 Some(text)
291 } else {
292 None
293 }
294 }
295 _ => None,
296 }
297 }
298
299 pub fn to_text(&self) -> String {
300 match self {
301 Self::Plain(text) => text.clone(),
302 Self::Multipart(parts) => parts
303 .iter()
304 .filter_map(|part| {
305 if let MessagePart::Text { text } = part {
306 Some(text.as_str())
307 } else {
308 None
309 }
310 })
311 .collect::<Vec<_>>()
312 .join(""),
313 }
314 }
315}
316
/// One part of multipart message content.
///
/// Internally tagged with a `type` field: `"text"` for `Text` and
/// `"image_url"` for `Image`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    /// A piece of text.
    Text {
        text: String,
    },
    /// An image reference.
    /// NOTE(review): `image_url` is a bare string here — confirm the API accepts
    /// this shape rather than a nested object.
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}
328
/// A tool call made in an assistant message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    /// Identifier of this tool call.
    pub id: String,
    /// Call payload; its fields are flattened into the same JSON object as `id`.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
335
/// Payload of a [`ToolCall`], internally tagged with a lowercase `type` field
/// (`"function"` for the single variant).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    /// A function invocation.
    Function { function: FunctionContent },
}
341
/// The function invoked by a tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    /// Name of the function.
    pub name: String,
    /// Arguments, carried as a string.
    /// NOTE(review): presumably JSON-encoded — confirm with callers.
    pub arguments: String,
    /// Optional opaque signature; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
349
/// Incremental message update carried by one streamed choice.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    /// Role of the message, when present in this chunk.
    pub role: Option<Role>,
    /// Content text in this chunk, if any.
    pub content: Option<String>,
    /// Reasoning text in this chunk, if any.
    pub reasoning: Option<String>,
    /// Tool call fragments; omitted from JSON when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
    /// Opaque reasoning details as arbitrary JSON; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning_details: Option<serde_json::Value>,
}
360
/// A fragment of a tool call in a streamed delta.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Index of the tool call.
    /// NOTE(review): presumably identifies which call this fragment belongs to — confirm.
    pub index: usize,
    /// Tool call identifier, when present in this fragment.
    pub id: Option<String>,
    /// Function fragment, when present.
    pub function: Option<FunctionChunk>,
}
367
/// A fragment of a function invocation in a streamed tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    /// Function name, when present in this fragment.
    pub name: Option<String>,
    /// Argument text, when present in this fragment.
    pub arguments: Option<String>,
    /// Optional opaque signature; defaults to `None` when absent from the input.
    #[serde(default)]
    pub thought_signature: Option<String>,
}
375
/// Token usage reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    /// Tokens counted for the prompt.
    pub prompt_tokens: u64,
    /// Tokens counted for the completion.
    pub completion_tokens: u64,
    /// Total token count as reported by the API.
    pub total_tokens: u64,
}
382
/// One choice entry of a streamed response event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    /// Index of this choice.
    pub index: u32,
    /// Incremental message update.
    pub delta: ResponseMessageDelta,
    /// Finish reason, when the API supplies one.
    pub finish_reason: Option<String>,
}
389
/// One event of a streamed completion, parsed from a `data: ` line by
/// `stream_completion`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    /// Response identifier; optional, and omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Creation timestamp as sent by the API.
    /// NOTE(review): `u32` here while `Response::created` is `u64` — confirm intended.
    pub created: u32,
    /// Model that produced the event.
    pub model: String,
    /// Choice deltas carried by this event.
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, when included in this event.
    pub usage: Option<Usage>,
}
399
/// A complete (non-streamed) completion response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    /// Response identifier.
    pub id: String,
    /// Object type string as sent by the API.
    pub object: String,
    /// Creation timestamp as sent by the API.
    pub created: u64,
    /// Model that produced the response.
    pub model: String,
    /// Completion choices.
    pub choices: Vec<Choice>,
    /// Token usage for the request.
    pub usage: Usage,
}
409
/// One choice of a non-streamed [`Response`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    /// Index of this choice.
    pub index: u32,
    /// The message for this choice.
    pub message: RequestMessage,
    /// Finish reason, when the API supplies one.
    pub finish_reason: Option<String>,
}
416
/// Response body of the model listing endpoint, as parsed by `list_models`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    /// The listed models.
    pub data: Vec<ModelEntry>,
}
421
/// One model as returned by the model listing endpoint.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    /// Model identifier; becomes `Model::name` in `list_models`.
    pub id: String,
    /// Display name; `list_models` keeps only the portion after the last `:`.
    pub name: String,
    /// Creation timestamp as sent by the API.
    pub created: usize,
    /// Description text.
    pub description: String,
    /// Context length, when reported; defaults to `None` if absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// Supported request parameters; `list_models` checks for `"tools"` and
    /// `"reasoning"`. Defaults to empty if absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
    /// Architecture details, when reported; defaults to `None` if absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ModelArchitecture>,
}
435
/// Architecture details of a [`ModelEntry`].
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelArchitecture {
    /// Input modalities; `list_models` checks for `"image"`. Defaults to empty if absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_modalities: Vec<String>,
}
441
442pub async fn stream_completion(
443 client: &dyn HttpClient,
444 api_url: &str,
445 api_key: &str,
446 request: Request,
447) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
448 let uri = format!("{api_url}/chat/completions");
449 let request_builder = HttpRequest::builder()
450 .method(Method::POST)
451 .uri(uri)
452 .header("Content-Type", "application/json")
453 .header("Authorization", format!("Bearer {}", api_key))
454 .header("HTTP-Referer", "https://zed.dev")
455 .header("X-Title", "Zed Editor");
456
457 let request = request_builder
458 .body(AsyncBody::from(
459 serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
460 ))
461 .map_err(OpenRouterError::BuildRequestBody)?;
462 let mut response = client
463 .send(request)
464 .await
465 .map_err(OpenRouterError::HttpSend)?;
466
467 if response.status().is_success() {
468 let reader = BufReader::new(response.into_body());
469 Ok(reader
470 .lines()
471 .filter_map(|line| async move {
472 match line {
473 Ok(line) => {
474 if line.starts_with(':') {
475 return None;
476 }
477
478 let line = line.strip_prefix("data: ")?;
479 if line == "[DONE]" {
480 None
481 } else {
482 match serde_json::from_str::<ResponseStreamEvent>(line) {
483 Ok(response) => Some(Ok(response)),
484 Err(error) => {
485 if line.trim().is_empty() {
486 None
487 } else {
488 Some(Err(OpenRouterError::DeserializeResponse(error)))
489 }
490 }
491 }
492 }
493 }
494 Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
495 }
496 })
497 .boxed())
498 } else {
499 let code = ApiErrorCode::from_status(response.status().as_u16());
500
501 let mut body = String::new();
502 response
503 .body_mut()
504 .read_to_string(&mut body)
505 .await
506 .map_err(OpenRouterError::ReadResponse)?;
507
508 let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
509 Ok(OpenRouterErrorResponse { error }) => error,
510 Err(_) => OpenRouterErrorBody {
511 code: response.status().as_u16(),
512 message: body,
513 metadata: None,
514 },
515 };
516
517 match code {
518 ApiErrorCode::RateLimitError => {
519 let retry_after = extract_retry_after(response.headers());
520 Err(OpenRouterError::RateLimit {
521 retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
522 })
523 }
524 ApiErrorCode::OverloadedError => {
525 let retry_after = extract_retry_after(response.headers());
526 Err(OpenRouterError::ServerOverloaded { retry_after })
527 }
528 _ => Err(OpenRouterError::ApiError(ApiError {
529 code: code,
530 message: error_response.message,
531 })),
532 }
533 }
534}
535
536pub async fn list_models(
537 client: &dyn HttpClient,
538 api_url: &str,
539 api_key: &str,
540) -> Result<Vec<Model>, OpenRouterError> {
541 let uri = format!("{api_url}/models/user");
542 let request_builder = HttpRequest::builder()
543 .method(Method::GET)
544 .uri(uri)
545 .header("Accept", "application/json")
546 .header("Authorization", format!("Bearer {}", api_key))
547 .header("HTTP-Referer", "https://zed.dev")
548 .header("X-Title", "Zed Editor");
549
550 let request = request_builder
551 .body(AsyncBody::default())
552 .map_err(OpenRouterError::BuildRequestBody)?;
553 let mut response = client
554 .send(request)
555 .await
556 .map_err(OpenRouterError::HttpSend)?;
557
558 let mut body = String::new();
559 response
560 .body_mut()
561 .read_to_string(&mut body)
562 .await
563 .map_err(OpenRouterError::ReadResponse)?;
564
565 if response.status().is_success() {
566 let response: ListModelsResponse =
567 serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
568
569 let models = response
570 .data
571 .into_iter()
572 .map(|entry| Model {
573 name: entry.id,
574 // OpenRouter returns display names in the format "provider_name: model_name".
575 // When displayed in the UI, these names can get truncated from the right.
576 // Since users typically already know the provider, we extract just the model name
577 // portion (after the colon) to create a more concise and user-friendly label
578 // for the model dropdown in the agent panel.
579 display_name: Some(
580 entry
581 .name
582 .split(':')
583 .next_back()
584 .unwrap_or(&entry.name)
585 .trim()
586 .to_string(),
587 ),
588 max_tokens: entry.context_length.unwrap_or(2000000),
589 supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
590 supports_images: Some(
591 entry
592 .architecture
593 .as_ref()
594 .map(|arch| arch.input_modalities.contains(&"image".to_string()))
595 .unwrap_or(false),
596 ),
597 mode: if entry
598 .supported_parameters
599 .contains(&"reasoning".to_string())
600 {
601 ModelMode::Thinking {
602 budget_tokens: Some(4_096),
603 }
604 } else {
605 ModelMode::Default
606 },
607 provider: None,
608 })
609 .collect();
610
611 Ok(models)
612 } else {
613 let code = ApiErrorCode::from_status(response.status().as_u16());
614
615 let mut body = String::new();
616 response
617 .body_mut()
618 .read_to_string(&mut body)
619 .await
620 .map_err(OpenRouterError::ReadResponse)?;
621
622 let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
623 Ok(OpenRouterErrorResponse { error }) => error,
624 Err(_) => OpenRouterErrorBody {
625 code: response.status().as_u16(),
626 message: body,
627 metadata: None,
628 },
629 };
630
631 match code {
632 ApiErrorCode::RateLimitError => {
633 let retry_after = extract_retry_after(response.headers());
634 Err(OpenRouterError::RateLimit {
635 retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
636 })
637 }
638 ApiErrorCode::OverloadedError => {
639 let retry_after = extract_retry_after(response.headers());
640 Err(OpenRouterError::ServerOverloaded { retry_after })
641 }
642 _ => Err(OpenRouterError::ApiError(ApiError {
643 code: code,
644 message: error_response.message,
645 })),
646 }
647 }
648}
649
/// Errors produced by the request functions in this file.
#[derive(Debug)]
pub enum OpenRouterError {
    /// Failed to serialize the HTTP request body to JSON
    SerializeRequest(serde_json::Error),

    /// Failed to construct the HTTP request body
    BuildRequestBody(http::Error),

    /// Failed to send the HTTP request
    HttpSend(anyhow::Error),

    /// Failed to deserialize the response from JSON
    DeserializeResponse(serde_json::Error),

    /// Failed to read from response stream
    ReadResponse(io::Error),

    /// Rate limit exceeded. `retry_after` comes from the `X-RateLimit-Reset`
    /// header when usable, otherwise it is 60 seconds.
    RateLimit { retry_after: Duration },

    /// Server overloaded. `retry_after` is `None` when the response carries no
    /// usable `X-RateLimit-Reset` header.
    ServerOverloaded { retry_after: Option<Duration> },

    /// API returned an error response
    ApiError(ApiError),
}
676
/// The `error` object of an error response body.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
    /// Numeric error code as sent in the body.
    pub code: u16,
    /// Human-readable error message.
    pub message: String,
    /// Optional extra details as arbitrary JSON values; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}
684
/// Top-level shape of an error response body: `{ "error": { ... } }`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
    /// The error details.
    pub error: OpenRouterErrorBody,
}
689
/// An error reported by the API, displayed as
/// `OpenRouter API Error: {code}: {message}`.
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
    /// Error category derived from the HTTP status.
    pub code: ApiErrorCode,
    /// Message from the response body, or the raw body when it could not be parsed.
    pub message: String,
}
696
/// An OpenRouter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
///
/// Parsing from strings (via `EnumString`) uses snake case variant names.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400: Bad Request (invalid or missing params, CORS)
    InvalidRequestError,
    /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
    AuthenticationError,
    /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
    PaymentRequiredError,
    /// 403: Your chosen model requires moderation and your input was flagged
    PermissionError,
    /// 408: Your request timed out
    RequestTimedOut,
    /// 429: You are being rate limited
    RateLimitError,
    /// 502: Your chosen model is down or we received an invalid response from it.
    /// Also used by `from_status` for any status code not listed here.
    ApiError,
    /// 503: There is no available model provider that meets your routing requirements
    OverloadedError,
}
719
720impl std::fmt::Display for ApiErrorCode {
721 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
722 let s = match self {
723 ApiErrorCode::InvalidRequestError => "invalid_request_error",
724 ApiErrorCode::AuthenticationError => "authentication_error",
725 ApiErrorCode::PaymentRequiredError => "payment_required_error",
726 ApiErrorCode::PermissionError => "permission_error",
727 ApiErrorCode::RequestTimedOut => "request_timed_out",
728 ApiErrorCode::RateLimitError => "rate_limit_error",
729 ApiErrorCode::ApiError => "api_error",
730 ApiErrorCode::OverloadedError => "overloaded_error",
731 };
732 write!(f, "{s}")
733 }
734}
735
736impl ApiErrorCode {
737 pub fn from_status(status: u16) -> Self {
738 match status {
739 400 => ApiErrorCode::InvalidRequestError,
740 401 => ApiErrorCode::AuthenticationError,
741 402 => ApiErrorCode::PaymentRequiredError,
742 403 => ApiErrorCode::PermissionError,
743 408 => ApiErrorCode::RequestTimedOut,
744 429 => ApiErrorCode::RateLimitError,
745 502 => ApiErrorCode::ApiError,
746 503 => ApiErrorCode::OverloadedError,
747 _ => ApiErrorCode::ApiError,
748 }
749 }
750}