1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::{convert::TryFrom, io, time::Duration};
7use strum::EnumString;
8use thiserror::Error;
9
/// Base URL of the OpenRouter REST API (v1).
///
/// Endpoint paths such as `/chat/completions` and `/models/user` are appended to an
/// `api_url` by `stream_completion` and `list_models`; presumably this constant is the
/// value callers pass for `api_url` — confirm at the call sites.
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
11
12fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
13 if let Some(reset) = headers.get("X-RateLimit-Reset") {
14 if let Ok(s) = reset.to_str() {
15 if let Ok(epoch_ms) = s.parse::<u64>() {
16 let now = std::time::SystemTime::now()
17 .duration_since(std::time::UNIX_EPOCH)
18 .unwrap_or_default()
19 .as_millis() as u64;
20 if epoch_ms > now {
21 return Some(std::time::Duration::from_millis(epoch_ms - now));
22 }
23 }
24 }
25 }
26 None
27}
28
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a serde `skip_serializing_if` predicate for optional lists.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
32
/// The author of a chat message.
///
/// Serialized as its lowercase name: `"user"`, `"assistant"`, `"system"` or `"tool"`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
41
42impl TryFrom<String> for Role {
43 type Error = anyhow::Error;
44
45 fn try_from(value: String) -> Result<Self> {
46 match value.as_str() {
47 "user" => Ok(Self::User),
48 "assistant" => Ok(Self::Assistant),
49 "system" => Ok(Self::System),
50 "tool" => Ok(Self::Tool),
51 _ => Err(anyhow!("invalid role '{value}'")),
52 }
53 }
54}
55
56impl From<Role> for String {
57 fn from(val: Role) -> Self {
58 match val {
59 Role::User => "user".to_owned(),
60 Role::Assistant => "assistant".to_owned(),
61 Role::System => "system".to_owned(),
62 Role::Tool => "tool".to_owned(),
63 }
64 }
65}
66
// An OpenRouter model together with the capabilities this crate cares about.
//
// Plain `//` comments are used on this type and `ModelMode` on purpose: with the
// `schemars` feature enabled, `///` doc comments would be copied into the generated
// JSON schema as descriptions.
//
// NOTE: `Default` is derived here and yields an empty model (`max_tokens == 0`). The
// inherent `Model::default()` defined in `impl Model` instead returns the Auto Router
// model. Only the `Model::default()` call syntax picks the inherent method;
// `Default::default()`, `..Default::default()` and `#[serde(default)]` use the derive.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    // OpenRouter model id (returned by `Model::id`).
    pub name: String,
    // Human-readable label; `Model::display_name` falls back to `name` when unset.
    pub display_name: Option<String>,
    // Context window size in tokens (`list_models` fills it from `context_length`).
    pub max_tokens: u64,
    // `None` is treated as "no tool support" by `Model::supports_tool_calls`.
    pub supports_tools: Option<bool>,
    // Whether the model accepts image input.
    pub supports_images: Option<bool>,
    // Falls back to `ModelMode::Default` when missing from serialized data.
    #[serde(default)]
    pub mode: ModelMode,
}

