1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::thread;
4use std::time::Duration;
5
6use serde::{Deserialize, Serialize};
7use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
8use zed_extension_api::{self as zed, *};
9
/// GitHub endpoint that issues the device code and user code that start the OAuth device flow.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";

/// GitHub endpoint polled with the device code until the user approves (or denies) sign-in.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";

/// Endpoint that exchanges the stored GitHub OAuth token for a Copilot API token.
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";

/// OAuth `client_id` sent with both the device-code request and the token-poll requests.
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
14
/// Progress of an OAuth device-flow sign-in, saved when the flow starts and consumed by the
/// polling step.
struct DeviceFlowState {
    /// Minimum number of seconds the server asked for between token polls.
    interval: u64,
    /// Number of seconds, from when the flow started, until `device_code` stops being accepted.
    expires_in: u64,
    /// Opaque code returned by the device-code endpoint and echoed back on every poll.
    device_code: String,
}
20
21#[derive(Clone)]
22struct ApiToken {
23 api_key: String,
24 api_endpoint: String,
25}
26
27#[derive(Clone, Deserialize)]
28struct CopilotModel {
29 id: String,
30 name: String,
31 #[serde(default)]
32 is_chat_default: bool,
33 #[serde(default)]
34 is_chat_fallback: bool,
35 #[serde(default)]
36 model_picker_enabled: bool,
37 #[serde(default)]
38 capabilities: ModelCapabilities,
39 #[serde(default)]
40 policy: Option<ModelPolicy>,
41}
42
43#[derive(Clone, Default, Deserialize)]
44struct ModelCapabilities {
45 #[serde(default)]
46 family: String,
47 #[serde(default)]
48 limits: ModelLimits,
49 #[serde(default)]
50 supports: ModelSupportedFeatures,
51 #[serde(rename = "type", default)]
52 model_type: String,
53}
54
55#[derive(Clone, Default, Deserialize)]
56struct ModelLimits {
57 #[serde(default)]
58 max_context_window_tokens: u64,
59 #[serde(default)]
60 max_output_tokens: u64,
61}
62
63#[derive(Clone, Default, Deserialize)]
64struct ModelSupportedFeatures {
65 #[serde(default)]
66 streaming: bool,
67 #[serde(default)]
68 tool_calls: bool,
69 #[serde(default)]
70 vision: bool,
71}
72
73#[derive(Clone, Deserialize)]
74struct ModelPolicy {
75 state: String,
76}
77
78struct CopilotChatProvider {
79 streams: Mutex<HashMap<String, StreamState>>,
80 next_stream_id: Mutex<u64>,
81 device_flow_state: Mutex<Option<DeviceFlowState>>,
82 api_token: Mutex<Option<ApiToken>>,
83 cached_models: Mutex<Option<Vec<CopilotModel>>>,
84}
85
86struct StreamState {
87 response_stream: Option<HttpResponseStream>,
88 buffer: String,
89 started: bool,
90 tool_calls: HashMap<usize, AccumulatedToolCall>,
91 tool_calls_emitted: bool,
92}
93
94#[derive(Clone, Default)]
95struct AccumulatedToolCall {
96 id: String,
97 name: String,
98 arguments: String,
99}
100
101#[derive(Serialize)]
102struct OpenAiRequest {
103 model: String,
104 messages: Vec<OpenAiMessage>,
105 #[serde(skip_serializing_if = "Option::is_none")]
106 max_tokens: Option<u64>,
107 #[serde(skip_serializing_if = "Vec::is_empty")]
108 tools: Vec<OpenAiTool>,
109 #[serde(skip_serializing_if = "Option::is_none")]
110 tool_choice: Option<String>,
111 #[serde(skip_serializing_if = "Vec::is_empty")]
112 stop: Vec<String>,
113 #[serde(skip_serializing_if = "Option::is_none")]
114 temperature: Option<f32>,
115 stream: bool,
116 #[serde(skip_serializing_if = "Option::is_none")]
117 stream_options: Option<StreamOptions>,
118}
119
120#[derive(Serialize)]
121struct StreamOptions {
122 include_usage: bool,
123}
124
125#[derive(Serialize)]
126struct OpenAiMessage {
127 role: String,
128 #[serde(skip_serializing_if = "Option::is_none")]
129 content: Option<OpenAiContent>,
130 #[serde(skip_serializing_if = "Option::is_none")]
131 tool_calls: Option<Vec<OpenAiToolCall>>,
132 #[serde(skip_serializing_if = "Option::is_none")]
133 tool_call_id: Option<String>,
134}
135
136#[derive(Serialize, Clone)]
137#[serde(untagged)]
138enum OpenAiContent {
139 Text(String),
140 Parts(Vec<OpenAiContentPart>),
141}
142
143#[derive(Serialize, Clone)]
144#[serde(tag = "type")]
145enum OpenAiContentPart {
146 #[serde(rename = "text")]
147 Text { text: String },
148 #[serde(rename = "image_url")]
149 ImageUrl { image_url: ImageUrl },
150}
151
152#[derive(Serialize, Clone)]
153struct ImageUrl {
154 url: String,
155}
156
157#[derive(Serialize, Clone)]
158struct OpenAiToolCall {
159 id: String,
160 #[serde(rename = "type")]
161 call_type: String,
162 function: OpenAiFunctionCall,
163}
164
165#[derive(Serialize, Clone)]
166struct OpenAiFunctionCall {
167 name: String,
168 arguments: String,
169}
170
171#[derive(Serialize)]
172struct OpenAiTool {
173 #[serde(rename = "type")]
174 tool_type: String,
175 function: OpenAiFunctionDef,
176}
177
178#[derive(Serialize)]
179struct OpenAiFunctionDef {
180 name: String,
181 description: String,
182 parameters: serde_json::Value,
183}
184
185#[derive(Deserialize, Debug)]
186struct OpenAiStreamResponse {
187 choices: Vec<OpenAiStreamChoice>,
188 #[serde(default)]
189 usage: Option<OpenAiUsage>,
190}
191
192#[derive(Deserialize, Debug)]
193struct OpenAiStreamChoice {
194 delta: OpenAiDelta,
195 finish_reason: Option<String>,
196}
197
198#[derive(Deserialize, Debug, Default)]
199struct OpenAiDelta {
200 #[serde(default)]
201 content: Option<String>,
202 #[serde(default)]
203 tool_calls: Option<Vec<OpenAiToolCallDelta>>,
204}
205
206#[derive(Deserialize, Debug)]
207struct OpenAiToolCallDelta {
208 index: usize,
209 #[serde(default)]
210 id: Option<String>,
211 #[serde(default)]
212 function: Option<OpenAiFunctionDelta>,
213}
214
215#[derive(Deserialize, Debug, Default)]
216struct OpenAiFunctionDelta {
217 #[serde(default)]
218 name: Option<String>,
219 #[serde(default)]
220 arguments: Option<String>,
221}
222
223#[derive(Deserialize, Debug)]
224struct OpenAiUsage {
225 prompt_tokens: u64,
226 completion_tokens: u64,
227}
228
229fn convert_request(
230 model_id: &str,
231 request: &LlmCompletionRequest,
232) -> Result<OpenAiRequest, String> {
233 let mut messages: Vec<OpenAiMessage> = Vec::new();
234
235 for msg in &request.messages {
236 match msg.role {
237 LlmMessageRole::System => {
238 let mut text_content = String::new();
239 for content in &msg.content {
240 if let LlmMessageContent::Text(text) = content {
241 if !text_content.is_empty() {
242 text_content.push('\n');
243 }
244 text_content.push_str(text);
245 }
246 }
247 if !text_content.is_empty() {
248 messages.push(OpenAiMessage {
249 role: "system".to_string(),
250 content: Some(OpenAiContent::Text(text_content)),
251 tool_calls: None,
252 tool_call_id: None,
253 });
254 }
255 }
256 LlmMessageRole::User => {
257 let mut parts: Vec<OpenAiContentPart> = Vec::new();
258 let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();
259
260 for content in &msg.content {
261 match content {
262 LlmMessageContent::Text(text) => {
263 if !text.is_empty() {
264 parts.push(OpenAiContentPart::Text { text: text.clone() });
265 }
266 }
267 LlmMessageContent::Image(img) => {
268 let data_url = format!("data:image/png;base64,{}", img.source);
269 parts.push(OpenAiContentPart::ImageUrl {
270 image_url: ImageUrl { url: data_url },
271 });
272 }
273 LlmMessageContent::ToolResult(result) => {
274 let content_text = match &result.content {
275 LlmToolResultContent::Text(t) => t.clone(),
276 LlmToolResultContent::Image(_) => "[Image]".to_string(),
277 };
278 tool_result_messages.push(OpenAiMessage {
279 role: "tool".to_string(),
280 content: Some(OpenAiContent::Text(content_text)),
281 tool_calls: None,
282 tool_call_id: Some(result.tool_use_id.clone()),
283 });
284 }
285 _ => {}
286 }
287 }
288
289 if !parts.is_empty() {
290 let content = if parts.len() == 1 {
291 if let OpenAiContentPart::Text { text } = &parts[0] {
292 OpenAiContent::Text(text.clone())
293 } else {
294 OpenAiContent::Parts(parts)
295 }
296 } else {
297 OpenAiContent::Parts(parts)
298 };
299
300 messages.push(OpenAiMessage {
301 role: "user".to_string(),
302 content: Some(content),
303 tool_calls: None,
304 tool_call_id: None,
305 });
306 }
307
308 messages.extend(tool_result_messages);
309 }
310 LlmMessageRole::Assistant => {
311 let mut text_content = String::new();
312 let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();
313
314 for content in &msg.content {
315 match content {
316 LlmMessageContent::Text(text) => {
317 if !text.is_empty() {
318 if !text_content.is_empty() {
319 text_content.push('\n');
320 }
321 text_content.push_str(text);
322 }
323 }
324 LlmMessageContent::ToolUse(tool_use) => {
325 tool_calls.push(OpenAiToolCall {
326 id: tool_use.id.clone(),
327 call_type: "function".to_string(),
328 function: OpenAiFunctionCall {
329 name: tool_use.name.clone(),
330 arguments: tool_use.input.clone(),
331 },
332 });
333 }
334 _ => {}
335 }
336 }
337
338 messages.push(OpenAiMessage {
339 role: "assistant".to_string(),
340 content: if text_content.is_empty() {
341 None
342 } else {
343 Some(OpenAiContent::Text(text_content))
344 },
345 tool_calls: if tool_calls.is_empty() {
346 None
347 } else {
348 Some(tool_calls)
349 },
350 tool_call_id: None,
351 });
352 }
353 }
354 }
355
356 let tools: Vec<OpenAiTool> = request
357 .tools
358 .iter()
359 .map(|t| OpenAiTool {
360 tool_type: "function".to_string(),
361 function: OpenAiFunctionDef {
362 name: t.name.clone(),
363 description: t.description.clone(),
364 parameters: serde_json::from_str(&t.input_schema)
365 .unwrap_or(serde_json::Value::Object(Default::default())),
366 },
367 })
368 .collect();
369
370 let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
371 LlmToolChoice::Auto => "auto".to_string(),
372 LlmToolChoice::Any => "required".to_string(),
373 LlmToolChoice::None => "none".to_string(),
374 });
375
376 let max_tokens = request.max_tokens;
377
378 Ok(OpenAiRequest {
379 model: model_id.to_string(),
380 messages,
381 max_tokens,
382 tools,
383 tool_choice,
384 stop: request.stop_sequences.clone(),
385 temperature: request.temperature,
386 stream: true,
387 stream_options: Some(StreamOptions {
388 include_usage: true,
389 }),
390 })
391}
392
393fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
394 let data = line.strip_prefix("data: ")?;
395 if data.trim() == "[DONE]" {
396 return None;
397 }
398 serde_json::from_str(data).ok()
399}
400
401impl zed::Extension for CopilotChatProvider {
402 fn new() -> Self {
403 Self {
404 streams: Mutex::new(HashMap::new()),
405 next_stream_id: Mutex::new(0),
406 device_flow_state: Mutex::new(None),
407 api_token: Mutex::new(None),
408 cached_models: Mutex::new(None),
409 }
410 }
411
412 fn llm_providers(&self) -> Vec<LlmProviderInfo> {
413 vec![LlmProviderInfo {
414 id: "copilot-chat".into(),
415 name: "Copilot Chat".into(),
416 icon: Some("icons/copilot.svg".into()),
417 }]
418 }
419
420 fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
421 // Try to get models from cache first
422 if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
423 return Ok(convert_models_to_llm_info(models));
424 }
425
426 // Need to fetch models - requires authentication
427 let oauth_token = match llm_get_credential("copilot-chat") {
428 Some(token) => token,
429 None => return Ok(Vec::new()), // Not authenticated, return empty
430 };
431
432 // Get API token
433 let api_token = self.get_api_token(&oauth_token)?;
434
435 // Fetch models from API
436 let models = self.fetch_models(&api_token)?;
437
438 // Cache the models
439 *self.cached_models.lock().unwrap() = Some(models.clone());
440
441 Ok(convert_models_to_llm_info(&models))
442 }
443
444 fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
445 llm_get_credential("copilot-chat").is_some()
446 }
447
448 fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
449 Some(
450 "To use Copilot Chat, sign in with your GitHub account. This requires an active [GitHub Copilot subscription](https://github.com/features/copilot).".to_string(),
451 )
452 }
453
454 fn llm_provider_start_device_flow_sign_in(
455 &mut self,
456 _provider_id: &str,
457 ) -> Result<LlmDeviceFlowPromptInfo, String> {
458 // Step 1: Request device and user verification codes
459 let device_code_response = llm_oauth_send_http_request(&HttpRequest {
460 method: HttpMethod::Post,
461 url: GITHUB_DEVICE_CODE_URL.to_string(),
462 headers: vec![
463 ("Accept".to_string(), "application/json".to_string()),
464 (
465 "Content-Type".to_string(),
466 "application/x-www-form-urlencoded".to_string(),
467 ),
468 ],
469 body: Some(
470 format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID).into_bytes(),
471 ),
472 redirect_policy: RedirectPolicy::NoFollow,
473 })?;
474
475 if device_code_response.status != 200 {
476 return Err(format!(
477 "Failed to get device code: HTTP {}",
478 device_code_response.status
479 ));
480 }
481
482 #[derive(Deserialize)]
483 struct DeviceCodeResponse {
484 device_code: String,
485 user_code: String,
486 verification_uri: String,
487 #[serde(default)]
488 verification_uri_complete: Option<String>,
489 expires_in: u64,
490 interval: u64,
491 }
492
493 let device_info: DeviceCodeResponse = serde_json::from_slice(&device_code_response.body)
494 .map_err(|e| format!("Failed to parse device code response: {}", e))?;
495
496 // Store device flow state for polling
497 *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
498 device_code: device_info.device_code,
499 interval: device_info.interval,
500 expires_in: device_info.expires_in,
501 });
502
503 // Step 2: Construct verification URL
504 // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
505 let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
506 format!(
507 "{}?user_code={}",
508 device_info.verification_uri, &device_info.user_code
509 )
510 });
511
512 // Return prompt info for the host to display in the modal
513 Ok(LlmDeviceFlowPromptInfo {
514 user_code: device_info.user_code,
515 verification_url,
516 headline: "Use GitHub Copilot in Zed.".to_string(),
517 description: "Using Copilot requires an active subscription on GitHub.".to_string(),
518 connect_button_label: "Connect to GitHub".to_string(),
519 success_headline: "Copilot Enabled!".to_string(),
520 success_message:
521 "You can update your settings or sign out from the Copilot menu in the status bar."
522 .to_string(),
523 })
524 }
525
526 fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
527 let state = self
528 .device_flow_state
529 .lock()
530 .unwrap()
531 .take()
532 .ok_or("No device flow in progress")?;
533
534 let poll_interval = Duration::from_secs(state.interval.max(5));
535 let max_attempts = (state.expires_in / state.interval.max(5)) as usize;
536
537 for _ in 0..max_attempts {
538 thread::sleep(poll_interval);
539
540 let token_response = llm_oauth_send_http_request(&HttpRequest {
541 method: HttpMethod::Post,
542 url: GITHUB_ACCESS_TOKEN_URL.to_string(),
543 headers: vec![
544 ("Accept".to_string(), "application/json".to_string()),
545 (
546 "Content-Type".to_string(),
547 "application/x-www-form-urlencoded".to_string(),
548 ),
549 ],
550 body: Some(
551 format!(
552 "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
553 GITHUB_COPILOT_CLIENT_ID, state.device_code
554 )
555 .into_bytes(),
556 ),
557 redirect_policy: RedirectPolicy::NoFollow,
558 })?;
559
560 #[derive(Deserialize)]
561 struct TokenResponse {
562 access_token: Option<String>,
563 error: Option<String>,
564 error_description: Option<String>,
565 }
566
567 let token_json: TokenResponse = serde_json::from_slice(&token_response.body)
568 .map_err(|e| format!("Failed to parse token response: {}", e))?;
569
570 if let Some(access_token) = token_json.access_token {
571 llm_store_credential("copilot-chat", &access_token)?;
572 return Ok(());
573 }
574
575 if let Some(error) = &token_json.error {
576 match error.as_str() {
577 "authorization_pending" => {
578 // User hasn't authorized yet, keep polling
579 continue;
580 }
581 "slow_down" => {
582 // Need to slow down polling
583 thread::sleep(Duration::from_secs(5));
584 continue;
585 }
586 "expired_token" => {
587 return Err("Device code expired. Please try again.".to_string());
588 }
589 "access_denied" => {
590 return Err("Authorization was denied.".to_string());
591 }
592 _ => {
593 let description = token_json.error_description.unwrap_or_default();
594 return Err(format!("OAuth error: {} - {}", error, description));
595 }
596 }
597 }
598 }
599
600 Err("Authorization timed out. Please try again.".to_string())
601 }
602
603 fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
604 // Clear cached API token and models
605 *self.api_token.lock().unwrap() = None;
606 *self.cached_models.lock().unwrap() = None;
607 llm_delete_credential("copilot-chat")
608 }
609
610 fn llm_stream_completion_start(
611 &mut self,
612 _provider_id: &str,
613 model_id: &str,
614 request: &LlmCompletionRequest,
615 ) -> Result<String, String> {
616 let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
617 "No token configured. Please add your GitHub Copilot token in settings.".to_string()
618 })?;
619
620 // Get or refresh API token
621 let api_token = self.get_api_token(&oauth_token)?;
622
623 let openai_request = convert_request(model_id, request)?;
624
625 let body = serde_json::to_vec(&openai_request)
626 .map_err(|e| format!("Failed to serialize request: {}", e))?;
627
628 let completions_url = format!("{}/chat/completions", api_token.api_endpoint);
629
630 let http_request = HttpRequest {
631 method: HttpMethod::Post,
632 url: completions_url,
633 headers: vec![
634 ("Content-Type".to_string(), "application/json".to_string()),
635 (
636 "Authorization".to_string(),
637 format!("Bearer {}", api_token.api_key),
638 ),
639 (
640 "Copilot-Integration-Id".to_string(),
641 "vscode-chat".to_string(),
642 ),
643 ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
644 ],
645 body: Some(body),
646 redirect_policy: RedirectPolicy::FollowAll,
647 };
648
649 let response_stream = http_request
650 .fetch_stream()
651 .map_err(|e| format!("HTTP request failed: {}", e))?;
652
653 let stream_id = {
654 let mut id_counter = self.next_stream_id.lock().unwrap();
655 let id = format!("copilot-stream-{}", *id_counter);
656 *id_counter += 1;
657 id
658 };
659
660 self.streams.lock().unwrap().insert(
661 stream_id.clone(),
662 StreamState {
663 response_stream: Some(response_stream),
664 buffer: String::new(),
665 started: false,
666 tool_calls: HashMap::new(),
667 tool_calls_emitted: false,
668 },
669 );
670
671 Ok(stream_id)
672 }
673
674 fn llm_stream_completion_next(
675 &mut self,
676 stream_id: &str,
677 ) -> Result<Option<LlmCompletionEvent>, String> {
678 let mut streams = self.streams.lock().unwrap();
679 let state = streams
680 .get_mut(stream_id)
681 .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
682
683 if !state.started {
684 state.started = true;
685 return Ok(Some(LlmCompletionEvent::Started));
686 }
687
688 let response_stream = state
689 .response_stream
690 .as_mut()
691 .ok_or_else(|| "Stream already closed".to_string())?;
692
693 loop {
694 if let Some(newline_pos) = state.buffer.find('\n') {
695 let line = state.buffer[..newline_pos].to_string();
696 state.buffer = state.buffer[newline_pos + 1..].to_string();
697
698 if line.trim().is_empty() {
699 continue;
700 }
701
702 if let Some(response) = parse_sse_line(&line) {
703 if let Some(choice) = response.choices.first() {
704 if let Some(content) = &choice.delta.content {
705 if !content.is_empty() {
706 return Ok(Some(LlmCompletionEvent::Text(content.clone())));
707 }
708 }
709
710 if let Some(tool_calls) = &choice.delta.tool_calls {
711 for tc in tool_calls {
712 let entry = state
713 .tool_calls
714 .entry(tc.index)
715 .or_insert_with(AccumulatedToolCall::default);
716
717 if let Some(id) = &tc.id {
718 entry.id = id.clone();
719 }
720 if let Some(func) = &tc.function {
721 if let Some(name) = &func.name {
722 entry.name = name.clone();
723 }
724 if let Some(args) = &func.arguments {
725 entry.arguments.push_str(args);
726 }
727 }
728 }
729 }
730
731 if let Some(finish_reason) = &choice.finish_reason {
732 if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
733 state.tool_calls_emitted = true;
734 let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
735 tool_calls.sort_by_key(|(idx, _)| *idx);
736
737 if let Some((_, tc)) = tool_calls.into_iter().next() {
738 return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
739 id: tc.id,
740 name: tc.name,
741 input: tc.arguments,
742 is_input_complete: true,
743 thought_signature: None,
744 })));
745 }
746 }
747
748 let stop_reason = match finish_reason.as_str() {
749 "stop" => LlmStopReason::EndTurn,
750 "length" => LlmStopReason::MaxTokens,
751 "tool_calls" => LlmStopReason::ToolUse,
752 "content_filter" => LlmStopReason::Refusal,
753 _ => LlmStopReason::EndTurn,
754 };
755 return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
756 }
757 }
758
759 if let Some(usage) = response.usage {
760 return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
761 input_tokens: usage.prompt_tokens,
762 output_tokens: usage.completion_tokens,
763 cache_creation_input_tokens: None,
764 cache_read_input_tokens: None,
765 })));
766 }
767 }
768
769 continue;
770 }
771
772 match response_stream.next_chunk() {
773 Ok(Some(chunk)) => {
774 let text = String::from_utf8_lossy(&chunk);
775 state.buffer.push_str(&text);
776 }
777 Ok(None) => {
778 return Ok(None);
779 }
780 Err(e) => {
781 return Err(format!("Stream error: {}", e));
782 }
783 }
784 }
785 }
786
787 fn llm_stream_completion_close(&mut self, stream_id: &str) {
788 self.streams.lock().unwrap().remove(stream_id);
789 }
790}
791
792impl CopilotChatProvider {
793 fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
794 // Check if we have a cached token
795 if let Some(token) = self.api_token.lock().unwrap().clone() {
796 return Ok(token);
797 }
798
799 // Request a new API token
800 let http_request = HttpRequest {
801 method: HttpMethod::Get,
802 url: GITHUB_COPILOT_TOKEN_URL.to_string(),
803 headers: vec![
804 (
805 "Authorization".to_string(),
806 format!("token {}", oauth_token),
807 ),
808 ("Accept".to_string(), "application/json".to_string()),
809 ],
810 body: None,
811 redirect_policy: RedirectPolicy::FollowAll,
812 };
813
814 let response = http_request
815 .fetch()
816 .map_err(|e| format!("Failed to request API token: {}", e))?;
817
818 #[derive(Deserialize)]
819 struct ApiTokenResponse {
820 token: String,
821 endpoints: ApiEndpoints,
822 }
823
824 #[derive(Deserialize)]
825 struct ApiEndpoints {
826 api: String,
827 }
828
829 let token_response: ApiTokenResponse =
830 serde_json::from_slice(&response.body).map_err(|e| {
831 format!(
832 "Failed to parse API token response: {} - body: {}",
833 e,
834 String::from_utf8_lossy(&response.body)
835 )
836 })?;
837
838 let api_token = ApiToken {
839 api_key: token_response.token,
840 api_endpoint: token_response.endpoints.api,
841 };
842
843 // Cache the token
844 *self.api_token.lock().unwrap() = Some(api_token.clone());
845
846 Ok(api_token)
847 }
848
849 fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
850 let models_url = format!("{}/models", api_token.api_endpoint);
851
852 let http_request = HttpRequest {
853 method: HttpMethod::Get,
854 url: models_url,
855 headers: vec![
856 (
857 "Authorization".to_string(),
858 format!("Bearer {}", api_token.api_key),
859 ),
860 ("Content-Type".to_string(), "application/json".to_string()),
861 (
862 "Copilot-Integration-Id".to_string(),
863 "vscode-chat".to_string(),
864 ),
865 ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
866 ("x-github-api-version".to_string(), "2025-05-01".to_string()),
867 ],
868 body: None,
869 redirect_policy: RedirectPolicy::FollowAll,
870 };
871
872 let response = http_request
873 .fetch()
874 .map_err(|e| format!("Failed to fetch models: {}", e))?;
875
876 #[derive(Deserialize)]
877 struct ModelsResponse {
878 data: Vec<CopilotModel>,
879 }
880
881 let models_response: ModelsResponse =
882 serde_json::from_slice(&response.body).map_err(|e| {
883 format!(
884 "Failed to parse models response: {} - body: {}",
885 e,
886 String::from_utf8_lossy(&response.body)
887 )
888 })?;
889
890 // Filter models like the built-in Copilot Chat does
891 let mut models: Vec<CopilotModel> = models_response
892 .data
893 .into_iter()
894 .filter(|model| {
895 model.model_picker_enabled
896 && model.capabilities.model_type == "chat"
897 && model
898 .policy
899 .as_ref()
900 .map(|p| p.state == "enabled")
901 .unwrap_or(true)
902 })
903 .collect();
904
905 // Sort so default model is first
906 if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
907 let default_model = models.remove(pos);
908 models.insert(0, default_model);
909 }
910
911 Ok(models)
912 }
913}
914
915fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
916 models
917 .iter()
918 .map(|m| {
919 let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 {
920 m.capabilities.limits.max_context_window_tokens
921 } else {
922 128_000 // Default fallback
923 };
924 let max_output = if m.capabilities.limits.max_output_tokens > 0 {
925 Some(m.capabilities.limits.max_output_tokens)
926 } else {
927 None
928 };
929
930 LlmModelInfo {
931 id: m.id.clone(),
932 name: m.name.clone(),
933 max_token_count: max_tokens,
934 max_output_tokens: max_output,
935 capabilities: LlmModelCapabilities {
936 supports_images: m.capabilities.supports.vision,
937 supports_tools: m.capabilities.supports.tool_calls,
938 supports_tool_choice_auto: m.capabilities.supports.tool_calls,
939 supports_tool_choice_any: m.capabilities.supports.tool_calls,
940 supports_tool_choice_none: m.capabilities.supports.tool_calls,
941 supports_thinking: false,
942 tool_input_format: LlmToolInputFormat::JsonSchema,
943 },
944 is_default: m.is_chat_default,
945 is_default_fast: m.is_chat_fallback,
946 }
947 })
948 .collect()
949}
950
#[cfg(test)]
mod tests {
    use super::*;

