1use anyhow::{Context, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7
/// Default base URL of the OpenRouter REST API (v1).
///
/// The request functions in this module take the base URL as their `api_url`
/// argument and append endpoint paths such as `/chat/completions` and `/models`.
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
9
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a serde `skip_serializing_if` predicate so that absent and empty
/// lists are both left out of the serialized output.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
13
/// Author of a chat message.
///
/// Serialized as a lowercase string: `"user"`, `"assistant"`, `"system"` or `"tool"`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
22
23impl TryFrom<String> for Role {
24 type Error = anyhow::Error;
25
26 fn try_from(value: String) -> Result<Self> {
27 match value.as_str() {
28 "user" => Ok(Self::User),
29 "assistant" => Ok(Self::Assistant),
30 "system" => Ok(Self::System),
31 "tool" => Ok(Self::Tool),
32 _ => Err(anyhow!("invalid role '{value}'")),
33 }
34 }
35}
36
37impl From<Role> for String {
38 fn from(val: Role) -> Self {
39 match val {
40 Role::User => "user".to_owned(),
41 Role::Assistant => "assistant".to_owned(),
42 Role::System => "system".to_owned(),
43 Role::Tool => "tool".to_owned(),
44 }
45 }
46}
47
/// A model that can be selected for OpenRouter requests, together with the
/// capabilities this client tracks for it.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
// NOTE(review): `impl Model` also defines an inherent `Model::default()` that returns
// the Auto Router model, while this derived `Default` yields an empty `name` and
// `max_tokens == 0`; which one a caller gets depends on how `default` is invoked.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// Model identifier (`list_models` fills this from the API's model `id`).
    pub name: String,
    /// Optional human-readable label shown instead of `name`.
    pub display_name: Option<String>,
    /// Context window size in tokens (`list_models` fills this from `context_length`).
    pub max_tokens: u64,
    /// Whether the model accepts tool definitions; `None` when not specified.
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image input; `None` when not specified.
    pub supports_images: Option<bool>,
    /// Normal or thinking mode; `ModelMode::Default` when absent from the input.
    #[serde(default)]
    pub mode: ModelMode,
}
59
/// Whether requests for a model run in its normal mode or in a reasoning
/// ("thinking") mode.
///
/// Uses serde's default externally-tagged representation, e.g. `"Default"` or
/// `{"Thinking": {"budget_tokens": 4096}}`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub enum ModelMode {
    /// Normal mode; also the `Default` value.
    #[default]
    Default,
    /// Reasoning mode.
    Thinking {
        /// Optional token budget for reasoning.
        budget_tokens: Option<u32>,
    },
}
69
70impl Model {
71 pub fn default_fast() -> Self {
72 Self::new(
73 "openrouter/auto",
74 Some("Auto Router"),
75 Some(2000000),
76 Some(true),
77 Some(false),
78 Some(ModelMode::Default),
79 )
80 }
81
82 pub fn default() -> Self {
83 Self::default_fast()
84 }
85
86 pub fn new(
87 name: &str,
88 display_name: Option<&str>,
89 max_tokens: Option<u64>,
90 supports_tools: Option<bool>,
91 supports_images: Option<bool>,
92 mode: Option<ModelMode>,
93 ) -> Self {
94 Self {
95 name: name.to_owned(),
96 display_name: display_name.map(|s| s.to_owned()),
97 max_tokens: max_tokens.unwrap_or(2000000),
98 supports_tools,
99 supports_images,
100 mode: mode.unwrap_or(ModelMode::Default),
101 }
102 }
103
104 pub fn id(&self) -> &str {
105 &self.name
106 }
107
108 pub fn display_name(&self) -> &str {
109 self.display_name.as_ref().unwrap_or(&self.name)
110 }
111
112 pub fn max_token_count(&self) -> u64 {
113 self.max_tokens
114 }
115
116 pub fn max_output_tokens(&self) -> Option<u64> {
117 None
118 }
119
120 pub fn supports_tool_calls(&self) -> bool {
121 self.supports_tools.unwrap_or(false)
122 }
123
124 pub fn supports_parallel_tool_calls(&self) -> bool {
125 false
126 }
127}
128
/// Body of a `POST /chat/completions` request.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model identifier.
    pub model: String,
    /// The conversation to complete.
    pub messages: Vec<RequestMessage>,
    /// Whether to request a server-sent-event stream (`complete` forces this to `false`).
    pub stream: bool,
    /// Output token limit; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature; always serialized.
    pub temperature: f32,
    /// Tool-selection policy; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether the model may issue several tool calls at once; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools offered to the model; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Reasoning options; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub reasoning: Option<Reasoning>,
    /// Usage-reporting options; always serialized.
    pub usage: RequestUsage,
}
149
/// The `usage` object of a request.
// NOTE(review): presumably `include: true` asks OpenRouter to return token counts
// (see `Usage` on responses); confirm against OpenRouter's usage-accounting docs.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    pub include: bool,
}
154
/// How the model may use the supplied tools.
///
/// `Auto`, `Required` and `None` serialize as the strings `"auto"`, `"required"` and
/// `"none"`. `Other` is `#[serde(untagged)]`, so it serializes as the contained tool
/// definition itself.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
    #[serde(untagged)]
    Other(ToolDefinition),
}
164
/// A tool offered to the model, internally tagged by `type`:
/// `{"type": "function", "function": {...}}`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for `allow(dead_code)` is not visible here.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
172
/// Name, description and parameters of a function tool.
///
/// There is no `skip_serializing_if` here, so `None` fields serialize as `null`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// Parameters as raw JSON; presumably a JSON Schema object (confirm with callers).
    pub parameters: Option<Value>,
}
180
/// The `reasoning` options of a request. Every `None` field is omitted when serialized.
#[derive(Debug, Serialize, Deserialize)]
pub struct Reasoning {
    /// Reasoning effort level; presumably a value such as `"low"`/`"medium"`/`"high"` (confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub effort: Option<String>,
    /// Token budget for reasoning.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Presumably leaves reasoning text out of the response (TODO confirm semantics).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub exclude: Option<bool>,
    /// Whether reasoning is enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub enabled: Option<bool>,
}
192
/// A message in the conversation, internally tagged by `role`
/// (`"assistant"`, `"user"`, `"system"` or `"tool"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// Model output. `content` may be absent; `tool_calls` is omitted when empty.
    Assistant {
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// A tool result; `tool_call_id` presumably matches the `ToolCall::id` it answers.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
212
/// Message content: either a plain string or a list of parts.
///
/// `#[serde(untagged)]`, so it serializes as a bare JSON string or a JSON array.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
219
220impl MessageContent {
221 pub fn empty() -> Self {
222 Self::Plain(String::new())
223 }
224
225 pub fn push_part(&mut self, part: MessagePart) {
226 match self {
227 Self::Plain(text) if text.is_empty() => {
228 *self = Self::Multipart(vec![part]);
229 }
230 Self::Plain(text) => {
231 let text_part = MessagePart::Text {
232 text: std::mem::take(text),
233 };
234 *self = Self::Multipart(vec![text_part, part]);
235 }
236 Self::Multipart(parts) => parts.push(part),
237 }
238 }
239}
240
241impl From<Vec<MessagePart>> for MessageContent {
242 fn from(parts: Vec<MessagePart>) -> Self {
243 if parts.len() == 1
244 && let MessagePart::Text { text } = &parts[0]
245 {
246 return Self::Plain(text.clone());
247 }
248 Self::Multipart(parts)
249 }
250}
251
252impl From<String> for MessageContent {
253 fn from(text: String) -> Self {
254 Self::Plain(text)
255 }
256}
257
258impl From<&str> for MessageContent {
259 fn from(text: &str) -> Self {
260 Self::Plain(text.to_string())
261 }
262}
263
264impl MessageContent {
265 pub fn as_text(&self) -> Option<&str> {
266 match self {
267 Self::Plain(text) => Some(text),
268 Self::Multipart(parts) if parts.len() == 1 => {
269 if let MessagePart::Text { text } = &parts[0] {
270 Some(text)
271 } else {
272 None
273 }
274 }
275 _ => None,
276 }
277 }
278
279 pub fn to_text(&self) -> String {
280 match self {
281 Self::Plain(text) => text.clone(),
282 Self::Multipart(parts) => parts
283 .iter()
284 .filter_map(|part| {
285 if let MessagePart::Text { text } = part {
286 Some(text.as_str())
287 } else {
288 None
289 }
290 })
291 .collect::<Vec<_>>()
292 .join(""),
293 }
294 }
295}
296
/// One element of multipart content, internally tagged by `type`
/// (`"text"` or `"image_url"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    // NOTE(review): `image_url` is serialized as a bare string. OpenAI-style APIs
    // usually document an object `{"url": ...}` here; confirm that OpenRouter accepts
    // the string form.
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}
308
/// A tool call made by the assistant.
///
/// `content` is flattened, so its `type` and `function` keys sit next to `id` in JSON.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
315
/// Payload of a tool call, tagged by `type` (currently only `"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
321
/// The function a tool call invokes.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments as a single string; presumably JSON-encoded (confirm).
    pub arguments: String,
}
327
/// Incremental message data carried by one streamed chunk. All fields are optional.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    pub reasoning: Option<String>,
    /// Omitted from serialized output when `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