// Reasoning mode of a model.
//
// `Thinking::budget_tokens` is an optional reasoning-token budget; `list_models` sets it
// to 4096 for models that advertise the `reasoning` parameter. NOTE(review): presumably
// callers translate this into `Request::reasoning` — confirm at the call site.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        budget_tokens: Option<u32>,
    },
}
88
89impl Model {
90 pub fn default_fast() -> Self {
91 Self::new(
92 "openrouter/auto",
93 Some("Auto Router"),
94 Some(2000000),
95 Some(true),
96 Some(false),
97 Some(ModelMode::Default),
98 )
99 }
100
101 pub fn default() -> Self {
102 Self::default_fast()
103 }
104
105 pub fn new(
106 name: &str,
107 display_name: Option<&str>,
108 max_tokens: Option<u64>,
109 supports_tools: Option<bool>,
110 supports_images: Option<bool>,
111 mode: Option<ModelMode>,
112 ) -> Self {
113 Self {
114 name: name.to_owned(),
115 display_name: display_name.map(|s| s.to_owned()),
116 max_tokens: max_tokens.unwrap_or(2000000),
117 supports_tools,
118 supports_images,
119 mode: mode.unwrap_or(ModelMode::Default),
120 }
121 }
122
123 pub fn id(&self) -> &str {
124 &self.name
125 }
126
127 pub fn display_name(&self) -> &str {
128 self.display_name.as_ref().unwrap_or(&self.name)
129 }
130
131 pub fn max_token_count(&self) -> u64 {
132 self.max_tokens
133 }
134
135 pub fn max_output_tokens(&self) -> Option<u64> {
136 None
137 }
138
139 pub fn supports_tool_calls(&self) -> bool {
140 self.supports_tools.unwrap_or(false)
141 }
142
143 pub fn supports_parallel_tool_calls(&self) -> bool {
144 false
145 }
146}
147
/// JSON body of a `POST {api_url}/chat/completions` request (sent by `stream_completion`).
///
/// `None` options and empty lists are omitted from the serialized JSON; `model`,
/// `messages`, `stream`, `temperature` and `usage` are always written.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// OpenRouter model id.
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// NOTE(review): `stream_completion` parses the reply as server-sent events, so this
    /// is presumably expected to be `true` when used with it — confirm with callers.
    pub stream: bool,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    pub usage: RequestUsage,
}

/// Usage-accounting options of a request, serialized as `{"include": <bool>}`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    pub include: bool,
}
173
/// How the model may use the supplied tools.
///
/// Serialized as the string `"auto"`, `"required"` or `"none"`, or, for `Other`, as the
/// bare (untagged) tool definition object.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}

// A tool the model may call, tagged by `"type"`:
// `{"type": "function", "function": {"name": ..., "description": ..., "parameters": ...}}`.
// (`//` rather than `///` so the schemars-generated schema is unaffected.)
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for this `allow` is not evident from this file.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

// A callable function: its name plus an optional description and parameter spec.
// NOTE(review): `parameters` is presumably a JSON Schema object — confirm with callers.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    pub parameters: Option<Value>,
}

/// Reasoning ("thinking") options of a request; every `None` field is omitted from JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Reasoning {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
211
/// A chat message, tagged in JSON by its `"role"` field
/// (`"assistant"`, `"user"`, `"system"` or `"tool"`).
///
/// Also used as the `message` of a non-streamed [`Choice`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Optional for assistant messages.
        content: Option<MessageContent>,
        /// Omitted from JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        /// NOTE(review): presumably the `ToolCall::id` this message answers — confirm.
        tool_call_id: String,
    },
}

/// Message content: either a plain string or a list of typed parts.
///
/// Untagged, so it (de)serializes as a bare JSON string or a JSON array of parts.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
238
239impl MessageContent {
240 pub fn empty() -> Self {
241 Self::Plain(String::new())
242 }
243
244 pub fn push_part(&mut self, part: MessagePart) {
245 match self {
246 Self::Plain(text) if text.is_empty() => {
247 *self = Self::Multipart(vec![part]);
248 }
249 Self::Plain(text) => {
250 let text_part = MessagePart::Text {
251 text: std::mem::take(text),
252 };
253 *self = Self::Multipart(vec![text_part, part]);
254 }
255 Self::Multipart(parts) => parts.push(part),
256 }
257 }
258}
259
260impl From<Vec<MessagePart>> for MessageContent {
261 fn from(parts: Vec<MessagePart>) -> Self {
262 if parts.len() == 1
263 && let MessagePart::Text { text } = &parts[0]
264 {
265 return Self::Plain(text.clone());
266 }
267 Self::Multipart(parts)
268 }
269}
270
271impl From<String> for MessageContent {
272 fn from(text: String) -> Self {
273 Self::Plain(text)
274 }
275}
276
277impl From<&str> for MessageContent {
278 fn from(text: &str) -> Self {
279 Self::Plain(text.to_string())
280 }
281}
282
283impl MessageContent {
284 pub fn as_text(&self) -> Option<&str> {
285 match self {
286 Self::Plain(text) => Some(text),
287 Self::Multipart(parts) if parts.len() == 1 => {
288 if let MessagePart::Text { text } = &parts[0] {
289 Some(text)
290 } else {
291 None
292 }
293 }
294 _ => None,
295 }
296 }
297
298 pub fn to_text(&self) -> String {
299 match self {
300 Self::Plain(text) => text.clone(),
301 Self::Multipart(parts) => parts
302 .iter()
303 .filter_map(|part| {
304 if let MessagePart::Text { text } = part {
305 Some(text.as_str())
306 } else {
307 None
308 }
309 })
310 .collect::<Vec<_>>()
311 .join(""),
312 }
313 }
314}
315
/// One part of multipart message content, tagged by `"type"`: `"text"` or `"image_url"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    #[serde(rename = "image_url")]
    Image {
        /// Serialized as a plain JSON string.
        /// NOTE(review): OpenAI-style APIs document `image_url` as an object
        /// (`{"url": ...}`); confirm OpenRouter accepts the bare-string form.
        image_url: String,
    },
}

