open_router.rs

  1use std::collections::HashMap;
  2use std::sync::Mutex;
  3
  4use serde::{Deserialize, Serialize};
  5use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
  6use zed_extension_api::{self as zed, *};
  7
/// Zed LLM-provider extension backed by the OpenRouter HTTP API.
///
/// Holds all in-flight completion streams, keyed by a locally generated
/// stream id minted from `next_stream_id`.
struct OpenRouterProvider {
    /// Active SSE streams by stream id; entries are removed on close.
    streams: Mutex<HashMap<String, StreamState>>,
    /// Monotonic counter used to mint unique stream ids.
    next_stream_id: Mutex<u64>,
}
 12
/// Per-stream bookkeeping for one in-flight completion.
struct StreamState {
    /// Underlying HTTP body stream; `None` once closed.
    response_stream: Option<HttpResponseStream>,
    /// Unconsumed bytes carried over between polls (may hold a partial SSE line).
    buffer: String,
    /// Whether the initial `Started` event has been emitted.
    started: bool,
    /// Tool-call fragments accumulated so far, keyed by delta index.
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    /// Set once accumulated tool calls have begun being emitted to the host.
    tool_calls_emitted: bool,
}
 20
/// A tool call being assembled from streamed fragments; `arguments` grows by
/// concatenating partial JSON text until the stream reports a finish reason.
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}
 27
/// Static metadata describing one model offered through this provider.
struct ModelDefinition {
    /// OpenRouter model slug, e.g. "anthropic/claude-sonnet-4".
    id: &'static str,
    display_name: &'static str,
    /// Total context window, in tokens.
    max_tokens: u64,
    /// Default output cap used when a request does not specify one.
    max_output_tokens: Option<u64>,
    supports_images: bool,
    supports_tools: bool,
    /// Marks the provider-wide default model (at most one entry).
    is_default: bool,
    /// Marks the default "fast" model (at most one entry).
    is_default_fast: bool,
}
 38
