1use anyhow::{Context, Result, anyhow};
2use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::convert::TryFrom;
7
8pub const OPEN_ROUTER_API_URL: &str = "https://openrouter.ai/api/v1";
9
/// Returns `true` when `opt` is `None` or holds an empty slice-like value.
///
/// Used as a `skip_serializing_if` predicate so that absent and empty lists are both omitted.
fn is_none_or_empty<T: AsRef<[U]>, U>(opt: &Option<T>) -> bool {
    match opt {
        Some(items) => items.as_ref().is_empty(),
        None => true,
    }
}
13
14#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
15#[serde(rename_all = "lowercase")]
16pub enum Role {
17 User,
18 Assistant,
19 System,
20 Tool,
21}
22
23impl TryFrom<String> for Role {
24 type Error = anyhow::Error;
25
26 fn try_from(value: String) -> Result<Self> {
27 match value.as_str() {
28 "user" => Ok(Self::User),
29 "assistant" => Ok(Self::Assistant),
30 "system" => Ok(Self::System),
31 "tool" => Ok(Self::Tool),
32 _ => Err(anyhow!("invalid role '{value}'")),
33 }
34 }
35}
36
37impl From<Role> for String {
38 fn from(val: Role) -> Self {
39 match val {
40 Role::User => "user".to_owned(),
41 Role::Assistant => "assistant".to_owned(),
42 Role::System => "system".to_owned(),
43 Role::Tool => "tool".to_owned(),
44 }
45 }
46}
47
48#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
49#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
50pub struct Model {
51 pub name: String,
52 pub display_name: Option<String>,
53 pub max_tokens: usize,
54 pub supports_tools: Option<bool>,
55}
56
57impl Model {
58 pub fn default_fast() -> Self {
59 Self::new(
60 "openrouter/auto",
61 Some("Auto Router"),
62 Some(2000000),
63 Some(true),
64 )
65 }
66
67 pub fn default() -> Self {
68 Self::default_fast()
69 }
70
71 pub fn new(
72 name: &str,
73 display_name: Option<&str>,
74 max_tokens: Option<usize>,
75 supports_tools: Option<bool>,
76 ) -> Self {
77 Self {
78 name: name.to_owned(),
79 display_name: display_name.map(|s| s.to_owned()),
80 max_tokens: max_tokens.unwrap_or(2000000),
81 supports_tools,
82 }
83 }
84
85 pub fn id(&self) -> &str {
86 &self.name
87 }
88
89 pub fn display_name(&self) -> &str {
90 self.display_name.as_ref().unwrap_or(&self.name)
91 }
92
93 pub fn max_token_count(&self) -> usize {
94 self.max_tokens
95 }
96
97 pub fn max_output_tokens(&self) -> Option<u32> {
98 None
99 }
100
101 pub fn supports_tool_calls(&self) -> bool {
102 self.supports_tools.unwrap_or(false)
103 }
104
105 pub fn supports_parallel_tool_calls(&self) -> bool {
106 false
107 }
108}
109
110#[derive(Debug, Serialize, Deserialize)]
111pub struct Request {
112 pub model: String,
113 pub messages: Vec<RequestMessage>,
114 pub stream: bool,
115 #[serde(default, skip_serializing_if = "Option::is_none")]
116 pub max_tokens: Option<u32>,
117 #[serde(default, skip_serializing_if = "Vec::is_empty")]
118 pub stop: Vec<String>,
119 pub temperature: f32,
120 #[serde(default, skip_serializing_if = "Option::is_none")]
121 pub tool_choice: Option<ToolChoice>,
122 #[serde(default, skip_serializing_if = "Option::is_none")]
123 pub parallel_tool_calls: Option<bool>,
124 #[serde(default, skip_serializing_if = "Vec::is_empty")]
125 pub tools: Vec<ToolDefinition>,
126}
127
128#[derive(Debug, Serialize, Deserialize)]
129#[serde(untagged)]
130pub enum ToolChoice {
131 Auto,
132 Required,
133 None,
134 Other(ToolDefinition),
135}
136
137#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
138#[derive(Clone, Deserialize, Serialize, Debug)]
139#[serde(tag = "type", rename_all = "snake_case")]
140pub enum ToolDefinition {
141 #[allow(dead_code)]
142 Function { function: FunctionDefinition },
143}
144
145#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
146#[derive(Clone, Debug, Serialize, Deserialize)]
147pub struct FunctionDefinition {
148 pub name: String,
149 pub description: Option<String>,
150 pub parameters: Option<Value>,
151}
152
153#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
154#[serde(tag = "role", rename_all = "lowercase")]
155pub enum RequestMessage {
156 Assistant {
157 content: Option<String>,
158 #[serde(default, skip_serializing_if = "Vec::is_empty")]
159 tool_calls: Vec<ToolCall>,
160 },
161 User {
162 content: String,
163 },
164 System {
165 content: String,
166 },
167 Tool {
168 content: String,
169 tool_call_id: String,
170 },
171}
172
173#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
174pub struct ToolCall {
175 pub id: String,
176 #[serde(flatten)]
177 pub content: ToolCallContent,
178}
179
180#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
181#[serde(tag = "type", rename_all = "lowercase")]
182pub enum ToolCallContent {
183 Function { function: FunctionContent },
184}
185
186#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
187pub struct FunctionContent {
188 pub name: String,
189 pub arguments: String,
190}
191
192#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
193pub struct ResponseMessageDelta {
194 pub role: Option<Role>,
195 pub content: Option<String>,
196 #[serde(default, skip_serializing_if = "is_none_or_empty")]
197 pub tool_calls: Option<Vec<ToolCallChunk>>,
198}
199
200#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
201pub struct ToolCallChunk {
202 pub index: usize,
203 pub id: Option<String>,
204 pub function: Option<FunctionChunk>,
205}
206
207#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
208pub struct FunctionChunk {
209 pub name: Option<String>,
210 pub arguments: Option<String>,
211}
212
213#[derive(Serialize, Deserialize, Debug)]
214pub struct Usage {
215 pub prompt_tokens: u32,
216 pub completion_tokens: u32,
217 pub total_tokens: u32,
218}
219
220#[derive(Serialize, Deserialize, Debug)]
221pub struct ChoiceDelta {
222 pub index: u32,
223 pub delta: ResponseMessageDelta,
224 pub finish_reason: Option<String>,
225}
226
227#[derive(Serialize, Deserialize, Debug)]
228pub struct ResponseStreamEvent {
229 #[serde(default, skip_serializing_if = "Option::is_none")]
230 pub id: Option<String>,
231 pub created: u32,
232 pub model: String,
233 pub choices: Vec<ChoiceDelta>,
234 pub usage: Option<Usage>,
235}
236
237#[derive(Serialize, Deserialize, Debug)]
238pub struct Response {
239 pub id: String,
240 pub object: String,
241 pub created: u64,
242 pub model: String,
243 pub choices: Vec<Choice>,
244 pub usage: Usage,
245}
246
247#[derive(Serialize, Deserialize, Debug)]
248pub struct Choice {
249 pub index: u32,
250 pub message: RequestMessage,
251 pub finish_reason: Option<String>,
252}
253
254#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
255pub struct ListModelsResponse {
256 pub data: Vec<ModelEntry>,
257}
258
259#[derive(Default, Debug, Clone, PartialEq, Deserialize)]
260pub struct ModelEntry {
261 pub id: String,
262 pub name: String,
263 pub created: usize,
264 pub description: String,
265 #[serde(default, skip_serializing_if = "Option::is_none")]
266 pub context_length: Option<usize>,
267 #[serde(default, skip_serializing_if = "Vec::is_empty")]
268 pub supported_parameters: Vec<String>,
269}
270
271pub async fn complete(
272 client: &dyn HttpClient,
273 api_url: &str,
274 api_key: &str,
275 request: Request,
276) -> Result<Response> {
277 let uri = format!("{api_url}/chat/completions");
278 let request_builder = HttpRequest::builder()
279 .method(Method::POST)
280 .uri(uri)
281 .header("Content-Type", "application/json")
282 .header("Authorization", format!("Bearer {}", api_key))
283 .header("HTTP-Referer", "https://zed.dev")
284 .header("X-Title", "Zed Editor");
285
286 let mut request_body = request;
287 request_body.stream = false;
288
289 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request_body)?))?;
290 let mut response = client.send(request).await?;
291
292 if response.status().is_success() {
293 let mut body = String::new();
294 response.body_mut().read_to_string(&mut body).await?;
295 let response: Response = serde_json::from_str(&body)?;
296 Ok(response)
297 } else {
298 let mut body = String::new();
299 response.body_mut().read_to_string(&mut body).await?;
300
301 #[derive(Deserialize)]
302 struct OpenRouterResponse {
303 error: OpenRouterError,
304 }
305
306 #[derive(Deserialize)]
307 struct OpenRouterError {
308 message: String,
309 #[serde(default)]
310 code: String,
311 }
312
313 match serde_json::from_str::<OpenRouterResponse>(&body) {
314 Ok(response) if !response.error.message.is_empty() => {
315 let error_message = if !response.error.code.is_empty() {
316 format!("{}: {}", response.error.code, response.error.message)
317 } else {
318 response.error.message
319 };
320
321 Err(anyhow!(
322 "Failed to connect to OpenRouter API: {}",
323 error_message
324 ))
325 }
326 _ => Err(anyhow!(
327 "Failed to connect to OpenRouter API: {} {}",
328 response.status(),
329 body,
330 )),
331 }
332 }
333}
334
335pub async fn stream_completion(
336 client: &dyn HttpClient,
337 api_url: &str,
338 api_key: &str,
339 request: Request,
340) -> Result<BoxStream<'static, Result<ResponseStreamEvent>>> {
341 let uri = format!("{api_url}/chat/completions");
342 let request_builder = HttpRequest::builder()
343 .method(Method::POST)
344 .uri(uri)
345 .header("Content-Type", "application/json")
346 .header("Authorization", format!("Bearer {}", api_key))
347 .header("HTTP-Referer", "https://zed.dev")
348 .header("X-Title", "Zed Editor");
349
350 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
351 let mut response = client.send(request).await?;
352
353 if response.status().is_success() {
354 let reader = BufReader::new(response.into_body());
355 Ok(reader
356 .lines()
357 .filter_map(|line| async move {
358 match line {
359 Ok(line) => {
360 if line.starts_with(':') {
361 return None;
362 }
363
364 let line = line.strip_prefix("data: ")?;
365 if line == "[DONE]" {
366 None
367 } else {
368 match serde_json::from_str::<ResponseStreamEvent>(line) {
369 Ok(response) => Some(Ok(response)),
370 Err(error) => {
371 #[derive(Deserialize)]
372 struct ErrorResponse {
373 error: String,
374 }
375
376 match serde_json::from_str::<ErrorResponse>(line) {
377 Ok(err_response) => Some(Err(anyhow!(err_response.error))),
378 Err(_) => {
379 if line.trim().is_empty() {
380 None
381 } else {
382 Some(Err(anyhow!(
383 "Failed to parse response: {}. Original content: '{}'",
384 error, line
385 )))
386 }
387 }
388 }
389 }
390 }
391 }
392 }
393 Err(error) => Some(Err(anyhow!(error))),
394 }
395 })
396 .boxed())
397 } else {
398 let mut body = String::new();
399 response.body_mut().read_to_string(&mut body).await?;
400
401 #[derive(Deserialize)]
402 struct OpenRouterResponse {
403 error: OpenRouterError,
404 }
405
406 #[derive(Deserialize)]
407 struct OpenRouterError {
408 message: String,
409 #[serde(default)]
410 code: String,
411 }
412
413 match serde_json::from_str::<OpenRouterResponse>(&body) {
414 Ok(response) if !response.error.message.is_empty() => {
415 let error_message = if !response.error.code.is_empty() {
416 format!("{}: {}", response.error.code, response.error.message)
417 } else {
418 response.error.message
419 };
420
421 Err(anyhow!(
422 "Failed to connect to OpenRouter API: {}",
423 error_message
424 ))
425 }
426 _ => Err(anyhow!(
427 "Failed to connect to OpenRouter API: {} {}",
428 response.status(),
429 body,
430 )),
431 }
432 }
433}
434
435pub async fn list_models(client: &dyn HttpClient, api_url: &str) -> Result<Vec<Model>> {
436 let uri = format!("{api_url}/models");
437 let request_builder = HttpRequest::builder()
438 .method(Method::GET)
439 .uri(uri)
440 .header("Accept", "application/json");
441
442 let request = request_builder.body(AsyncBody::default())?;
443 let mut response = client.send(request).await?;
444
445 let mut body = String::new();
446 response.body_mut().read_to_string(&mut body).await?;
447
448 if response.status().is_success() {
449 let response: ListModelsResponse =
450 serde_json::from_str(&body).context("Unable to parse OpenRouter models response")?;
451
452 let models = response
453 .data
454 .into_iter()
455 .map(|entry| Model {
456 name: entry.id,
457 // OpenRouter returns display names in the format "provider_name: model_name".
458 // When displayed in the UI, these names can get truncated from the right.
459 // Since users typically already know the provider, we extract just the model name
460 // portion (after the colon) to create a more concise and user-friendly label
461 // for the model dropdown in the agent panel.
462 display_name: Some(
463 entry
464 .name
465 .split(':')
466 .next_back()
467 .unwrap_or(&entry.name)
468 .trim()
469 .to_string(),
470 ),
471 max_tokens: entry.context_length.unwrap_or(2000000),
472 supports_tools: Some(entry.supported_parameters.contains(&"tools".to_string())),
473 })
474 .collect();
475
476 Ok(models)
477 } else {
478 Err(anyhow!(
479 "Failed to connect to OpenRouter API: {} {}",
480 response.status(),
481 body,
482 ))
483 }
484}