1use anyhow::{Context, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7
/// Default base URL of the OpenRouter REST API.
pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";

/// Returns `true` when `opt` is `None` or wraps an empty slice-like value.
///
/// Serves as a `skip_serializing_if` predicate so that both a missing and an
/// empty collection are left out of the serialized output.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        None => true,
        Some(items) => items.as_ref().is_empty(),
    }
}
13
14#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
15#[serde(rename_all = "lowercase")]
16pub enum Role {
17 User,
18 Assistant,
19 System,
20 Tool,
21}
22
23impl TryFrom<String> for Role {
24 type Error = anyhow::Error;
25
26 fn try_from(value: String) -> Result<Self> {
27 match value.as_str() {
28 "user" => Ok(Self::User),
29 "assistant" => Ok(Self::Assistant),
30 "system" => Ok(Self::System),
31 "tool" => Ok(Self::Tool),
32 _ => Err(anyhow!("invalid role '{value}'")),
33 }
34 }
35}
36
37impl From<Role> for String {
38 fn from(val: Role) -> Self {
39 match val {
40 Role::User => "user".to_owned(),
41 Role::Assistant => "assistant".to_owned(),
42 Role::System => "system".to_owned(),
43 Role::Tool => "tool".to_owned(),
44 }
45 }
46}
47
48#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
49#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
50pub struct Model {
51 pub name: String,
52 pub display_name: Option<String>,
53 pub max_tokens: usize,
54 pub supports_tools: Option<bool>,
55 pub supports_images: Option<bool>,
56}
57
58impl Model {
59 pub fn default_fast() -> Self {
60 Self::new(
61 "openrouter/auto",
62 Some("Auto Router"),
63 Some(2000000),
64 Some(true),
65 Some(false),
66 )
67 }
68
69 pub fn default() -> Self {
70 Self::default_fast()
71 }
72
73 pub fn new(
74 name: &str,
75 display_name: Option<&str>,
76 max_tokens: Option<usize>,
77 supports_tools: Option<bool>,
78 supports_images: Option<bool>,
79 ) -> Self {
80 Self {
81 name: name.to_owned(),
82 display_name: display_name.map(|s| s.to_owned()),
83 max_tokens: max_tokens.unwrap_or(2000000),
84 supports_tools,
85 supports_images,
86 }
87 }
88
89 pub fn id(&self) -> &str {
90 &self.name
91 }
92
93 pub fn display_name(&self) -> &str {
94 self.display_name.as_ref().unwrap_or(&self.name)
95 }
96
97 pub fn max_token_count(&self) -> usize {
98 self.max_tokens
99 }
100
101 pub fn max_output_tokens(&self) -> Option<u32> {
102 None
103 }
104
105 pub fn supports_tool_calls(&self) -> bool {
106 self.supports_tools.unwrap_or(false)
107 }
108
109 pub fn supports_parallel_tool_calls(&self) -> bool {
110 false
111 }
112}
113
/// Body of a chat completion request (`POST {api_url}/chat/completions`).
///
/// Field order determines the key order of the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Identifier of the model to use.
    pub model: String,
    /// Conversation messages.
    pub messages: Vec<RequestMessage>,
    /// Whether the server should stream the response; [`complete`] overrides it to `false`.
    pub stream: bool,
    /// Optional cap on generated tokens; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Sampling temperature; always serialized.
    pub temperature: f32,
    /// Tool selection policy; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Whether parallel tool calls are allowed; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tools the model may call; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
