open_router.rs

  1use anyhow::{Context, Result, anyhow};
  2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use serde::{Deserialize, Serialize};
  5use serde_json::Value;
  6use std::convert::TryFrom;
  7
  8pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
  9
 10fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
 11    opt.as_ref().map_or(true, |v| v.as_ref().is_empty())
 12}
 13
/// Role of a chat message's author.
///
/// Serialized as the lowercase variant name: `"user"`, `"assistant"`,
/// `"system"` or `"tool"`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
    Tool,
}
 22
 23impl TryFrom<String> for Role {
 24    type Error = anyhow::Error;
 25
 26    fn try_from(value: String) -> Result<Self> {
 27        match value.as_str() {
 28            "user" => Ok(Self::User),
 29            "assistant" => Ok(Self::Assistant),
 30            "system" => Ok(Self::System),
 31            "tool" => Ok(Self::Tool),
 32            _ => Err(anyhow!("invalid role '{value}'")),
 33        }
 34    }
 35}
 36
 37impl From<Role> for String {
 38    fn from(val: Role) -> Self {
 39        match val {
 40            Role::User => "user".to_owned(),
 41            Role::Assistant => "assistant".to_owned(),
 42            Role::System => "system".to_owned(),
 43            Role::Tool => "tool".to_owned(),
 44        }
 45    }
 46}
 47
/// An OpenRouter model as presented to the rest of the application.
///
/// NOTE(review): the derived `Default` yields an empty `name` and
/// `max_tokens == 0`, whereas the inherent `Model::default()` in the `impl`
/// block returns the auto-router model. Callers should not rely on the two agreeing.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
    /// OpenRouter model id (e.g. `openrouter/auto`); also returned by `id()`.
    pub name: String,
    /// Optional label for the UI; `display_name()` falls back to `name`.
    pub display_name: Option<String>,
    /// Context size in tokens, as reported by `max_token_count()`.
    pub max_tokens: usize,
    /// Whether tool calling is supported; `None` is treated as `false`.
    pub supports_tools: Option<bool>,
}
 56
 57impl Model {
 58    pub fn default_fast() -> Self {
 59        Self::new(
 60            "openrouter/auto",
 61            Some("Auto Router"),
 62            Some(2000000),
 63            Some(true),
 64        )
 65    }
 66
 67    pub fn default() -> Self {
 68        Self::default_fast()
 69    }
 70
 71    pub fn new(
 72        name: &str,
 73        display_name: Option<&str>,
 74        max_tokens: Option<usize>,
 75        supports_tools: Option<bool>,
 76    ) -> Self {
 77        Self {
 78            name: name.to_owned(),
 79            display_name: display_name.map(|s| s.to_owned()),
 80            max_tokens: max_tokens.unwrap_or(2000000),
 81            supports_tools,
 82        }
 83    }
 84
 85    pub fn id(&self) -> &str {
 86        &self.name
 87    }
 88
 89    pub fn display_name(&self) -> &str {
 90        self.display_name.as_ref().unwrap_or(&self.name)
 91    }
 92
 93    pub fn max_token_count(&self) -> usize {
 94        self.max_tokens
 95    }
 96
 97    pub fn max_output_tokens(&self) -> Option<u32> {
 98        None
 99    }
100
101    pub fn supports_tool_calls(&self) -> bool {
102        self.supports_tools.unwrap_or(false)
103    }
104
105    pub fn supports_parallel_tool_calls(&self) -> bool {
106        false
107    }
108}
109
/// JSON body sent to `{api_url}/chat/completions`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// OpenRouter model id.
    pub model: String,
    pub messages: Vec<RequestMessage>,
    /// Streaming flag; `complete` forces this to `false` before sending.
    pub stream: bool,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Stop sequences; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop: Vec<String>,
    /// Always serialized (no skip attribute).
    pub temperature: f32,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    /// Tool definitions; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<ToolDefinition>,
}
127
128#[derive(Debug, Serialize, Deserialize)]
129#[serde(untagged)]
130pub enum ToolChoice {
131    Auto,
132    Required,
133    None,
134    Other(ToolDefinition),
135}
136
/// A tool offered to the model, serialized as
/// `{"type": "function", "function": {...}}` (internally tagged by `type`).
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
    // NOTE(review): the reason for this allow is not visible here; a variant of
    // a `pub` enum is not normally reported as dead code. Confirm it is still needed.
    #[allow(dead_code)]
    Function { function: FunctionDefinition },
}