/// A tool call attached to an assistant message.
///
/// `content` is flattened, so the JSON looks like
/// `{"id": ..., "type": "function", "function": {"name": ..., "arguments": ...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// The kind-specific payload of a [`ToolCall`], tagged by `"type"` (`"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// A function invocation: the function name and its arguments as a raw string.
///
/// NOTE(review): `arguments` is presumably JSON-encoded text; it is not parsed here.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
346
/// Incremental message content carried by one streamed choice.
///
/// Every field may be absent in a given chunk. When serializing, `tool_calls` is omitted
/// if it is `None` or empty.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Reasoning ("thinking") text, when the model emits it.
    pub reasoning: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A fragment of a streamed tool call.
///
/// NOTE(review): presumably fragments sharing an `index` belong to the same call and are
/// concatenated by the consumer — confirm with callers.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// A fragment of a streamed function call's name and/or arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

/// Token counts reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

/// One event parsed from a `data:` line of the streaming response (see `stream_completion`).
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, when the server includes it in this event.
    pub usage: Option<Usage>,
}

/// A non-streamed chat-completion response (not used by `stream_completion` or `list_models`).
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One choice of a [`Response`]; `message` reuses the request-side message type.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
409
/// Body of a successful `GET {api_url}/models/user` response (parsed by `list_models`).
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

/// One model as described by the models endpoint.
///
/// NOTE: only `Deserialize` is derived, so the `skip_serializing_if` attributes below have
/// no effect; only their `default` parts matter.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    /// Model id; becomes `Model::name`.
    pub id: String,
    /// Display name of the form "provider_name: model_name"; `list_models` keeps only the
    /// model-name part.
    pub name: String,
    pub created: usize,
    pub description: String,
    /// Context window in tokens; `list_models` falls back to 2,000,000 when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// Request parameters the model accepts; `list_models` checks for `"tools"` and
    /// `"reasoning"`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
    /// Used by `list_models` to detect image input support.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ModelArchitecture>,
}

