1use anyhow::{Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, http};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::{convert::TryFrom, io, time::Duration};
7use strum::EnumString;
8use thiserror::Error;
9
/// Base URL of OpenRouter's v1 REST API.
///
/// `stream_completion` and `list_models` take an `api_url` of this form and append endpoint
/// paths to it (`/chat/completions` and `/models` respectively).
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
11
12fn extract_retry_after(headers: &http::HeaderMap) -> Option<std::time::Duration> {
13 if let Some(reset) = headers.get("X-RateLimit-Reset") {
14 if let Ok(s) = reset.to_str() {
15 if let Ok(epoch_ms) = s.parse::<u64>() {
16 let now = std::time::SystemTime::now()
17 .duration_since(std::time::UNIX_EPOCH)
18 .unwrap_or_default()
19 .as_millis() as u64;
20 if epoch_ms > now {
21 return Some(std::time::Duration::from_millis(epoch_ms - now));
22 }
23 }
24 }
25 }
26 None
27}
28
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a serde `skip_serializing_if` predicate for optional collections.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
32
/// The author of a chat message.
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`, `"tool"`). This is the same
/// spelling accepted by `Role::try_from(String)` and produced by `String::from(Role)`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
41
42impl TryFrom<String> for Role {
43 type Error = anyhow::Error;
44
45 fn try_from(value: String) -> Result<Self> {
46 match value.as_str() {
47 "user" => Ok(Self::User),
48 "assistant" => Ok(Self::Assistant),
49 "system" => Ok(Self::System),
50 "tool" => Ok(Self::Tool),
51 _ => Err(anyhow!("invalid role '{value}'")),
52 }
53 }
54}
55
56impl From<Role> for String {
57 fn from(val: Role) -> Self {
58 match val {
59 Role::User => "user".to_owned(),
60 Role::Assistant => "assistant".to_owned(),
61 Role::System => "system".to_owned(),
62 Role::Tool => "tool".to_owned(),
63 }
64 }
65}
66
67#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
68#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
69pub struct Model {
70 pub name: String,
71 pub display_name: Option<String>,
72 pub max_tokens: u64,
73 pub supports_tools: Option<bool>,
74 pub supports_images: Option<bool>,
75 #[serde(default)]
76 pub mode: ModelMode,
77}
78
/// Reasoning ("thinking") mode of a model.
///
/// `list_models` selects `Thinking { budget_tokens: Some(4_096) }` for catalogue entries whose
/// `supported_parameters` contain `"reasoning"`, and `Default` otherwise.
///
/// It has no serde attributes, so it uses serde's externally tagged form, e.g. `"Default"` or
/// `{"Thinking":{"budget_tokens":4096}}`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum ModelMode {
    /// No extended reasoning requested.
    #[default]
    Default,
    /// Extended reasoning.
    Thinking {
        /// Reasoning token budget. NOTE(review): how this maps onto the request's `Reasoning`
        /// settings is decided by callers not visible in this file — confirm there.
        budget_tokens: Option<u32>,
    },
}
88
89impl Model {
90 pub fn default_fast() -> Self {
91 Self::new(
92 "openrouter/auto",
93 Some("Auto Router"),
94 Some(2000000),
95 Some(true),
96 Some(false),
97 Some(ModelMode::Default),
98 )
99 }
100
101 pub fn default() -> Self {
102 Self::default_fast()
103 }
104
105 pub fn new(
106 name: &str,
107 display_name: Option<&str>,
108 max_tokens: Option<u64>,
109 supports_tools: Option<bool>,
110 supports_images: Option<bool>,
111 mode: Option<ModelMode>,
112 ) -> Self {
113 Self {
114 name: name.to_owned(),
115 display_name: display_name.map(|s| s.to_owned()),
116 max_tokens: max_tokens.unwrap_or(2000000),
117 supports_tools,
118 supports_images,
119 mode: mode.unwrap_or(ModelMode::Default),
120 }
121 }
122
123 pub fn id(&self) -> &str {
124 &self.name
125 }
126
127 pub fn display_name(&self) -> &str {
128 self.display_name.as_ref().unwrap_or(&self.name)
129 }
130
131 pub fn max_token_count(&self) -> u64 {
132 self.max_tokens
133 }
134
135 pub fn max_output_tokens(&self) -> Option<u64> {
136 None
137 }
138
139 pub fn supports_tool_calls(&self) -> bool {
140 self.supports_tools.unwrap_or(false)
141 }
142
143 pub fn supports_parallel_tool_calls(&self) -> bool {
144 false
145 }
146}
147
/// JSON body of a chat-completions request.
///
/// `stream_completion` serializes this and POSTs it to `{api_url}/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id to run, e.g. `openrouter/auto`.
    pub model: String,
    /// Conversation so far, in order.
    pub messages: Vec<RequestMessage>,
    /// Streaming flag. `stream_completion` parses the response as SSE `data:` lines, so this is
    /// presumably `true` for that path — confirm with callers.
    pub stream: bool,
    /// Optional cap on generated tokens. Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Stop sequences. Omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature. Always serialized.
    pub temperature: f32,
    /// Tool-selection policy. Omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools offered to the model. Omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Reasoning configuration. Omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    /// Usage-reporting options. Always serialized.
    pub usage: RequestUsage,
}