131
132#[derive(Debug, Serialize, Deserialize)]
133#[serde(untagged)]
134pub enum ToolChoice {
135 Auto,
136 Required,
137 None,
138 Other(ToolDefinition),
139}
140
/// A tool the model may call, serialized as
/// `{"type": "function", "function": {...}}`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for this `allow` is not evident from this file.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// A function exposed to the model as a tool.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    /// Function name.
    pub name: String,
    /// Optional description; serialized as `null` when `None`.
    pub description: Option<String>,
    /// Parameter description as raw JSON (presumably a JSON Schema object —
    /// TODO confirm); serialized as `null` when `None`.
    pub parameters: Option<Value>,
}
156
/// A chat message, discriminated by its `role` field
/// (`"assistant"`, `"user"`, `"system"` or `"tool"`).
///
/// Also reused for the `message` of a non-streaming [`Choice`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    /// A message produced by the model. `content` is serialized as `null`
    /// when `None`; `tool_calls` is omitted when empty.
    Assistant {
        content: Option<MessageContent>,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: MessageContent,
    },
    System {
        content: MessageContent,
    },
    /// Output of a tool; `tool_call_id` presumably matches the `id` of the
    /// assistant's corresponding `ToolCall` — TODO confirm.
    Tool {
        content: MessageContent,
        tool_call_id: String,
    },
}
176
177#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
178#[serde(untagged)]
179pub enum MessageContent {
180 Plain(String),
181 Multipart(Vec<MessagePart>),
182}
183
184impl MessageContent {
185 pub fn empty() -> Self {
186 Self::Plain(String::new())
187 }
188
189 pub fn push_part(&mut self, part: MessagePart) {
190 match self {
191 Self::Plain(text) if text.is_empty() => {
192 *self = Self::Multipart(vec![part]);
193 }
194 Self::Plain(text) => {
195 let text_part = MessagePart::Text {
196 text: std::mem::take(text),
197 };
198 *self = Self::Multipart(vec![text_part, part]);
199 }
200 Self::Multipart(parts) => parts.push(part),
201 }
202 }
203}
204
205impl From<Vec<MessagePart>> for MessageContent {
206 fn from(parts: Vec<MessagePart>) -> Self {
207 if parts.len() == 1 {
208 if let MessagePart::Text { text } = &parts[0] {
209 return Self::Plain(text.clone());
210 }
211 }
212 Self::Multipart(parts)
213 }
214}
215
216impl From<String> for MessageContent {
217 fn from(text: String) -> Self {
218 Self::Plain(text)
219 }
220}
221
222impl From<&str> for MessageContent {
223 fn from(text: &str) -> Self {
224 Self::Plain(text.to_string())
225 }
226}
227
228impl MessageContent {
229 pub fn as_text(&self) -> Option<&str> {
230 match self {
231 Self::Plain(text) => Some(text),
232 Self::Multipart(parts) if parts.len() == 1 => {
233 if let MessagePart::Text { text } = &parts[0] {
234 Some(text)
235 } else {
236 None
237 }
238 }
239 _ => None,
240 }
241 }
242
243 pub fn to_text(&self) -> String {
244 match self {
245 Self::Plain(text) => text.clone(),
246 Self::Multipart(parts) => parts
247 .iter()
248 .filter_map(|part| {
249 if let MessagePart::Text { text } = part {
250 Some(text.as_str())
251 } else {
252 None
253 }
254 })
255 .collect::<Vec<_>>()
256 .join(""),
257 }
258 }
259}
260
/// One element of a multipart message, tagged by `type`
/// (`"text"` or `"image_url"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MessagePart {
    Text {
        text: String,
    },
    // NOTE(review): `image_url` is serialized as a bare string here — confirm
    // whether OpenRouter expects a string or an object with a `url` field.
    #[serde(rename = "image_url")]
    Image {
        image_url: String,
    },
}
272
/// A tool call made by the assistant.
///
/// The fields of [`ToolCallContent`] (`type` and `function`) are flattened
/// into the same JSON object as `id`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Payload of a tool call, tagged by `type` (only `"function"` exists).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Called function and its arguments as a raw string (presumably
/// JSON-encoded — TODO confirm).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
}
291
/// Incremental message update carried by a streamed chunk.
///
/// All fields are `Option`s, so any of them may be absent from the JSON.
/// `tool_calls` is omitted on serialization when `None` or empty.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// Fragment of a streamed tool call; `index` presumably identifies which
/// call the fragment extends — TODO confirm against the consumer.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// Partial function name and/or arguments of a streamed tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
312
/// Token counts reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// A single choice inside a streamed chunk.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    pub finish_reason: Option<String>,
}