/// Architecture metadata of a model entry; only the input modalities are read.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelArchitecture {
    /// Accepted input kinds; `list_models` looks for `"image"`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_modalities: Vec<String>,
}
434
435pub async fn stream_completion(
436 client: &dyn HttpClient,
437 api_url: &str,
438 api_key: &str,
439 request: Request,
440) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
441 let uri = format!("{api_url}/chat/completions");
442 let request_builder = HttpRequest::builder()
443 .method(Method::POST)
444 .uri(uri)
445 .header("Content-Type", "application/json")
446 .header("Authorization", format!("Bearer {}", api_key))
447 .header("HTTP-Referer", "https://zed.dev")
448 .header("X-Title", "Zed Editor");
449
450 let request = request_builder
451 .body(AsyncBody::from(
452 serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
453 ))
454 .map_err(OpenRouterError::BuildRequestBody)?;
455 let mut response = client
456 .send(request)
457 .await
458 .map_err(OpenRouterError::HttpSend)?;
459
460 if response.status().is_success() {
461 let reader = BufReader::new(response.into_body());
462 Ok(reader
463 .lines()
464 .filter_map(|line| async move {
465 match line {
466 Ok(line) => {
467 if line.starts_with(':') {
468 return None;
469 }
470
471 let line = line.strip_prefix("data: ")?;
472 if line == "[DONE]" {
473 None
474 } else {
475 match serde_json::from_str::<ResponseStreamEvent>(line) {
476 Ok(response) => Some(Ok(response)),
477 Err(error) => {
478 if line.trim().is_empty() {
479 None
480 } else {
481 Some(Err(OpenRouterError::DeserializeResponse(error)))
482 }
483 }
484 }
485 }
486 }
487 Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
488 }
489 })
490 .boxed())
491 } else {
492 let code = ApiErrorCode::from_status(response.status().as_u16());
493
494 let mut body = String::new();
495 response
496 .body_mut()
497 .read_to_string(&mut body)
498 .await
499 .map_err(OpenRouterError::ReadResponse)?;
500
501 let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
502 Ok(OpenRouterErrorResponse { error }) => error,
503 Err(_) => OpenRouterErrorBody {
504 code: response.status().as_u16(),
505 message: body,
506 metadata: None,
507 },
508 };
509
510 match code {
511 ApiErrorCode::RateLimitError => {
512 let retry_after = extract_retry_after(response.headers());
513 Err(OpenRouterError::RateLimit {
514 retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
515 })
516 }
517 ApiErrorCode::OverloadedError => {
518 let retry_after = extract_retry_after(response.headers());
519 Err(OpenRouterError::ServerOverloaded { retry_after })
520 }
521 _ => Err(OpenRouterError::ApiError(ApiError {
522 code: code,
523 message: error_response.message,
524 })),
525 }
526 }
527}
528
529pub async fn list_models(
530 client: &dyn HttpClient,
531 api_url: &str,
532 api_key: &str,
533) -> Result<Vec<Model>, OpenRouterError> {
534 let uri = format!("{api_url}/models/user");
535 let request_builder = HttpRequest::builder()
536 .method(Method::GET)
537 .uri(uri)
538 .header("Accept", "application/json")
539 .header("Authorization", format!("Bearer {}", api_key))
540 .header("HTTP-Referer", "https://zed.dev")
541 .header("X-Title", "Zed Editor");
542
543 let request = request_builder
544 .body(AsyncBody::default())
545 .map_err(OpenRouterError::BuildRequestBody)?;
546 let mut response = client
547 .send(request)
548 .await
549 .map_err(OpenRouterError::HttpSend)?;
550
551 let mut body = String::new();
552 response
553 .body_mut()
554 .read_to_string(&mut body)
555 .await
556 .map_err(OpenRouterError::ReadResponse)?;
557
558 if response.status().is_success() {
559 let response: ListModelsResponse =
560 serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
561
562 let models = response
563 .data
564 .into_iter()
565 .map(|entry| Model {
566 name: entry.id,
567 // OpenRouter returns display names in the format "provider_name: model_name".
568 // When displayed in the UI, these names can get truncated from the right.
569 // Since users typically already know the provider, we extract just the model name
570 // portion (after the colon) to create a more concise and user-friendly label
571 // for the model dropdown in the agent panel.
572 display_name: Some(
573 entry
574 .name
575 .split(':')
576 .next_back()
577 .unwrap_or(&entry.name)
578 .trim()
579 .to_string(),
580 ),
581 max_tokens: entry.context_length.unwrap_or(2000000),
582 supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
583 supports_images: Some(
584 entry
585 .architecture
586 .as_ref()
587 .map(|arch| arch.input_modalities.contains(&"image".to_string()))
588 .unwrap_or(false),
589 ),
590 mode: if entry
591 .supported_parameters
592 .contains(&"reasoning".to_string())
593 {
594 ModelMode::Thinking {
595 budget_tokens: Some(4_096),
596 }
597 } else {
598 ModelMode::Default
599 },
600 })
601 .collect();
602
603 Ok(models)
604 } else {
605 let code = ApiErrorCode::from_status(response.status().as_u16());
606
607 let mut body = String::new();
608 response
609 .body_mut()
610 .read_to_string(&mut body)
611 .await
612 .map_err(OpenRouterError::ReadResponse)?;
613
614 let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
615 Ok(OpenRouterErrorResponse { error }) => error,
616 Err(_) => OpenRouterErrorBody {
617 code: response.status().as_u16(),
618 message: body,
619 metadata: None,
620 },
621 };
622
623 match code {
624 ApiErrorCode::RateLimitError => {
625 let retry_after = extract_retry_after(response.headers());
626 Err(OpenRouterError::RateLimit {
627 retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
628 })
629 }
630 ApiErrorCode::OverloadedError => {
631 let retry_after = extract_retry_after(response.headers());
632 Err(OpenRouterError::ServerOverloaded { retry_after })
633 }
634 _ => Err(OpenRouterError::ApiError(ApiError {
635 code: code,
636 message: error_response.message,
637 })),
638 }
639 }
640}
641
/// Errors produced by `stream_completion` and `list_models`.
///
/// NOTE(review): only `Debug` is derived here; no `Display` / `std::error::Error` impl is
/// visible alongside this type.
#[derive(Debug)]
pub enum OpenRouterError {
    /// Failed to serialize the HTTP request body to JSON
    SerializeRequest(serde_json::Error),