/// Usage-reporting options sent with every [`Request`].
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    /// NOTE(review): presumably asks OpenRouter to report token usage (cf.
    /// `ResponseStreamEvent::usage`) — confirm against the API docs.
    pub include: bool,
}
173
/// How the model may use the supplied tools.
///
/// `Auto`, `Required` and `None` serialize as the strings `"auto"`, `"required"` and `"none"`.
/// `Other` is untagged, so it serializes as the inner [`ToolDefinition`] object itself.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}

/// A tool the model may call, internally tagged by `"type"` (only `"function"` exists).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for this `allow(dead_code)` is not visible here — confirm it is
    // still needed.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// Name, description and parameter schema of a callable function tool.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    /// Serialized as `null` when `None` (there is no `skip_serializing_if`).
    pub description: Option<String>,
    /// Parameter schema as raw JSON. Presumably a JSON Schema object; it is not validated here.
    pub parameters: Option<Value>,
}

/// Reasoning ("thinking") options for a request. Each `None` field is omitted from the JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Reasoning {
    /// Effort level as a free-form string. Accepted values are defined by OpenRouter and not
    /// checked here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    /// Reasoning token budget — TODO confirm exact semantics against the OpenRouter docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// NOTE(review): presumably hides reasoning text from the response — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    /// Explicit on/off switch for reasoning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
211
/// A chat message, internally tagged by `"role"`.
///
/// The role is one of `"assistant"`, `"user"`, `"system"` or `"tool"`. Also used as the message
/// of a non-streaming [`Choice`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// An assistant turn.
    ///
    /// `content` is optional and serializes as `null` when `None`. `tool_calls` is omitted when
    /// empty.
    Assistant {
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// The result of a tool invocation. `tool_call_id` presumably matches the `id` of the
    /// assistant's [`ToolCall`] being answered.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}

/// Message content: untagged, so it is either a bare JSON string or an array of parts.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
238
239impl MessageContent {
240 pub fn empty() -> Self {
241 Self::Plain(String::new())
242 }
243
244 pub fn push_part(&mut self, part: MessagePart) {
245 match self {
246 Self::Plain(text) if text.is_empty() => {
247 *self = Self::Multipart(vec![part]);
248 }
249 Self::Plain(text) => {
250 let text_part = MessagePart::Text {
251 text: std::mem::take(text),
252 };
253 *self = Self::Multipart(vec![text_part, part]);
254 }
255 Self::Multipart(parts) => parts.push(part),
256 }
257 }
258}
259
260impl From<Vec<MessagePart>> for MessageContent {
261 fn from(parts: Vec<MessagePart>) -> Self {
262 if parts.len() == 1
263 && let MessagePart::Text { text } = &parts[0]
264 {
265 return Self::Plain(text.clone());
266 }
267 Self::Multipart(parts)
268 }
269}
270
271impl From<String> for MessageContent {
272 fn from(text: String) -> Self {
273 Self::Plain(text)
274 }
275}
276
277impl From<&str> for MessageContent {
278 fn from(text: &str) -> Self {
279 Self::Plain(text.to_string())
280 }
281}
282
283impl MessageContent {
284 pub fn as_text(&self) -> Option<&str> {
285 match self {
286 Self::Plain(text) => Some(text),
287 Self::Multipart(parts) if parts.len() == 1 => {
288 if let MessagePart::Text { text } = &parts[0] {
289 Some(text)
290 } else {
291 None
292 }
293 }
294 _ => None,
295 }
296 }
297
298 pub fn to_text(&self) -> String {
299 match self {
300 Self::Plain(text) => text.clone(),
301 Self::Multipart(parts) => parts
302 .iter()
303 .filter_map(|part| {
304 if let MessagePart::Text { text } = part {
305 Some(text.as_str())
306 } else {
307 None
308 }
309 })
310 .collect::<Vec<_>>()
311 .join(""),
312 }
313 }
314}
315
/// One part of a multipart message, internally tagged by `"type"`: `"text"` or `"image_url"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    /// NOTE(review): `image_url` is a plain string here (presumably a URL or data URI), so it
    /// serializes as `"image_url": "<string>"` — confirm this matches the shape the API expects.
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}

