1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::thread;
4use std::time::Duration;
5
6use serde::{Deserialize, Serialize};
7use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
8use zed_extension_api::{self as zed, *};
9
/// GitHub OAuth device-authorization endpoint; the first request of the device sign-in flow.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
/// Endpoint polled with the device code until GitHub issues the OAuth access token.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
/// Endpoint that exchanges the stored GitHub OAuth token for a Copilot API token and endpoint.
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
/// OAuth client id sent with the device-code and access-token requests.
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
14
/// Device-flow data saved by `llm_provider_start_device_flow_sign_in` and consumed
/// (taken) by `llm_provider_poll_device_flow_sign_in`.
struct DeviceFlowState {
    /// Code identifying this sign-in attempt; sent back to GitHub on every poll.
    device_code: String,
    /// Server-suggested polling interval in seconds (the poller never uses less than 5).
    interval: u64,
    /// Lifetime of the device code in seconds; bounds how long polling continues.
    expires_in: u64,
}
20
/// Copilot API credentials obtained from `GITHUB_COPILOT_TOKEN_URL`.
#[derive(Clone)]
struct ApiToken {
    /// Sent as `Authorization: Bearer <api_key>` on Copilot API requests.
    api_key: String,
    /// Base URL of the Copilot API; `/models` and `/chat/completions` are appended to it.
    api_endpoint: String,
}
26
/// One entry of the Copilot `/models` catalog (the `data` array fetched by `fetch_models`).
#[derive(Clone, Deserialize)]
struct CopilotModel {
    /// Model identifier; becomes `LlmModelInfo::id` and is later sent as the request `model`.
    id: String,
    /// Display name passed through as `LlmModelInfo::name`.
    name: String,
    /// Catalog's default chat model; `fetch_models` moves it to the front of the list.
    #[serde(default)]
    is_chat_default: bool,
    /// Catalog's fallback chat model; surfaced as Zed's default fast model.
    #[serde(default)]
    is_chat_fallback: bool,
    /// Entries without this flag are filtered out by `fetch_models`.
    #[serde(default)]
    model_picker_enabled: bool,
    #[serde(default)]
    capabilities: ModelCapabilities,
    /// When present, only entries whose policy state is `"enabled"` are kept.
    #[serde(default)]
    policy: Option<ModelPolicy>,
}

/// Capability metadata attached to a catalog entry.
#[derive(Clone, Default, Deserialize)]
struct ModelCapabilities {
    /// Model family name (deserialized but not read elsewhere in this file).
    #[serde(default)]
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    #[serde(default)]
    supports: ModelSupportedFeatures,
    /// The JSON `type` field; `fetch_models` keeps only `"chat"` entries.
    #[serde(rename = "type", default)]
    model_type: String,
}

/// Token limits reported for a model; a missing value deserializes as 0.
///
/// `convert_models_to_llm_info` replaces a 0 context window with a 128k fallback and maps
/// a 0 output limit to `None`.
#[derive(Clone, Default, Deserialize)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: u64,
    #[serde(default)]
    max_output_tokens: u64,
}

/// Feature flags reported for a model; missing flags deserialize as `false`.
#[derive(Clone, Default, Deserialize)]
struct ModelSupportedFeatures {
    /// Deserialized but not read elsewhere in this file.
    #[serde(default)]
    streaming: bool,
    /// Drives every tool-related capability in `convert_models_to_llm_info`.
    #[serde(default)]
    tool_calls: bool,
    /// Maps to `supports_images`.
    #[serde(default)]
    vision: bool,
}

