open_router.rs

  1use anyhow::{Context, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7
  8pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
  9
 10fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 11    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 12}
 13
 14#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 15#[serde(rename_all = "lowercase")]
 16pub enum Role {
 17    User,
 18    Assistant,
 19    System,
 20    Tool,
 21}
 22
 23impl TryFrom<String> for Role {
 24    type Error = anyhow::Error;
 25
 26    fn try_from(value: String) -> Result<Self> {
 27        match value.as_str() {
 28            "user" => Ok(Self::User),
 29            "assistant" => Ok(Self::Assistant),
 30            "system" => Ok(Self::System),
 31            "tool" => Ok(Self::Tool),
 32            _ => Err(anyhow!("invalid role '{value}'")),
 33        }
 34    }
 35}
 36
 37impl From<Role> for String {
 38    fn from(val: Role) -> Self {
 39        match val {
 40            Role::User => "user".to_owned(),
 41            Role::Assistant => "assistant".to_owned(),
 42            Role::System => "system".to_owned(),
 43            Role::Tool => "tool".to_owned(),
 44        }
 45    }
 46}
 47
 48#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 49#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 50pub struct Model {
 51    pub name: String,
 52    pub display_name: Option<String>,
 53    pub max_tokens: u64,
 54    pub supports_tools: Option<bool>,
 55    pub supports_images: Option<bool>,
 56}
 57
 58impl Model {
 59    pub fn default_fast() -> Self {
 60        Self::new(
 61            "openrouter/auto",
 62            Some("Auto Router"),
 63            Some(2000000),
 64            Some(true),
 65            Some(false),
 66        )
 67    }
 68
 69    pub fn default() -> Self {
 70        Self::default_fast()
 71    }
 72
 73    pub fn new(
 74        name: &str,
 75        display_name: Option<&str>,
 76        max_tokens: Option<u64>,
 77        supports_tools: Option<bool>,
 78        supports_images: Option<bool>,
 79    ) -> Self {
 80        Self {
 81            name: name.to_owned(),
 82            display_name: display_name.map(|s| s.to_owned()),
 83            max_tokens: max_tokens.unwrap_or(2000000),
 84            supports_tools,
 85            supports_images,
 86        }
 87    }
 88
 89    pub fn id(&self) -> &str {
 90        &self.name
 91    }
 92
 93    pub fn display_name(&self) -> &str {
 94        self.display_name.as_ref().unwrap_or(&self.name)
 95    }
 96
 97    pub fn max_token_count(&self) -> u64 {
 98        self.max_tokens
 99    }
100
101    pub fn max_output_tokens(&self) -> Option<u64> {
102        None
103    }
104
105    pub fn supports_tool_calls(&self) -> bool {
106        self.supports_tools.unwrap_or(false)
107    }
108
109    pub fn supports_parallel_tool_calls(&self) -> bool {
110        false
111    }
112}
113
114#[derive(Debug, Serialize, Deserialize)]
115pub struct Request {
116    pub model: String,
117    pub messages: Vec<RequestMessage>,
118    pub stream: bool,
119    #[serde(default, skip_serializing_if = "Option::is_none")]
120    pub max_tokens: Option<u64>,
121    #[serde(default, skip_serializing_if = "Vec::is_empty")]
122    pub stop: Vec<String>,
123    pub temperature: f32,
124    #[serde(default, skip_serializing_if = "Option::is_none")]
125    pub tool_choice: Option<ToolChoice>,
126    #[serde(default, skip_serializing_if = "Option::is_none")]
127    pub parallel_tool_calls: Option<bool>,
128    #[serde(default, skip_serializing_if = "Vec::is_empty")]
129    pub tools: Vec<ToolDefinition>,
130    pub usage: RequestUsage,
131}
132
133#[derive(Debug, Default, Serialize, Deserialize)]
134pub struct RequestUsage {
135    pub include: bool,
136}
137
138#[derive(Debug, Serialize, Deserialize)]
139#[serde(untagged)]
140pub enum ToolChoice {
141    Auto,
142    Required,
143    None,
144    Other(ToolDefinition),
145}
146
147#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
148#[derive(Clone, Deserialize, Serialize, Debug)]
149#[serde(tag = "type", rename_all = "snake_case")]
150pub enum ToolDefinition {
151    #[allow(dead_code)]
152    Function { function: FunctionDefinition },
153}
154
155#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
156#[derive(Clone, Debug, Serialize, Deserialize)]
157pub struct FunctionDefinition {
158    pub name: String,
159    pub description: Option<String>,
160    pub parameters: Option<Value>,
161}
162
163#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
164#[serde(tag = "role", rename_all = "lowercase")]
165pub enum RequestMessage {
166    Assistant {
167        content: Option<MessageContent>,
168        #[serde(default, skip_serializing_if = "Vec::is_empty")]
169        tool_calls: Vec<ToolCall>,
170    },
171    User {
172        content: MessageContent,
173    },
174    System {
175        content: MessageContent,
176    },
177    Tool {
178        content: MessageContent,
179        tool_call_id: String,
180    },
181}
182
183#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
184#[serde(untagged)]
185pub enum MessageContent {
186    Plain(String),
187    Multipart(Vec<MessagePart>),
188}
189
190impl MessageContent {
191    pub fn empty() -> Self {
192        Self::Plain(String::new())
193    }
194
195    pub fn push_part(&mut self, part: MessagePart) {
196        match self {
197            Self::Plain(text) if text.is_empty() => {
198                *self = Self::Multipart(vec![part]);
199            }
200            Self::Plain(text) => {
201                let text_part = MessagePart::Text {
202                    text: std::mem::take(text),
203                };
204                *self = Self::Multipart(vec![text_part, part]);
205            }
206            Self::Multipart(parts) => parts.push(part),
207        }
208    }
209}
210
211impl From<Vec<MessagePart>> for MessageContent {
212    fn from(parts: Vec<MessagePart>) -> Self {
213        if parts.len() == 1 {
214            if let MessagePart::Text { text } = &parts[0] {
215                return Self::Plain(text.clone());
216            }
217        }
218        Self::Multipart(parts)
219    }
220}
221
222impl From<String> for MessageContent {
223    fn from(text: String) -> Self {
224        Self::Plain(text)
225    }
226}
227
228impl From<&str> for MessageContent {
229    fn from(text: &str) -> Self {
230        Self::Plain(text.to_string())
231    }
232}
233
234impl MessageContent {
235    pub fn as_text(&self) -> Option<&str> {
236        match self {
237            Self::Plain(text) => Some(text),
238            Self::Multipart(parts) if parts.len() == 1 => {
239                if let MessagePart::Text { text } = &parts[0] {
240                    Some(text)
241                } else {
242                    None
243                }
244            }
245            _ => None,
246        }
247    }
248
249    pub fn to_text(&self) -> String {
250        match self {
251            Self::Plain(text) => text.clone(),
252            Self::Multipart(parts) => parts
253                .iter()
254                .filter_map(|part| {
255                    if let MessagePart::Text { text } = part {
256                        Some(text.as_str())
257                    } else {
258                        None
259                    }
260                })
261                .collect::<Vec<_>>()
262                .join(""),
263        }
264    }
265}
266
267#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
268#[serde(tag = "type", rename_all = "snake_case")]
269pub enum MessagePart {
270    Text {
271        text: String,
272    },
273    #[serde(rename = "image_url")]
274    Image {
275        image_url: String,
276    },
277}
278
279#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
280pub struct ToolCall {
281    pub id: String,
282    #[serde(flatten)]
283    pub content: ToolCallContent,
284}
285
286#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
287#[serde(tag = "type", rename_all = "lowercase")]
288pub enum ToolCallContent {
289    Function { function: FunctionContent },
290}
291
292#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
293pub struct FunctionContent {
294    pub name: String,
295    pub arguments: String,
296}
297
298#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
299pub struct ResponseMessageDelta {
300    pub role: Option<Role>,
301    pub content: Option<String>,
302    #[serde(default, skip_serializing_if = "is_none_or_empty")]
303    pub tool_calls: Option<Vec<ToolCallChunk>>,
304}
305
306#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
307pub struct ToolCallChunk {
308    pub index: usize,
309    pub id: Option<String>,
310    pub function: Option<FunctionChunk>,
311}
312
313#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
314pub struct FunctionChunk {
315    pub name: Option<String>,
316    pub arguments: Option<String>,
317}
318
319#[derive(Serialize, Deserialize, Debug)]
320pub struct Usage {
321    pub prompt_tokens: u64,
322    pub completion_tokens: u64,
323    pub total_tokens: u64,
324}
325
326#[derive(Serialize, Deserialize, Debug)]
327pub struct ChoiceDelta {
328    pub index: u32,
329    pub delta: ResponseMessageDelta,
330    pub finish_reason: Option<String>,
331}
332
333#[derive(Serialize, Deserialize, Debug)]
334pub struct ResponseStreamEvent {
335    #[serde(default, skip_serializing_if = "Option::is_none")]
336    pub id: Option<String>,
337    pub created: u32,
338    pub model: String,
339    pub choices: Vec<ChoiceDelta>,
340    pub usage: Option<Usage>,
341}
342
343#[derive(Serialize, Deserialize, Debug)]
344pub struct Response {
345    pub id: String,
346    pub object: String,
347    pub created: u64,
348    pub model: String,
349    pub choices: Vec<Choice>,
350    pub usage: Usage,
351}
352
353#[derive(Serialize, Deserialize, Debug)]
354pub struct Choice {
355    pub index: u32,
356    pub message: RequestMessage,
357    pub finish_reason: Option<String>,
358}
359
360#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
361pub struct ListModelsResponse {
362    pub data: Vec<ModelEntry>,
363}
364
365#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
366pub struct ModelEntry {
367    pub id: String,
368    pub name: String,
369    pub created: usize,
370    pub description: String,
371    #[serde(default, skip_serializing_if = "Option::is_none")]
372    pub context_length: Option<u64>,
373    #[serde(default, skip_serializing_if = "Vec::is_empty")]
374    pub supported_parameters: Vec<String>,
375    #[serde(default, skip_serializing_if = "Option::is_none")]
376    pub architecture: Option<ModelArchitecture>,
377}
378
379#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
380pub struct ModelArchitecture {
381    #[serde(default, skip_serializing_if = "Vec::is_empty")]
382    pub input_modalities: Vec<String>,
383}
384
385pub async fn complete(
386    client: &dyn HttpClient,
387    api_url: &str,
388    api_key: &str,
389    request: Request,
390) -> Result<Response> {
391    let uri = format!("{api_url}/chat/completions");
392    let request_builder = HttpRequest::builder()
393        .method(Method::POST)
394        .uri(uri)
395        .header("Content-Type", "application/json")
396        .header("Authorization", format!("Bearer {}", api_key))
397        .header("HTTP-Referer", "https://zed.dev")
398        .header("X-Title", "Zed Editor");
399
400    let mut request_body = request;
401    request_body.stream = false;
402
403    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
404    let mut response = client.send(request).await?;
405
406    if response.status().is_success() {
407        let mut body = String::new();
408        response.body_mut().read_to_string(&mut body).await?;
409        let response: Response = serde_json::from_str(&body)?;
410        Ok(response)
411    } else {
412        let mut body = String::new();
413        response.body_mut().read_to_string(&mut body).await?;
414
415        #[derive(Deserialize)]
416        struct OpenRouterResponse {
417            error: OpenRouterError,
418        }
419
420        #[derive(Deserialize)]
421        struct OpenRouterError {
422            message: String,
423            #[serde(default)]
424            code: String,
425        }
426
427        match serde_json::from_str::<OpenRouterResponse>(&body) {
428            Ok(response) if !response.error.message.is_empty() => {
429                let error_message = if !response.error.code.is_empty() {
430                    format!("{}: {}", response.error.code, response.error.message)
431                } else {
432                    response.error.message
433                };
434
435                Err(anyhow!(
436                    "Failed to connect to OpenRouter API: {}",
437                    error_message
438                ))
439            }
440            _ => Err(anyhow!(
441                "Failed to connect to OpenRouter API: {} {}",
442                response.status(),
443                body,
444            )),
445        }
446    }
447}
448
449pub async fn stream_completion(
450    client: &dyn HttpClient,
451    api_url: &str,
452    api_key: &str,
453    request: Request,
454) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
455    let uri = format!("{api_url}/chat/completions");
456    let request_builder = HttpRequest::builder()
457        .method(Method::POST)
458        .uri(uri)
459        .header("Content-Type", "application/json")
460        .header("Authorization", format!("Bearer {}", api_key))
461        .header("HTTP-Referer", "https://zed.dev")
462        .header("X-Title", "Zed Editor");
463
464    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
465    let mut response = client.send(request).await?;
466
467    if response.status().is_success() {
468        let reader = BufReader::new(response.into_body());
469        Ok(reader
470            .lines()
471            .filter_map(|line| async move {
472                match line {
473                    Ok(line) => {
474                        if line.starts_with(':') {
475                            return None;
476                        }
477
478                        let line = line.strip_prefix("data: ")?;
479                        if line == "[DONE]" {
480                            None
481                        } else {
482                            match serde_json::from_str::<ResponseStreamEvent>(line) {
483                                Ok(response) => Some(Ok(response)),
484                                Err(error) => {
485                                    #[derive(Deserialize)]
486                                    struct ErrorResponse {
487                                        error: String,
488                                    }
489
490                                    match serde_json::from_str::<ErrorResponse>(line) {
491                                        Ok(err_response) => Some(Err(anyhow!(err_response.error))),
492                                        Err(_) => {
493                                            if line.trim().is_empty() {
494                                                None
495                                            } else {
496                                                Some(Err(anyhow!(
497                                                    "Failed to parse response: {}. Original content: '{}'",
498                                                    error, line
499                                                )))
500                                            }
501                                        }
502                                    }
503                                }
504                            }
505                        }
506                    }
507                    Err(error) => Some(Err(anyhow!(error))),
508                }
509            })
510            .boxed())
511    } else {
512        let mut body = String::new();
513        response.body_mut().read_to_string(&mut body).await?;
514
515        #[derive(Deserialize)]
516        struct OpenRouterResponse {
517            error: OpenRouterError,
518        }
519
520        #[derive(Deserialize)]
521        struct OpenRouterError {
522            message: String,
523            #[serde(default)]
524            code: String,
525        }
526
527        match serde_json::from_str::<OpenRouterResponse>(&body) {
528            Ok(response) if !response.error.message.is_empty() => {
529                let error_message = if !response.error.code.is_empty() {
530                    format!("{}: {}", response.error.code, response.error.message)
531                } else {
532                    response.error.message
533                };
534
535                Err(anyhow!(
536                    "Failed to connect to OpenRouter API: {}",
537                    error_message
538                ))
539            }
540            _ => Err(anyhow!(
541                "Failed to connect to OpenRouter API: {} {}",
542                response.status(),
543                body,
544            )),
545        }
546    }
547}
548
549pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
550    let uri = format!("{api_url}/models");
551    let request_builder = HttpRequest::builder()
552        .method(Method::GET)
553        .uri(uri)
554        .header("Accept", "application/json");
555
556    let request = request_builder.body(AsyncBody::default())?;
557    let mut response = client.send(request).await?;
558
559    let mut body = String::new();
560    response.body_mut().read_to_string(&mut body).await?;
561
562    if response.status().is_success() {
563        let response: ListModelsResponse =
564            serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
565
566        let models = response
567            .data
568            .into_iter()
569            .map(|entry| Model {
570                name: entry.id,
571                // OpenRouter returns display names in the format "provider_name: model_name".
572                // When displayed in the UI, these names can get truncated from the right.
573                // Since users typically already know the provider, we extract just the model name
574                // portion (after the colon) to create a more concise and user-friendly label
575                // for the model dropdown in the agent panel.
576                display_name: Some(
577                    entry
578                        .name
579                        .split(':')
580                        .next_back()
581                        .unwrap_or(&entry.name)
582                        .trim()
583                        .to_string(),
584                ),
585                max_tokens: entry.context_length.unwrap_or(2000000),
586                supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
587                supports_images: Some(
588                    entry
589                        .architecture
590                        .as_ref()
591                        .map(|arch| arch.input_modalities.contains(&"image".to_string()))
592                        .unwrap_or(false),
593                ),
594            })
595            .collect();
596
597        Ok(models)
598    } else {
599        Err(anyhow!(
600            "Failed to connect to OpenRouter API: {} {}",
601            response.status(),
602            body,
603        ))
604    }
605}