    /// The device-code request body must carry the Copilot client id and the `read:user` scope.
    #[test]
    fn test_device_flow_request_body() {
        let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("scope=read:user"));
    }

    /// The token-poll body must carry the client id, device code and device-code grant type.
    #[test]
    fn test_token_poll_request_body() {
        let device_code = "test_device_code_123";
        let body = format!(
            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
            GITHUB_COPILOT_CLIENT_ID, device_code
        );
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("device_code=test_device_code_123"));
        assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code"));
    }

    /// A text delta chunk parses into one choice carrying that text.
    #[test]
    fn test_parse_sse_line_text_delta() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let response = parse_sse_line(line).expect("valid data line should parse");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].delta.content.as_deref(), Some("Hello"));
        assert!(response.choices[0].finish_reason.is_none());
        assert!(response.usage.is_none());
    }

    /// The `[DONE]` sentinel and non-`data` SSE lines yield no chunk.
    #[test]
    fn test_parse_sse_line_ignores_done_and_metadata() {
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line(": keep-alive").is_none());
        assert!(parse_sse_line("event: message").is_none());
    }

    /// Tool-call deltas, the finish reason and usage in one chunk are all parsed.
    #[test]
    fn test_parse_sse_line_tool_call_and_usage() {
        let line = r#"data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","function":{"name":"read","arguments":"{\"p"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":3,"completion_tokens":5}}"#;
        let response = parse_sse_line(line).expect("valid data line should parse");
        let choice = &response.choices[0];
        assert_eq!(choice.finish_reason.as_deref(), Some("tool_calls"));

        let calls = choice.delta.tool_calls.as_ref().expect("tool calls present");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].index, 0);
        assert_eq!(calls[0].id.as_deref(), Some("call_1"));
        let function = calls[0].function.as_ref().expect("function present");
        assert_eq!(function.name.as_deref(), Some("read"));
        assert_eq!(function.arguments.as_deref(), Some("{\"p"));

        let usage = response.usage.expect("usage present");
        assert_eq!(usage.prompt_tokens, 3);
        assert_eq!(usage.completion_tokens, 5);
    }
}
974
975zed::register_extension!(CopilotChatProvider);