/// Per-account policy attached to a catalog entry.
#[derive(Clone, Deserialize)]
struct ModelPolicy {
    /// Compared against `"enabled"` when filtering models.
    state: String,
}
77
/// The Copilot Chat language-model provider registered with Zed.
///
/// State lives behind `Mutex`es so it can be updated from `&self` methods such as
/// `llm_provider_models` and `get_api_token`.
struct CopilotChatProvider {
    /// In-flight completion streams keyed by the id returned from `llm_stream_completion_start`.
    streams: Mutex<HashMap<String, StreamState>>,
    /// Counter used to mint unique `copilot-stream-N` ids.
    next_stream_id: Mutex<u64>,
    /// Pending device-flow sign-in; taken by the first poll.
    device_flow_state: Mutex<Option<DeviceFlowState>>,
    /// Cached Copilot API token; cleared by `llm_provider_reset_credentials`.
    api_token: Mutex<Option<ApiToken>>,
    /// Cached, filtered model catalog; cleared by `llm_provider_reset_credentials`.
    cached_models: Mutex<Option<Vec<CopilotModel>>>,
}

/// Per-stream parsing state for a streaming chat completion.
struct StreamState {
    /// HTTP response body being read chunk by chunk.
    response_stream: Option<HttpResponseStream>,
    /// Decoded text received but not yet consumed as complete lines.
    buffer: String,
    /// Whether the initial `Started` event has been returned.
    started: bool,
    /// Tool-call fragments accumulated so far, keyed by the stream's `index` field.
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    /// Set once accumulated tool calls have been emitted after a finish reason.
    tool_calls_emitted: bool,
}

/// A tool call reassembled from streamed fragments.
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    /// Concatenation of every streamed argument fragment, in arrival order.
    arguments: String,
}
100
/// Body of `POST {api_endpoint}/chat/completions` in the OpenAI-style chat-completions shape.
///
/// `None` options and empty `tools`/`stop` vectors are omitted from the JSON.
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    /// `"auto"`, `"required"` or `"none"`, as produced by `convert_request`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// `convert_request` always sets this to `true`; the response is read as an SSE stream.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}

/// Streaming options; `include_usage` asks the server to report token usage.
#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

/// One chat message (`system`, `user`, `assistant` or `tool`).
#[derive(Serialize)]
struct OpenAiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    /// Tool calls made by an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    /// For `tool` messages: id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// Message content, serialized untagged: either a bare string or an array of parts.
#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    Text(String),
    Parts(Vec<OpenAiContentPart>),
}

/// A content part, serialized with a `"type"` tag of `"text"` or `"image_url"`.
#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

/// Image reference; `convert_request` fills it with a `data:` URL.
#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}

/// A tool call recorded on an assistant message.
#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    id: String,
    /// Serialized as `"type"`; always `"function"` in `convert_request`.
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

/// Function name plus its raw argument string.
#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

