use std::collections::{HashMap, VecDeque};
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};
7
8struct CopilotChatProvider {
9 streams: Mutex<HashMap<String, StreamState>>,
10 next_stream_id: Mutex<u64>,
11}
12
13struct StreamState {
14 response_stream: Option<HttpResponseStream>,
15 buffer: String,
16 started: bool,
17 tool_calls: HashMap<usize, AccumulatedToolCall>,
18 tool_calls_emitted: bool,
19}
20
21#[derive(Clone, Default)]
22struct AccumulatedToolCall {
23 id: String,
24 name: String,
25 arguments: String,
26}
27
28struct ModelDefinition {
29 id: &'static str,
30 display_name: &'static str,
31 max_tokens: u64,
32 max_output_tokens: Option<u64>,
33 supports_images: bool,
34 is_default: bool,
35 is_default_fast: bool,
36}
37
38const MODELS: &[ModelDefinition] = &[
39 ModelDefinition {
40 id: "gpt-4o",
41 display_name: "GPT-4o",
42 max_tokens: 128_000,
43 max_output_tokens: Some(16_384),
44 supports_images: true,
45 is_default: true,
46 is_default_fast: false,
47 },
48 ModelDefinition {
49 id: "gpt-4o-mini",
50 display_name: "GPT-4o Mini",
51 max_tokens: 128_000,
52 max_output_tokens: Some(16_384),
53 supports_images: true,
54 is_default: false,
55 is_default_fast: true,
56 },
57 ModelDefinition {
58 id: "gpt-4.1",
59 display_name: "GPT-4.1",
60 max_tokens: 1_000_000,
61 max_output_tokens: Some(32_768),
62 supports_images: true,
63 is_default: false,
64 is_default_fast: false,
65 },
66 ModelDefinition {
67 id: "o1",
68 display_name: "o1",
69 max_tokens: 200_000,
70 max_output_tokens: Some(100_000),
71 supports_images: true,
72 is_default: false,
73 is_default_fast: false,
74 },
75 ModelDefinition {
76 id: "o3-mini",
77 display_name: "o3-mini",
78 max_tokens: 200_000,
79 max_output_tokens: Some(100_000),
80 supports_images: false,
81 is_default: false,
82 is_default_fast: false,
83 },
84 ModelDefinition {
85 id: "claude-3.5-sonnet",
86 display_name: "Claude 3.5 Sonnet",
87 max_tokens: 200_000,
88 max_output_tokens: Some(8_192),
89 supports_images: true,
90 is_default: false,
91 is_default_fast: false,
92 },
93 ModelDefinition {
94 id: "claude-3.7-sonnet",
95 display_name: "Claude 3.7 Sonnet",
96 max_tokens: 200_000,
97 max_output_tokens: Some(8_192),
98 supports_images: true,
99 is_default: false,
100 is_default_fast: false,
101 },
102 ModelDefinition {
103 id: "gemini-2.0-flash-001",
104 display_name: "Gemini 2.0 Flash",
105 max_tokens: 1_000_000,
106 max_output_tokens: Some(8_192),
107 supports_images: true,
108 is_default: false,
109 is_default_fast: false,
110 },
111];
112
113fn get_model_definition(model_id: &str) -> Option<&'static ModelDefinition> {
114 MODELS.iter().find(|m| m.id == model_id)
115}
116
117#[derive(Serialize)]
118struct OpenAiRequest {
119 model: String,
120 messages: Vec<OpenAiMessage>,
121 #[serde(skip_serializing_if = "Option::is_none")]
122 max_tokens: Option<u64>,
123 #[serde(skip_serializing_if = "Vec::is_empty")]
124 tools: Vec<OpenAiTool>,
125 #[serde(skip_serializing_if = "Option::is_none")]
126 tool_choice: Option<String>,
127 #[serde(skip_serializing_if = "Vec::is_empty")]
128 stop: Vec<String>,
129 #[serde(skip_serializing_if = "Option::is_none")]
130 temperature: Option<f32>,
131 stream: bool,
132 #[serde(skip_serializing_if = "Option::is_none")]
133 stream_options: Option<StreamOptions>,
134}
135
136#[derive(Serialize)]
137struct StreamOptions {
138 include_usage: bool,
139}
140
141#[derive(Serialize)]
142struct OpenAiMessage {
143 role: String,
144 #[serde(skip_serializing_if = "Option::is_none")]
145 content: Option<OpenAiContent>,
146 #[serde(skip_serializing_if = "Option::is_none")]
147 tool_calls: Option<Vec<OpenAiToolCall>>,
148 #[serde(skip_serializing_if = "Option::is_none")]
149 tool_call_id: Option<String>,
150}
151
152#[derive(Serialize, Clone)]
153#[serde(untagged)]
154enum OpenAiContent {
155 Text(String),
156 Parts(Vec<OpenAiContentPart>),
157}
158
159#[derive(Serialize, Clone)]
160#[serde(tag = "type")]
161enum OpenAiContentPart {
162 #[serde(rename = "text")]
163 Text { text: String },
164 #[serde(rename = "image_url")]
165 ImageUrl { image_url: ImageUrl },
166}
167
168#[derive(Serialize, Clone)]
169struct ImageUrl {
170 url: String,
171}
172
173#[derive(Serialize, Clone)]
174struct OpenAiToolCall {
175 id: String,
176 #[serde(rename = "type")]
177 call_type: String,
178 function: OpenAiFunctionCall,
179}
180
181#[derive(Serialize, Clone)]
182struct OpenAiFunctionCall {
183 name: String,
184 arguments: String,
185}
186
187#[derive(Serialize)]
188struct OpenAiTool {
189 #[serde(rename = "type")]
190 tool_type: String,
191 function: OpenAiFunctionDef,
192}
193
194#[derive(Serialize)]
195struct OpenAiFunctionDef {
196 name: String,
197 description: String,
198 parameters: serde_json::Value,
199}
200
201#[derive(Deserialize, Debug)]
202struct OpenAiStreamResponse {
203 choices: Vec<OpenAiStreamChoice>,
204 #[serde(default)]
205 usage: Option<OpenAiUsage>,
206}
207
208#[derive(Deserialize, Debug)]
209struct OpenAiStreamChoice {
210 delta: OpenAiDelta,
211 finish_reason: Option<String>,
212}
213
214#[derive(Deserialize, Debug, Default)]
215struct OpenAiDelta {
216 #[serde(default)]
217 content: Option<String>,
218 #[serde(default)]
219 tool_calls: Option<Vec<OpenAiToolCallDelta>>,
220}
221
222#[derive(Deserialize, Debug)]
223struct OpenAiToolCallDelta {
224 index: usize,
225 #[serde(default)]
226 id: Option<String>,
227 #[serde(default)]
228 function: Option<OpenAiFunctionDelta>,
229}
230
231#[derive(Deserialize, Debug, Default)]
232struct OpenAiFunctionDelta {
233 #[serde(default)]
234 name: Option<String>,
235 #[serde(default)]
236 arguments: Option<String>,
237}
238
239#[derive(Deserialize, Debug)]
240struct OpenAiUsage {
241 prompt_tokens: u64,
242 completion_tokens: u64,
243}
244
245fn convert_request(
246 model_id: &str,
247 request: &LlmCompletionRequest,
248) -> Result<OpenAiRequest, String> {
249 let mut messages: Vec<OpenAiMessage> = Vec::new();
250
251 for msg in &request.messages {
252 match msg.role {
253 LlmMessageRole::System => {
254 let mut text_content = String::new();
255 for content in &msg.content {
256 if let LlmMessageContent::Text(text) = content {
257 if !text_content.is_empty() {
258 text_content.push('\n');
259 }
260 text_content.push_str(text);
261 }
262 }
263 if !text_content.is_empty() {
264 messages.push(OpenAiMessage {
265 role: "system".to_string(),
266 content: Some(OpenAiContent::Text(text_content)),
267 tool_calls: None,
268 tool_call_id: None,
269 });
270 }
271 }
272 LlmMessageRole::User => {
273 let mut parts: Vec<OpenAiContentPart> = Vec::new();
274 let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();
275
276 for content in &msg.content {
277 match content {
278 LlmMessageContent::Text(text) => {
279 if !text.is_empty() {
280 parts.push(OpenAiContentPart::Text { text: text.clone() });
281 }
282 }
283 LlmMessageContent::Image(img) => {
284 let data_url = format!("data:image/png;base64,{}", img.source);
285 parts.push(OpenAiContentPart::ImageUrl {
286 image_url: ImageUrl { url: data_url },
287 });
288 }
289 LlmMessageContent::ToolResult(result) => {
290 let content_text = match &result.content {
291 LlmToolResultContent::Text(t) => t.clone(),
292 LlmToolResultContent::Image(_) => "[Image]".to_string(),
293 };
294 tool_result_messages.push(OpenAiMessage {
295 role: "tool".to_string(),
296 content: Some(OpenAiContent::Text(content_text)),
297 tool_calls: None,
298 tool_call_id: Some(result.tool_use_id.clone()),
299 });
300 }
301 _ => {}
302 }
303 }
304
305 if !parts.is_empty() {
306 let content = if parts.len() == 1 {
307 if let OpenAiContentPart::Text { text } = &parts[0] {
308 OpenAiContent::Text(text.clone())
309 } else {
310 OpenAiContent::Parts(parts)
311 }
312 } else {
313 OpenAiContent::Parts(parts)
314 };
315
316 messages.push(OpenAiMessage {
317 role: "user".to_string(),
318 content: Some(content),
319 tool_calls: None,
320 tool_call_id: None,
321 });
322 }
323
324 messages.extend(tool_result_messages);
325 }
326 LlmMessageRole::Assistant => {
327 let mut text_content = String::new();
328 let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();
329
330 for content in &msg.content {
331 match content {
332 LlmMessageContent::Text(text) => {
333 if !text.is_empty() {
334 if !text_content.is_empty() {
335 text_content.push('\n');
336 }
337 text_content.push_str(text);
338 }
339 }
340 LlmMessageContent::ToolUse(tool_use) => {
341 tool_calls.push(OpenAiToolCall {
342 id: tool_use.id.clone(),
343 call_type: "function".to_string(),
344 function: OpenAiFunctionCall {
345 name: tool_use.name.clone(),
346 arguments: tool_use.input.clone(),
347 },
348 });
349 }
350 _ => {}
351 }
352 }
353
354 messages.push(OpenAiMessage {
355 role: "assistant".to_string(),
356 content: if text_content.is_empty() {
357 None
358 } else {
359 Some(OpenAiContent::Text(text_content))
360 },
361 tool_calls: if tool_calls.is_empty() {
362 None
363 } else {
364 Some(tool_calls)
365 },
366 tool_call_id: None,
367 });
368 }
369 }
370 }
371
372 let tools: Vec<OpenAiTool> = request
373 .tools
374 .iter()
375 .map(|t| OpenAiTool {
376 tool_type: "function".to_string(),
377 function: OpenAiFunctionDef {
378 name: t.name.clone(),
379 description: t.description.clone(),
380 parameters: serde_json::from_str(&t.input_schema)
381 .unwrap_or(serde_json::Value::Object(Default::default())),
382 },
383 })
384 .collect();
385
386 let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
387 LlmToolChoice::Auto => "auto".to_string(),
388 LlmToolChoice::Any => "required".to_string(),
389 LlmToolChoice::None => "none".to_string(),
390 });
391
392 let model_def = get_model_definition(model_id);
393 let max_tokens = request
394 .max_tokens
395 .or(model_def.and_then(|m| m.max_output_tokens));
396
397 Ok(OpenAiRequest {
398 model: model_id.to_string(),
399 messages,
400 max_tokens,
401 tools,
402 tool_choice,
403 stop: request.stop_sequences.clone(),
404 temperature: request.temperature,
405 stream: true,
406 stream_options: Some(StreamOptions {
407 include_usage: true,
408 }),
409 })
410}
411
412fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
413 let data = line.strip_prefix("data: ")?;
414 if data.trim() == "[DONE]" {
415 return None;
416 }
417 serde_json::from_str(data).ok()
418}
419
420impl zed::Extension for CopilotChatProvider {
421 fn new() -> Self {
422 Self {
423 streams: Mutex::new(HashMap::new()),
424 next_stream_id: Mutex::new(0),
425 }
426 }
427
428 fn llm_providers(&self) -> Vec<LlmProviderInfo> {
429 vec![LlmProviderInfo {
430 id: "copilot_chat".into(),
431 name: "Copilot Chat".into(),
432 icon: Some("icons/copilot.svg".into()),
433 }]
434 }
435
436 fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
437 Ok(MODELS
438 .iter()
439 .map(|m| LlmModelInfo {
440 id: m.id.to_string(),
441 name: m.display_name.to_string(),
442 max_token_count: m.max_tokens,
443 max_output_tokens: m.max_output_tokens,
444 capabilities: LlmModelCapabilities {
445 supports_images: m.supports_images,
446 supports_tools: true,
447 supports_tool_choice_auto: true,
448 supports_tool_choice_any: true,
449 supports_tool_choice_none: true,
450 supports_thinking: false,
451 tool_input_format: LlmToolInputFormat::JsonSchema,
452 },
453 is_default: m.is_default,
454 is_default_fast: m.is_default_fast,
455 })
456 .collect())
457 }
458
459 fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
460 llm_get_credential("copilot_chat").is_some()
461 }
462
463 fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
464 Some(
465 r#"# Copilot Chat Setup
466
467Welcome to **Copilot Chat**! This extension provides access to GitHub Copilot's chat models.
468
469## Configuration
470
471Enter your GitHub Copilot token below. You need an active GitHub Copilot subscription.
472
473To get your token:
4741. Ensure you have a GitHub Copilot subscription
4752. Generate a token from your GitHub Copilot settings
476
477## Available Models
478
479| Model | Context | Output |
480|-------|---------|--------|
481| GPT-4o | 128K | 16K |
482| GPT-4o Mini | 128K | 16K |
483| GPT-4.1 | 1M | 32K |
484| o1 | 200K | 100K |
485| o3-mini | 200K | 100K |
486| Claude 3.5 Sonnet | 200K | 8K |
487| Claude 3.7 Sonnet | 200K | 8K |
488| Gemini 2.0 Flash | 1M | 8K |
489
490## Features
491
492- ✅ Full streaming support
493- ✅ Tool/function calling
494- ✅ Vision (image inputs)
495- ✅ Multiple model providers via Copilot
496
497## Note
498
499This extension requires an active GitHub Copilot subscription.
500"#
501 .to_string(),
502 )
503 }
504
505 fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
506 llm_delete_credential("copilot_chat")
507 }
508
509 fn llm_stream_completion_start(
510 &mut self,
511 _provider_id: &str,
512 model_id: &str,
513 request: &LlmCompletionRequest,
514 ) -> Result<String, String> {
515 let api_key = llm_get_credential("copilot_chat").ok_or_else(|| {
516 "No token configured. Please add your GitHub Copilot token in settings.".to_string()
517 })?;
518
519 let openai_request = convert_request(model_id, request)?;
520
521 let body = serde_json::to_vec(&openai_request)
522 .map_err(|e| format!("Failed to serialize request: {}", e))?;
523
524 let http_request = HttpRequest {
525 method: HttpMethod::Post,
526 url: "https://api.githubcopilot.com/chat/completions".to_string(),
527 headers: vec![
528 ("Content-Type".to_string(), "application/json".to_string()),
529 ("Authorization".to_string(), format!("Bearer {}", api_key)),
530 (
531 "Copilot-Integration-Id".to_string(),
532 "vscode-chat".to_string(),
533 ),
534 ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
535 ],
536 body: Some(body),
537 redirect_policy: RedirectPolicy::FollowAll,
538 };
539
540 let response_stream = http_request
541 .fetch_stream()
542 .map_err(|e| format!("HTTP request failed: {}", e))?;
543
544 let stream_id = {
545 let mut id_counter = self.next_stream_id.lock().unwrap();
546 let id = format!("copilot-stream-{}", *id_counter);
547 *id_counter += 1;
548 id
549 };
550
551 self.streams.lock().unwrap().insert(
552 stream_id.clone(),
553 StreamState {
554 response_stream: Some(response_stream),
555 buffer: String::new(),
556 started: false,
557 tool_calls: HashMap::new(),
558 tool_calls_emitted: false,
559 },
560 );
561
562 Ok(stream_id)
563 }
564
565 fn llm_stream_completion_next(
566 &mut self,
567 stream_id: &str,
568 ) -> Result<Option<LlmCompletionEvent>, String> {
569 let mut streams = self.streams.lock().unwrap();
570 let state = streams
571 .get_mut(stream_id)
572 .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
573
574 if !state.started {
575 state.started = true;
576 return Ok(Some(LlmCompletionEvent::Started));
577 }
578
579 let response_stream = state
580 .response_stream
581 .as_mut()
582 .ok_or_else(|| "Stream already closed".to_string())?;
583
584 loop {
585 if let Some(newline_pos) = state.buffer.find('\n') {
586 let line = state.buffer[..newline_pos].to_string();
587 state.buffer = state.buffer[newline_pos + 1..].to_string();
588
589 if line.trim().is_empty() {
590 continue;
591 }
592
593 if let Some(response) = parse_sse_line(&line) {
594 if let Some(choice) = response.choices.first() {
595 if let Some(content) = &choice.delta.content {
596 if !content.is_empty() {
597 return Ok(Some(LlmCompletionEvent::Text(content.clone())));
598 }
599 }
600
601 if let Some(tool_calls) = &choice.delta.tool_calls {
602 for tc in tool_calls {
603 let entry = state
604 .tool_calls
605 .entry(tc.index)
606 .or_insert_with(AccumulatedToolCall::default);
607
608 if let Some(id) = &tc.id {
609 entry.id = id.clone();
610 }
611 if let Some(func) = &tc.function {
612 if let Some(name) = &func.name {
613 entry.name = name.clone();
614 }
615 if let Some(args) = &func.arguments {
616 entry.arguments.push_str(args);
617 }
618 }
619 }
620 }
621
622 if let Some(finish_reason) = &choice.finish_reason {
623 if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
624 state.tool_calls_emitted = true;
625 let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
626 tool_calls.sort_by_key(|(idx, _)| *idx);
627
628 if let Some((_, tc)) = tool_calls.into_iter().next() {
629 return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
630 id: tc.id,
631 name: tc.name,
632 input: tc.arguments,
633 thought_signature: None,
634 })));
635 }
636 }
637
638 let stop_reason = match finish_reason.as_str() {
639 "stop" => LlmStopReason::EndTurn,
640 "length" => LlmStopReason::MaxTokens,
641 "tool_calls" => LlmStopReason::ToolUse,
642 "content_filter" => LlmStopReason::Refusal,
643 _ => LlmStopReason::EndTurn,
644 };
645 return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
646 }
647 }
648
649 if let Some(usage) = response.usage {
650 return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
651 input_tokens: usage.prompt_tokens,
652 output_tokens: usage.completion_tokens,
653 cache_creation_input_tokens: None,
654 cache_read_input_tokens: None,
655 })));
656 }
657 }
658
659 continue;
660 }
661
662 match response_stream.next_chunk() {
663 Ok(Some(chunk)) => {
664 let text = String::from_utf8_lossy(&chunk);
665 state.buffer.push_str(&text);
666 }
667 Ok(None) => {
668 return Ok(None);
669 }
670 Err(e) => {
671 return Err(format!("Stream error: {}", e));
672 }
673 }
674 }
675 }
676
677 fn llm_stream_completion_close(&mut self, stream_id: &str) {
678 self.streams.lock().unwrap().remove(stream_id);
679 }
680}
681
682zed::register_extension!(CopilotChatProvider);