1use anyhow::{Context, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7
/// Base URL of the OpenRouter REST API; endpoint paths such as
/// `/chat/completions` and `/models` are appended to it by the request helpers.
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
9
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a `skip_serializing_if` predicate so that both an absent and an
/// empty list are left out of the serialized JSON.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
13
/// The author of a chat message.
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`, `"tool"`),
/// the same strings used by the `TryFrom<String>` and `From<Role> for String`
/// impls in this module.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
22
23impl TryFrom<String> for Role {
24 type Error = anyhow::Error;
25
26 fn try_from(value: String) -> Result<Self> {
27 match value.as_str() {
28 "user" => Ok(Self::User),
29 "assistant" => Ok(Self::Assistant),
30 "system" => Ok(Self::System),
31 "tool" => Ok(Self::Tool),
32 _ => Err(anyhow!("invalid role '{value}'")),
33 }
34 }
35}
36
37impl From<Role> for String {
38 fn from(val: Role) -> Self {
39 match val {
40 Role::User => "user".to_owned(),
41 Role::Assistant => "assistant".to_owned(),
42 Role::System => "system".to_owned(),
43 Role::Tool => "tool".to_owned(),
44 }
45 }
46}
47
/// A model that can be selected for OpenRouter requests.
///
/// NOTE(review): the derived `Default` produces an empty model, but the
/// inherent `Model::default()` in this module returns the auto router and takes
/// precedence when called as `Model::default()`; `<Model as Default>::default()`
/// and `..Default::default()` still get the empty one.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// OpenRouter model identifier (e.g. `"openrouter/auto"`), returned by `id()`.
    pub name: String,
    /// Human-readable label; `display_name()` falls back to `name` when `None`.
    pub display_name: Option<String>,
    /// Token limit reported by `max_token_count()`; `list_models` fills it from
    /// the catalogue entry's `context_length`.
    pub max_tokens: u64,
    /// Whether tools may be sent; `supports_tool_calls()` treats `None` as `false`.
    pub supports_tools: Option<bool>,
    /// Whether image input is accepted; `list_models` derives it from the
    /// entry's input modalities.
    pub supports_images: Option<bool>,
}
57
58impl Model {
59 pub fn default_fast() -> Self {
60 Self::new(
61 "openrouter/auto",
62 Some("Auto Router"),
63 Some(2000000),
64 Some(true),
65 Some(false),
66 )
67 }
68
69 pub fn default() -> Self {
70 Self::default_fast()
71 }
72
73 pub fn new(
74 name: &str,
75 display_name: Option<&str>,
76 max_tokens: Option<u64>,
77 supports_tools: Option<bool>,
78 supports_images: Option<bool>,
79 ) -> Self {
80 Self {
81 name: name.to_owned(),
82 display_name: display_name.map(|s| s.to_owned()),
83 max_tokens: max_tokens.unwrap_or(2000000),
84 supports_tools,
85 supports_images,
86 }
87 }
88
89 pub fn id(&self) -> &str {
90 &self.name
91 }
92
93 pub fn display_name(&self) -> &str {
94 self.display_name.as_ref().unwrap_or(&self.name)
95 }
96
97 pub fn max_token_count(&self) -> u64 {
98 self.max_tokens
99 }
100
101 pub fn max_output_tokens(&self) -> Option<u64> {
102 None
103 }
104
105 pub fn supports_tool_calls(&self) -> bool {
106 self.supports_tools.unwrap_or(false)
107 }
108
109 pub fn supports_parallel_tool_calls(&self) -> bool {
110 false
111 }
112}
113
/// Body of a `POST {api_url}/chat/completions` request.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model identifier, e.g. the value of [`Model::id`].
    pub model: String,
    /// Messages of the conversation.
    pub messages: Vec<RequestMessage>,
    /// Whether a streamed (SSE) response is requested; `complete` forces this
    /// to `false` before sending.
    pub stream: bool,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u64>,
    /// Stop sequences; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    pub temperature: f32,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
    /// Usage-reporting options; always serialized.
    pub usage: RequestUsage,
}
132
/// The `usage` object sent with every [`Request`].
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct RequestUsage {
    // Presumably asks OpenRouter to report token usage in the response
    // (see `ResponseStreamEvent::usage`) — confirm against the API docs.
    pub include: bool,
}
137
138#[derive(Debug, Serialize, Deserialize)]
139#[serde(untagged)]
140pub enum ToolChoice {
141 Auto,
142 Required,
143 None,
144 Other(ToolDefinition),
145}
146
/// A tool the model may call, serialized as
/// `{"type": "function", "function": { ... }}`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for `allow(dead_code)` is not stated; confirm it
    // is still needed.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}
154
/// Describes a function tool offered to the model.
///
/// `description` and `parameters` have no `skip_serializing_if`, so `None`
/// is sent as JSON `null`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    pub description: Option<String>,
    /// Raw JSON; presumably a JSON Schema for the arguments — confirm with callers.
    pub parameters: Option<Value>,
}
162
/// A chat message, tagged by a `"role"` field holding the lowercase variant
/// name (`"assistant"`, `"user"`, `"system"`, `"tool"`).
///
/// Also used for the `message` of a non-streaming [`Choice`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// May be `None`, which is serialized as `null`.
        content: Option<MessageContent>,
        /// Omitted from the JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    Tool {
        content: MessageContent,
        /// Presumably the `id` of the [`ToolCall`] this result answers.
        tool_call_id: String,
    },
}
182
/// Message content: either a plain string or a list of typed parts.
///
/// `untagged`, so `Plain` is written as a bare JSON string and `Multipart`
/// as a JSON array of [`MessagePart`] objects.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(untagged)]
pub enum MessageContent {
    Plain(String),
    Multipart(Vec<MessagePart>),
}
189
190impl MessageContent {
191 pub fn empty() -> Self {
192 Self::Plain(String::new())
193 }
194
195 pub fn push_part(&mut self, part: MessagePart) {
196 match self {
197 Self::Plain(text) if text.is_empty() => {
198 *self = Self::Multipart(vec![part]);
199 }
200 Self::Plain(text) => {
201 let text_part = MessagePart::Text {
202 text: std::mem::take(text),
203 };
204 *self = Self::Multipart(vec![text_part, part]);
205 }
206 Self::Multipart(parts) => parts.push(part),
207 }
208 }
209}
210
211impl From<Vec<MessagePart>> for MessageContent {
212 fn from(parts: Vec<MessagePart>) -> Self {
213 if parts.len() == 1 {
214 if let MessagePart::Text { text } = &parts[0] {
215 return Self::Plain(text.clone());
216 }
217 }
218 Self::Multipart(parts)
219 }
220}
221
222impl From<String> for MessageContent {
223 fn from(text: String) -> Self {
224 Self::Plain(text)
225 }
226}
227
228impl From<&str> for MessageContent {
229 fn from(text: &str) -> Self {
230 Self::Plain(text.to_string())
231 }
232}
233
234impl MessageContent {
235 pub fn as_text(&self) -> Option<&str> {
236 match self {
237 Self::Plain(text) => Some(text),
238 Self::Multipart(parts) if parts.len() == 1 => {
239 if let MessagePart::Text { text } = &parts[0] {
240 Some(text)
241 } else {
242 None
243 }
244 }
245 _ => None,
246 }
247 }
248
249 pub fn to_text(&self) -> String {
250 match self {
251 Self::Plain(text) => text.clone(),
252 Self::Multipart(parts) => parts
253 .iter()
254 .filter_map(|part| {
255 if let MessagePart::Text { text } = part {
256 Some(text.as_str())
257 } else {
258 None
259 }
260 })
261 .collect::<Vec<_>>()
262 .join(""),
263 }
264 }
265}
266
/// One element of multipart message content, tagged by a `"type"` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    /// Serialized as `{"type": "text", "text": ...}`.
    Text {
        text: String,
    },
    /// Serialized as `{"type": "image_url", "image_url": "..."}`.
    // NOTE(review): `image_url` is a bare string here; OpenAI-style APIs usually
    // take an object `{"url": ...}` — confirm OpenRouter accepts this form.
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}
278
/// A complete tool call made by the assistant.
///
/// `content` is flattened, so its `type`/`function` fields appear next to
/// `id` in the same JSON object.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
285
/// Payload of a [`ToolCall`], tagged by `"type"`; `"function"` is the only
/// variant defined here.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
291
/// The function invoked by a [`ToolCall`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments kept as a raw string (not parsed here); presumably JSON-encoded.
    pub arguments: String,
}
297
/// Incremental message data carried by one streamed [`ChoiceDelta`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    /// Text fragment in this chunk, if any.
    pub content: Option<String>,
    /// Tool-call fragments; omitted when serializing if `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}