    /// Failed to build the HTTP request (the request builder's `body` call returned an error)
    BuildRequestBody(http::Error),

    /// Failed to send the HTTP request
    HttpSend(anyhow::Error),

    /// Failed to deserialize the response (or one streamed event) from JSON
    DeserializeResponse(serde_json::Error),

    /// Failed to read from the response body or stream
    ReadResponse(io::Error),

    /// Rate limit exceeded (HTTP 429). `retry_after` comes from `X-RateLimit-Reset` when
    /// that header lies in the future, and is 60 seconds otherwise.
    RateLimit { retry_after: Duration },

    /// Server overloaded (HTTP 503). `retry_after` comes from `X-RateLimit-Reset` when
    /// that header lies in the future.
    ServerOverloaded { retry_after: Option<Duration> },

    /// Any other non-success response from the API
    ApiError(ApiError),
}

/// The `error` object inside an OpenRouter error response body.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
    /// Numeric error code reported by the API.
    pub code: u16,
    pub message: String,
    /// Optional free-form extra details; omitted from JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}

/// Envelope of an OpenRouter error response: `{"error": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
    pub error: OpenRouterErrorBody,
}

/// A non-success API response, displayed as `OpenRouter API Error: {code}: {message}`.
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
    /// Error category derived from the HTTP status code.
    pub code: ApiErrorCode,
    pub message: String,
}
688
/// An OpenRouter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
///
/// `FromStr` (via strum) and `Display` use snake_case names such as `"rate_limit_error"`.
/// NOTE(review): serde has no `rename_all` here, so `Serialize`/`Deserialize` use the
/// variant names verbatim (e.g. `"RateLimitError"`), unlike `Display`.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400: Bad Request (invalid or missing params, CORS)
    InvalidRequestError,
    /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
    AuthenticationError,
    /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
    PaymentRequiredError,
    /// 403: Your chosen model requires moderation and your input was flagged
    PermissionError,
    /// 408: Your request timed out
    RequestTimedOut,
    /// 429: You are being rate limited
    RateLimitError,
    /// 502: Your chosen model is down or we received an invalid response from it.
    /// `ApiErrorCode::from_status` also uses this for any unrecognized status.
    ApiError,
    /// 503: There is no available model provider that meets your routing requirements
    OverloadedError,
}
711
712impl std::fmt::Display for ApiErrorCode {
713 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
714 let s = match self {
715 ApiErrorCode::InvalidRequestError => "invalid_request_error",
716 ApiErrorCode::AuthenticationError => "authentication_error",
717 ApiErrorCode::PaymentRequiredError => "payment_required_error",
718 ApiErrorCode::PermissionError => "permission_error",
719 ApiErrorCode::RequestTimedOut => "request_timed_out",
720 ApiErrorCode::RateLimitError => "rate_limit_error",
721 ApiErrorCode::ApiError => "api_error",
722 ApiErrorCode::OverloadedError => "overloaded_error",
723 };
724 write!(f, "{s}")
725 }
726}
727
728impl ApiErrorCode {
729 pub fn from_status(status: u16) -> Self {
730 match status {
731 400 => ApiErrorCode::InvalidRequestError,
732 401 => ApiErrorCode::AuthenticationError,
733 402 => ApiErrorCode::PaymentRequiredError,
734 403 => ApiErrorCode::PermissionError,
735 408 => ApiErrorCode::RequestTimedOut,
736 429 => ApiErrorCode::RateLimitError,
737 502 => ApiErrorCode::ApiError,
738 503 => ApiErrorCode::OverloadedError,
739 _ => ApiErrorCode::ApiError,
740 }
741 }
742}