/// One `data:` event of a streamed completion, as yielded by [`stream_completion`].
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    // NOTE(review): `Response::created` is `u64`; confirm `u32` is intended here.
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

/// Body of a non-streaming completion, as returned by [`complete`].
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Usage,
}

/// One choice of a non-streaming completion; `message` is a
/// [`RequestMessage`], so it is discriminated by its `role` field.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
353
354#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
355pub struct ListModelsResponse {
356 pub data: Vec<ModelEntry>,
357}
358
359#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
360pub struct ModelEntry {
361 pub id: String,
362 pub name: String,
363 pub created: usize,
364 pub description: String,
365 #[serde(default, skip_serializing_if = "Option::is_none")]
366 pub context_length: Option<usize>,
367 #[serde(default, skip_serializing_if = "Vec::is_empty")]
368 pub supported_parameters: Vec<String>,
369 #[serde(default, skip_serializing_if = "Option::is_none")]
370 pub architecture: Option<ModelArchitecture>,
371}
372
373#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
374pub struct ModelArchitecture {
375 #[serde(default, skip_serializing_if = "Vec::is_empty")]
376 pub input_modalities: Vec<String>,
377}
378
379pub async fn complete(
380 client: &dyn HttpClient,
381 api_url: &str,
382 api_key: &str,
383 request: Request,
384) -> Result<Response> {
385 let uri = format!("{api_url}/chat/completions");
386 let request_builder = HttpRequest::builder()
387 .method(Method::POST)
388 .uri(uri)
389 .header("Content-Type", "application/json")
390 .header("Authorization", format!("Bearer {}", api_key))
391 .header("HTTP-Referer", "https://zed.dev")
392 .header("X-Title", "Zed Editor");
393
394 let mut request_body = request;
395 request_body.stream = false;
396
397 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
398 let mut response = client.send(request).await?;
399
400 if response.status().is_success() {
401 let mut body = String::new();
402 response.body_mut().read_to_string(&mut body).await?;
403 let response: Response = serde_json::from_str(&body)?;
404 Ok(response)
405 } else {
406 let mut body = String::new();
407 response.body_mut().read_to_string(&mut body).await?;
408
409 #[derive(Deserialize)]
410 struct OpenRouterResponse {
411 error: OpenRouterError,
412 }
413
414 #[derive(Deserialize)]
415 struct OpenRouterError {
416 message: String,
417 #[serde(default)]
418 code: String,
419 }
420
421 match serde_json::from_str::<OpenRouterResponse>(&body) {
422 Ok(response) if !response.error.message.is_empty() => {
423 let error_message = if !response.error.code.is_empty() {
424 format!("{}: {}", response.error.code, response.error.message)
425 } else {
426 response.error.message
427 };
428
429 Err(anyhow!(
430 "Failed to connect to OpenRouter API: {}",
431 error_message
432 ))
433 }
434 _ => Err(anyhow!(
435 "Failed to connect to OpenRouter API: {} {}",
436 response.status(),
437 body,
438 )),
439 }
440 }
441}
442
443pub async fn stream_completion(
444 client: &dyn HttpClient,
445 api_url: &str,
446 api_key: &str,
447 request: Request,
448) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
449 let uri = format!("{api_url}/chat/completions");
450 let request_builder = HttpRequest::builder()
451 .method(Method::POST)
452 .uri(uri)
453 .header("Content-Type", "application/json")
454 .header("Authorization", format!("Bearer {}", api_key))
455 .header("HTTP-Referer", "https://zed.dev")
456 .header("X-Title", "Zed Editor");
457
458 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
459 let mut response = client.send(request).await?;
460
461 if response.status().is_success() {
462 let reader = BufReader::new(response.into_body());
463 Ok(reader
464 .lines()
465 .filter_map(|line| async move {
466 match line {
467 Ok(line) => {
468 if line.starts_with(':') {
469 return None;
470 }
471
472 let line = line.strip_prefix("data: ")?;
473 if line == "[DONE]" {
474 None
475 } else {
476 match serde_json::from_str::<ResponseStreamEvent>(line) {
477 Ok(response) => Some(Ok(response)),
478 Err(error) => {
479 #[derive(Deserialize)]
480 struct ErrorResponse {
481 error: String,
482 }
483
484 match serde_json::from_str::<ErrorResponse>(line) {
485 Ok(err_response) => Some(Err(anyhow!(err_response.error))),
486 Err(_) => {
487 if line.trim().is_empty() {
488 None
489 } else {
490 Some(Err(anyhow!(
491 "Failed to parse response: {}. Original content: '{}'",
492 error, line
493 )))
494 }
495 }
496 }
497 }
498 }
499 }
500 }
501 Err(error) => Some(Err(anyhow!(error))),
502 }
503 })
504 .boxed())
505 } else {
506 let mut body = String::new();
507 response.body_mut().read_to_string(&mut body).await?;
508
509 #[derive(Deserialize)]
510 struct OpenRouterResponse {
511 error: OpenRouterError,
512 }
513
514 #[derive(Deserialize)]
515 struct OpenRouterError {
516 message: String,
517 #[serde(default)]
518 code: String,
519 }
520
521 match serde_json::from_str::<OpenRouterResponse>(&body) {
522 Ok(response) if !response.error.message.is_empty() => {
523 let error_message = if !response.error.code.is_empty() {
524 format!("{}: {}", response.error.code, response.error.message)
525 } else {
526 response.error.message
527 };
528
529 Err(anyhow!(
530 "Failed to connect to OpenRouter API: {}",
531 error_message
532 ))
533 }
534 _ => Err(anyhow!(
535 "Failed to connect to OpenRouter API: {} {}",
536 response.status(),
537 body,
538 )),
539 }
540 }
541}
542
543pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
544 let uri = format!("{api_url}/models");
545 let request_builder = HttpRequest::builder()
546 .method(Method::GET)
547 .uri(uri)
548 .header("Accept", "application/json");
549
550 let request = request_builder.body(AsyncBody::default())?;
551 let mut response = client.send(request).await?;
552
553 let mut body = String::new();
554 response.body_mut().read_to_string(&mut body).await?;
555
556 if response.status().is_success() {
557 let response: ListModelsResponse =
558 serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
559
560 let models = response
561 .data
562 .into_iter()
563 .map(|entry| Model {
564 name: entry.id,
565 // OpenRouter returns display names in the format "provider_name: model_name".
566 // When displayed in the UI, these names can get truncated from the right.
567 // Since users typically already know the provider, we extract just the model name
568 // portion (after the colon) to create a more concise and user-friendly label
569 // for the model dropdown in the agent panel.
570 display_name: Some(
571 entry
572 .name
573 .split(':')
574 .next_back()
575 .unwrap_or(&entry.name)
576 .trim()
577 .to_string(),
578 ),
579 max_tokens: entry.context_length.unwrap_or(2000000),
580 supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
581 supports_images: Some(
582 entry
583 .architecture
584 .as_ref()
585 .map(|arch| arch.input_modalities.contains(&"image".to_string()))
586 .unwrap_or(false),
587 ),
588 })
589 .collect();
590
591 Ok(models)
592 } else {
593 Err(anyhow!(
594 "Failed to connect to OpenRouter API: {} {}",
595 response.status(),
596 body,
597 ))
598 }
599}