1use anyhow::{Context, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7
/// Base URL of the OpenRouter HTTP API, without a trailing slash.
///
/// The request helpers in this file append paths such as `/chat/completions`
/// and `/models` to whatever base URL they are given.
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
9
/// Returns `true` when `opt` is `None` or wraps a value whose slice view is empty.
///
/// Used as a serde `skip_serializing_if` predicate for optional collections.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
13
14#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
15#[serde(rename_all = "lowercase")]
16pub enum Role {
17 User,
18 Assistant,
19 System,
20 Tool,
21}
22
23impl TryFrom<String> for Role {
24 type Error = anyhow::Error;
25
26 fn try_from(value: String) -> Result<Self> {
27 match value.as_str() {
28 "user" => Ok(Self::User),
29 "assistant" => Ok(Self::Assistant),
30 "system" => Ok(Self::System),
31 "tool" => Ok(Self::Tool),
32 _ => Err(anyhow!("invalid role '{value}'")),
33 }
34 }
35}
36
37impl From<Role> for String {
38 fn from(val: Role) -> Self {
39 match val {
40 Role::User => "user".to_owned(),
41 Role::Assistant => "assistant".to_owned(),
42 Role::System => "system".to_owned(),
43 Role::Tool => "tool".to_owned(),
44 }
45 }
46}
47
48#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
49#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
50pub struct Model {
51 pub name: String,
52 pub display_name: Option<String>,
53 pub max_tokens: u64,
54 pub supports_tools: Option<bool>,
55 pub supports_images: Option<bool>,
56 #[serde(default)]
57 pub mode: ModelMode,
58}
59
60#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
61#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
62pub enum ModelMode {
63 #[default]
64 Default,
65 Thinking {
66 budget_tokens: Option<u32>,
67 },
68}
69
70impl Model {
71 pub fn default_fast() -> Self {
72 Self::new(
73 "openrouter/auto",
74 Some("Auto Router"),
75 Some(2000000),
76 Some(true),
77 Some(false),
78 Some(ModelMode::Default),
79 )
80 }
81
82 pub fn default() -> Self {
83 Self::default_fast()
84 }
85
86 pub fn new(
87 name: &str,
88 display_name: Option<&str>,
89 max_tokens: Option<u64>,
90 supports_tools: Option<bool>,
91 supports_images: Option<bool>,
92 mode: Option<ModelMode>,
93 ) -> Self {
94 Self {
95 name: name.to_owned(),
96 display_name: display_name.map(|s| s.to_owned()),
97 max_tokens: max_tokens.unwrap_or(2000000),
98 supports_tools,
99 supports_images,
100 mode: mode.unwrap_or(ModelMode::Default),
101 }
102 }
103
104 pub fn id(&self) -> &str {
105 &self.name
106 }
107
108 pub fn display_name(&self) -> &str {
109 self.display_name.as_ref().unwrap_or(&self.name)
110 }
111
112 pub fn max_token_count(&self) -> u64 {
113 self.max_tokens
114 }
115
116 pub fn max_output_tokens(&self) -> Option<u64> {
117 None
118 }
119
120 pub fn supports_tool_calls(&self) -> bool {
121 self.supports_tools.unwrap_or(false)
122 }
123
124 pub fn supports_parallel_tool_calls(&self) -> bool {
125 false
126 }
127}
128
129#[derive(Debug, Serialize, Deserialize)]
130pub struct Request {
131 pub model: String,
132 pub messages: Vec<RequestMessage>,
133 pub stream: bool,
134 #[serde(default, skip_serializing_if = "Option::is_none")]
135 pub max_tokens: Option<u64>,
136 #[serde(default, skip_serializing_if = "Vec::is_empty")]
137 pub stop: Vec<String>,
138 pub temperature: f32,
139 #[serde(default, skip_serializing_if = "Option::is_none")]
140 pub tool_choice: Option<ToolChoice>,
141 #[serde(default, skip_serializing_if = "Option::is_none")]
142 pub parallel_tool_calls: Option<bool>,
143 #[serde(default, skip_serializing_if = "Vec::is_empty")]
144 pub tools: Vec<ToolDefinition>,
145 #[serde(default, skip_serializing_if = "Option::is_none")]
146 pub reasoning: Option<Reasoning>,
147 pub usage: RequestUsage,
148}
149
150#[derive(Debug, Default, Serialize, Deserialize)]
151pub struct RequestUsage {
152 pub include: bool,
153}
154
155#[derive(Debug, Serialize, Deserialize)]
156#[serde(rename_all = "lowercase")]
157pub enum ToolChoice {
158 Auto,
159 Required,
160 None,
161 #[serde(untagged)]
162 Other(ToolDefinition),
163}
164
165#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
166#[derive(Clone, Deserialize, Serialize, Debug)]
167#[serde(tag = "type", rename_all = "snake_case")]
168pub enum ToolDefinition {
169 #[allow(dead_code)]
170 Function { function: FunctionDefinition },
171}
172
173#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
174#[derive(Clone, Debug, Serialize, Deserialize)]
175pub struct FunctionDefinition {
176 pub name: String,
177 pub description: Option<String>,
178 pub parameters: Option<Value>,
179}
180
181#[derive(Debug, Serialize, Deserialize)]
182pub struct Reasoning {
183 #[serde(skip_serializing_if = "Option::is_none")]
184 pub effort: Option<String>,
185 #[serde(skip_serializing_if = "Option::is_none")]
186 pub max_tokens: Option<u32>,
187 #[serde(skip_serializing_if = "Option::is_none")]
188 pub exclude: Option<bool>,
189 #[serde(skip_serializing_if = "Option::is_none")]
190 pub enabled: Option<bool>,
191}
192
193#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
194#[serde(tag = "role", rename_all = "lowercase")]
195pub enum RequestMessage {
196 Assistant {
197 content: Option<MessageContent>,
198 #[serde(default, skip_serializing_if = "Vec::is_empty")]
199 tool_calls: Vec<ToolCall>,
200 },
201 User {
202 content: MessageContent,
203 },
204 System {
205 content: MessageContent,
206 },
207 Tool {
208 content: MessageContent,
209 tool_call_id: String,
210 },
211}
212
213#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
214#[serde(untagged)]
215pub enum MessageContent {
216 Plain(String),
217 Multipart(Vec<MessagePart>),
218}
219
220impl MessageContent {
221 pub fn empty() -> Self {
222 Self::Plain(String::new())
223 }
224
225 pub fn push_part(&mut self, part: MessagePart) {
226 match self {
227 Self::Plain(text) if text.is_empty() => {
228 *self = Self::Multipart(vec![part]);
229 }
230 Self::Plain(text) => {
231 let text_part = MessagePart::Text {
232 text: std::mem::take(text),
233 };
234 *self = Self::Multipart(vec![text_part, part]);
235 }
236 Self::Multipart(parts) => parts.push(part),
237 }
238 }
239}
240
241impl From<Vec<MessagePart>> for MessageContent {
242 fn from(parts: Vec<MessagePart>) -> Self {
243 if parts.len() == 1
244 && let MessagePart::Text { text } = &parts[0] {
245 return Self::Plain(text.clone());
246 }
247 Self::Multipart(parts)
248 }
249}
250
251impl From<String> for MessageContent {
252 fn from(text: String) -> Self {
253 Self::Plain(text)
254 }
255}
256
257impl From<&str> for MessageContent {
258 fn from(text: &str) -> Self {
259 Self::Plain(text.to_string())
260 }
261}
262
263impl MessageContent {
264 pub fn as_text(&self) -> Option<&str> {
265 match self {
266 Self::Plain(text) => Some(text),
267 Self::Multipart(parts) if parts.len() == 1 => {
268 if let MessagePart::Text { text } = &parts[0] {
269 Some(text)
270 } else {
271 None
272 }
273 }
274 _ => None,
275 }
276 }
277
278 pub fn to_text(&self) -> String {
279 match self {
280 Self::Plain(text) => text.clone(),
281 Self::Multipart(parts) => parts
282 .iter()
283 .filter_map(|part| {
284 if let MessagePart::Text { text } = part {
285 Some(text.as_str())
286 } else {
287 None
288 }
289 })
290 .collect::<Vec<_>>()
291 .join(""),
292 }
293 }
294}
295
296#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
297#[serde(tag = "type", rename_all = "snake_case")]
298pub enum MessagePart {
299 Text {
300 text: String,
301 },
302 #[serde(rename = "image_url")]
303 Image {
304 image_url: String,
305 },
306}
307
308#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
309pub struct ToolCall {
310 pub id: String,
311 #[serde(flatten)]
312 pub content: ToolCallContent,
313}
314
315#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
316#[serde(tag = "type", rename_all = "lowercase")]
317pub enum ToolCallContent {
318 Function { function: FunctionContent },
319}
320
321#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
322pub struct FunctionContent {
323 pub name: String,
324 pub arguments: String,
325}
326
327#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
328pub struct ResponseMessageDelta {
329 pub role: Option<Role>,
330 pub content: Option<String>,
331 pub reasoning: Option<String>,
332 #[serde(default, skip_serializing_if = "is_none_or_empty")]
333 pub tool_calls: Option<Vec<ToolCallChunk>>,
334}
335
336#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
337pub struct ToolCallChunk {
338 pub index: usize,
339 pub id: Option<String>,
340 pub function: Option<FunctionChunk>,
341}
342
343#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
344pub struct FunctionChunk {
345 pub name: Option<String>,
346 pub arguments: Option<String>,
347}
348
349#[derive(Serialize, Deserialize, Debug)]
350pub struct Usage {
351 pub prompt_tokens: u64,
352 pub completion_tokens: u64,
353 pub total_tokens: u64,
354}
355
356#[derive(Serialize, Deserialize, Debug)]
357pub struct ChoiceDelta {
358 pub index: u32,
359 pub delta: ResponseMessageDelta,
360 pub finish_reason: Option<String>,
361}
362
363#[derive(Serialize, Deserialize, Debug)]
364pub struct ResponseStreamEvent {
365 #[serde(default, skip_serializing_if = "Option::is_none")]
366 pub id: Option<String>,
367 pub created: u32,
368 pub model: String,
369 pub choices: Vec<ChoiceDelta>,
370 pub usage: Option<Usage>,
371}
372
373#[derive(Serialize, Deserialize, Debug)]
374pub struct Response {
375 pub id: String,
376 pub object: String,
377 pub created: u64,
378 pub model: String,
379 pub choices: Vec<Choice>,
380 pub usage: Usage,
381}
382
383#[derive(Serialize, Deserialize, Debug)]
384pub struct Choice {
385 pub index: u32,
386 pub message: RequestMessage,
387 pub finish_reason: Option<String>,
388}
389
390#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
391pub struct ListModelsResponse {
392 pub data: Vec<ModelEntry>,
393}
394
395#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
396pub struct ModelEntry {
397 pub id: String,
398 pub name: String,
399 pub created: usize,
400 pub description: String,
401 #[serde(default, skip_serializing_if = "Option::is_none")]
402 pub context_length: Option<u64>,
403 #[serde(default, skip_serializing_if = "Vec::is_empty")]
404 pub supported_parameters: Vec<String>,
405 #[serde(default, skip_serializing_if = "Option::is_none")]
406 pub architecture: Option<ModelArchitecture>,
407}
408
409#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
410pub struct ModelArchitecture {
411 #[serde(default, skip_serializing_if = "Vec::is_empty")]
412 pub input_modalities: Vec<String>,
413}
414
415pub async fn complete(
416 client: &dyn HttpClient,
417 api_url: &str,
418 api_key: &str,
419 request: Request,
420) -> Result<Response> {
421 let uri = format!("{api_url}/chat/completions");
422 let request_builder = HttpRequest::builder()
423 .method(Method::POST)
424 .uri(uri)
425 .header("Content-Type", "application/json")
426 .header("Authorization", format!("Bearer {}", api_key))
427 .header("HTTP-Referer", "https://zed.dev")
428 .header("X-Title", "Zed Editor");
429
430 let mut request_body = request;
431 request_body.stream = false;
432
433 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
434 let mut response = client.send(request).await?;
435
436 if response.status().is_success() {
437 let mut body = String::new();
438 response.body_mut().read_to_string(&mut body).await?;
439 let response: Response = serde_json::from_str(&body)?;
440 Ok(response)
441 } else {
442 let mut body = String::new();
443 response.body_mut().read_to_string(&mut body).await?;
444
445 #[derive(Deserialize)]
446 struct OpenRouterResponse {
447 error: OpenRouterError,
448 }
449
450 #[derive(Deserialize)]
451 struct OpenRouterError {
452 message: String,
453 #[serde(default)]
454 code: String,
455 }
456
457 match serde_json::from_str::<OpenRouterResponse>(&body) {
458 Ok(response) if !response.error.message.is_empty() => {
459 let error_message = if !response.error.code.is_empty() {
460 format!("{}: {}", response.error.code, response.error.message)
461 } else {
462 response.error.message
463 };
464
465 Err(anyhow!(
466 "Failed to connect to OpenRouter API: {}",
467 error_message
468 ))
469 }
470 _ => Err(anyhow!(
471 "Failed to connect to OpenRouter API: {} {}",
472 response.status(),
473 body,
474 )),
475 }
476 }
477}
478
479pub async fn stream_completion(
480 client: &dyn HttpClient,
481 api_url: &str,
482 api_key: &str,
483 request: Request,
484) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
485 let uri = format!("{api_url}/chat/completions");
486 let request_builder = HttpRequest::builder()
487 .method(Method::POST)
488 .uri(uri)
489 .header("Content-Type", "application/json")
490 .header("Authorization", format!("Bearer {}", api_key))
491 .header("HTTP-Referer", "https://zed.dev")
492 .header("X-Title", "Zed Editor");
493
494 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
495 let mut response = client.send(request).await?;
496
497 if response.status().is_success() {
498 let reader = BufReader::new(response.into_body());
499 Ok(reader
500 .lines()
501 .filter_map(|line| async move {
502 match line {
503 Ok(line) => {
504 if line.starts_with(':') {
505 return None;
506 }
507
508 let line = line.strip_prefix("data: ")?;
509 if line == "[DONE]" {
510 None
511 } else {
512 match serde_json::from_str::<ResponseStreamEvent>(line) {
513 Ok(response) => Some(Ok(response)),
514 Err(error) => {
515 #[derive(Deserialize)]
516 struct ErrorResponse {
517 error: String,
518 }
519
520 match serde_json::from_str::<ErrorResponse>(line) {
521 Ok(err_response) => Some(Err(anyhow!(err_response.error))),
522 Err(_) => {
523 if line.trim().is_empty() {
524 None
525 } else {
526 Some(Err(anyhow!(
527 "Failed to parse response: {}. Original content: '{}'",
528 error, line
529 )))
530 }
531 }
532 }
533 }
534 }
535 }
536 }
537 Err(error) => Some(Err(anyhow!(error))),
538 }
539 })
540 .boxed())
541 } else {
542 let mut body = String::new();
543 response.body_mut().read_to_string(&mut body).await?;
544
545 #[derive(Deserialize)]
546 struct OpenRouterResponse {
547 error: OpenRouterError,
548 }
549
550 #[derive(Deserialize)]
551 struct OpenRouterError {
552 message: String,
553 #[serde(default)]
554 code: String,
555 }
556
557 match serde_json::from_str::<OpenRouterResponse>(&body) {
558 Ok(response) if !response.error.message.is_empty() => {
559 let error_message = if !response.error.code.is_empty() {
560 format!("{}: {}", response.error.code, response.error.message)
561 } else {
562 response.error.message
563 };
564
565 Err(anyhow!(
566 "Failed to connect to OpenRouter API: {}",
567 error_message
568 ))
569 }
570 _ => Err(anyhow!(
571 "Failed to connect to OpenRouter API: {} {}",
572 response.status(),
573 body,
574 )),
575 }
576 }
577}
578
579pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
580 let uri = format!("{api_url}/models");
581 let request_builder = HttpRequest::builder()
582 .method(Method::GET)
583 .uri(uri)
584 .header("Accept", "application/json");
585
586 let request = request_builder.body(AsyncBody::default())?;
587 let mut response = client.send(request).await?;
588
589 let mut body = String::new();
590 response.body_mut().read_to_string(&mut body).await?;
591
592 if response.status().is_success() {
593 let response: ListModelsResponse =
594 serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
595
596 let models = response
597 .data
598 .into_iter()
599 .map(|entry| Model {
600 name: entry.id,
601 // OpenRouter returns display names in the format "provider_name: model_name".
602 // When displayed in the UI, these names can get truncated from the right.
603 // Since users typically already know the provider, we extract just the model name
604 // portion (after the colon) to create a more concise and user-friendly label
605 // for the model dropdown in the agent panel.
606 display_name: Some(
607 entry
608 .name
609 .split(':')
610 .next_back()
611 .unwrap_or(&entry.name)
612 .trim()
613 .to_string(),
614 ),
615 max_tokens: entry.context_length.unwrap_or(2000000),
616 supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
617 supports_images: Some(
618 entry
619 .architecture
620 .as_ref()
621 .map(|arch| arch.input_modalities.contains(&"image".to_string()))
622 .unwrap_or(false),
623 ),
624 mode: if entry
625 .supported_parameters
626 .contains(&"reasoning".to_string())
627 {
628 ModelMode::Thinking {
629 budget_tokens: Some(4_096),
630 }
631 } else {
632 ModelMode::Default
633 },
634 })
635 .collect();
636
637 Ok(models)
638 } else {
639 Err(anyhow!(
640 "Failed to connect to OpenRouter API: {} {}",
641 response.status(),
642 body,
643 ))
644 }
645}