use std::collections::HashMap;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

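// GitHub device-flow OAuth endpoints, the Copilot token-exchange endpoint, and the
// OAuth client ID this extension authenticates with.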
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";

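/// State saved between starting the GitHub device flow and polling it for completion.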
struct DeviceFlowState {
    device_code: String,
    interval: u64,
    expires_in: u64,
}

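/// Short-lived Copilot API token (and its API endpoint) exchanged from the stored
/// GitHub OAuth token.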
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    api_endpoint: String,
}

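/// Model metadata deserialized from the Copilot models endpoint.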
#[derive(Clone, Deserialize)]
struct CopilotModel {
    id: String,
    name: String,
    #[serde(default)]
    is_chat_default: bool,
    #[serde(default)]
    is_chat_fallback: bool,
    #[serde(default)]
    model_picker_enabled: bool,
    #[serde(default)]
    capabilities: ModelCapabilities,
    #[serde(default)]
    policy: Option<ModelPolicy>,
}

#[derive(Clone, Default, Deserialize)]
struct ModelCapabilities {
    #[serde(default)]
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    #[serde(default)]
    supports: ModelSupportedFeatures,
    #[serde(rename = "type", default)]
    model_type: String,
}

#[derive(Clone, Default, Deserialize)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: u64,
    #[serde(default)]
    max_output_tokens: u64,
}

#[derive(Clone, Default, Deserialize)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

#[derive(Clone, Deserialize)]
struct ModelPolicy {
    state: String,
}

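/// Extension entry point. All fields sit behind `Mutex`es so they can be mutated through
/// `&self` callbacks; active completion streams are keyed by the IDs returned from
/// `llm_stream_completion_start`.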
struct CopilotChatProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
    device_flow_state: Mutex<Option<DeviceFlowState>>,
    api_token: Mutex<Option<ApiToken>>,
    cached_models: Mutex<Option<Vec<CopilotModel>>>,
}

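/// Per-stream parser state: the live HTTP response, any partially received SSE data,
/// and tool calls accumulated across deltas until a finish reason arrives.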
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    tool_calls_emitted: bool,
    /// Events already parsed but not yet handed to the host (queued tool calls and the
    /// stop reason), emitted one per call to `llm_stream_completion_next`.
    pending_events: Vec<LlmCompletionEvent>,
}

#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

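// Request/response types for the OpenAI-compatible chat completions wire format that the
// Copilot API speaks.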
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    Text(String),
    Parts(Vec<OpenAiContentPart>),
}

#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamResponse {
    choices: Vec<OpenAiStreamChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

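/// Converts Zed's provider-agnostic completion request into the OpenAI-style body sent
/// to the Copilot chat completions endpoint: system/user/assistant messages (including
/// image parts and tool results), tool definitions, and streaming options.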
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let mut messages: Vec<OpenAiMessage> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let mut text_content = String::new();
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text_content.is_empty() {
                            text_content.push('\n');
                        }
                        text_content.push_str(text);
                    }
                }
                if !text_content.is_empty() {
                    messages.push(OpenAiMessage {
                        role: "system".to_string(),
                        content: Some(OpenAiContent::Text(text_content)),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<OpenAiContentPart> = Vec::new();
                let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(OpenAiContentPart::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            let data_url = format!("data:image/png;base64,{}", img.source);
                            parts.push(OpenAiContentPart::ImageUrl {
                                image_url: ImageUrl { url: data_url },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            tool_result_messages.push(OpenAiMessage {
                                role: "tool".to_string(),
                                content: Some(OpenAiContent::Text(content_text)),
                                tool_calls: None,
                                tool_call_id: Some(result.tool_use_id.clone()),
                            });
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    let content = if parts.len() == 1 {
                        if let OpenAiContentPart::Text { text } = &parts[0] {
                            OpenAiContent::Text(text.clone())
                        } else {
                            OpenAiContent::Parts(parts)
                        }
                    } else {
                        OpenAiContent::Parts(parts)
                    };

                    messages.push(OpenAiMessage {
                        role: "user".to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }

                messages.extend(tool_result_messages);
            }
            LlmMessageRole::Assistant => {
                let mut text_content = String::new();
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                if !text_content.is_empty() {
                                    text_content.push('\n');
                                }
                                text_content.push_str(text);
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage {
                    role: "assistant".to_string(),
                    content: if text_content.is_empty() {
                        None
                    } else {
                        Some(OpenAiContent::Text(text_content))
                    },
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                    tool_call_id: None,
                });
            }
        }
    }

    let tools: Vec<OpenAiTool> = request
        .tools
        .iter()
        .map(|t| OpenAiTool {
            tool_type: "function".to_string(),
            function: OpenAiFunctionDef {
                name: t.name.clone(),
                description: t.description.clone(),
                parameters: serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default())),
            },
        })
        .collect();

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    let max_tokens = request.max_tokens;

    Ok(OpenAiRequest {
        model: model_id.to_string(),
        messages,
        max_tokens,
        tools,
        tool_choice,
        stop: request.stop_sequences.clone(),
        temperature: request.temperature,
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

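/// Parses a single SSE line from the completions stream. Returns `None` for
/// non-`data:` lines, the terminating `[DONE]` marker, and payloads that fail to
/// deserialize.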
fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
    let data = line.strip_prefix("data: ")?;
    if data.trim() == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}

impl zed::Extension for CopilotChatProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
            device_flow_state: Mutex::new(None),
            api_token: Mutex::new(None),
            cached_models: Mutex::new(None),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "copilot-chat".into(),
            name: "Copilot Chat".into(),
            icon: Some("icons/copilot.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        // Try to get models from cache first
        if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
            return Ok(convert_models_to_llm_info(models));
        }

        // Need to fetch models - requires authentication
        let oauth_token = match llm_get_credential("copilot-chat") {
            Some(token) => token,
            None => return Ok(Vec::new()), // Not authenticated, return empty
        };

        // Get API token
        let api_token = self.get_api_token(&oauth_token)?;

        // Fetch models from API
        let models = self.fetch_models(&api_token)?;

        // Cache the models
        *self.cached_models.lock().unwrap() = Some(models.clone());

        Ok(convert_models_to_llm_info(&models))
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("copilot-chat").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            r#"# Copilot Chat Setup

Welcome to **Copilot Chat**! This extension provides access to GitHub Copilot's chat models.

## Authentication

Click **Sign in with GitHub** to authenticate with your GitHub account. You'll be redirected to GitHub to authorize access. This requires an active GitHub Copilot subscription.

Alternatively, you can set the `GH_COPILOT_TOKEN` environment variable with your token.

## Available Models

| Model | Context | Output |
|-------|---------|--------|
| GPT-4o | 128K | 16K |
| GPT-4o Mini | 128K | 16K |
| GPT-4.1 | 1M | 32K |
| o1 | 200K | 100K |
| o3-mini | 200K | 100K |
| Claude 3.5 Sonnet | 200K | 8K |
| Claude 3.7 Sonnet | 200K | 8K |
| Gemini 2.0 Flash | 1M | 8K |

## Features

- ✅ Full streaming support
- ✅ Tool/function calling
- ✅ Vision (image inputs)
- ✅ Multiple model providers via Copilot

## Note

This extension requires an active GitHub Copilot subscription.
"#
            .to_string(),
        )
    }

    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
        // Check if we have existing credentials
        if llm_get_credential("copilot-chat").is_some() {
            return Ok(());
        }

        // No credentials found - return error for background auth checks.
        // The device flow will be triggered by the host when the user clicks
        // the "Sign in with GitHub" button, which calls llm_provider_start_device_flow_sign_in.
        Err("CredentialsNotFound".to_string())
    }

    fn llm_provider_start_device_flow_sign_in(
        &mut self,
        _provider_id: &str,
    ) -> Result<String, String> {
        // Step 1: Request device and user verification codes
        let device_code_response = llm_oauth_http_request(&LlmOauthHttpRequest {
            url: GITHUB_DEVICE_CODE_URL.to_string(),
            method: "POST".to_string(),
            headers: vec![
                ("Accept".to_string(), "application/json".to_string()),
                (
                    "Content-Type".to_string(),
                    "application/x-www-form-urlencoded".to_string(),
                ),
            ],
            body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID),
        })?;

        if device_code_response.status != 200 {
            return Err(format!(
                "Failed to get device code: HTTP {}",
                device_code_response.status
            ));
        }

        #[derive(Deserialize)]
        struct DeviceCodeResponse {
            device_code: String,
            user_code: String,
            verification_uri: String,
            #[serde(default)]
            verification_uri_complete: Option<String>,
            expires_in: u64,
            interval: u64,
        }

        let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body)
            .map_err(|e| format!("Failed to parse device code response: {}", e))?;

        // Store device flow state for polling
        *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
            device_code: device_info.device_code,
            interval: device_info.interval,
            expires_in: device_info.expires_in,
        });

        // Step 2: Open browser to verification URL
        // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
        let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
            format!(
                "{}?user_code={}",
                device_info.verification_uri, &device_info.user_code
            )
        });
        llm_oauth_open_browser(&verification_url)?;

        // Return the user code for the host to display
        Ok(device_info.user_code)
    }

    fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
        let state = self
            .device_flow_state
            .lock()
            .unwrap()
            .take()
            .ok_or("No device flow in progress")?;

        let poll_interval = Duration::from_secs(state.interval.max(5));
        let max_attempts = (state.expires_in / state.interval.max(5)) as usize;

        for _ in 0..max_attempts {
            thread::sleep(poll_interval);

            let token_response = llm_oauth_http_request(&LlmOauthHttpRequest {
                url: GITHUB_ACCESS_TOKEN_URL.to_string(),
                method: "POST".to_string(),
                headers: vec![
                    ("Accept".to_string(), "application/json".to_string()),
                    (
                        "Content-Type".to_string(),
                        "application/x-www-form-urlencoded".to_string(),
                    ),
                ],
                body: format!(
                    "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
                    GITHUB_COPILOT_CLIENT_ID, state.device_code
                ),
            })?;

            #[derive(Deserialize)]
            struct TokenResponse {
                access_token: Option<String>,
                error: Option<String>,
                error_description: Option<String>,
            }

            let token_json: TokenResponse = serde_json::from_str(&token_response.body)
                .map_err(|e| format!("Failed to parse token response: {}", e))?;

            if let Some(access_token) = token_json.access_token {
                llm_store_credential("copilot-chat", &access_token)?;
                return Ok(());
            }

            if let Some(error) = &token_json.error {
                match error.as_str() {
                    "authorization_pending" => {
                        // User hasn't authorized yet, keep polling
                        continue;
                    }
                    "slow_down" => {
                        // Need to slow down polling
                        thread::sleep(Duration::from_secs(5));
                        continue;
                    }
                    "expired_token" => {
                        return Err("Device code expired. Please try again.".to_string());
                    }
                    "access_denied" => {
                        return Err("Authorization was denied.".to_string());
                    }
                    _ => {
                        let description = token_json.error_description.unwrap_or_default();
                        return Err(format!("OAuth error: {} - {}", error, description));
                    }
                }
            }
        }

        Err("Authorization timed out. Please try again.".to_string())
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        // Clear cached API token and models
        *self.api_token.lock().unwrap() = None;
        *self.cached_models.lock().unwrap() = None;
        llm_delete_credential("copilot-chat")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
            "No token configured. Please add your GitHub Copilot token in settings.".to_string()
        })?;

        // Get or refresh API token
        let api_token = self.get_api_token(&oauth_token)?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let completions_url = format!("{}/chat/completions", api_token.api_endpoint);

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: completions_url,
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("copilot-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
                pending_events: Vec::new(),
            },
        );

        Ok(stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Drain events queued from a previously parsed chunk (tool calls and the stop
        // reason) before reading more of the response stream.
        if !state.pending_events.is_empty() {
            return Ok(Some(state.pending_events.remove(0)));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.trim().is_empty() {
                    continue;
                }

                if let Some(response) = parse_sse_line(&line) {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }

                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state
                                    .tool_calls
                                    .entry(tc.index)
                                    .or_insert_with(AccumulatedToolCall::default);

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(finish_reason) = &choice.finish_reason {
                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
                                state.tool_calls_emitted = true;
                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
                                tool_calls.sort_by_key(|(idx, _)| *idx);

                                // Queue every accumulated tool call in index order; they are
                                // handed to the host one per `next` call, followed by the
                                // stop reason below.
                                for (_, tc) in tool_calls {
                                    state.pending_events.push(LlmCompletionEvent::ToolUse(
                                        LlmToolUse {
                                            id: tc.id,
                                            name: tc.name,
                                            input: tc.arguments,
                                            thought_signature: None,
                                        },
                                    ));
                                }
                            }

                            let stop_reason = match finish_reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };
                            if state.pending_events.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                            }
                            state.pending_events.push(LlmCompletionEvent::Stop(stop_reason));
                            return Ok(Some(state.pending_events.remove(0)));
                        }
                    }

                    if let Some(usage) = response.usage {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

impl CopilotChatProvider {
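    /// Returns a cached Copilot API token, or exchanges the GitHub OAuth token for a new
    /// one via the Copilot token endpoint and caches the result.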
    fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
        // Check if we have a cached token
        if let Some(token) = self.api_token.lock().unwrap().clone() {
            return Ok(token);
        }

        // Request a new API token
        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: GITHUB_COPILOT_TOKEN_URL.to_string(),
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("token {}", oauth_token),
                ),
                ("Accept".to_string(), "application/json".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to request API token: {}", e))?;

        #[derive(Deserialize)]
        struct ApiTokenResponse {
            token: String,
            endpoints: ApiEndpoints,
        }

        #[derive(Deserialize)]
        struct ApiEndpoints {
            api: String,
        }

        let token_response: ApiTokenResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse API token response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;

        let api_token = ApiToken {
            api_key: token_response.token,
            api_endpoint: token_response.endpoints.api,
        };

        // Cache the token
        *self.api_token.lock().unwrap() = Some(api_token.clone());

        Ok(api_token)
    }

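    /// Fetches the model list from the Copilot API and filters it to chat models that are
    /// enabled for the model picker and permitted by policy, with the default chat model
    /// sorted first.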
    fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
        let models_url = format!("{}/models", api_token.api_endpoint);

        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: models_url,
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
                ("x-github-api-version".to_string(), "2025-05-01".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to fetch models: {}", e))?;

        #[derive(Deserialize)]
        struct ModelsResponse {
            data: Vec<CopilotModel>,
        }

        let models_response: ModelsResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse models response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;

        // Filter models like the built-in Copilot Chat does
        let mut models: Vec<CopilotModel> = models_response
            .data
            .into_iter()
            .filter(|model| {
                model.model_picker_enabled
                    && model.capabilities.model_type == "chat"
                    && model
                        .policy
                        .as_ref()
                        .map(|p| p.state == "enabled")
                        .unwrap_or(true)
            })
            .collect();

        // Sort so default model is first
        if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
            let default_model = models.remove(pos);
            models.insert(0, default_model);
        }

        Ok(models)
    }
}

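/// Maps the filtered Copilot model list into Zed's `LlmModelInfo`, falling back to a
/// 128K context window when the API reports no limit.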
fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
    models
        .iter()
        .map(|m| {
            let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 {
                m.capabilities.limits.max_context_window_tokens
            } else {
                128_000 // Default fallback
            };
            let max_output = if m.capabilities.limits.max_output_tokens > 0 {
                Some(m.capabilities.limits.max_output_tokens)
            } else {
                None
            };

            LlmModelInfo {
                id: m.id.clone(),
                name: m.name.clone(),
                max_token_count: max_tokens,
                max_output_tokens: max_output,
                capabilities: LlmModelCapabilities {
                    supports_images: m.capabilities.supports.vision,
                    supports_tools: m.capabilities.supports.tool_calls,
                    supports_tool_choice_auto: m.capabilities.supports.tool_calls,
                    supports_tool_choice_any: m.capabilities.supports.tool_calls,
                    supports_tool_choice_none: m.capabilities.supports.tool_calls,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_chat_default,
                is_default_fast: m.is_chat_fallback,
            }
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_flow_request_body() {
        let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("scope=read:user"));
    }

    #[test]
    fn test_token_poll_request_body() {
        let device_code = "test_device_code_123";
        let body = format!(
            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
            GITHUB_COPILOT_CLIENT_ID, device_code
        );
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("device_code=test_device_code_123"));
        assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code"));
    }
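
    // Added example: exercises the `data:` framing handled by `parse_sse_line`, using a
    // hand-written chunk in the OpenAI-style streaming shape this extension deserializes.
    #[test]
    fn test_parse_sse_line() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let response = parse_sse_line(line).expect("data line should parse");
        assert_eq!(response.choices[0].delta.content.as_deref(), Some("Hello"));

        // The [DONE] sentinel and non-`data:` lines (e.g. SSE comments) yield `None`.
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line(": keep-alive").is_none());
    }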
}

zed::register_extension!(CopilotChatProvider);