/// A tool call made by the assistant.
///
/// `content` is flattened, so the JSON is `{"id": ..., "type": "function", "function": {...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Body of a [`ToolCall`], internally tagged by `"type"` (only `"function"` exists).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// The function invoked by a tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments kept as a raw string. Presumably JSON-encoded; they are not parsed here.
    pub arguments: String,
}
346
/// The incremental message carried by one streamed [`ChoiceDelta`]. Every field may be absent.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    /// Fragment of the visible reply text.
    pub content: Option<String>,
    /// Fragment of reasoning text, when the model streams it.
    pub reasoning: Option<String>,
    /// Tool-call fragments. Skipped on serialization when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A fragment of a tool call.
///
/// `index` identifies which call the fragment belongs to. Presumably fragments sharing an index
/// are concatenated by the consumer — confirm with callers.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// Partial function name and/or argument text of a streamed tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}

/// Token counts reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}

/// One choice within a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Set once the choice is finished; the possible values come from the API.
    pub finish_reason: Option<String>,
}

/// One event of the chat-completions stream.
///
/// `stream_completion` deserializes each SSE `data:` payload into this type.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Creation time — presumably Unix seconds; confirm (note that `Response::created` is `u64`).
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Token usage when included in this event. A missing field deserializes as `None`.
    pub usage: Option<Usage>,
}

/// A complete (non-streamed) chat-completions response. Not used by the request functions in
/// this file.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One choice of a non-streamed [`Response`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
409
/// Body of `GET {api_url}/models`, as parsed by `list_models`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

/// One catalogue entry from the `/models` endpoint.
///
/// Only `Deserialize` is derived, so the `skip_serializing_if` attributes below have no effect.
/// `id`, `name`, `created` and `description` are required: an entry missing any of them makes
/// the whole list fail to deserialize.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    /// Model id; becomes `Model::name`.
    pub id: String,
    /// Display name, expected in the form `"provider_name: model_name"`.
    pub name: String,
    pub created: usize,
    pub description: String,
    /// Context window in tokens. `list_models` falls back to 2,000,000 when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// Supported request parameters. `list_models` looks for `"tools"` and `"reasoning"`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ModelArchitecture>,
}