336
/// A fragment of a streamed tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Position of the tool call this fragment belongs to; presumably used to
    /// stitch fragments together (confirm with the consumer).
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
343
/// Partial function name and/or arguments inside a `ToolCallChunk`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
349
/// Token counts reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
356
/// One choice inside a streamed chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Optional reason the choice finished.
    pub finish_reason: Option<String>,
}
363
/// One `data:` event from the chat-completions stream.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Presumably a Unix timestamp in seconds (confirm).
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Optional; presumably present only on the chunk that reports usage (confirm).
    pub usage: Option<Usage>,
}
373
/// Body of a non-streaming chat-completions response.
///
/// `usage` is required, so a response without it fails to deserialize.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}
383
/// One choice in a non-streaming response.
///
/// `message` reuses `RequestMessage`, so its `role` field selects the variant.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
390
/// Body of `GET /models`; the model list is under `data`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}
395
/// One model entry from `GET /models`.
// NOTE(review): the `skip_serializing_if` attributes below have no effect because
// this type only derives `Deserialize`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    pub id: String,
    /// Display name; `list_models` keeps only the part after the last `:`.
    pub name: String,
    pub created: usize,
    pub description: String,
    /// Context window in tokens, if reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// Request parameters the model accepts; `list_models` checks for `"tools"` and `"reasoning"`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ModelArchitecture>,
}
409
/// Architecture details of a model entry.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelArchitecture {
    /// Accepted input kinds; `list_models` looks for `"image"`.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_modalities: Vec<String>,
}
415
416pub async fn complete(
417 client: &dyn HttpClient,
418 api_url: &str,
419 api_key: &str,
420 request: Request,
421) -> Result<Response> {
422 let uri = format!("{api_url}/chat/completions");
423 let request_builder = HttpRequest::builder()
424 .method(Method::POST)
425 .uri(uri)
426 .header("Content-Type", "application/json")
427 .header("Authorization", format!("Bearer {}", api_key.trim()))
428 .header("HTTP-Referer", "https://zed.dev")
429 .header("X-Title", "Zed Editor");
430
431 let mut request_body = request;
432 request_body.stream = false;
433
434 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
435 let mut response = client.send(request).await?;
436
437 if response.status().is_success() {
438 let mut body = String::new();
439 response.body_mut().read_to_string(&mut body).await?;
440 let response: Response = serde_json::from_str(&body)?;
441 Ok(response)
442 } else {
443 let mut body = String::new();
444 response.body_mut().read_to_string(&mut body).await?;
445
446 #[derive(Deserialize)]
447 struct OpenRouterResponse {
448 error: OpenRouterError,
449 }
450
451 #[derive(Deserialize)]
452 struct OpenRouterError {
453 message: String,
454 #[serde(default)]
455 code: String,
456 }
457
458 match serde_json::from_str::<OpenRouterResponse>(&body) {
459 Ok(response) if !response.error.message.is_empty() => {
460 let error_message = if !response.error.code.is_empty() {
461 format!("{}: {}", response.error.code, response.error.message)
462 } else {
463 response.error.message
464 };
465
466 Err(anyhow!(
467 "Failed to connect to OpenRouter API: {}",
468 error_message
469 ))
470 }
471 _ => Err(anyhow!(
472 "Failed to connect to OpenRouter API: {} {}",
473 response.status(),
474 body,
475 )),
476 }
477 }
478}
479
480pub async fn stream_completion(
481 client: &dyn HttpClient,
482 api_url: &str,
483 api_key: &str,
484 request: Request,
485) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
486 let uri = format!("{api_url}/chat/completions");
487 let request_builder = HttpRequest::builder()
488 .method(Method::POST)
489 .uri(uri)
490 .header("Content-Type", "application/json")
491 .header("Authorization", format!("Bearer {}", api_key))
492 .header("HTTP-Referer", "https://zed.dev")
493 .header("X-Title", "Zed Editor");
494
495 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
496 let mut response = client.send(request).await?;
497
498 if response.status().is_success() {
499 let reader = BufReader::new(response.into_body());
500 Ok(reader
501 .lines()
502 .filter_map(|line| async move {
503 match line {
504 Ok(line) => {
505 if line.starts_with(':') {
506 return None;
507 }
508
509 let line = line.strip_prefix("data: ")?;
510 if line == "[DONE]" {
511 None
512 } else {
513 match serde_json::from_str::<ResponseStreamEvent>(line) {
514 Ok(response) => Some(Ok(response)),
515 Err(error) => {
516 #[derive(Deserialize)]
517 struct ErrorResponse {
518 error: String,
519 }
520
521 match serde_json::from_str::<ErrorResponse>(line) {
522 Ok(err_response) => Some(Err(anyhow!(err_response.error))),
523 Err(_) => {
524 if line.trim().is_empty() {
525 None
526 } else {
527 Some(Err(anyhow!(
528 "Failed to parse response: {}. Original content: '{}'",
529 error, line
530 )))
531 }
532 }
533 }
534 }
535 }
536 }
537 }
538 Err(error) => Some(Err(anyhow!(error))),
539 }
540 })
541 .boxed())
542 } else {
543 let mut body = String::new();
544 response.body_mut().read_to_string(&mut body).await?;
545
546 #[derive(Deserialize)]
547 struct OpenRouterResponse {
548 error: OpenRouterError,
549 }
550
551 #[derive(Deserialize)]
552 struct OpenRouterError {
553 message: String,
554 #[serde(default)]
555 code: String,
556 }
557
558 match serde_json::from_str::<OpenRouterResponse>(&body) {
559 Ok(response) if !response.error.message.is_empty() => {
560 let error_message = if !response.error.code.is_empty() {
561 format!("{}: {}", response.error.code, response.error.message)
562 } else {
563 response.error.message
564 };
565
566 Err(anyhow!(
567 "Failed to connect to OpenRouter API: {}",
568 error_message
569 ))
570 }
571 _ => Err(anyhow!(
572 "Failed to connect to OpenRouter API: {} {}",
573 response.status(),
574 body,
575 )),
576 }
577 }
578}
579
580pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
581 let uri = format!("{api_url}/models");
582 let request_builder = HttpRequest::builder()
583 .method(Method::GET)
584 .uri(uri)
585 .header("Accept", "application/json");
586
587 let request = request_builder.body(AsyncBody::default())?;
588 let mut response = client.send(request).await?;
589
590 let mut body = String::new();
591 response.body_mut().read_to_string(&mut body).await?;
592
593 if response.status().is_success() {
594 let response: ListModelsResponse =
595 serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
596
597 let models = response
598 .data
599 .into_iter()
600 .map(|entry| Model {
601 name: entry.id,
602 // OpenRouter returns display names in the format "provider_name: model_name".
603 // When displayed in the UI, these names can get truncated from the right.
604 // Since users typically already know the provider, we extract just the model name
605 // portion (after the colon) to create a more concise and user-friendly label
606 // for the model dropdown in the agent panel.
607 display_name: Some(
608 entry
609 .name
610 .split(':')
611 .next_back()
612 .unwrap_or(&entry.name)
613 .trim()
614 .to_string(),
615 ),
616 max_tokens: entry.context_length.unwrap_or(2000000),
617 supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
618 supports_images: Some(
619 entry
620 .architecture
621 .as_ref()
622 .map(|arch| arch.input_modalities.contains(&"image".to_string()))
623 .unwrap_or(false),
624 ),
625 mode: if entry
626 .supported_parameters
627 .contains(&"reasoning".to_string())
628 {
629 ModelMode::Thinking {
630 budget_tokens: Some(4_096),
631 }
632 } else {
633 ModelMode::Default
634 },
635 })
636 .collect();
637
638 Ok(models)
639 } else {
640 Err(anyhow!(
641 "Failed to connect to OpenRouter API: {} {}",
642 response.status(),
643 body,
644 ))
645 }
646}