use std::collections::{HashMap, VecDeque};
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

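/// A Zed extension that exposes OpenRouter's OpenAI-compatible
/// chat-completions API as an LLM provider.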
struct OpenRouterProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

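/// Per-stream state: the open HTTP response, a buffer of undelivered SSE
/// bytes, tool calls accumulated across deltas, and events parsed but not
/// yet handed back to the host.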
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    tool_calls_emitted: bool,
    pending_events: VecDeque<LlmCompletionEvent>,
}

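/// A tool call assembled incrementally; `arguments` receives JSON fragments
/// as they stream in and is complete only once a finish_reason arrives.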
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

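/// Static metadata for a model surfaced to Zed; mirrors the fields of
/// `LlmModelInfo`.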
struct ModelDefinition {
    id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    supports_tools: bool,
    is_default: bool,
    is_default_fast: bool,
}

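// A curated subset of OpenRouter's catalog. Token limits are conservative
// values and may lag what OpenRouter currently reports for each model.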
const MODELS: &[ModelDefinition] = &[
    // Anthropic Models
    ModelDefinition {
        id: "anthropic/claude-sonnet-4",
        display_name: "Claude Sonnet 4",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "anthropic/claude-opus-4",
        display_name: "Claude Opus 4",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "anthropic/claude-haiku-4",
        display_name: "Claude Haiku 4",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        id: "anthropic/claude-3.5-sonnet",
        display_name: "Claude 3.5 Sonnet",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // OpenAI Models
    ModelDefinition {
        id: "openai/gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "openai/gpt-4o-mini",
        display_name: "GPT-4o Mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "openai/o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        supports_tools: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "openai/o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        supports_tools: false,
        is_default: false,
        is_default_fast: false,
    },
    // Google Models
    ModelDefinition {
        id: "google/gemini-2.0-flash-001",
        display_name: "Gemini 2.0 Flash",
        max_tokens: 1_000_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "google/gemini-2.5-pro-preview",
        display_name: "Gemini 2.5 Pro",
        max_tokens: 1_000_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // Meta Models
    ModelDefinition {
        id: "meta-llama/llama-3.3-70b-instruct",
        display_name: "Llama 3.3 70B",
        max_tokens: 128_000,
        max_output_tokens: Some(4_096),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "meta-llama/llama-4-maverick",
        display_name: "Llama 4 Maverick",
        max_tokens: 128_000,
        max_output_tokens: Some(4_096),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // Mistral Models
    ModelDefinition {
        id: "mistralai/mistral-large-2411",
        display_name: "Mistral Large",
        max_tokens: 128_000,
        max_output_tokens: Some(4_096),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "mistralai/codestral-latest",
        display_name: "Codestral",
        max_tokens: 32_000,
        max_output_tokens: Some(4_096),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // DeepSeek Models
    ModelDefinition {
        id: "deepseek/deepseek-chat-v3-0324",
        display_name: "DeepSeek V3",
        max_tokens: 64_000,
        max_output_tokens: Some(8_192),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "deepseek/deepseek-r1",
        display_name: "DeepSeek R1",
        max_tokens: 64_000,
        max_output_tokens: Some(8_192),
        supports_images: false,
        supports_tools: false,
        is_default: false,
        is_default_fast: false,
    },
    // Qwen Models
    ModelDefinition {
        id: "qwen/qwen3-235b-a22b",
        display_name: "Qwen 3 235B",
        max_tokens: 40_000,
        max_output_tokens: Some(8_192),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
];

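/// Looks up a model in the static catalog by its OpenRouter id.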
fn get_model_definition(model_id: &str) -> Option<&'static ModelDefinition> {
    MODELS.iter().find(|m| m.id == model_id)
}

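// Request-side wire types matching OpenRouter's OpenAI-compatible schema.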
#[derive(Serialize)]
struct OpenRouterRequest {
    model: String,
    messages: Vec<OpenRouterMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenRouterTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
}

#[derive(Serialize)]
struct OpenRouterMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenRouterContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenRouterToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenRouterContent {
    Text(String),
    Parts(Vec<OpenRouterContentPart>),
}

#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenRouterContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Clone)]
struct OpenRouterToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenRouterFunctionCall,
}

#[derive(Serialize, Clone)]
struct OpenRouterFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenRouterTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenRouterFunctionDef,
}

#[derive(Serialize)]
struct OpenRouterFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

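// Response-side wire types for the streamed (SSE) completion chunks.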
#[derive(Deserialize, Debug)]
struct OpenRouterStreamResponse {
    // The final usage-only chunk may carry an empty or missing choices array.
    #[serde(default)]
    choices: Vec<OpenRouterStreamChoice>,
    #[serde(default)]
    usage: Option<OpenRouterUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenRouterStreamChoice {
    delta: OpenRouterDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenRouterDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenRouterToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenRouterToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenRouterFunctionDelta>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenRouterFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenRouterUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

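/// Converts Zed's provider-agnostic request into an OpenRouter request body:
/// system/user/assistant messages map directly, tool results become
/// `role: "tool"` messages, and tool definitions are dropped for models
/// that do not support tools.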
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenRouterRequest, String> {
    let mut messages: Vec<OpenRouterMessage> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let mut text_content = String::new();
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text_content.is_empty() {
                            text_content.push('\n');
                        }
                        text_content.push_str(text);
                    }
                }
                if !text_content.is_empty() {
                    messages.push(OpenRouterMessage {
                        role: "system".to_string(),
                        content: Some(OpenRouterContent::Text(text_content)),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<OpenRouterContentPart> = Vec::new();
                let mut tool_result_messages: Vec<OpenRouterMessage> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(OpenRouterContentPart::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            // The media type is not carried with the payload,
                            // so the base64 data is wrapped as a PNG data URL.
                            let data_url = format!("data:image/png;base64,{}", img.source);
                            parts.push(OpenRouterContentPart::ImageUrl {
                                image_url: ImageUrl { url: data_url },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            tool_result_messages.push(OpenRouterMessage {
                                role: "tool".to_string(),
                                content: Some(OpenRouterContent::Text(content_text)),
                                tool_calls: None,
                                tool_call_id: Some(result.tool_use_id.clone()),
                            });
                        }
                        _ => {}
                    }
                }

                // Tool results must directly follow the assistant message that
                // issued the calls, so emit them before any new user content.
                messages.extend(tool_result_messages);

                if !parts.is_empty() {
                    // Collapse a lone text part into a plain string for
                    // compatibility with models that reject part arrays.
                    let content = if parts.len() == 1 {
                        if let OpenRouterContentPart::Text { text } = &parts[0] {
                            OpenRouterContent::Text(text.clone())
                        } else {
                            OpenRouterContent::Parts(parts)
                        }
                    } else {
                        OpenRouterContent::Parts(parts)
                    };

                    messages.push(OpenRouterMessage {
                        role: "user".to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::Assistant => {
                let mut text_content = String::new();
                let mut tool_calls: Vec<OpenRouterToolCall> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                if !text_content.is_empty() {
                                    text_content.push('\n');
                                }
                                text_content.push_str(text);
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenRouterToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenRouterFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenRouterMessage {
                    role: "assistant".to_string(),
                    content: if text_content.is_empty() {
                        None
                    } else {
                        Some(OpenRouterContent::Text(text_content))
                    },
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                    tool_call_id: None,
                });
            }
        }
    }

    let model_def = get_model_definition(model_id);
    let supports_tools = model_def.map(|m| m.supports_tools).unwrap_or(true);

    let tools: Vec<OpenRouterTool> = if supports_tools {
        request
            .tools
            .iter()
            .map(|t| OpenRouterTool {
                tool_type: "function".to_string(),
                function: OpenRouterFunctionDef {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters: serde_json::from_str(&t.input_schema)
                        .unwrap_or_else(|_| serde_json::Value::Object(Default::default())),
                },
            })
            .collect()
    } else {
        Vec::new()
    };

    let tool_choice = if supports_tools {
        request.tool_choice.as_ref().map(|tc| match tc {
            LlmToolChoice::Auto => "auto".to_string(),
            LlmToolChoice::Any => "required".to_string(),
            LlmToolChoice::None => "none".to_string(),
        })
    } else {
        None
    };

    // Fall back to the catalog's output cap when the caller sets no limit.
    let max_tokens = request
        .max_tokens
        .or(model_def.and_then(|m| m.max_output_tokens));

    Ok(OpenRouterRequest {
        model: model_id.to_string(),
        messages,
        max_tokens,
        tools,
        tool_choice,
        stop: request.stop_sequences.clone(),
        temperature: request.temperature,
        stream: true,
    })
}

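/// Parses a single SSE line, returning `None` for non-data lines, the
/// `[DONE]` sentinel, and payloads that fail to deserialize.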
fn parse_sse_line(line: &str) -> Option<OpenRouterStreamResponse> {
    let data = line.strip_prefix("data: ")?;
    if data.trim() == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}

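// Extension entry points, invoked by Zed through the extension API.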
impl zed::Extension for OpenRouterProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "openrouter".into(),
            name: "OpenRouter".into(),
            icon: Some("icons/openrouter.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.id.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: m.supports_tools,
                    supports_tool_choice_auto: m.supports_tools,
                    supports_tool_choice_any: m.supports_tools,
                    supports_tool_choice_none: m.supports_tools,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("open_router").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "[Create an API key](https://openrouter.ai/keys) to use OpenRouter as your LLM provider.".to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("open_router")
    }

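    /// Builds the streaming chat-completions request and registers the open
    /// HTTP response under a freshly allocated stream id.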
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("open_router").ok_or_else(|| {
            "No API key configured. Please add your OpenRouter API key in settings.".to_string()
        })?;

        let openrouter_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openrouter_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://openrouter.ai/api/v1/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
                ("HTTP-Referer".to_string(), "https://zed.dev".to_string()),
                ("X-Title".to_string(), "Zed Editor".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("openrouter-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
                pending_events: VecDeque::new(),
            },
        );

        Ok(stream_id)
    }

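    /// Pumps the SSE stream: drains queued events first, then parses complete
    /// lines from the buffer, reading more bytes only when the buffer is dry.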
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Emit any events queued while handling an earlier chunk (remaining
        // tool calls, the stop event) before reading more data.
        if let Some(event) = state.pending_events.pop_front() {
            return Ok(Some(event));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                // Trim a trailing '\r' in case the server sends CRLF.
                let line = state.buffer[..newline_pos].trim_end().to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.is_empty() {
                    continue;
                }

                if let Some(response) = parse_sse_line(&line) {
                    if let Some(choice) = response.choices.first() {
                        // Accumulate tool-call deltas before returning any
                        // text, so a chunk carrying both loses nothing.
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state.tool_calls.entry(tc.index).or_default();

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(finish_reason) = &choice.finish_reason {
                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
                                state.tool_calls_emitted = true;
                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
                                tool_calls.sort_by_key(|(idx, _)| *idx);

                                // Queue every accumulated call; emitting only
                                // the first would drop parallel tool calls.
                                for (_, tc) in tool_calls {
                                    state.pending_events.push_back(LlmCompletionEvent::ToolUse(
                                        LlmToolUse {
                                            id: tc.id,
                                            name: tc.name,
                                            input: tc.arguments,
                                            is_input_complete: true,
                                            thought_signature: None,
                                        },
                                    ));
                                }
                            }

                            let stop_reason = match finish_reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };
                            state
                                .pending_events
                                .push_back(LlmCompletionEvent::Stop(stop_reason));
                        }

                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }
                    }

                    if let Some(usage) = response.usage {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }

                    if let Some(event) = state.pending_events.pop_front() {
                        return Ok(Some(event));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

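    /// Drops all per-stream state, releasing the buffered data and the
    /// underlying HTTP response.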
    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(OpenRouterProvider);