/// Function signature of a tool.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct FunctionDefinition {
    pub name: String,
    /// Serialized as `null` when `None` (no skip attribute).
    pub description: Option<String>,
    /// Arbitrary JSON, presumably a JSON Schema for the arguments (confirm with callers).
    /// Serialized as `null` when `None`.
    pub parameters: Option<Value>,
}
152
/// A chat message, internally tagged by `role`, e.g.
/// `{"role": "user", "content": "..."}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {
    Assistant {
        /// Serialized as `null` when `None`.
        content: Option<String>,
        /// Omitted from the JSON when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    User {
        content: String,
    },
    System {
        content: String,
    },
    /// A tool result; `tool_call_id` refers to the `ToolCall::id` it answers.
    Tool {
        content: String,
        tool_call_id: String,
    },
}

/// A tool call made by the assistant. `content` is flattened, so the JSON is
/// `{"id": ..., "type": "function", "function": {...}}`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Body of a tool call, internally tagged by `type` (`"function"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Name of the called function and its arguments.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Kept as a raw string; presumably JSON-encoded arguments (confirm).
    pub arguments: String,
}
191
/// Incremental message content carried by one streamed choice.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ResponseMessageDelta {
    pub role: Option<Role>,
    pub content: Option<String>,
    /// Omitted when serializing if `None` or empty.
    #[serde(default, skip_serializing_if = "is_none_or_empty")]
    pub tool_calls: Option<Vec<ToolCallChunk>>,
}

/// A fragment of a tool call in a streamed response.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    /// Identifies which tool call this fragment belongs to; presumably
    /// fragments with the same index are concatenated by the consumer (confirm).
    pub index: usize,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// Partial function name and/or argument text of a streamed tool call.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
212
/// Token accounting reported by the API.
#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// One choice inside a streamed event.
#[derive(Serialize, Deserialize, Debug)]
pub struct ChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessageDelta,
    /// Presumably `None` until the choice finishes (confirm).
    pub finish_reason: Option<String>,
}

/// One decoded `data:` payload of the streaming response.
#[derive(Serialize, Deserialize, Debug)]
pub struct ResponseStreamEvent {
    /// Omitted when serializing if `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    // NOTE(review): `u32` here but `u64` in `Response::created`; presumably
    // both are Unix timestamps (confirm).
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChoiceDelta>,
    pub usage: Option<Usage>,
}

/// Body of a successful non-streaming completion.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    /// Required: a response without `usage` fails to deserialize.
    pub usage: Usage,
}

/// One choice of a non-streaming completion. The message reuses
/// `RequestMessage`, so it must carry a `role` tag to deserialize.
#[derive(Serialize, Deserialize, Debug)]
pub struct Choice {
    pub index: u32,
    pub message: RequestMessage,
    pub finish_reason: Option<String>,
}
253
/// Body returned by `{api_url}/models`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ListModelsResponse {
    pub data: Vec<ModelEntry>,
}