305
/// A fragment of a tool call within a streamed response.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Presumably identifies which tool call this fragment belongs to, so
    /// fragments can be stitched together — confirm with the consumer.
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
312
/// Partial function name and/or arguments within a [`ToolCallChunk`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
318
/// Token counts reported in a response.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
325
/// One choice within a streamed [`ResponseStreamEvent`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Presumably set only on the chunk that ends this choice.
    pub finish_reason: Option<String>,
}
332
/// One parsed `data:` event of the streaming chat-completions response, as
/// yielded by `stream_completion`.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    // NOTE(review): `u32` here but `u64` in `Response::created`; presumably a
    // Unix timestamp in both — confirm and consider unifying.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    /// Token usage, when present in this event.
    pub usage: Option<Usage>,
}
342
/// Non-streaming chat-completions response, as returned by `complete`.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    /// Required here, unlike the optional `usage` of streamed events.
    pub usage: Usage,
}
352
/// One completed choice of a [`Response`]; `message` reuses
/// [`RequestMessage`], so it is tagged by a `"role"` field.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
359
/// Response body of `GET {api_url}/models`, as parsed by `list_models`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}
364
/// One model of the `GET /models` catalogue.
///
/// NOTE(review): this type only derives `Deserialize`, so the
/// `skip_serializing_if` attributes below have no effect; only `default`
/// matters. `id`, `name`, `created` and `description` have no `default`, so an
/// entry missing any of them makes the whole list fail to parse.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    /// Becomes `Model::name` in `list_models`.
    pub id: String,
    /// Display name; `list_models` shortens it for `Model::display_name`.
    pub name: String,
    pub created: usize,
    pub description: String,
    /// Used as `Model::max_tokens` by `list_models`, with a fallback when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// `list_models` treats the presence of `"tools"` as tool-call support.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub architecture: Option<ModelArchitecture>,
}
378
/// Architecture metadata of a [`ModelEntry`].
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelArchitecture {
    /// Input modalities; `list_models` looks for `"image"` to detect image support.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub input_modalities: Vec<String>,
}
384
385pub async fn complete(
386 client: &dyn HttpClient,
387 api_url: &str,
388 api_key: &str,
389 request: Request,
390) -> Result<Response> {
391 let uri = format!("{api_url}/chat/completions");
392 let request_builder = HttpRequest::builder()
393 .method(Method::POST)
394 .uri(uri)
395 .header("Content-Type", "application/json")
396 .header("Authorization", format!("Bearer {}", api_key))
397 .header("HTTP-Referer", "https://zed.dev")
398 .header("X-Title", "Zed Editor");
399
400 let mut request_body = request;
401 request_body.stream = false;
402
403 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
404 let mut response = client.send(request).await?;
405
406 if response.status().is_success() {
407 let mut body = String::new();
408 response.body_mut().read_to_string(&mut body).await?;
409 let response: Response = serde_json::from_str(&body)?;
410 Ok(response)
411 } else {
412 let mut body = String::new();
413 response.body_mut().read_to_string(&mut body).await?;
414
415 #[derive(Deserialize)]
416 struct OpenRouterResponse {
417 error: OpenRouterError,
418 }
419
420 #[derive(Deserialize)]
421 struct OpenRouterError {
422 message: String,
423 #[serde(default)]
424 code: String,
425 }
426
427 match serde_json::from_str::<OpenRouterResponse>(&body) {
428 Ok(response) if !response.error.message.is_empty() => {
429 let error_message = if !response.error.code.is_empty() {
430 format!("{}: {}", response.error.code, response.error.message)
431 } else {
432 response.error.message
433 };
434
435 Err(anyhow!(
436 "Failed to connect to OpenRouter API: {}",
437 error_message
438 ))
439 }
440 _ => Err(anyhow!(
441 "Failed to connect to OpenRouter API: {} {}",
442 response.status(),
443 body,
444 )),
445 }
446 }
447}
448
449pub async fn stream_completion(
450 client: &dyn HttpClient,
451 api_url: &str,
452 api_key: &str,
453 request: Request,
454) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
455 let uri = format!("{api_url}/chat/completions");
456 let request_builder = HttpRequest::builder()
457 .method(Method::POST)
458 .uri(uri)
459 .header("Content-Type", "application/json")
460 .header("Authorization", format!("Bearer {}", api_key))
461 .header("HTTP-Referer", "https://zed.dev")
462 .header("X-Title", "Zed Editor");
463
464 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
465 let mut response = client.send(request).await?;
466
467 if response.status().is_success() {
468 let reader = BufReader::new(response.into_body());
469 Ok(reader
470 .lines()
471 .filter_map(|line| async move {
472 match line {
473 Ok(line) => {
474 if line.starts_with(':') {
475 return None;
476 }
477
478 let line = line.strip_prefix("data: ")?;
479 if line == "[DONE]" {
480 None
481 } else {
482 match serde_json::from_str::<ResponseStreamEvent>(line) {
483 Ok(response) => Some(Ok(response)),
484 Err(error) => {
485 #[derive(Deserialize)]
486 struct ErrorResponse {
487 error: String,
488 }
489
490 match serde_json::from_str::<ErrorResponse>(line) {
491 Ok(err_response) => Some(Err(anyhow!(err_response.error))),
492 Err(_) => {
493 if line.trim().is_empty() {
494 None
495 } else {
496 Some(Err(anyhow!(
497 "Failed to parse response: {}. Original content: '{}'",
498 error, line
499 )))
500 }
501 }
502 }
503 }
504 }
505 }
506 }
507 Err(error) => Some(Err(anyhow!(error))),
508 }
509 })
510 .boxed())
511 } else {
512 let mut body = String::new();
513 response.body_mut().read_to_string(&mut body).await?;
514
515 #[derive(Deserialize)]
516 struct OpenRouterResponse {
517 error: OpenRouterError,
518 }
519
520 #[derive(Deserialize)]
521 struct OpenRouterError {
522 message: String,
523 #[serde(default)]
524 code: String,
525 }
526
527 match serde_json::from_str::<OpenRouterResponse>(&body) {
528 Ok(response) if !response.error.message.is_empty() => {
529 let error_message = if !response.error.code.is_empty() {
530 format!("{}: {}", response.error.code, response.error.message)
531 } else {
532 response.error.message
533 };
534
535 Err(anyhow!(
536 "Failed to connect to OpenRouter API: {}",
537 error_message
538 ))
539 }
540 _ => Err(anyhow!(
541 "Failed to connect to OpenRouter API: {} {}",
542 response.status(),
543 body,
544 )),
545 }
546 }
547}
548
549pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
550 let uri = format!("{api_url}/models");
551 let request_builder = HttpRequest::builder()
552 .method(Method::GET)
553 .uri(uri)
554 .header("Accept", "application/json");
555
556 let request = request_builder.body(AsyncBody::default())?;
557 let mut response = client.send(request).await?;
558
559 let mut body = String::new();
560 response.body_mut().read_to_string(&mut body).await?;
561
562 if response.status().is_success() {
563 let response: ListModelsResponse =
564 serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
565
566 let models = response
567 .data
568 .into_iter()
569 .map(|entry| Model {
570 name: entry.id,
571 // OpenRouter returns display names in the format "provider_name: model_name".
572 // When displayed in the UI, these names can get truncated from the right.
573 // Since users typically already know the provider, we extract just the model name
574 // portion (after the colon) to create a more concise and user-friendly label
575 // for the model dropdown in the agent panel.
576 display_name: Some(
577 entry
578 .name
579 .split(':')
580 .next_back()
581 .unwrap_or(&entry.name)
582 .trim()
583 .to_string(),
584 ),
585 max_tokens: entry.context_length.unwrap_or(2000000),
586 supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
587 supports_images: Some(
588 entry
589 .architecture
590 .as_ref()
591 .map(|arch| arch.input_modalities.contains(&"image".to_string()))
592 .unwrap_or(false),
593 ),
594 })
595 .collect();
596
597 Ok(models)
598 } else {
599 Err(anyhow!(
600 "Failed to connect to OpenRouter API: {} {}",
601 response.status(),
602 body,
603 ))
604 }
605}