use std::collections::HashMap;
use std::sync::Mutex;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

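/// Extension state: active completion streams keyed by a locally generated
/// stream id, plus the counter used to mint those ids.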
struct CopilotChatProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
}

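/// Per-stream state for one in-flight chat completion, including the raw SSE
/// buffer and tool-call fragments accumulated across chunks.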
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    tool_calls_emitted: bool,
    // Tool calls and stop reason parsed from a finish chunk but not yet
    // returned to the host; drained one event per call to
    // `llm_stream_completion_next`.
    pending_tool_calls: Vec<AccumulatedToolCall>,
    pending_stop: Option<LlmStopReason>,
}

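/// A tool call assembled from streamed deltas: the id and name arrive once,
/// while the JSON arguments are concatenated across chunks.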
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

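/// Static metadata for a model advertised to Zed.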
struct ModelDefinition {
    id: &'static str,
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    is_default: bool,
    is_default_fast: bool,
}

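/// Models exposed by this provider, with their context and output token limits.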
const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        id: "gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "gpt-4o-mini",
        display_name: "GPT-4o Mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        id: "gpt-4.1",
        display_name: "GPT-4.1",
        max_tokens: 1_000_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "claude-3.5-sonnet",
        display_name: "Claude 3.5 Sonnet",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "claude-3.7-sonnet",
        display_name: "Claude 3.7 Sonnet",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "gemini-2.0-flash-001",
        display_name: "Gemini 2.0 Flash",
        max_tokens: 1_000_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
];

fn get_model_definition(model_id: &str) -> Option<&'static ModelDefinition> {
    MODELS.iter().find(|m| m.id == model_id)
}

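/// Request body for the OpenAI-compatible `/chat/completions` endpoint.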
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    Text(String),
    Parts(Vec<OpenAiContentPart>),
}

#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

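/// One parsed SSE `data:` payload from the streaming response.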
#[derive(Deserialize, Debug)]
struct OpenAiStreamResponse {
    choices: Vec<OpenAiStreamChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

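/// Converts Zed's completion request into the OpenAI-compatible wire format:
/// system and assistant text is flattened, tool results become `tool` role
/// messages, and tool definitions are passed through as JSON-schema functions.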
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let mut messages: Vec<OpenAiMessage> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let mut text_content = String::new();
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text_content.is_empty() {
                            text_content.push('\n');
                        }
                        text_content.push_str(text);
                    }
                }
                if !text_content.is_empty() {
                    messages.push(OpenAiMessage {
                        role: "system".to_string(),
                        content: Some(OpenAiContent::Text(text_content)),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<OpenAiContentPart> = Vec::new();
                let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(OpenAiContentPart::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            let data_url = format!("data:image/png;base64,{}", img.source);
                            parts.push(OpenAiContentPart::ImageUrl {
                                image_url: ImageUrl { url: data_url },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            tool_result_messages.push(OpenAiMessage {
                                role: "tool".to_string(),
                                content: Some(OpenAiContent::Text(content_text)),
                                tool_calls: None,
                                tool_call_id: Some(result.tool_use_id.clone()),
                            });
                        }
                        _ => {}
                    }
                }

                // Tool results must directly follow the assistant message that
                // issued the tool calls, so emit them before any new user text.
                messages.extend(tool_result_messages);

                if !parts.is_empty() {
                    let content = if parts.len() == 1 {
                        if let OpenAiContentPart::Text { text } = &parts[0] {
                            OpenAiContent::Text(text.clone())
                        } else {
                            OpenAiContent::Parts(parts)
                        }
                    } else {
                        OpenAiContent::Parts(parts)
                    };

                    messages.push(OpenAiMessage {
                        role: "user".to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::Assistant => {
                let mut text_content = String::new();
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                if !text_content.is_empty() {
                                    text_content.push('\n');
                                }
                                text_content.push_str(text);
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage {
                    role: "assistant".to_string(),
                    content: if text_content.is_empty() {
                        None
                    } else {
                        Some(OpenAiContent::Text(text_content))
                    },
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                    tool_call_id: None,
                });
            }
        }
    }

    let tools: Vec<OpenAiTool> = request
        .tools
        .iter()
        .map(|t| OpenAiTool {
            tool_type: "function".to_string(),
            function: OpenAiFunctionDef {
                name: t.name.clone(),
                description: t.description.clone(),
                parameters: serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default())),
            },
        })
        .collect();

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    let model_def = get_model_definition(model_id);
    let max_tokens = request
        .max_tokens
        .or(model_def.and_then(|m| m.max_output_tokens));

    Ok(OpenAiRequest {
        model: model_id.to_string(),
        messages,
        max_tokens,
        tools,
        tool_choice,
        stop: request.stop_sequences.clone(),
        temperature: request.temperature,
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

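/// Parses a single SSE line, returning `None` for non-data lines and for the
/// terminal `[DONE]` sentinel.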
fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
    let data = line.strip_prefix("data: ")?;
    if data.trim() == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}

impl zed::Extension for CopilotChatProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

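    /// Advertises the single "copilot_chat" provider to Zed.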
    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "copilot_chat".into(),
            name: "Copilot Chat".into(),
            icon: Some("icons/copilot.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.id.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("copilot_chat").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            r#"# Copilot Chat Setup

Welcome to **Copilot Chat**! This extension provides access to GitHub Copilot's chat models.

## Configuration

Enter your GitHub Copilot token below. You need an active GitHub Copilot subscription.

To get your token:
1. Ensure you have a GitHub Copilot subscription
2. Generate a token from your GitHub Copilot settings

## Available Models

| Model | Context | Output |
|-------|---------|--------|
| GPT-4o | 128K | 16K |
| GPT-4o Mini | 128K | 16K |
| GPT-4.1 | 1M | 32K |
| o1 | 200K | 100K |
| o3-mini | 200K | 100K |
| Claude 3.5 Sonnet | 200K | 8K |
| Claude 3.7 Sonnet | 200K | 8K |
| Gemini 2.0 Flash | 1M | 8K |

## Features

- ✅ Full streaming support
- ✅ Tool/function calling
- ✅ Vision (image inputs)
- ✅ Multiple model providers via Copilot

## Note

This extension requires an active GitHub Copilot subscription.
"#
            .to_string(),
        )
    }

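    /// Prompts the user for a Copilot token and stores it as the
    /// "copilot_chat" credential.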
    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
        let provided = llm_request_credential(
            "copilot_chat",
            LlmCredentialType::ApiKey,
            "GitHub Copilot Token",
            "ghu_...",
        )?;
        if provided {
            Ok(())
        } else {
            Err("Authentication cancelled".to_string())
        }
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("copilot_chat")
    }

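    /// Builds the OpenAI-compatible request, sends it to the Copilot endpoint,
    /// and registers the response stream under a fresh stream id.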
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("copilot_chat").ok_or_else(|| {
            "No token configured. Please add your GitHub Copilot token in settings.".to_string()
        })?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.githubcopilot.com/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("copilot-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
                pending_tool_calls: Vec::new(),
                pending_stop: None,
            },
        );

        Ok(stream_id)
    }

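    /// Returns the next completion event, buffering SSE bytes across calls and
    /// emitting `Started`, `Text`, `ToolUse`, `Stop`, and `Usage` events.
    /// Returns `Ok(None)` once the underlying HTTP stream is exhausted.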
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Drain tool calls and the stop reason deferred from a previously
        // parsed finish chunk, one event per call.
        if !state.pending_tool_calls.is_empty() {
            let tc = state.pending_tool_calls.remove(0);
            return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                id: tc.id,
                name: tc.name,
                input: tc.arguments,
                thought_signature: None,
            })));
        }
        if let Some(stop_reason) = state.pending_stop.take() {
            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.trim().is_empty() {
                    continue;
                }

                if let Some(response) = parse_sse_line(&line) {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }

                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state
                                    .tool_calls
                                    .entry(tc.index)
                                    .or_insert_with(AccumulatedToolCall::default);

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(finish_reason) = &choice.finish_reason {
                            let stop_reason = match finish_reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
                                state.tool_calls_emitted = true;
                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
                                tool_calls.sort_by_key(|(idx, _)| *idx);

                                // Queue every accumulated tool call; emit the first
                                // one now and defer the rest (and the stop event) to
                                // subsequent calls so none are dropped.
                                state.pending_tool_calls =
                                    tool_calls.into_iter().map(|(_, tc)| tc).collect();
                                state.pending_stop = Some(stop_reason);

                                let tc = state.pending_tool_calls.remove(0);
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tc.id,
                                    name: tc.name,
                                    input: tc.arguments,
                                    thought_signature: None,
                                })));
                            }

                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                    }

                    if let Some(usage) = response.usage {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

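    /// Drops the stream's state, and with it the buffered data and the
    /// underlying HTTP response stream.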
    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

zed::register_extension!(CopilotChatProvider);