copilot_chat.rs

use std::collections::HashMap;
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

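/// Copilot Chat language-model provider for Zed. In-flight streaming
/// completions are tracked in `streams`, keyed by an id handed back to Zed.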
struct CopilotChatProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

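/// Per-stream state: the underlying HTTP response, an SSE line buffer, and
/// tool calls accumulated from incremental deltas.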
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    tool_calls_emitted: bool,
    // Tool calls waiting to be emitted (one per `llm_stream_completion_next`
    // call) once a finish_reason has been seen, followed by the stop reason.
    pending_tool_calls: Vec<AccumulatedToolCall>,
    pending_stop: Option<LlmStopReason>,
}

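/// A single tool call assembled from the incremental `tool_calls` deltas of a
/// streamed response.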
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

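/// Static description of a model exposed through Copilot Chat.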
struct ModelDefinition {
    id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    is_default: bool,
    is_default_fast: bool,
}

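/// Models advertised to Zed; these mirror the table in the settings markdown.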
const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        id: "gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "gpt-4o-mini",
        display_name: "GPT-4o Mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        id: "gpt-4.1",
        display_name: "GPT-4.1",
        max_tokens: 1_000_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "claude-3.5-sonnet",
        display_name: "Claude 3.5 Sonnet",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "claude-3.7-sonnet",
        display_name: "Claude 3.7 Sonnet",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "gemini-2.0-flash-001",
        display_name: "Gemini 2.0 Flash",
        max_tokens: 1_000_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
];

fn get_model_definition(model_id: &str) -> Option<&'static ModelDefinition> {
    MODELS.iter().find(|m| m.id == model_id)
}

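// Wire types for Copilot's OpenAI-compatible `/chat/completions` endpoint.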
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    Text(String),
    Parts(Vec<OpenAiContentPart>),
}

#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamResponse {
    choices: Vec<OpenAiStreamChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

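/// Converts Zed's `LlmCompletionRequest` into the OpenAI-style chat request
/// body expected by the Copilot endpoint.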
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let mut messages: Vec<OpenAiMessage> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let mut text_content = String::new();
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text_content.is_empty() {
                            text_content.push('\n');
                        }
                        text_content.push_str(text);
                    }
                }
                if !text_content.is_empty() {
                    messages.push(OpenAiMessage {
                        role: "system".to_string(),
                        content: Some(OpenAiContent::Text(text_content)),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<OpenAiContentPart> = Vec::new();
                let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(OpenAiContentPart::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            let data_url = format!("data:image/png;base64,{}", img.source);
                            parts.push(OpenAiContentPart::ImageUrl {
                                image_url: ImageUrl { url: data_url },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            tool_result_messages.push(OpenAiMessage {
                                role: "tool".to_string(),
                                content: Some(OpenAiContent::Text(content_text)),
                                tool_calls: None,
                                tool_call_id: Some(result.tool_use_id.clone()),
                            });
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    let content = if parts.len() == 1 {
                        if let OpenAiContentPart::Text { text } = &parts[0] {
                            OpenAiContent::Text(text.clone())
                        } else {
                            OpenAiContent::Parts(parts)
                        }
                    } else {
                        OpenAiContent::Parts(parts)
                    };

                    messages.push(OpenAiMessage {
                        role: "user".to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }

                messages.extend(tool_result_messages);
            }
            LlmMessageRole::Assistant => {
                let mut text_content = String::new();
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                if !text_content.is_empty() {
                                    text_content.push('\n');
                                }
                                text_content.push_str(text);
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage {
                    role: "assistant".to_string(),
                    content: if text_content.is_empty() {
                        None
                    } else {
                        Some(OpenAiContent::Text(text_content))
                    },
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                    tool_call_id: None,
                });
            }
        }
    }

    let tools: Vec<OpenAiTool> = request
        .tools
        .iter()
        .map(|t| OpenAiTool {
            tool_type: "function".to_string(),
            function: OpenAiFunctionDef {
                name: t.name.clone(),
                description: t.description.clone(),
                parameters: serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default())),
            },
        })
        .collect();

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    let model_def = get_model_definition(model_id);
    let max_tokens = request
        .max_tokens
        .or(model_def.and_then(|m| m.max_output_tokens));

    Ok(OpenAiRequest {
        model: model_id.to_string(),
        messages,
        max_tokens,
        tools,
        tool_choice,
        stop: request.stop_sequences.clone(),
        temperature: request.temperature,
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

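/// Parses a single SSE `data:` line into a stream chunk. Returns `None` for
/// non-data lines, the `[DONE]` sentinel, and unparsable payloads.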
fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
    let data = line.strip_prefix("data: ")?;
    if data.trim() == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}

impl zed::Extension for CopilotChatProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "copilot_chat".into(),
            name: "Copilot Chat".into(),
            icon: Some("icons/copilot.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.id.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("copilot_chat").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            r#"# Copilot Chat Setup

Welcome to **Copilot Chat**! This extension provides access to GitHub Copilot's chat models.

## Configuration

Enter your GitHub Copilot token below. You need an active GitHub Copilot subscription.

To get your token:
1. Ensure you have a GitHub Copilot subscription
2. Generate a token from your GitHub Copilot settings

## Available Models

| Model | Context | Output |
|-------|---------|--------|
| GPT-4o | 128K | 16K |
| GPT-4o Mini | 128K | 16K |
| GPT-4.1 | 1M | 32K |
| o1 | 200K | 100K |
| o3-mini | 200K | 100K |
| Claude 3.5 Sonnet | 200K | 8K |
| Claude 3.7 Sonnet | 200K | 8K |
| Gemini 2.0 Flash | 1M | 8K |

## Features

- ✅ Full streaming support
- ✅ Tool/function calling
- ✅ Vision (image inputs)
- ✅ Multiple model providers via Copilot

## Note

This extension requires an active GitHub Copilot subscription.
"#
            .to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("copilot_chat")
    }

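    /// Kicks off a streaming chat completion: builds the request body, opens
    /// the HTTP response stream, and registers stream state under a fresh id.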
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("copilot_chat").ok_or_else(|| {
            "No token configured. Please add your GitHub Copilot token in settings.".to_string()
        })?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.githubcopilot.com/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("copilot-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
                pending_tool_calls: Vec::new(),
                pending_stop: None,
            },
        );

        Ok(stream_id)
    }

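    /// Returns the next completion event for a stream, pulling and parsing SSE
    /// lines from the HTTP response as needed. `Ok(None)` signals end of stream.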
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Flush tool calls (and then the stop reason) left over from a
        // previously seen finish_reason before reading more of the stream.
        if !state.pending_tool_calls.is_empty() {
            let tc = state.pending_tool_calls.remove(0);
            return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                id: tc.id,
                name: tc.name,
                input: tc.arguments,
                thought_signature: None,
            })));
        }
        if let Some(stop_reason) = state.pending_stop.take() {
            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.trim().is_empty() {
                    continue;
                }

                if let Some(response) = parse_sse_line(&line) {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }

                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state
                                    .tool_calls
                                    .entry(tc.index)
                                    .or_insert_with(AccumulatedToolCall::default);

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(finish_reason) = &choice.finish_reason {
                            let stop_reason = match finish_reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
                                state.tool_calls_emitted = true;
                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
                                tool_calls.sort_by_key(|(idx, _)| *idx);
                                // Emit the first tool call now; queue the rest and the
                                // stop reason so later calls can drain them in order.
                                state.pending_tool_calls =
                                    tool_calls.into_iter().map(|(_, tc)| tc).collect();
                                state.pending_stop = Some(stop_reason);
                                let tc = state.pending_tool_calls.remove(0);
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tc.id,
                                    name: tc.name,
                                    input: tc.arguments,
                                    thought_signature: None,
                                })));
                            }

                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                    }

                    if let Some(usage) = response.usage {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(CopilotChatProvider);
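
// Illustrative tests (added as a sketch; not part of the upstream extension)
// covering the SSE line parser used by the streaming loop above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_content_delta() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}"#;
        let response = parse_sse_line(line).expect("chunk should parse");
        assert_eq!(response.choices[0].delta.content.as_deref(), Some("Hi"));
    }

    #[test]
    fn ignores_done_marker_and_non_data_lines() {
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line(": keep-alive comment").is_none());
    }
}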