/// Curated set of OpenRouter models exposed to Zed, with context-window sizes
/// and capability flags. Exactly one entry is the default (Claude Sonnet 4)
/// and exactly one is the default fast model (Claude Haiku 4).
const MODELS: &[ModelDefinition] = &[
    // Anthropic Models
    ModelDefinition {
        id: "anthropic/claude-sonnet-4",
        display_name: "Claude Sonnet 4",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "anthropic/claude-opus-4",
        display_name: "Claude Opus 4",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "anthropic/claude-haiku-4",
        display_name: "Claude Haiku 4",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        id: "anthropic/claude-3.5-sonnet",
        display_name: "Claude 3.5 Sonnet",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // OpenAI Models
    ModelDefinition {
        id: "openai/gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "openai/gpt-4o-mini",
        display_name: "GPT-4o Mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "openai/o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        supports_tools: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "openai/o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        supports_tools: false,
        is_default: false,
        is_default_fast: false,
    },
    // Google Models
    ModelDefinition {
        id: "google/gemini-2.0-flash-001",
        display_name: "Gemini 2.0 Flash",
        max_tokens: 1_000_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "google/gemini-2.5-pro-preview",
        display_name: "Gemini 2.5 Pro",
        max_tokens: 1_000_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // Meta Models
    ModelDefinition {
        id: "meta-llama/llama-3.3-70b-instruct",
        display_name: "Llama 3.3 70B",
        max_tokens: 128_000,
        max_output_tokens: Some(4_096),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "meta-llama/llama-4-maverick",
        display_name: "Llama 4 Maverick",
        max_tokens: 128_000,
        max_output_tokens: Some(4_096),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // Mistral Models
    ModelDefinition {
        id: "mistralai/mistral-large-2411",
        display_name: "Mistral Large",
        max_tokens: 128_000,
        max_output_tokens: Some(4_096),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "mistralai/codestral-latest",
        display_name: "Codestral",
        max_tokens: 32_000,
        max_output_tokens: Some(4_096),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // DeepSeek Models
    ModelDefinition {
        id: "deepseek/deepseek-chat-v3-0324",
        display_name: "DeepSeek V3",
        max_tokens: 64_000,
        max_output_tokens: Some(8_192),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "deepseek/deepseek-r1",
        display_name: "DeepSeek R1",
        max_tokens: 64_000,
        max_output_tokens: Some(8_192),
        supports_images: false,
        supports_tools: false,
        is_default: false,
        is_default_fast: false,
    },
    // Qwen Models
    ModelDefinition {
        id: "qwen/qwen3-235b-a22b",
        display_name: "Qwen 3 235B",
        max_tokens: 40_000,
        max_output_tokens: Some(8_192),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
];
218
219fn get_model_definition(model_id: &str) -> Option<&'static ModelDefinition> {
220    MODELS.iter().find(|m| m.id == model_id)
221}
222
/// Request body for `POST /api/v1/chat/completions`.
#[derive(Serialize)]
struct OpenRouterRequest {
    model: String,
    messages: Vec<OpenRouterMessage>,
    /// Cap on generated tokens; omitted from JSON when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Tool definitions; omitted entirely when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenRouterTool>,
    /// "auto" | "required" | "none"; omitted when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Always true here: this extension only implements the SSE streaming path.
    stream: bool,
}
239
/// One chat message in OpenAI wire format. `role` is "system", "user",
/// "assistant", or "tool"; `tool_call_id` is set only on "tool" messages.
#[derive(Serialize)]
struct OpenRouterMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenRouterContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenRouterToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
250
/// Message content: either a plain string or a multi-part (text/image) list.
/// `untagged` serializes the string variant directly, matching the OpenAI schema.
#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenRouterContent {
    Text(String),
    Parts(Vec<OpenRouterContentPart>),
}
257
/// One element of multi-part message content, tagged by a "type" field.
#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenRouterContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
266
/// Image reference; `url` holds a base64 `data:` URL in this extension.
#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}
271
/// A completed tool call replayed back inside an assistant message.
#[derive(Serialize, Clone)]
struct OpenRouterToolCall {
    id: String,
    /// Always "function".
    #[serde(rename = "type")]
    call_type: String,
    function: OpenRouterFunctionCall,
}
279
/// Function name plus raw JSON-encoded arguments for a replayed tool call.
#[derive(Serialize, Clone)]
struct OpenRouterFunctionCall {
    name: String,
    arguments: String,
}
285
/// A tool made available to the model (`tool_type` is always "function").
#[derive(Serialize)]
struct OpenRouterTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenRouterFunctionDef,
}
292
/// Function declaration: name, description, and JSON-schema parameters.
#[derive(Serialize)]
struct OpenRouterFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
299
/// One parsed SSE chunk of a streaming completion.
#[derive(Deserialize, Debug)]
struct OpenRouterStreamResponse {
    choices: Vec<OpenRouterStreamChoice>,
    /// Token accounting; presumably present only on the final chunk(s) —
    /// NOTE(review): confirm against OpenRouter's streaming behavior.
    #[serde(default)]
    usage: Option<OpenRouterUsage>,
}
306
/// A single choice delta; `finish_reason` is set on the terminal chunk.
#[derive(Deserialize, Debug)]
struct OpenRouterStreamChoice {
    delta: OpenRouterDelta,
    finish_reason: Option<String>,
}
312
/// Incremental payload: new text and/or partial tool-call fragments.
#[derive(Deserialize, Debug, Default)]
struct OpenRouterDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenRouterToolCallDelta>>,
}
320
/// Fragment of a tool call; fragments sharing `index` are merged by the
/// stream reader (id/name arrive once, arguments arrive piecewise).
#[derive(Deserialize, Debug)]
struct OpenRouterToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenRouterFunctionDelta>,
}
329
/// Partial function data carried inside a tool-call fragment.
#[derive(Deserialize, Debug, Default)]
struct OpenRouterFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
337
/// Token counts reported by the API.
#[derive(Deserialize, Debug)]
struct OpenRouterUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}
343
344fn convert_request(
345    model_id: &str,
346    request: &LlmCompletionRequest,
347) -> Result<OpenRouterRequest, String> {
348    let mut messages: Vec<OpenRouterMessage> = Vec::new();
349
350    for msg in &request.messages {
351        match msg.role {
352            LlmMessageRole::System => {
353                let mut text_content = String::new();
354                for content in &msg.content {
355                    if let LlmMessageContent::Text(text) = content {
356                        if !text_content.is_empty() {
357                            text_content.push('\n');
358                        }
359                        text_content.push_str(text);
360                    }
361                }
362                if !text_content.is_empty() {
363                    messages.push(OpenRouterMessage {
364                        role: "system".to_string(),
365                        content: Some(OpenRouterContent::Text(text_content)),
366                        tool_calls: None,
367                        tool_call_id: None,
368                    });
369                }
370            }
371            LlmMessageRole::User => {
372                let mut parts: Vec<OpenRouterContentPart> = Vec::new();
373                let mut tool_result_messages: Vec<OpenRouterMessage> = Vec::new();
374
375                for content in &msg.content {
376                    match content {
377                        LlmMessageContent::Text(text) => {
378                            if !text.is_empty() {
379                                parts.push(OpenRouterContentPart::Text { text: text.clone() });
380                            }
381                        }
382                        LlmMessageContent::Image(img) => {
383                            let data_url = format!("data:image/png;base64,{}", img.source);
384                            parts.push(OpenRouterContentPart::ImageUrl {
385                                image_url: ImageUrl { url: data_url },
386                            });
387                        }
388                        LlmMessageContent::ToolResult(result) => {
389                            let content_text = match &result.content {
390                                LlmToolResultContent::Text(t) => t.clone(),
391                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
392                            };
393                            tool_result_messages.push(OpenRouterMessage {
394                                role: "tool".to_string(),
395                                content: Some(OpenRouterContent::Text(content_text)),
396                                tool_calls: None,
397                                tool_call_id: Some(result.tool_use_id.clone()),
398                            });
399                        }
400                        _ => {}
401                    }
402                }
403
404                if !parts.is_empty() {
405                    let content = if parts.len() == 1 {
406                        if let OpenRouterContentPart::Text { text } = &parts[0] {
407                            OpenRouterContent::Text(text.clone())
408                        } else {
409                            OpenRouterContent::Parts(parts)
410                        }
411                    } else {
412                        OpenRouterContent::Parts(parts)
413                    };
414
415                    messages.push(OpenRouterMessage {
416                        role: "user".to_string(),
417                        content: Some(content),
418                        tool_calls: None,
419                        tool_call_id: None,
420                    });
421                }
422
423                messages.extend(tool_result_messages);
424            }
425            LlmMessageRole::Assistant => {
426                let mut text_content = String::new();
427                let mut tool_calls: Vec<OpenRouterToolCall> = Vec::new();
428
429                for content in &msg.content {
430                    match content {
431                        LlmMessageContent::Text(text) => {
432                            if !text.is_empty() {
433                                if !text_content.is_empty() {
434                                    text_content.push('\n');
435                                }
436                                text_content.push_str(text);
437                            }
438                        }
439                        LlmMessageContent::ToolUse(tool_use) => {
440                            tool_calls.push(OpenRouterToolCall {
441                                id: tool_use.id.clone(),
442                                call_type: "function".to_string(),
443                                function: OpenRouterFunctionCall {
444                                    name: tool_use.name.clone(),
445                                    arguments: tool_use.input.clone(),
446                                },
447                            });
448                        }
449                        _ => {}
450                    }
451                }
452
453                messages.push(OpenRouterMessage {
454                    role: "assistant".to_string(),
455                    content: if text_content.is_empty() {
456                        None
457                    } else {
458                        Some(OpenRouterContent::Text(text_content))
459                    },
460                    tool_calls: if tool_calls.is_empty() {
461                        None
462                    } else {
463                        Some(tool_calls)
464                    },
465                    tool_call_id: None,
466                });
467            }
468        }
469    }
470
471    let model_def = get_model_definition(model_id);
472    let supports_tools = model_def.map(|m| m.supports_tools).unwrap_or(true);
473
474    let tools: Vec<OpenRouterTool> = if supports_tools {
475        request
476            .tools
477            .iter()
478            .map(|t| OpenRouterTool {
479                tool_type: "function".to_string(),
480                function: OpenRouterFunctionDef {
481                    name: t.name.clone(),
482                    description: t.description.clone(),
483                    parameters: serde_json::from_str(&t.input_schema)
484                        .unwrap_or(serde_json::Value::Object(Default::default())),
485                },
486            })
487            .collect()
488    } else {
489        Vec::new()
490    };
491
492    let tool_choice = if supports_tools {
493        request.tool_choice.as_ref().map(|tc| match tc {
494            LlmToolChoice::Auto => "auto".to_string(),
495            LlmToolChoice::Any => "required".to_string(),
496            LlmToolChoice::None => "none".to_string(),
497        })
498    } else {
499        None
500    };
501
502    let max_tokens = request
503        .max_tokens
504        .or(model_def.and_then(|m| m.max_output_tokens));
505
506    Ok(OpenRouterRequest {
507        model: model_id.to_string(),
508        messages,
509        max_tokens,
510        tools,
511        tool_choice,
512        stop: request.stop_sequences.clone(),
513        temperature: request.temperature,
514        stream: true,
515    })
516}
517
518fn parse_sse_line(line: &str) -> Option<OpenRouterStreamResponse> {
519    let data = line.strip_prefix("data: ")?;
520    if data.trim() == "[DONE]" {
521        return None;
522    }
523    serde_json::from_str(data).ok()
524}
525
526impl zed::Extension for OpenRouterProvider {
527    fn new() -> Self {
528        Self {
529            streams: Mutex::new(HashMap::new()),
530            next_stream_id: Mutex::new(0),
531        }
532    }
533
534    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
535        vec![LlmProviderInfo {
536            id: "openrouter".into(),
537            name: "OpenRouter".into(),
538            icon: Some("icons/openrouter.svg".into()),
539        }]
540    }
541
542    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
543        Ok(MODELS
544            .iter()
545            .map(|m| LlmModelInfo {
546                id: m.id.to_string(),
547                name: m.display_name.to_string(),
548                max_token_count: m.max_tokens,
549                max_output_tokens: m.max_output_tokens,
550                capabilities: LlmModelCapabilities {
551                    supports_images: m.supports_images,
552                    supports_tools: m.supports_tools,
553                    supports_tool_choice_auto: m.supports_tools,
554                    supports_tool_choice_any: m.supports_tools,
555                    supports_tool_choice_none: m.supports_tools,
556                    supports_thinking: false,
557                    tool_input_format: LlmToolInputFormat::JsonSchema,
558                },
559                is_default: m.is_default,
560                is_default_fast: m.is_default_fast,
561            })
562            .collect())
563    }
564
565    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
566        llm_get_credential("open_router").is_some()
567    }
568
569    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
570        Some(
571            "[Create an API key](https://openrouter.ai/keys) to use OpenRouter as your LLM provider.".to_string(),
572        )
573    }
574
575    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
576        llm_delete_credential("open_router")
577    }
578
579    fn llm_stream_completion_start(
580        &mut self,
581        _provider_id: &str,
582        model_id: &str,
583        request: &LlmCompletionRequest,
584    ) -> Result<String, String> {
585        let api_key = llm_get_credential("open_router").ok_or_else(|| {
586            "No API key configured. Please add your OpenRouter API key in settings.".to_string()
587        })?;
588
589        let openrouter_request = convert_request(model_id, request)?;
590
591        let body = serde_json::to_vec(&openrouter_request)
592            .map_err(|e| format!("Failed to serialize request: {}", e))?;
593
594        let http_request = HttpRequest {
595            method: HttpMethod::Post,
596            url: "https://openrouter.ai/api/v1/chat/completions".to_string(),
597            headers: vec![
598                ("Content-Type".to_string(), "application/json".to_string()),
599                ("Authorization".to_string(), format!("Bearer {}", api_key)),
600                ("HTTP-Referer".to_string(), "https://zed.dev".to_string()),
601                ("X-Title".to_string(), "Zed Editor".to_string()),
602            ],
603            body: Some(body),
604            redirect_policy: RedirectPolicy::FollowAll,
605        };
606
607        let response_stream = http_request
608            .fetch_stream()
609            .map_err(|e| format!("HTTP request failed: {}", e))?;
610
611        let stream_id = {
612            let mut id_counter = self.next_stream_id.lock().unwrap();
613            let id = format!("openrouter-stream-{}", *id_counter);
614            *id_counter += 1;
615            id
616        };
617
618        self.streams.lock().unwrap().insert(
619            stream_id.clone(),
620            StreamState {
621                response_stream: Some(response_stream),
622                buffer: String::new(),
623                started: false,
624                tool_calls: HashMap::new(),
625                tool_calls_emitted: false,
626            },
627        );
628
629        Ok(stream_id)
630    }
631
632    fn llm_stream_completion_next(
633        &mut self,
634        stream_id: &str,
635    ) -> Result<Option<LlmCompletionEvent>, String> {
636        let mut streams = self.streams.lock().unwrap();
637        let state = streams
638            .get_mut(stream_id)
639            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
640
641        if !state.started {
642            state.started = true;
643            return Ok(Some(LlmCompletionEvent::Started));
644        }
645
646        let response_stream = state
647            .response_stream
648            .as_mut()
649            .ok_or_else(|| "Stream already closed".to_string())?;
650
651        loop {
652            if let Some(newline_pos) = state.buffer.find('\n') {
653                let line = state.buffer[..newline_pos].to_string();
654                state.buffer = state.buffer[newline_pos + 1..].to_string();
655
656                if line.trim().is_empty() {
657                    continue;
658                }
659
660                if let Some(response) = parse_sse_line(&line) {
661                    if let Some(choice) = response.choices.first() {
662                        if let Some(content) = &choice.delta.content {
663                            if !content.is_empty() {
664                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
665                            }
666                        }
667
668                        if let Some(tool_calls) = &choice.delta.tool_calls {
669                            for tc in tool_calls {
670                                let entry = state
671                                    .tool_calls
672                                    .entry(tc.index)
673                                    .or_insert_with(AccumulatedToolCall::default);
674
675                                if let Some(id) = &tc.id {
676                                    entry.id = id.clone();
677                                }
678                                if let Some(func) = &tc.function {
679                                    if let Some(name) = &func.name {
680                                        entry.name = name.clone();
681                                    }
682                                    if let Some(args) = &func.arguments {
683                                        entry.arguments.push_str(args);
684                                    }
685                                }
686                            }
687                        }
688
689                        if let Some(finish_reason) = &choice.finish_reason {
690                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
691                                state.tool_calls_emitted = true;
692                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
693                                tool_calls.sort_by_key(|(idx, _)| *idx);
694
695                                if let Some((_, tc)) = tool_calls.into_iter().next() {
696                                    return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
697                                        id: tc.id,
698                                        name: tc.name,
699                                        input: tc.arguments,
700                                        is_input_complete: true,
701                                        thought_signature: None,
702                                    })));
703                                }
704                            }
705
706                            let stop_reason = match finish_reason.as_str() {
707                                "stop" => LlmStopReason::EndTurn,
708                                "length" => LlmStopReason::MaxTokens,
709                                "tool_calls" => LlmStopReason::ToolUse,
710                                "content_filter" => LlmStopReason::Refusal,
711                                _ => LlmStopReason::EndTurn,
712                            };
713                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
714                        }
715                    }
716
717                    if let Some(usage) = response.usage {
718                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
719                            input_tokens: usage.prompt_tokens,
720                            output_tokens: usage.completion_tokens,
721                            cache_creation_input_tokens: None,
722                            cache_read_input_tokens: None,
723                        })));
724                    }
725                }
726
727                continue;
728            }
729
730            match response_stream.next_chunk() {
731                Ok(Some(chunk)) => {
732                    let text = String::from_utf8_lossy(&chunk);
733                    state.buffer.push_str(&text);
734                }
735                Ok(None) => {
736                    return Ok(None);
737                }
738                Err(e) => {
739                    return Err(format!("Stream error: {}", e));
740                }
741            }
742        }
743    }
744
745    fn llm_stream_completion_close(&mut self, stream_id: &str) {
746        self.streams.lock().unwrap().remove(stream_id);
747    }
748}
749
// Entry point: registers this provider with the Zed extension host.
zed::register_extension!(OpenRouterProvider);