/// Architecture details of a catalogue entry.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelArchitecture {
    /// Accepted input kinds. `list_models` treats `"image"` as image support.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_modalities: Vec<String>,
}
434
435pub async fn stream_completion(
436 client: &dyn HttpClient,
437 api_url: &str,
438 api_key: &str,
439 request: Request,
440) -> Result<BoxStream<'static, Result<ResponseStreamEvent, OpenRouterError>>, OpenRouterError> {
441 let uri = format!("{api_url}/chat/completions");
442 let request_builder = HttpRequest::builder()
443 .method(Method::POST)
444 .uri(uri)
445 .header("Content-Type", "application/json")
446 .header("Authorization", format!("Bearer {}", api_key))
447 .header("HTTP-Referer", "https://zed.dev")
448 .header("X-Title", "Zed Editor");
449
450 let request = request_builder
451 .body(AsyncBody::from(
452 serde_json::to_string(&request).map_err(OpenRouterError::SerializeRequest)?,
453 ))
454 .map_err(OpenRouterError::BuildRequestBody)?;
455 let mut response = client
456 .send(request)
457 .await
458 .map_err(OpenRouterError::HttpSend)?;
459
460 if response.status().is_success() {
461 let reader = BufReader::new(response.into_body());
462 Ok(reader
463 .lines()
464 .filter_map(|line| async move {
465 match line {
466 Ok(line) => {
467 if line.starts_with(':') {
468 return None;
469 }
470
471 let line = line.strip_prefix("data: ")?;
472 if line == "[DONE]" {
473 None
474 } else {
475 match serde_json::from_str::<ResponseStreamEvent>(line) {
476 Ok(response) => Some(Ok(response)),
477 Err(error) => {
478 if line.trim().is_empty() {
479 None
480 } else {
481 Some(Err(OpenRouterError::DeserializeResponse(error)))
482 }
483 }
484 }
485 }
486 }
487 Err(error) => Some(Err(OpenRouterError::ReadResponse(error))),
488 }
489 })
490 .boxed())
491 } else {
492 let code = ApiErrorCode::from_status(response.status().as_u16());
493
494 let mut body = String::new();
495 response
496 .body_mut()
497 .read_to_string(&mut body)
498 .await
499 .map_err(OpenRouterError::ReadResponse)?;
500
501 let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
502 Ok(OpenRouterErrorResponse { error }) => error,
503 Err(_) => OpenRouterErrorBody {
504 code: response.status().as_u16(),
505 message: body,
506 metadata: None,
507 },
508 };
509
510 match code {
511 ApiErrorCode::RateLimitError => {
512 let retry_after = extract_retry_after(response.headers());
513 Err(OpenRouterError::RateLimit {
514 retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
515 })
516 }
517 ApiErrorCode::OverloadedError => {
518 let retry_after = extract_retry_after(response.headers());
519 Err(OpenRouterError::ServerOverloaded { retry_after })
520 }
521 _ => Err(OpenRouterError::ApiError(ApiError {
522 code: code,
523 message: error_response.message,
524 })),
525 }
526 }
527}
528
529pub async fn list_models(
530 client: &dyn HttpClient,
531 api_url: &str,
532) -> Result<Vec<Model>, OpenRouterError> {
533 let uri = format!("{api_url}/models");
534 let request_builder = HttpRequest::builder()
535 .method(Method::GET)
536 .uri(uri)
537 .header("Accept", "application/json");
538
539 let request = request_builder
540 .body(AsyncBody::default())
541 .map_err(OpenRouterError::BuildRequestBody)?;
542 let mut response = client
543 .send(request)
544 .await
545 .map_err(OpenRouterError::HttpSend)?;
546
547 let mut body = String::new();
548 response
549 .body_mut()
550 .read_to_string(&mut body)
551 .await
552 .map_err(OpenRouterError::ReadResponse)?;
553
554 if response.status().is_success() {
555 let response: ListModelsResponse =
556 serde_json::from_str(&body).map_err(OpenRouterError::DeserializeResponse)?;
557
558 let models = response
559 .data
560 .into_iter()
561 .map(|entry| Model {
562 name: entry.id,
563 // OpenRouter returns display names in the format "provider_name: model_name".
564 // When displayed in the UI, these names can get truncated from the right.
565 // Since users typically already know the provider, we extract just the model name
566 // portion (after the colon) to create a more concise and user-friendly label
567 // for the model dropdown in the agent panel.
568 display_name: Some(
569 entry
570 .name
571 .split(':')
572 .next_back()
573 .unwrap_or(&entry.name)
574 .trim()
575 .to_string(),
576 ),
577 max_tokens: entry.context_length.unwrap_or(2000000),
578 supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
579 supports_images: Some(
580 entry
581 .architecture
582 .as_ref()
583 .map(|arch| arch.input_modalities.contains(&"image".to_string()))
584 .unwrap_or(false),
585 ),
586 mode: if entry
587 .supported_parameters
588 .contains(&"reasoning".to_string())
589 {
590 ModelMode::Thinking {
591 budget_tokens: Some(4_096),
592 }
593 } else {
594 ModelMode::Default
595 },
596 })
597 .collect();
598
599 Ok(models)
600 } else {
601 let code = ApiErrorCode::from_status(response.status().as_u16());
602
603 let mut body = String::new();
604 response
605 .body_mut()
606 .read_to_string(&mut body)
607 .await
608 .map_err(OpenRouterError::ReadResponse)?;
609
610 let error_response = match serde_json::from_str::<OpenRouterErrorResponse>(&body) {
611 Ok(OpenRouterErrorResponse { error }) => error,
612 Err(_) => OpenRouterErrorBody {
613 code: response.status().as_u16(),
614 message: body,
615 metadata: None,
616 },
617 };
618
619 match code {
620 ApiErrorCode::RateLimitError => {
621 let retry_after = extract_retry_after(response.headers());
622 Err(OpenRouterError::RateLimit {
623 retry_after: retry_after.unwrap_or_else(|| std::time::Duration::from_secs(60)),
624 })
625 }
626 ApiErrorCode::OverloadedError => {
627 let retry_after = extract_retry_after(response.headers());
628 Err(OpenRouterError::ServerOverloaded { retry_after })
629 }
630 _ => Err(OpenRouterError::ApiError(ApiError {
631 code: code,
632 message: error_response.message,
633 })),
634 }
635 }
636}
637
638#[derive(Debug)]
639pub enum OpenRouterError {
640 /// Failed to serialize the HTTP request body to JSON
641 SerializeRequest(serde_json::Error),
642
643 /// Failed to construct the HTTP request body
644 BuildRequestBody(http::Error),
645
646 /// Failed to send the HTTP request
647 HttpSend(anyhow::Error),
648
649 /// Failed to deserialize the response from JSON
650 DeserializeResponse(serde_json::Error),
651
652 /// Failed to read from response stream
653 ReadResponse(io::Error),
654
655 /// Rate limit exceeded
656 RateLimit { retry_after: Duration },
657
658 /// Server overloaded
659 ServerOverloaded { retry_after: Option<Duration> },
660
661 /// API returned an error response
662 ApiError(ApiError),
663}
664
/// The inner `error` object of an OpenRouter error response body.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorBody {
    /// Error code from the body. NOTE(review): presumably mirrors the HTTP status. The response
    /// handlers in this file derive `ApiErrorCode` from the HTTP status, not from this field.
    pub code: u16,
    /// Human-readable message; becomes `ApiError::message`.
    pub message: String,
    /// Optional extra details. Not inspected in this file.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<std::collections::HashMap<String, serde_json::Value>>,
}

