use std::collections::HashMap;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";

struct DeviceFlowState {
    device_code: String,
    interval: u64,
    expires_in: u64,
}

#[derive(Clone)]
struct ApiToken {
    api_key: String,
    api_endpoint: String,
}

#[derive(Clone, Deserialize)]
struct CopilotModel {
    id: String,
    name: String,
    #[serde(default)]
    is_chat_default: bool,
    #[serde(default)]
    is_chat_fallback: bool,
    #[serde(default)]
    model_picker_enabled: bool,
    #[serde(default)]
    capabilities: ModelCapabilities,
    #[serde(default)]
    policy: Option<ModelPolicy>,
}

#[derive(Clone, Default, Deserialize)]
struct ModelCapabilities {
    #[serde(default)]
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    #[serde(default)]
    supports: ModelSupportedFeatures,
    #[serde(rename = "type", default)]
    model_type: String,
}

#[derive(Clone, Default, Deserialize)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: u64,
    #[serde(default)]
    max_output_tokens: u64,
}

#[derive(Clone, Default, Deserialize)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

#[derive(Clone, Deserialize)]
struct ModelPolicy {
    state: String,
}

struct CopilotChatProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
    device_flow_state: Mutex<Option<DeviceFlowState>>,
    api_token: Mutex<Option<ApiToken>>,
    cached_models: Mutex<Option<Vec<CopilotModel>>>,
}

struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    tool_calls_emitted: bool,
    // Tool calls and stop reason queued for emission across successive
    // `llm_stream_completion_next` calls (one event per call), so parallel
    // tool calls are not dropped.
    pending_tool_calls: Vec<AccumulatedToolCall>,
    pending_stop: Option<LlmStopReason>,
}

#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    Text(String),
    Parts(Vec<OpenAiContentPart>),
}

#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamResponse {
    choices: Vec<OpenAiStreamChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

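/// Converts a Zed `LlmCompletionRequest` into the OpenAI-style chat-completions
/// payload used by the Copilot API: system/user/assistant messages are flattened,
/// tool results become `role: "tool"` messages, images are sent as base64 data
/// URLs, and tool definitions and tool choice map onto `tools`/`tool_choice`.
/// Streaming is always enabled so usage can be reported via
/// `stream_options.include_usage`.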
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let mut messages: Vec<OpenAiMessage> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let mut text_content = String::new();
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text_content.is_empty() {
                            text_content.push('\n');
                        }
                        text_content.push_str(text);
                    }
                }
                if !text_content.is_empty() {
                    messages.push(OpenAiMessage {
                        role: "system".to_string(),
                        content: Some(OpenAiContent::Text(text_content)),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<OpenAiContentPart> = Vec::new();
                let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(OpenAiContentPart::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            let data_url = format!("data:image/png;base64,{}", img.source);
                            parts.push(OpenAiContentPart::ImageUrl {
                                image_url: ImageUrl { url: data_url },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            tool_result_messages.push(OpenAiMessage {
                                role: "tool".to_string(),
                                content: Some(OpenAiContent::Text(content_text)),
                                tool_calls: None,
                                tool_call_id: Some(result.tool_use_id.clone()),
                            });
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    let content = if parts.len() == 1 {
                        if let OpenAiContentPart::Text { text } = &parts[0] {
                            OpenAiContent::Text(text.clone())
                        } else {
                            OpenAiContent::Parts(parts)
                        }
                    } else {
                        OpenAiContent::Parts(parts)
                    };

                    messages.push(OpenAiMessage {
                        role: "user".to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }

                messages.extend(tool_result_messages);
            }
            LlmMessageRole::Assistant => {
                let mut text_content = String::new();
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                if !text_content.is_empty() {
                                    text_content.push('\n');
                                }
                                text_content.push_str(text);
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage {
                    role: "assistant".to_string(),
                    content: if text_content.is_empty() {
                        None
                    } else {
                        Some(OpenAiContent::Text(text_content))
                    },
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                    tool_call_id: None,
                });
            }
        }
    }

    let tools: Vec<OpenAiTool> = request
        .tools
        .iter()
        .map(|t| OpenAiTool {
            tool_type: "function".to_string(),
            function: OpenAiFunctionDef {
                name: t.name.clone(),
                description: t.description.clone(),
                parameters: serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default())),
            },
        })
        .collect();

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    let max_tokens = request.max_tokens;

    Ok(OpenAiRequest {
        model: model_id.to_string(),
        messages,
        max_tokens,
        tools,
        tool_choice,
        stop: request.stop_sequences.clone(),
        temperature: request.temperature,
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

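/// Parses a single `data: {...}` line from the Copilot SSE stream. Returns
/// `None` for the `[DONE]` sentinel and for lines that are not valid JSON.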
fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
    let data = line.strip_prefix("data: ")?;
    if data.trim() == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}

impl zed::Extension for CopilotChatProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
            device_flow_state: Mutex::new(None),
            api_token: Mutex::new(None),
            cached_models: Mutex::new(None),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "copilot-chat".into(),
            name: "Copilot Chat".into(),
            icon: Some("icons/copilot.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        // Try to get models from cache first
        if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
            return Ok(convert_models_to_llm_info(models));
        }

        // Need to fetch models - requires authentication
        let oauth_token = match llm_get_credential("copilot-chat") {
            Some(token) => token,
            None => return Ok(Vec::new()), // Not authenticated, return empty
        };

        // Get API token
        let api_token = self.get_api_token(&oauth_token)?;

        // Fetch models from API
        let models = self.fetch_models(&api_token)?;

        // Cache the models
        *self.cached_models.lock().unwrap() = Some(models.clone());

        Ok(convert_models_to_llm_info(&models))
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("copilot-chat").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use Copilot Chat, sign in with your GitHub account. This requires an active [GitHub Copilot subscription](https://github.com/features/copilot).".to_string(),
        )
    }

    fn llm_provider_start_device_flow_sign_in(
        &mut self,
        _provider_id: &str,
    ) -> Result<LlmDeviceFlowPromptInfo, String> {
        // Step 1: Request device and user verification codes
        let device_code_response = llm_oauth_send_http_request(&LlmOauthHttpRequest {
            url: GITHUB_DEVICE_CODE_URL.to_string(),
            method: "POST".to_string(),
            headers: vec![
                ("Accept".to_string(), "application/json".to_string()),
                (
                    "Content-Type".to_string(),
                    "application/x-www-form-urlencoded".to_string(),
                ),
            ],
            body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID),
        })?;

        if device_code_response.status != 200 {
            return Err(format!(
                "Failed to get device code: HTTP {}",
                device_code_response.status
            ));
        }

        #[derive(Deserialize)]
        struct DeviceCodeResponse {
            device_code: String,
            user_code: String,
            verification_uri: String,
            #[serde(default)]
            verification_uri_complete: Option<String>,
            expires_in: u64,
            interval: u64,
        }

        let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body)
            .map_err(|e| format!("Failed to parse device code response: {}", e))?;

        // Store device flow state for polling
        *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
            device_code: device_info.device_code,
            interval: device_info.interval,
            expires_in: device_info.expires_in,
        });

        // Step 2: Construct verification URL
        // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
        let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
            format!(
                "{}?user_code={}",
                device_info.verification_uri, &device_info.user_code
            )
        });

        // Return prompt info for the host to display in the modal
        Ok(LlmDeviceFlowPromptInfo {
            user_code: device_info.user_code,
            verification_url,
            headline: "Use GitHub Copilot in Zed.".to_string(),
            description: "Using Copilot requires an active subscription on GitHub.".to_string(),
            connect_button_label: "Connect to GitHub".to_string(),
            success_headline: "Copilot Enabled!".to_string(),
            success_message:
                "You can update your settings or sign out from the Copilot menu in the status bar."
                    .to_string(),
        })
    }

    fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
        let state = self
            .device_flow_state
            .lock()
            .unwrap()
            .take()
            .ok_or("No device flow in progress")?;

        let poll_interval = Duration::from_secs(state.interval.max(5));
        let max_attempts = (state.expires_in / state.interval.max(5)) as usize;

        for _ in 0..max_attempts {
            thread::sleep(poll_interval);

            let token_response = llm_oauth_send_http_request(&LlmOauthHttpRequest {
                url: GITHUB_ACCESS_TOKEN_URL.to_string(),
                method: "POST".to_string(),
                headers: vec![
                    ("Accept".to_string(), "application/json".to_string()),
                    (
                        "Content-Type".to_string(),
                        "application/x-www-form-urlencoded".to_string(),
                    ),
                ],
                body: format!(
                    "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
                    GITHUB_COPILOT_CLIENT_ID, state.device_code
                ),
            })?;

            #[derive(Deserialize)]
            struct TokenResponse {
                access_token: Option<String>,
                error: Option<String>,
                error_description: Option<String>,
            }

            let token_json: TokenResponse = serde_json::from_str(&token_response.body)
                .map_err(|e| format!("Failed to parse token response: {}", e))?;

            if let Some(access_token) = token_json.access_token {
                llm_store_credential("copilot-chat", &access_token)?;
                return Ok(());
            }

            if let Some(error) = &token_json.error {
                match error.as_str() {
                    "authorization_pending" => {
                        // User hasn't authorized yet, keep polling
                        continue;
                    }
                    "slow_down" => {
                        // Need to slow down polling
                        thread::sleep(Duration::from_secs(5));
                        continue;
                    }
                    "expired_token" => {
                        return Err("Device code expired. Please try again.".to_string());
                    }
                    "access_denied" => {
                        return Err("Authorization was denied.".to_string());
                    }
                    _ => {
                        let description = token_json.error_description.unwrap_or_default();
                        return Err(format!("OAuth error: {} - {}", error, description));
                    }
                }
            }
        }

        Err("Authorization timed out. Please try again.".to_string())
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        // Clear cached API token and models
        *self.api_token.lock().unwrap() = None;
        *self.cached_models.lock().unwrap() = None;
        llm_delete_credential("copilot-chat")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
            "Not signed in to GitHub Copilot. Please sign in from the provider settings."
                .to_string()
        })?;

        // Get or refresh API token
        let api_token = self.get_api_token(&oauth_token)?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let completions_url = format!("{}/chat/completions", api_token.api_endpoint);

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: completions_url,
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("copilot-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
                pending_tool_calls: Vec::new(),
                pending_stop: None,
            },
        );

        Ok(stream_id)
    }

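    /// Pulls the next event from a Copilot response stream. Raw SSE bytes are
    /// buffered and parsed line by line; text deltas are emitted as they
    /// arrive, tool-call deltas are accumulated until a finish reason is seen,
    /// and usage data is reported when the API includes it.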
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Drain any tool calls (and the stop reason) queued by a previous call
        // before reading more of the response stream.
        if !state.pending_tool_calls.is_empty() {
            let tc = state.pending_tool_calls.remove(0);
            return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                id: tc.id,
                name: tc.name,
                input: tc.arguments,
                thought_signature: None,
            })));
        }
        if let Some(stop_reason) = state.pending_stop.take() {
            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.trim().is_empty() {
                    continue;
                }

                if let Some(response) = parse_sse_line(&line) {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }

                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state
                                    .tool_calls
                                    .entry(tc.index)
                                    .or_insert_with(AccumulatedToolCall::default);

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(finish_reason) = &choice.finish_reason {
                            let stop_reason = match finish_reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
                                state.tool_calls_emitted = true;
                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
                                tool_calls.sort_by_key(|(idx, _)| *idx);

                                // Queue every accumulated tool call (in index order) plus the
                                // stop reason; emit the first tool call now and the rest on
                                // subsequent calls so none are dropped.
                                state.pending_tool_calls =
                                    tool_calls.into_iter().map(|(_, tc)| tc).collect();
                                state.pending_stop = Some(stop_reason);

                                let tc = state.pending_tool_calls.remove(0);
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tc.id,
                                    name: tc.name,
                                    input: tc.arguments,
                                    thought_signature: None,
                                })));
                            }

                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                    }

                    if let Some(usage) = response.usage {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

impl CopilotChatProvider {
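    /// Exchanges the stored GitHub OAuth token for a Copilot API token and its
    /// API endpoint, caching the result for subsequent requests.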
    fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
        // Check if we have a cached token
        if let Some(token) = self.api_token.lock().unwrap().clone() {
            return Ok(token);
        }

        // Request a new API token
        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: GITHUB_COPILOT_TOKEN_URL.to_string(),
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("token {}", oauth_token),
                ),
                ("Accept".to_string(), "application/json".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to request API token: {}", e))?;

        #[derive(Deserialize)]
        struct ApiTokenResponse {
            token: String,
            endpoints: ApiEndpoints,
        }

        #[derive(Deserialize)]
        struct ApiEndpoints {
            api: String,
        }

        let token_response: ApiTokenResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse API token response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;

        let api_token = ApiToken {
            api_key: token_response.token,
            api_endpoint: token_response.endpoints.api,
        };

        // Cache the token
        *self.api_token.lock().unwrap() = Some(api_token.clone());

        Ok(api_token)
    }

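    /// Fetches the model list from the Copilot API, keeping only chat models
    /// that are enabled in the model picker and permitted by policy, with the
    /// default chat model sorted to the front.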
    fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
        let models_url = format!("{}/models", api_token.api_endpoint);

        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: models_url,
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
                ("x-github-api-version".to_string(), "2025-05-01".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to fetch models: {}", e))?;

        #[derive(Deserialize)]
        struct ModelsResponse {
            data: Vec<CopilotModel>,
        }

        let models_response: ModelsResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse models response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;

        // Filter models like the built-in Copilot Chat does
        let mut models: Vec<CopilotModel> = models_response
            .data
            .into_iter()
            .filter(|model| {
                model.model_picker_enabled
                    && model.capabilities.model_type == "chat"
                    && model
                        .policy
                        .as_ref()
                        .map(|p| p.state == "enabled")
                        .unwrap_or(true)
            })
            .collect();

        // Sort so default model is first
        if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
            let default_model = models.remove(pos);
            models.insert(0, default_model);
        }

        Ok(models)
    }
}

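/// Maps Copilot model metadata onto Zed's `LlmModelInfo`, falling back to a
/// 128k context window when the API does not report one and deriving the tool
/// and vision capability flags from the model's supported features.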
fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
    models
        .iter()
        .map(|m| {
            let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 {
                m.capabilities.limits.max_context_window_tokens
            } else {
                128_000 // Default fallback
            };
            let max_output = if m.capabilities.limits.max_output_tokens > 0 {
                Some(m.capabilities.limits.max_output_tokens)
            } else {
                None
            };

            LlmModelInfo {
                id: m.id.clone(),
                name: m.name.clone(),
                max_token_count: max_tokens,
                max_output_tokens: max_output,
                capabilities: LlmModelCapabilities {
                    supports_images: m.capabilities.supports.vision,
                    supports_tools: m.capabilities.supports.tool_calls,
                    supports_tool_choice_auto: m.capabilities.supports.tool_calls,
                    supports_tool_choice_any: m.capabilities.supports.tool_calls,
                    supports_tool_choice_none: m.capabilities.supports.tool_calls,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_chat_default,
                is_default_fast: m.is_chat_fallback,
            }
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_flow_request_body() {
        let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("scope=read:user"));
    }

    #[test]
    fn test_token_poll_request_body() {
        let device_code = "test_device_code_123";
        let body = format!(
            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
            GITHUB_COPILOT_CLIENT_ID, device_code
        );
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("device_code=test_device_code_123"));
        assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code"));
    }
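
    // Additional sketches exercising the local helpers above; the SSE payload
    // is a hand-written example of the OpenAI-style chunk format this file parses.
    #[test]
    fn test_parse_sse_line() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let response = parse_sse_line(line).expect("should parse a data line");
        assert_eq!(
            response.choices[0].delta.content.as_deref(),
            Some("Hello")
        );

        // The terminating sentinel and non-data lines yield no response.
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line(": comment").is_none());
    }

    #[test]
    fn test_convert_models_to_llm_info_defaults() {
        let model = CopilotModel {
            id: "gpt-test".to_string(),
            name: "GPT Test".to_string(),
            is_chat_default: true,
            is_chat_fallback: false,
            model_picker_enabled: true,
            capabilities: ModelCapabilities::default(),
            policy: None,
        };

        let info = convert_models_to_llm_info(&[model]);
        assert_eq!(info[0].id, "gpt-test");
        // Zero-valued limits fall back to the 128k context window default.
        assert_eq!(info[0].max_token_count, 128_000);
        assert_eq!(info[0].max_output_tokens, None);
        assert!(info[0].is_default);
    }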
}

zed::register_extension!(CopilotChatProvider);