/// One model in the `/models` listing.
///
/// NOTE(review): the `skip_serializing_if` attributes below have no effect,
/// because this type only derives `Deserialize`.
#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
pub struct ModelEntry {
    pub id: String,
    /// Human-readable name; `list_models` expects "provider: model".
    pub name: String,
    /// Presumably a Unix timestamp (confirm).
    pub created: usize,
    pub description: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub context_length: Option<usize>,
    /// Request parameters the model accepts; `"tools"` signals tool support.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub supported_parameters: Vec<String>,
}
270
271pub async fn complete(
272    client: &dyn HttpClient,
273    api_url: &str,
274    api_key: &str,
275    request: Request,
276) -> Result<Response> {
277    let uri = format!("{api_url}/chat/completions");
278    let request_builder = HttpRequest::builder()
279        .method(Method::POST)
280        .uri(uri)
281        .header("Content-Type", "application/json")
282        .header("Authorization", format!("Bearer {}", api_key))
283        .header("HTTP-Referer", "https://zed.dev")
284        .header("X-Title", "Zed Editor");
285
286    let mut request_body = request;
287    request_body.stream = false;
288
289    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
290    let mut response = client.send(request).await?;
291
292    if response.status().is_success() {
293        let mut body = String::new();
294        response.body_mut().read_to_string(&mut body).await?;
295        let response: Response = serde_json::from_str(&body)?;
296        Ok(response)
297    } else {
298        let mut body = String::new();
299        response.body_mut().read_to_string(&mut body).await?;
300
301        #[derive(Deserialize)]
302        struct OpenRouterResponse {
303            error: OpenRouterError,
304        }
305
306        #[derive(Deserialize)]
307        struct OpenRouterError {
308            message: String,
309            #[serde(default)]
310            code: String,
311        }
312
313        match serde_json::from_str::<OpenRouterResponse>(&body) {
314            Ok(response) if !response.error.message.is_empty() => {
315                let error_message = if !response.error.code.is_empty() {
316                    format!("{}: {}", response.error.code, response.error.message)
317                } else {
318                    response.error.message
319                };
320
321                Err(anyhow!(
322                    "Failed to connect to OpenRouter API: {}",
323                    error_message
324                ))
325            }
326            _ => Err(anyhow!(
327                "Failed to connect to OpenRouter API: {} {}",
328                response.status(),
329                body,
330            )),
331        }
332    }
333}
334
335pub async fn stream_completion(
336    client: &dyn HttpClient,
337    api_url: &str,
338    api_key: &str,
339    request: Request,
340) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
341    let uri = format!("{api_url}/chat/completions");
342    let request_builder = HttpRequest::builder()
343        .method(Method::POST)
344        .uri(uri)
345        .header("Content-Type", "application/json")
346        .header("Authorization", format!("Bearer {}", api_key))
347        .header("HTTP-Referer", "https://zed.dev")
348        .header("X-Title", "Zed Editor");
349
350    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
351    let mut response = client.send(request).await?;
352
353    if response.status().is_success() {
354        let reader = BufReader::new(response.into_body());
355        Ok(reader
356            .lines()
357            .filter_map(|line| async move {
358                match line {
359                    Ok(line) => {
360                        if line.starts_with(':') {
361                            return None;
362                        }
363
364                        let line = line.strip_prefix("data: ")?;
365                        if line == "[DONE]" {
366                            None
367                        } else {
368                            match serde_json::from_str::<ResponseStreamEvent>(line) {
369                                Ok(response) => Some(Ok(response)),
370                                Err(error) => {
371                                    #[derive(Deserialize)]
372                                    struct ErrorResponse {
373                                        error: String,
374                                    }
375
376                                    match serde_json::from_str::<ErrorResponse>(line) {
377                                        Ok(err_response) => Some(Err(anyhow!(err_response.error))),
378                                        Err(_) => {
379                                            if line.trim().is_empty() {
380                                                None
381                                            } else {
382                                                Some(Err(anyhow!(
383                                                    "Failed to parse response: {}. Original content: '{}'",
384                                                    error, line
385                                                )))
386                                            }
387                                        }
388                                    }
389                                }
390                            }
391                        }
392                    }
393                    Err(error) => Some(Err(anyhow!(error))),
394                }
395            })
396            .boxed())
397    } else {
398        let mut body = String::new();
399        response.body_mut().read_to_string(&mut body).await?;
400
401        #[derive(Deserialize)]
402        struct OpenRouterResponse {
403            error: OpenRouterError,
404        }
405
406        #[derive(Deserialize)]
407        struct OpenRouterError {
408            message: String,
409            #[serde(default)]
410            code: String,
411        }
412
413        match serde_json::from_str::<OpenRouterResponse>(&body) {
414            Ok(response) if !response.error.message.is_empty() => {
415                let error_message = if !response.error.code.is_empty() {
416                    format!("{}: {}", response.error.code, response.error.message)
417                } else {
418                    response.error.message
419                };
420
421                Err(anyhow!(
422                    "Failed to connect to OpenRouter API: {}",
423                    error_message
424                ))
425            }
426            _ => Err(anyhow!(
427                "Failed to connect to OpenRouter API: {} {}",
428                response.status(),
429                body,
430            )),
431        }
432    }
433}
434
435pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
436    let uri = format!("{api_url}/models");
437    let request_builder = HttpRequest::builder()
438        .method(Method::GET)
439        .uri(uri)
440        .header("Accept", "application/json");
441
442    let request = request_builder.body(AsyncBody::default())?;
443    let mut response = client.send(request).await?;
444
445    let mut body = String::new();
446    response.body_mut().read_to_string(&mut body).await?;
447
448    if response.status().is_success() {
449        let response: ListModelsResponse =
450            serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
451
452        let models = response
453            .data
454            .into_iter()
455            .map(|entry| Model {
456                name: entry.id,
457                // OpenRouter returns display names in the format "provider_name: model_name".
458                // When displayed in the UI, these names can get truncated from the right.
459                // Since users typically already know the provider, we extract just the model name
460                // portion (after the colon) to create a more concise and user-friendly label
461                // for the model dropdown in the agent panel.
462                display_name: Some(
463                    entry
464                        .name
465                        .split(':')
466                        .next_back()
467                        .unwrap_or(&entry.name)
468                        .trim()
469                        .to_string(),
470                ),
471                max_tokens: entry.context_length.unwrap_or(2000000),
472                supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
473            })
474            .collect();
475
476        Ok(models)
477    } else {
478        Err(anyhow!(
479            "Failed to connect to OpenRouter API: {} {}",
480            response.status(),
481            body,
482        ))
483    }
484}