/// Envelope of an OpenRouter error response: `{"error": {...}}`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpenRouterErrorResponse {
    pub error: OpenRouterErrorBody,
}

/// A failed API call that is neither a rate-limit (429) nor an overload (503) error.
///
/// Displays as `OpenRouter API Error: {code}: {message}`.
#[derive(Debug, Serialize, Deserialize, Error)]
#[error("OpenRouter API Error: {code}: {message}")]
pub struct ApiError {
    pub code: ApiErrorCode,
    pub message: String,
}
684
/// An OpenRouter API error code.
/// <https://openrouter.ai/docs/api-reference/errors#error-codes>
///
/// `FromStr` (via strum) and `Display` use snake_case names such as `rate_limit_error`.
/// NOTE(review): the serde derives have no `rename_all`, so serde uses the Rust variant names
/// (e.g. `RateLimitError`) instead — confirm that difference is intended.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString, Serialize, Deserialize)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400: Bad Request (invalid or missing params, CORS)
    InvalidRequestError,
    /// 401: Invalid credentials (OAuth session expired, disabled/invalid API key)
    AuthenticationError,
    /// 402: Your account or API key has insufficient credits. Add more credits and retry the request.
    PaymentRequiredError,
    /// 403: Your chosen model requires moderation and your input was flagged
    PermissionError,
    /// 408: Your request timed out
    RequestTimedOut,
    /// 429: You are being rate limited
    RateLimitError,
    /// 502: Your chosen model is down or we received an invalid response from it.
    /// Also the fallback for any status not listed here (see `ApiErrorCode::from_status`).
    ApiError,
    /// 503: There is no available model provider that meets your routing requirements
    OverloadedError,
}
707
708impl std::fmt::Display for ApiErrorCode {
709 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
710 let s = match self {
711 ApiErrorCode::InvalidRequestError => "invalid_request_error",
712 ApiErrorCode::AuthenticationError => "authentication_error",
713 ApiErrorCode::PaymentRequiredError => "payment_required_error",
714 ApiErrorCode::PermissionError => "permission_error",
715 ApiErrorCode::RequestTimedOut => "request_timed_out",
716 ApiErrorCode::RateLimitError => "rate_limit_error",
717 ApiErrorCode::ApiError => "api_error",
718 ApiErrorCode::OverloadedError => "overloaded_error",
719 };
720 write!(f, "{s}")
721 }
722}
723
724impl ApiErrorCode {
725 pub fn from_status(status: u16) -> Self {
726 match status {
727 400 => ApiErrorCode::InvalidRequestError,
728 401 => ApiErrorCode::AuthenticationError,
729 402 => ApiErrorCode::PaymentRequiredError,
730 403 => ApiErrorCode::PermissionError,
731 408 => ApiErrorCode::RequestTimedOut,
732 429 => ApiErrorCode::RateLimitError,
733 502 => ApiErrorCode::ApiError,
734 503 => ApiErrorCode::OverloadedError,
735 _ => ApiErrorCode::ApiError,
736 }
737 }
738}