/// A tool offered to the model.
#[derive(Serialize)]
struct OpenAiTool {
    /// Serialized as `"type"`; always `"function"` in `convert_request`.
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

/// Function definition: name, description and parameter schema.
#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    /// Parsed from the tool's `input_schema`; an empty object when that fails to parse.
    parameters: serde_json::Value,
}
184
185#[derive(Deserialize, Debug)]
186struct OpenAiStreamResponse {
187 choices: Vec<OpenAiStreamChoice>,
188 #[serde(default)]
189 usage: Option<OpenAiUsage>,
190}
191
192#[derive(Deserialize, Debug)]
193struct OpenAiStreamChoice {
194 delta: OpenAiDelta,
195 finish_reason: Option<String>,
196}
197
198#[derive(Deserialize, Debug, Default)]
199struct OpenAiDelta {
200 #[serde(default)]
201 content: Option<String>,
202 #[serde(default)]
203 tool_calls: Option<Vec<OpenAiToolCallDelta>>,
204}
205
206#[derive(Deserialize, Debug)]
207struct OpenAiToolCallDelta {
208 index: usize,
209 #[serde(default)]
210 id: Option<String>,
211 #[serde(default)]
212 function: Option<OpenAiFunctionDelta>,
213}
214
215#[derive(Deserialize, Debug, Default)]
216struct OpenAiFunctionDelta {
217 #[serde(default)]
218 name: Option<String>,
219 #[serde(default)]
220 arguments: Option<String>,
221}
222
223#[derive(Deserialize, Debug)]
224struct OpenAiUsage {
225 prompt_tokens: u64,
226 completion_tokens: u64,
227}
228
229fn convert_request(
230 model_id: &str,
231 request: &LlmCompletionRequest,
232) -> Result<OpenAiRequest, String> {
233 let mut messages: Vec<OpenAiMessage> = Vec::new();
234
235 for msg in &request.messages {
236 match msg.role {
237 LlmMessageRole::System => {
238 let mut text_content = String::new();
239 for content in &msg.content {
240 if let LlmMessageContent::Text(text) = content {
241 if !text_content.is_empty() {
242 text_content.push('\n');
243 }
244 text_content.push_str(text);
245 }
246 }
247 if !text_content.is_empty() {
248 messages.push(OpenAiMessage {
249 role: "system".to_string(),
250 content: Some(OpenAiContent::Text(text_content)),
251 tool_calls: None,
252 tool_call_id: None,
253 });
254 }
255 }
256 LlmMessageRole::User => {
257 let mut parts: Vec<OpenAiContentPart> = Vec::new();
258 let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();
259
260 for content in &msg.content {
261 match content {
262 LlmMessageContent::Text(text) => {
263 if !text.is_empty() {
264 parts.push(OpenAiContentPart::Text { text: text.clone() });
265 }
266 }
267 LlmMessageContent::Image(img) => {
268 let data_url = format!("data:image/png;base64,{}", img.source);
269 parts.push(OpenAiContentPart::ImageUrl {
270 image_url: ImageUrl { url: data_url },
271 });
272 }
273 LlmMessageContent::ToolResult(result) => {
274 let content_text = match &result.content {
275 LlmToolResultContent::Text(t) => t.clone(),
276 LlmToolResultContent::Image(_) => "[Image]".to_string(),
277 };
278 tool_result_messages.push(OpenAiMessage {
279 role: "tool".to_string(),
280 content: Some(OpenAiContent::Text(content_text)),
281 tool_calls: None,
282 tool_call_id: Some(result.tool_use_id.clone()),
283 });
284 }
285 _ => {}
286 }
287 }
288
289 if !parts.is_empty() {
290 let content = if parts.len() == 1 {
291 if let OpenAiContentPart::Text { text } = &parts[0] {
292 OpenAiContent::Text(text.clone())
293 } else {
294 OpenAiContent::Parts(parts)
295 }
296 } else {
297 OpenAiContent::Parts(parts)
298 };
299
300 messages.push(OpenAiMessage {
301 role: "user".to_string(),
302 content: Some(content),
303 tool_calls: None,
304 tool_call_id: None,
305 });
306 }
307
308 messages.extend(tool_result_messages);
309 }
310 LlmMessageRole::Assistant => {
311 let mut text_content = String::new();
312 let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();
313
314 for content in &msg.content {
315 match content {
316 LlmMessageContent::Text(text) => {
317 if !text.is_empty() {
318 if !text_content.is_empty() {
319 text_content.push('\n');
320 }
321 text_content.push_str(text);
322 }
323 }
324 LlmMessageContent::ToolUse(tool_use) => {
325 tool_calls.push(OpenAiToolCall {
326 id: tool_use.id.clone(),
327 call_type: "function".to_string(),
328 function: OpenAiFunctionCall {
329 name: tool_use.name.clone(),
330 arguments: tool_use.input.clone(),
331 },
332 });
333 }
334 _ => {}
335 }
336 }
337
338 messages.push(OpenAiMessage {
339 role: "assistant".to_string(),
340 content: if text_content.is_empty() {
341 None
342 } else {
343 Some(OpenAiContent::Text(text_content))
344 },
345 tool_calls: if tool_calls.is_empty() {
346 None
347 } else {
348 Some(tool_calls)
349 },
350 tool_call_id: None,
351 });
352 }
353 }
354 }
355
356 let tools: Vec<OpenAiTool> = request
357 .tools
358 .iter()
359 .map(|t| OpenAiTool {
360 tool_type: "function".to_string(),
361 function: OpenAiFunctionDef {
362 name: t.name.clone(),
363 description: t.description.clone(),
364 parameters: serde_json::from_str(&t.input_schema)
365 .unwrap_or(serde_json::Value::Object(Default::default())),
366 },
367 })
368 .collect();
369
370 let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
371 LlmToolChoice::Auto => "auto".to_string(),
372 LlmToolChoice::Any => "required".to_string(),
373 LlmToolChoice::None => "none".to_string(),
374 });
375
376 let max_tokens = request.max_tokens;
377
378 Ok(OpenAiRequest {
379 model: model_id.to_string(),
380 messages,
381 max_tokens,
382 tools,
383 tool_choice,
384 stop: request.stop_sequences.clone(),
385 temperature: request.temperature,
386 stream: true,
387 stream_options: Some(StreamOptions {
388 include_usage: true,
389 }),
390 })
391}
392
393fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
394 let data = line.strip_prefix("data: ")?;
395 if data.trim() == "[DONE]" {
396 return None;
397 }
398 serde_json::from_str(data).ok()
399}
400
401impl zed::Extension for CopilotChatProvider {
402 fn new() -> Self {
403 Self {
404 streams: Mutex::new(HashMap::new()),
405 next_stream_id: Mutex::new(0),
406 device_flow_state: Mutex::new(None),
407 api_token: Mutex::new(None),
408 cached_models: Mutex::new(None),
409 }
410 }
411
412 fn llm_providers(&self) -> Vec<LlmProviderInfo> {
413 vec![LlmProviderInfo {
414 id: "copilot-chat".into(),
415 name: "Copilot Chat".into(),
416 icon: Some("icons/copilot.svg".into()),
417 }]
418 }
419
420 fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
421 // Try to get models from cache first
422 if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
423 return Ok(convert_models_to_llm_info(models));
424 }
425
426 // Need to fetch models - requires authentication
427 let oauth_token = match llm_get_credential("copilot-chat") {
428 Some(token) => token,
429 None => return Ok(Vec::new()), // Not authenticated, return empty
430 };
431
432 // Get API token
433 let api_token = self.get_api_token(&oauth_token)?;
434
435 // Fetch models from API
436 let models = self.fetch_models(&api_token)?;
437
438 // Cache the models
439 *self.cached_models.lock().unwrap() = Some(models.clone());
440
441 Ok(convert_models_to_llm_info(&models))
442 }
443
444 fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
445 llm_get_credential("copilot-chat").is_some()
446 }
447
448 fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
449 Some(
450 "To use Copilot Chat, sign in with your GitHub account. This requires an active [GitHub Copilot subscription](https://github.com/features/copilot).".to_string(),
451 )
452 }
453
454 fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
455 // Check if we have existing credentials
456 if llm_get_credential("copilot-chat").is_some() {
457 return Ok(());
458 }
459
460 // No credentials found - return error for background auth checks.
461 // The device flow will be triggered by the host when the user clicks
462 // the "Sign in with GitHub" button, which calls llm_provider_start_device_flow_sign_in.
463 Err("CredentialsNotFound".to_string())
464 }
465
466 fn llm_provider_start_device_flow_sign_in(
467 &mut self,
468 _provider_id: &str,
469 ) -> Result<String, String> {
470 // Step 1: Request device and user verification codes
471 let device_code_response = llm_oauth_http_request(&LlmOauthHttpRequest {
472 url: GITHUB_DEVICE_CODE_URL.to_string(),
473 method: "POST".to_string(),
474 headers: vec![
475 ("Accept".to_string(), "application/json".to_string()),
476 (
477 "Content-Type".to_string(),
478 "application/x-www-form-urlencoded".to_string(),
479 ),
480 ],
481 body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID),
482 })?;
483
484 if device_code_response.status != 200 {
485 return Err(format!(
486 "Failed to get device code: HTTP {}",
487 device_code_response.status
488 ));
489 }
490
491 #[derive(Deserialize)]
492 struct DeviceCodeResponse {
493 device_code: String,
494 user_code: String,
495 verification_uri: String,
496 #[serde(default)]
497 verification_uri_complete: Option<String>,
498 expires_in: u64,
499 interval: u64,
500 }
501
502 let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body)
503 .map_err(|e| format!("Failed to parse device code response: {}", e))?;
504
505 // Store device flow state for polling
506 *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
507 device_code: device_info.device_code,
508 interval: device_info.interval,
509 expires_in: device_info.expires_in,
510 });
511
512 // Step 2: Open browser to verification URL
513 // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
514 let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
515 format!(
516 "{}?user_code={}",
517 device_info.verification_uri, &device_info.user_code
518 )
519 });
520 llm_oauth_open_browser(&verification_url)?;
521
522 // Return the user code for the host to display
523 Ok(device_info.user_code)
524 }
525
526 fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
527 let state = self
528 .device_flow_state
529 .lock()
530 .unwrap()
531 .take()
532 .ok_or("No device flow in progress")?;
533
534 let poll_interval = Duration::from_secs(state.interval.max(5));
535 let max_attempts = (state.expires_in / state.interval.max(5)) as usize;
536
537 for _ in 0..max_attempts {
538 thread::sleep(poll_interval);
539
540 let token_response = llm_oauth_http_request(&LlmOauthHttpRequest {
541 url: GITHUB_ACCESS_TOKEN_URL.to_string(),
542 method: "POST".to_string(),
543 headers: vec![
544 ("Accept".to_string(), "application/json".to_string()),
545 (
546 "Content-Type".to_string(),
547 "application/x-www-form-urlencoded".to_string(),
548 ),
549 ],
550 body: format!(
551 "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
552 GITHUB_COPILOT_CLIENT_ID, state.device_code
553 ),
554 })?;
555
556 #[derive(Deserialize)]
557 struct TokenResponse {
558 access_token: Option<String>,
559 error: Option<String>,
560 error_description: Option<String>,
561 }
562
563 let token_json: TokenResponse = serde_json::from_str(&token_response.body)
564 .map_err(|e| format!("Failed to parse token response: {}", e))?;
565
566 if let Some(access_token) = token_json.access_token {
567 llm_store_credential("copilot-chat", &access_token)?;
568 return Ok(());
569 }
570
571 if let Some(error) = &token_json.error {
572 match error.as_str() {
573 "authorization_pending" => {
574 // User hasn't authorized yet, keep polling
575 continue;
576 }
577 "slow_down" => {
578 // Need to slow down polling
579 thread::sleep(Duration::from_secs(5));
580 continue;
581 }
582 "expired_token" => {
583 return Err("Device code expired. Please try again.".to_string());
584 }
585 "access_denied" => {
586 return Err("Authorization was denied.".to_string());
587 }
588 _ => {
589 let description = token_json.error_description.unwrap_or_default();
590 return Err(format!("OAuth error: {} - {}", error, description));
591 }
592 }
593 }
594 }
595
596 Err("Authorization timed out. Please try again.".to_string())
597 }
598
599 fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
600 // Clear cached API token and models
601 *self.api_token.lock().unwrap() = None;
602 *self.cached_models.lock().unwrap() = None;
603 llm_delete_credential("copilot-chat")
604 }
605
606 fn llm_stream_completion_start(
607 &mut self,
608 _provider_id: &str,
609 model_id: &str,
610 request: &LlmCompletionRequest,
611 ) -> Result<String, String> {
612 let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
613 "No token configured. Please add your GitHub Copilot token in settings.".to_string()
614 })?;
615
616 // Get or refresh API token
617 let api_token = self.get_api_token(&oauth_token)?;
618
619 let openai_request = convert_request(model_id, request)?;
620
621 let body = serde_json::to_vec(&openai_request)
622 .map_err(|e| format!("Failed to serialize request: {}", e))?;
623
624 let completions_url = format!("{}/chat/completions", api_token.api_endpoint);
625
626 let http_request = HttpRequest {
627 method: HttpMethod::Post,
628 url: completions_url,
629 headers: vec![
630 ("Content-Type".to_string(), "application/json".to_string()),
631 (
632 "Authorization".to_string(),
633 format!("Bearer {}", api_token.api_key),
634 ),
635 (
636 "Copilot-Integration-Id".to_string(),
637 "vscode-chat".to_string(),
638 ),
639 ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
640 ],
641 body: Some(body),
642 redirect_policy: RedirectPolicy::FollowAll,
643 };
644
645 let response_stream = http_request
646 .fetch_stream()
647 .map_err(|e| format!("HTTP request failed: {}", e))?;
648
649 let stream_id = {
650 let mut id_counter = self.next_stream_id.lock().unwrap();
651 let id = format!("copilot-stream-{}", *id_counter);
652 *id_counter += 1;
653 id
654 };
655
656 self.streams.lock().unwrap().insert(
657 stream_id.clone(),
658 StreamState {
659 response_stream: Some(response_stream),
660 buffer: String::new(),
661 started: false,
662 tool_calls: HashMap::new(),
663 tool_calls_emitted: false,
664 },
665 );
666
667 Ok(stream_id)
668 }
669
670 fn llm_stream_completion_next(
671 &mut self,
672 stream_id: &str,
673 ) -> Result<Option<LlmCompletionEvent>, String> {
674 let mut streams = self.streams.lock().unwrap();
675 let state = streams
676 .get_mut(stream_id)
677 .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
678
679 if !state.started {
680 state.started = true;
681 return Ok(Some(LlmCompletionEvent::Started));
682 }
683
684 let response_stream = state
685 .response_stream
686 .as_mut()
687 .ok_or_else(|| "Stream already closed".to_string())?;
688
689 loop {
690 if let Some(newline_pos) = state.buffer.find('\n') {
691 let line = state.buffer[..newline_pos].to_string();
692 state.buffer = state.buffer[newline_pos + 1..].to_string();
693
694 if line.trim().is_empty() {
695 continue;
696 }
697
698 if let Some(response) = parse_sse_line(&line) {
699 if let Some(choice) = response.choices.first() {
700 if let Some(content) = &choice.delta.content {
701 if !content.is_empty() {
702 return Ok(Some(LlmCompletionEvent::Text(content.clone())));
703 }
704 }
705
706 if let Some(tool_calls) = &choice.delta.tool_calls {
707 for tc in tool_calls {
708 let entry = state
709 .tool_calls
710 .entry(tc.index)
711 .or_insert_with(AccumulatedToolCall::default);
712
713 if let Some(id) = &tc.id {
714 entry.id = id.clone();
715 }
716 if let Some(func) = &tc.function {
717 if let Some(name) = &func.name {
718 entry.name = name.clone();
719 }
720 if let Some(args) = &func.arguments {
721 entry.arguments.push_str(args);
722 }
723 }
724 }
725 }
726
727 if let Some(finish_reason) = &choice.finish_reason {
728 if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
729 state.tool_calls_emitted = true;
730 let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
731 tool_calls.sort_by_key(|(idx, _)| *idx);
732
733 if let Some((_, tc)) = tool_calls.into_iter().next() {
734 return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
735 id: tc.id,
736 name: tc.name,
737 input: tc.arguments,
738 thought_signature: None,
739 })));
740 }
741 }
742
743 let stop_reason = match finish_reason.as_str() {
744 "stop" => LlmStopReason::EndTurn,
745 "length" => LlmStopReason::MaxTokens,
746 "tool_calls" => LlmStopReason::ToolUse,
747 "content_filter" => LlmStopReason::Refusal,
748 _ => LlmStopReason::EndTurn,
749 };
750 return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
751 }
752 }
753
754 if let Some(usage) = response.usage {
755 return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
756 input_tokens: usage.prompt_tokens,
757 output_tokens: usage.completion_tokens,
758 cache_creation_input_tokens: None,
759 cache_read_input_tokens: None,
760 })));
761 }
762 }
763
764 continue;
765 }
766
767 match response_stream.next_chunk() {
768 Ok(Some(chunk)) => {
769 let text = String::from_utf8_lossy(&chunk);
770 state.buffer.push_str(&text);
771 }
772 Ok(None) => {
773 return Ok(None);
774 }
775 Err(e) => {
776 return Err(format!("Stream error: {}", e));
777 }
778 }
779 }
780 }
781
782 fn llm_stream_completion_close(&mut self, stream_id: &str) {
783 self.streams.lock().unwrap().remove(stream_id);
784 }
785}
786
787impl CopilotChatProvider {
788 fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
789 // Check if we have a cached token
790 if let Some(token) = self.api_token.lock().unwrap().clone() {
791 return Ok(token);
792 }
793
794 // Request a new API token
795 let http_request = HttpRequest {
796 method: HttpMethod::Get,
797 url: GITHUB_COPILOT_TOKEN_URL.to_string(),
798 headers: vec![
799 (
800 "Authorization".to_string(),
801 format!("token {}", oauth_token),
802 ),
803 ("Accept".to_string(), "application/json".to_string()),
804 ],
805 body: None,
806 redirect_policy: RedirectPolicy::FollowAll,
807 };
808
809 let response = http_request
810 .fetch()
811 .map_err(|e| format!("Failed to request API token: {}", e))?;
812
813 #[derive(Deserialize)]
814 struct ApiTokenResponse {
815 token: String,
816 endpoints: ApiEndpoints,
817 }
818
819 #[derive(Deserialize)]
820 struct ApiEndpoints {
821 api: String,
822 }
823
824 let token_response: ApiTokenResponse =
825 serde_json::from_slice(&response.body).map_err(|e| {
826 format!(
827 "Failed to parse API token response: {} - body: {}",
828 e,
829 String::from_utf8_lossy(&response.body)
830 )
831 })?;
832
833 let api_token = ApiToken {
834 api_key: token_response.token,
835 api_endpoint: token_response.endpoints.api,
836 };
837
838 // Cache the token
839 *self.api_token.lock().unwrap() = Some(api_token.clone());
840
841 Ok(api_token)
842 }
843
844 fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
845 let models_url = format!("{}/models", api_token.api_endpoint);
846
847 let http_request = HttpRequest {
848 method: HttpMethod::Get,
849 url: models_url,
850 headers: vec![
851 (
852 "Authorization".to_string(),
853 format!("Bearer {}", api_token.api_key),
854 ),
855 ("Content-Type".to_string(), "application/json".to_string()),
856 (
857 "Copilot-Integration-Id".to_string(),
858 "vscode-chat".to_string(),
859 ),
860 ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
861 ("x-github-api-version".to_string(), "2025-05-01".to_string()),
862 ],
863 body: None,
864 redirect_policy: RedirectPolicy::FollowAll,
865 };
866
867 let response = http_request
868 .fetch()
869 .map_err(|e| format!("Failed to fetch models: {}", e))?;
870
871 #[derive(Deserialize)]
872 struct ModelsResponse {
873 data: Vec<CopilotModel>,
874 }
875
876 let models_response: ModelsResponse =
877 serde_json::from_slice(&response.body).map_err(|e| {
878 format!(
879 "Failed to parse models response: {} - body: {}",
880 e,
881 String::from_utf8_lossy(&response.body)
882 )
883 })?;
884
885 // Filter models like the built-in Copilot Chat does
886 let mut models: Vec<CopilotModel> = models_response
887 .data
888 .into_iter()
889 .filter(|model| {
890 model.model_picker_enabled
891 && model.capabilities.model_type == "chat"
892 && model
893 .policy
894 .as_ref()
895 .map(|p| p.state == "enabled")
896 .unwrap_or(true)
897 })
898 .collect();
899
900 // Sort so default model is first
901 if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
902 let default_model = models.remove(pos);
903 models.insert(0, default_model);
904 }
905
906 Ok(models)
907 }
908}
909
910fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
911 models
912 .iter()
913 .map(|m| {
914 let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 {
915 m.capabilities.limits.max_context_window_tokens
916 } else {
917 128_000 // Default fallback
918 };
919 let max_output = if m.capabilities.limits.max_output_tokens > 0 {
920 Some(m.capabilities.limits.max_output_tokens)
921 } else {
922 None
923 };
924
925 LlmModelInfo {
926 id: m.id.clone(),
927 name: m.name.clone(),
928 max_token_count: max_tokens,
929 max_output_tokens: max_output,
930 capabilities: LlmModelCapabilities {
931 supports_images: m.capabilities.supports.vision,
932 supports_tools: m.capabilities.supports.tool_calls,
933 supports_tool_choice_auto: m.capabilities.supports.tool_calls,
934 supports_tool_choice_any: m.capabilities.supports.tool_calls,
935 supports_tool_choice_none: m.capabilities.supports.tool_calls,
936 supports_thinking: false,
937 tool_input_format: LlmToolInputFormat::JsonSchema,
938 },
939 is_default: m.is_chat_default,
940 is_default_fast: m.is_chat_fallback,
941 }
942 })
943 .collect()
944}
945
#[cfg(test)]
mod tests {
    use super::*;

