open_router.rs

  1use anyhow::{Context, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7
  8pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
  9
 10fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 11    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 12}
 13
 14#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 15#[serde(rename_all = "lowercase")]
 16pub enum Role {
 17    User,
 18    Assistant,
 19    System,
 20    Tool,
 21}
 22
 23impl TryFrom<String> for Role {
 24    type Error = anyhow::Error;
 25
 26    fn try_from(value: String) -> Result<Self> {
 27        match value.as_str() {
 28            "user" => Ok(Self::User),
 29            "assistant" => Ok(Self::Assistant),
 30            "system" => Ok(Self::System),
 31            "tool" => Ok(Self::Tool),
 32            _ => Err(anyhow!("invalid role '{value}'")),
 33        }
 34    }
 35}
 36
 37impl From<Role> for String {
 38    fn from(val: Role) -> Self {
 39        match val {
 40            Role::User => "user".to_owned(),
 41            Role::Assistant => "assistant".to_owned(),
 42            Role::System => "system".to_owned(),
 43            Role::Tool => "tool".to_owned(),
 44        }
 45    }
 46}
 47
 48#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 49#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 50pub struct Model {
 51    pub name: String,
 52    pub display_name: Option<String>,
 53    pub max_tokens: usize,
 54    pub supports_tools: Option<bool>,
 55    pub supports_images: Option<bool>,
 56}
 57
 58impl Model {
 59    pub fn default_fast() -> Self {
 60        Self::new(
 61            "openrouter/auto",
 62            Some("Auto Router"),
 63            Some(2000000),
 64            Some(true),
 65            Some(false),
 66        )
 67    }
 68
 69    pub fn default() -> Self {
 70        Self::default_fast()
 71    }
 72
 73    pub fn new(
 74        name: &str,
 75        display_name: Option<&str>,
 76        max_tokens: Option<usize>,
 77        supports_tools: Option<bool>,
 78        supports_images: Option<bool>,
 79    ) -> Self {
 80        Self {
 81            name: name.to_owned(),
 82            display_name: display_name.map(|s| s.to_owned()),
 83            max_tokens: max_tokens.unwrap_or(2000000),
 84            supports_tools,
 85            supports_images,
 86        }
 87    }
 88
 89    pub fn id(&self) -> &str {
 90        &self.name
 91    }
 92
 93    pub fn display_name(&self) -> &str {
 94        self.display_name.as_ref().unwrap_or(&self.name)
 95    }
 96
 97    pub fn max_token_count(&self) -> usize {
 98        self.max_tokens
 99    }
100
101    pub fn max_output_tokens(&self) -> Option<u32> {
102        None
103    }
104
105    pub fn supports_tool_calls(&self) -> bool {
106        self.supports_tools.unwrap_or(false)
107    }
108
109    pub fn supports_parallel_tool_calls(&self) -> bool {
110        false
111    }
112}
113
114#[derive(Debug, Serialize, Deserialize)]
115pub struct Request {
116    pub model: String,
117    pub messages: Vec<RequestMessage>,
118    pub stream: bool,
119    #[serde(default, skip_serializing_if = "Option::is_none")]
120    pub max_tokens: Option<u32>,
121    #[serde(default, skip_serializing_if = "Vec::is_empty")]
122    pub stop: Vec<String>,
123    pub temperature: f32,
124    #[serde(default, skip_serializing_if = "Option::is_none")]
125    pub tool_choice: Option<ToolChoice>,
126    #[serde(default, skip_serializing_if = "Option::is_none")]
127    pub parallel_tool_calls: Option<bool>,
128    #[serde(default, skip_serializing_if = "Vec::is_empty")]
129    pub tools: Vec<ToolDefinition>,
130}
131
132#[derive(Debug, Serialize, Deserialize)]
133#[serde(untagged)]
134pub enum ToolChoice {
135    Auto,
136    Required,
137    None,
138    Other(ToolDefinition),
139}
140
141#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
142#[derive(Clone, Deserialize, Serialize, Debug)]
143#[serde(tag = "type", rename_all = "snake_case")]
144pub enum ToolDefinition {
145    #[allow(dead_code)]
146    Function { function: FunctionDefinition },
147}
148
149#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
150#[derive(Clone, Debug, Serialize, Deserialize)]
151pub struct FunctionDefinition {
152    pub name: String,
153    pub description: Option<String>,
154    pub parameters: Option<Value>,
155}
156
157#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
158#[serde(tag = "role", rename_all = "lowercase")]
159pub enum RequestMessage {
160    Assistant {
161        content: Option<MessageContent>,
162        #[serde(default, skip_serializing_if = "Vec::is_empty")]
163        tool_calls: Vec<ToolCall>,
164    },
165    User {
166        content: MessageContent,
167    },
168    System {
169        content: MessageContent,
170    },
171    Tool {
172        content: MessageContent,
173        tool_call_id: String,
174    },
175}
176
177#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
178#[serde(untagged)]
179pub enum MessageContent {
180    Plain(String),
181    Multipart(Vec<MessagePart>),
182}
183
184impl MessageContent {
185    pub fn empty() -> Self {
186        Self::Plain(String::new())
187    }
188
189    pub fn push_part(&mut self, part: MessagePart) {
190        match self {
191            Self::Plain(text) if text.is_empty() => {
192                *self = Self::Multipart(vec![part]);
193            }
194            Self::Plain(text) => {
195                let text_part = MessagePart::Text {
196                    text: std::mem::take(text),
197                };
198                *self = Self::Multipart(vec![text_part, part]);
199            }
200            Self::Multipart(parts) => parts.push(part),
201        }
202    }
203}
204
205impl From<Vec<MessagePart>> for MessageContent {
206    fn from(parts: Vec<MessagePart>) -> Self {
207        if parts.len() == 1 {
208            if let MessagePart::Text { text } = &parts[0] {
209                return Self::Plain(text.clone());
210            }
211        }
212        Self::Multipart(parts)
213    }
214}
215
216impl From<String> for MessageContent {
217    fn from(text: String) -> Self {
218        Self::Plain(text)
219    }
220}
221
222impl From<&str> for MessageContent {
223    fn from(text: &str) -> Self {
224        Self::Plain(text.to_string())
225    }
226}
227
228impl MessageContent {
229    pub fn as_text(&self) -> Option<&str> {
230        match self {
231            Self::Plain(text) => Some(text),
232            Self::Multipart(parts) if parts.len() == 1 => {
233                if let MessagePart::Text { text } = &parts[0] {
234                    Some(text)
235                } else {
236                    None
237                }
238            }
239            _ => None,
240        }
241    }
242
243    pub fn to_text(&self) -> String {
244        match self {
245            Self::Plain(text) => text.clone(),
246            Self::Multipart(parts) => parts
247                .iter()
248                .filter_map(|part| {
249                    if let MessagePart::Text { text } = part {
250                        Some(text.as_str())
251                    } else {
252                        None
253                    }
254                })
255                .collect::<Vec<_>>()
256                .join(""),
257        }
258    }
259}
260
261#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
262#[serde(tag = "type", rename_all = "snake_case")]
263pub enum MessagePart {
264    Text {
265        text: String,
266    },
267    #[serde(rename = "image_url")]
268    Image {
269        image_url: String,
270    },
271}
272
273#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
274pub struct ToolCall {
275    pub id: String,
276    #[serde(flatten)]
277    pub content: ToolCallContent,
278}
279
280#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
281#[serde(tag = "type", rename_all = "lowercase")]
282pub enum ToolCallContent {
283    Function { function: FunctionContent },
284}
285
286#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
287pub struct FunctionContent {
288    pub name: String,
289    pub arguments: String,
290}
291
292#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
293pub struct ResponseMessageDelta {
294    pub role: Option<Role>,
295    pub content: Option<String>,
296    #[serde(default, skip_serializing_if = "is_none_or_empty")]
297    pub tool_calls: Option<Vec<ToolCallChunk>>,
298}
299
300#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
301pub struct ToolCallChunk {
302    pub index: usize,
303    pub id: Option<String>,
304    pub function: Option<FunctionChunk>,
305}
306
307#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
308pub struct FunctionChunk {
309    pub name: Option<String>,
310    pub arguments: Option<String>,
311}
312
313#[derive(Serialize, Deserialize, Debug)]
314pub struct Usage {
315    pub prompt_tokens: u32,
316    pub completion_tokens: u32,
317    pub total_tokens: u32,
318}
319
320#[derive(Serialize, Deserialize, Debug)]
321pub struct ChoiceDelta {
322    pub index: u32,
323    pub delta: ResponseMessageDelta,
324    pub finish_reason: Option<String>,
325}
326
327#[derive(Serialize, Deserialize, Debug)]
328pub struct ResponseStreamEvent {
329    #[serde(default, skip_serializing_if = "Option::is_none")]
330    pub id: Option<String>,
331    pub created: u32,
332    pub model: String,
333    pub choices: Vec<ChoiceDelta>,
334    pub usage: Option<Usage>,
335}
336
337#[derive(Serialize, Deserialize, Debug)]
338pub struct Response {
339    pub id: String,
340    pub object: String,
341    pub created: u64,
342    pub model: String,
343    pub choices: Vec<Choice>,
344    pub usage: Usage,
345}
346
347#[derive(Serialize, Deserialize, Debug)]
348pub struct Choice {
349    pub index: u32,
350    pub message: RequestMessage,
351    pub finish_reason: Option<String>,
352}
353
354#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
355pub struct ListModelsResponse {
356    pub data: Vec<ModelEntry>,
357}
358
359#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
360pub struct ModelEntry {
361    pub id: String,
362    pub name: String,
363    pub created: usize,
364    pub description: String,
365    #[serde(default, skip_serializing_if = "Option::is_none")]
366    pub context_length: Option<usize>,
367    #[serde(default, skip_serializing_if = "Vec::is_empty")]
368    pub supported_parameters: Vec<String>,
369    #[serde(default, skip_serializing_if = "Option::is_none")]
370    pub architecture: Option<ModelArchitecture>,
371}
372
373#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
374pub struct ModelArchitecture {
375    #[serde(default, skip_serializing_if = "Vec::is_empty")]
376    pub input_modalities: Vec<String>,
377}
378
379pub async fn complete(
380    client: &dyn HttpClient,
381    api_url: &str,
382    api_key: &str,
383    request: Request,
384) -> Result<Response> {
385    let uri = format!("{api_url}/chat/completions");
386    let request_builder = HttpRequest::builder()
387        .method(Method::POST)
388        .uri(uri)
389        .header("Content-Type", "application/json")
390        .header("Authorization", format!("Bearer {}", api_key))
391        .header("HTTP-Referer", "https://zed.dev")
392        .header("X-Title", "Zed Editor");
393
394    let mut request_body = request;
395    request_body.stream = false;
396
397    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
398    let mut response = client.send(request).await?;
399
400    if response.status().is_success() {
401        let mut body = String::new();
402        response.body_mut().read_to_string(&mut body).await?;
403        let response: Response = serde_json::from_str(&body)?;
404        Ok(response)
405    } else {
406        let mut body = String::new();
407        response.body_mut().read_to_string(&mut body).await?;
408
409        #[derive(Deserialize)]
410        struct OpenRouterResponse {
411            error: OpenRouterError,
412        }
413
414        #[derive(Deserialize)]
415        struct OpenRouterError {
416            message: String,
417            #[serde(default)]
418            code: String,
419        }
420
421        match serde_json::from_str::<OpenRouterResponse>(&body) {
422            Ok(response) if !response.error.message.is_empty() => {
423                let error_message = if !response.error.code.is_empty() {
424                    format!("{}: {}", response.error.code, response.error.message)
425                } else {
426                    response.error.message
427                };
428
429                Err(anyhow!(
430                    "Failed to connect to OpenRouter API: {}",
431                    error_message
432                ))
433            }
434            _ => Err(anyhow!(
435                "Failed to connect to OpenRouter API: {} {}",
436                response.status(),
437                body,
438            )),
439        }
440    }
441}
442
443pub async fn stream_completion(
444    client: &dyn HttpClient,
445    api_url: &str,
446    api_key: &str,
447    request: Request,
448) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
449    let uri = format!("{api_url}/chat/completions");
450    let request_builder = HttpRequest::builder()
451        .method(Method::POST)
452        .uri(uri)
453        .header("Content-Type", "application/json")
454        .header("Authorization", format!("Bearer {}", api_key))
455        .header("HTTP-Referer", "https://zed.dev")
456        .header("X-Title", "Zed Editor");
457
458    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
459    let mut response = client.send(request).await?;
460
461    if response.status().is_success() {
462        let reader = BufReader::new(response.into_body());
463        Ok(reader
464            .lines()
465            .filter_map(|line| async move {
466                match line {
467                    Ok(line) => {
468                        if line.starts_with(':') {
469                            return None;
470                        }
471
472                        let line = line.strip_prefix("data: ")?;
473                        if line == "[DONE]" {
474                            None
475                        } else {
476                            match serde_json::from_str::<ResponseStreamEvent>(line) {
477                                Ok(response) => Some(Ok(response)),
478                                Err(error) => {
479                                    #[derive(Deserialize)]
480                                    struct ErrorResponse {
481                                        error: String,
482                                    }
483
484                                    match serde_json::from_str::<ErrorResponse>(line) {
485                                        Ok(err_response) => Some(Err(anyhow!(err_response.error))),
486                                        Err(_) => {
487                                            if line.trim().is_empty() {
488                                                None
489                                            } else {
490                                                Some(Err(anyhow!(
491                                                    "Failed to parse response: {}. Original content: '{}'",
492                                                    error, line
493                                                )))
494                                            }
495                                        }
496                                    }
497                                }
498                            }
499                        }
500                    }
501                    Err(error) => Some(Err(anyhow!(error))),
502                }
503            })
504            .boxed())
505    } else {
506        let mut body = String::new();
507        response.body_mut().read_to_string(&mut body).await?;
508
509        #[derive(Deserialize)]
510        struct OpenRouterResponse {
511            error: OpenRouterError,
512        }
513
514        #[derive(Deserialize)]
515        struct OpenRouterError {
516            message: String,
517            #[serde(default)]
518            code: String,
519        }
520
521        match serde_json::from_str::<OpenRouterResponse>(&body) {
522            Ok(response) if !response.error.message.is_empty() => {
523                let error_message = if !response.error.code.is_empty() {
524                    format!("{}: {}", response.error.code, response.error.message)
525                } else {
526                    response.error.message
527                };
528
529                Err(anyhow!(
530                    "Failed to connect to OpenRouter API: {}",
531                    error_message
532                ))
533            }
534            _ => Err(anyhow!(
535                "Failed to connect to OpenRouter API: {} {}",
536                response.status(),
537                body,
538            )),
539        }
540    }
541}
542
543pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
544    let uri = format!("{api_url}/models");
545    let request_builder = HttpRequest::builder()
546        .method(Method::GET)
547        .uri(uri)
548        .header("Accept", "application/json");
549
550    let request = request_builder.body(AsyncBody::default())?;
551    let mut response = client.send(request).await?;
552
553    let mut body = String::new();
554    response.body_mut().read_to_string(&mut body).await?;
555
556    if response.status().is_success() {
557        let response: ListModelsResponse =
558            serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
559
560        let models = response
561            .data
562            .into_iter()
563            .map(|entry| Model {
564                name: entry.id,
565                // OpenRouter returns display names in the format "provider_name: model_name".
566                // When displayed in the UI, these names can get truncated from the right.
567                // Since users typically already know the provider, we extract just the model name
568                // portion (after the colon) to create a more concise and user-friendly label
569                // for the model dropdown in the agent panel.
570                display_name: Some(
571                    entry
572                        .name
573                        .split(':')
574                        .next_back()
575                        .unwrap_or(&entry.name)
576                        .trim()
577                        .to_string(),
578                ),
579                max_tokens: entry.context_length.unwrap_or(2000000),
580                supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
581                supports_images: Some(
582                    entry
583                        .architecture
584                        .as_ref()
585                        .map(|arch| arch.input_modalities.contains(&"image".to_string()))
586                        .unwrap_or(false),
587                ),
588            })
589            .collect();
590
591        Ok(models)
592    } else {
593        Err(anyhow!(
594            "Failed to connect to OpenRouter API: {} {}",
595            response.status(),
596            body,
597        ))
598    }
599}