    /// The device-code request body carries the client id and the `read:user` scope.
    #[test]
    fn test_device_flow_request_body() {
        let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("scope=read:user"));
    }

    /// The token-poll body carries the client id, the device code and the device grant type.
    #[test]
    fn test_token_poll_request_body() {
        let device_code = "test_device_code_123";
        let body = format!(
            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
            GITHUB_COPILOT_CLIENT_ID, device_code
        );
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("device_code=test_device_code_123"));
        assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code"));
    }

    /// A text delta chunk parses with its content and no finish reason.
    #[test]
    fn test_parse_sse_line_text_delta() {
        let parsed = parse_sse_line(
            r#"data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}"#,
        )
        .expect("valid data line should parse");
        assert_eq!(parsed.choices.len(), 1);
        assert_eq!(parsed.choices[0].delta.content.as_deref(), Some("Hi"));
        assert!(parsed.choices[0].finish_reason.is_none());
        assert!(parsed.usage.is_none());
    }

    /// Tool-call fragments keep their index, id, name and raw argument text.
    #[test]
    fn test_parse_sse_line_tool_call_delta() {
        let parsed = parse_sse_line(
            r#"data: {"choices":[{"delta":{"tool_calls":[{"index":1,"id":"call_1","function":{"name":"read","arguments":"{\"a\""}}]},"finish_reason":null}]}"#,
        )
        .expect("tool call chunk should parse");
        let deltas = parsed.choices[0]
            .delta
            .tool_calls
            .as_ref()
            .expect("tool call deltas present");
        assert_eq!(deltas.len(), 1);
        assert_eq!(deltas[0].index, 1);
        assert_eq!(deltas[0].id.as_deref(), Some("call_1"));
        let function = deltas[0].function.as_ref().expect("function fragment present");
        assert_eq!(function.name.as_deref(), Some("read"));
        assert_eq!(function.arguments.as_deref(), Some("{\"a\""));
    }

    /// A usage chunk with no choices still yields its token counts.
    #[test]
    fn test_parse_sse_line_usage_chunk() {
        let parsed = parse_sse_line(
            r#"data: {"choices":[],"usage":{"prompt_tokens":3,"completion_tokens":5}}"#,
        )
        .expect("usage chunk should parse");
        let usage = parsed.usage.expect("usage present");
        assert_eq!((usage.prompt_tokens, usage.completion_tokens), (3, 5));
    }

    /// Terminators, non-data SSE lines and malformed payloads produce no chunk.
    #[test]
    fn test_parse_sse_line_ignores_non_chunks() {
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line("").is_none());
        assert!(parse_sse_line(": keep-alive").is_none());
        assert!(parse_sse_line("event: message").is_none());
        assert!(parse_sse_line("data: not json").is_none());
    }
}
969
// Register `CopilotChatProvider` with the Zed extension host as this extension's implementation.
zed::register_extension!(CopilotChatProvider);