use std::collections::HashMap;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

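// GitHub's device-flow OAuth endpoints, the Copilot token-exchange endpoint, and the
// OAuth client ID this extension signs in as (the public client ID used by GitHub's
// own Copilot editor plugins).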
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";

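/// State from GitHub's device-code response, kept while the user authorizes the device
/// in their browser: the device code to poll with, the minimum polling interval in
/// seconds, and how long the code remains valid in seconds.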
struct DeviceFlowState {
    device_code: String,
    interval: u64,
    expires_in: u64,
}

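/// A short-lived Copilot API token (exchanged from the stored GitHub OAuth token)
/// together with the API endpoint it is valid for.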
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    api_endpoint: String,
}

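/// The subset of a Copilot model-catalog entry that this extension uses.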
#[derive(Clone, Deserialize)]
struct CopilotModel {
    id: String,
    name: String,
    #[serde(default)]
    is_chat_default: bool,
    #[serde(default)]
    is_chat_fallback: bool,
    #[serde(default)]
    model_picker_enabled: bool,
    #[serde(default)]
    capabilities: ModelCapabilities,
    #[serde(default)]
    policy: Option<ModelPolicy>,
}

#[derive(Clone, Default, Deserialize)]
struct ModelCapabilities {
    #[serde(default)]
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    #[serde(default)]
    supports: ModelSupportedFeatures,
    #[serde(rename = "type", default)]
    model_type: String,
}

#[derive(Clone, Default, Deserialize)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: u64,
    #[serde(default)]
    max_output_tokens: u64,
}

#[derive(Clone, Default, Deserialize)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

#[derive(Clone, Deserialize)]
struct ModelPolicy {
    state: String,
}

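/// Extension state. Everything is kept behind a `Mutex` so `&self` entry points such
/// as `llm_provider_models` can update the cached token, model list, and in-flight
/// streams across calls.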
struct CopilotChatProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
    device_flow_state: Mutex<Option<DeviceFlowState>>,
    api_token: Mutex<Option<ApiToken>>,
    cached_models: Mutex<Option<Vec<CopilotModel>>>,
}

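/// Per-stream state: the underlying HTTP response stream, a buffer of not-yet-parsed
/// SSE text, whether the `Started` event has been emitted, tool calls accumulated from
/// streamed deltas (keyed by index), and events queued for later `next` calls.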
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    tool_calls_emitted: bool,
    /// Events already produced (additional tool calls and the final stop reason) that
    /// are returned one at a time from subsequent `next` calls.
    pending_events: Vec<LlmCompletionEvent>,
}

#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

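// The Copilot chat completions endpoint speaks an OpenAI-style chat completions
// protocol, so the request and streaming-response wire types below mirror that shape.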
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    Text(String),
    Parts(Vec<OpenAiContentPart>),
}

#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamResponse {
    choices: Vec<OpenAiStreamChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

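/// Converts Zed's provider-agnostic completion request into an OpenAI-style chat
/// request: system and assistant text is concatenated, user content becomes text or
/// image parts, tool results become separate `role: "tool"` messages, and assistant
/// tool uses become `tool_calls` entries.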
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let mut messages: Vec<OpenAiMessage> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let mut text_content = String::new();
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text_content.is_empty() {
                            text_content.push('\n');
                        }
                        text_content.push_str(text);
                    }
                }
                if !text_content.is_empty() {
                    messages.push(OpenAiMessage {
                        role: "system".to_string(),
                        content: Some(OpenAiContent::Text(text_content)),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<OpenAiContentPart> = Vec::new();
                let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(OpenAiContentPart::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            let data_url = format!("data:image/png;base64,{}", img.source);
                            parts.push(OpenAiContentPart::ImageUrl {
                                image_url: ImageUrl { url: data_url },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            tool_result_messages.push(OpenAiMessage {
                                role: "tool".to_string(),
                                content: Some(OpenAiContent::Text(content_text)),
                                tool_calls: None,
                                tool_call_id: Some(result.tool_use_id.clone()),
                            });
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    let content = if parts.len() == 1 {
                        if let OpenAiContentPart::Text { text } = &parts[0] {
                            OpenAiContent::Text(text.clone())
                        } else {
                            OpenAiContent::Parts(parts)
                        }
                    } else {
                        OpenAiContent::Parts(parts)
                    };

                    messages.push(OpenAiMessage {
                        role: "user".to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }

                messages.extend(tool_result_messages);
            }
            LlmMessageRole::Assistant => {
                let mut text_content = String::new();
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                if !text_content.is_empty() {
                                    text_content.push('\n');
                                }
                                text_content.push_str(text);
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage {
                    role: "assistant".to_string(),
                    content: if text_content.is_empty() {
                        None
                    } else {
                        Some(OpenAiContent::Text(text_content))
                    },
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                    tool_call_id: None,
                });
            }
        }
    }

    let tools: Vec<OpenAiTool> = request
        .tools
        .iter()
        .map(|t| OpenAiTool {
            tool_type: "function".to_string(),
            function: OpenAiFunctionDef {
                name: t.name.clone(),
                description: t.description.clone(),
                parameters: serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default())),
            },
        })
        .collect();

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    let max_tokens = request.max_tokens;

    Ok(OpenAiRequest {
        model: model_id.to_string(),
        messages,
        max_tokens,
        tools,
        tool_choice,
        stop: request.stop_sequences.clone(),
        temperature: request.temperature,
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

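/// Parses a single SSE line of the form `data: {...}`. Returns `None` for non-data
/// lines, the `data: [DONE]` sentinel, and payloads that fail to deserialize.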
fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
    let data = line.strip_prefix("data: ")?;
    if data.trim() == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}

impl zed::Extension for CopilotChatProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
            device_flow_state: Mutex::new(None),
            api_token: Mutex::new(None),
            cached_models: Mutex::new(None),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "copilot-chat".into(),
            name: "Copilot Chat".into(),
            icon: Some("icons/copilot.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        // Try to get models from cache first
        if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
            return Ok(convert_models_to_llm_info(models));
        }

        // Need to fetch models - requires authentication
        let oauth_token = match llm_get_credential("copilot-chat") {
            Some(token) => token,
            None => return Ok(Vec::new()), // Not authenticated, return empty
        };

        // Get API token
        let api_token = self.get_api_token(&oauth_token)?;

        // Fetch models from API
        let models = self.fetch_models(&api_token)?;

        // Cache the models
        *self.cached_models.lock().unwrap() = Some(models.clone());

        Ok(convert_models_to_llm_info(&models))
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("copilot-chat").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use Copilot Chat, sign in with your GitHub account. This requires an active [GitHub Copilot subscription](https://github.com/features/copilot).".to_string(),
        )
    }

    fn llm_provider_start_device_flow_sign_in(
        &mut self,
        _provider_id: &str,
    ) -> Result<String, String> {
        // Step 1: Request device and user verification codes
        let device_code_response = llm_oauth_send_http_request(&LlmOauthHttpRequest {
            url: GITHUB_DEVICE_CODE_URL.to_string(),
            method: "POST".to_string(),
            headers: vec![
                ("Accept".to_string(), "application/json".to_string()),
                (
                    "Content-Type".to_string(),
                    "application/x-www-form-urlencoded".to_string(),
                ),
            ],
            body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID),
        })?;

        if device_code_response.status != 200 {
            return Err(format!(
                "Failed to get device code: HTTP {}",
                device_code_response.status
            ));
        }

        #[derive(Deserialize)]
        struct DeviceCodeResponse {
            device_code: String,
            user_code: String,
            verification_uri: String,
            #[serde(default)]
            verification_uri_complete: Option<String>,
            expires_in: u64,
            interval: u64,
        }

        let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body)
            .map_err(|e| format!("Failed to parse device code response: {}", e))?;

        // Store device flow state for polling
        *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
            device_code: device_info.device_code,
            interval: device_info.interval,
            expires_in: device_info.expires_in,
        });

        // Step 2: Open browser to verification URL
        // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
        let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
            format!(
                "{}?user_code={}",
                device_info.verification_uri, &device_info.user_code
            )
        });
        llm_oauth_open_browser(&verification_url)?;

        // Return the user code for the host to display
        Ok(device_info.user_code)
    }

    fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
        let state = self
            .device_flow_state
            .lock()
            .unwrap()
            .take()
            .ok_or("No device flow in progress")?;

        let poll_interval = Duration::from_secs(state.interval.max(5));
        let max_attempts = (state.expires_in / state.interval.max(5)) as usize;

        for _ in 0..max_attempts {
            thread::sleep(poll_interval);

            let token_response = llm_oauth_send_http_request(&LlmOauthHttpRequest {
                url: GITHUB_ACCESS_TOKEN_URL.to_string(),
                method: "POST".to_string(),
                headers: vec![
                    ("Accept".to_string(), "application/json".to_string()),
                    (
                        "Content-Type".to_string(),
                        "application/x-www-form-urlencoded".to_string(),
                    ),
                ],
                body: format!(
                    "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
                    GITHUB_COPILOT_CLIENT_ID, state.device_code
                ),
            })?;

            #[derive(Deserialize)]
            struct TokenResponse {
                access_token: Option<String>,
                error: Option<String>,
                error_description: Option<String>,
            }

            let token_json: TokenResponse = serde_json::from_str(&token_response.body)
                .map_err(|e| format!("Failed to parse token response: {}", e))?;

            if let Some(access_token) = token_json.access_token {
                llm_store_credential("copilot-chat", &access_token)?;
                return Ok(());
            }

            if let Some(error) = &token_json.error {
                match error.as_str() {
                    "authorization_pending" => {
                        // User hasn't authorized yet, keep polling
                        continue;
                    }
                    "slow_down" => {
                        // Need to slow down polling
                        thread::sleep(Duration::from_secs(5));
                        continue;
                    }
                    "expired_token" => {
                        return Err("Device code expired. Please try again.".to_string());
                    }
                    "access_denied" => {
                        return Err("Authorization was denied.".to_string());
                    }
                    _ => {
                        let description = token_json.error_description.unwrap_or_default();
                        return Err(format!("OAuth error: {} - {}", error, description));
                    }
                }
            }
        }

        Err("Authorization timed out. Please try again.".to_string())
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        // Clear cached API token and models
        *self.api_token.lock().unwrap() = None;
        *self.cached_models.lock().unwrap() = None;
        llm_delete_credential("copilot-chat")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
            "Not signed in to GitHub Copilot. Please sign in from the provider settings."
                .to_string()
        })?;

        // Exchange the OAuth token for a Copilot API token (cached after the first request)
        let api_token = self.get_api_token(&oauth_token)?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let completions_url = format!("{}/chat/completions", api_token.api_endpoint);

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: completions_url,
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("copilot-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
                pending_events: Vec::new(),
            },
        );

        Ok(stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Drain events queued by a previous call (extra tool calls, the stop reason).
        if !state.pending_events.is_empty() {
            return Ok(Some(state.pending_events.remove(0)));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.trim().is_empty() {
                    continue;
                }

                if let Some(response) = parse_sse_line(&line) {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }

                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state
                                    .tool_calls
                                    .entry(tc.index)
                                    .or_insert_with(AccumulatedToolCall::default);

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(finish_reason) = &choice.finish_reason {
                            let stop_reason = match finish_reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
                                state.tool_calls_emitted = true;
                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
                                tool_calls.sort_by_key(|(idx, _)| *idx);

                                // Queue every accumulated tool call followed by the stop
                                // reason, then emit them one per call.
                                for (_, tc) in tool_calls {
                                    state.pending_events.push(LlmCompletionEvent::ToolUse(
                                        LlmToolUse {
                                            id: tc.id,
                                            name: tc.name,
                                            input: tc.arguments,
                                            thought_signature: None,
                                        },
                                    ));
                                }
                                state.pending_events.push(LlmCompletionEvent::Stop(stop_reason));
                                return Ok(Some(state.pending_events.remove(0)));
                            }

                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                    }

                    if let Some(usage) = response.usage {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

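// Helpers for exchanging the stored GitHub OAuth token for a short-lived Copilot API
// token and for fetching the model catalog from the Copilot API.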
impl CopilotChatProvider {
    fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
        // Check if we have a cached token
        if let Some(token) = self.api_token.lock().unwrap().clone() {
            return Ok(token);
        }

        // Request a new API token
        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: GITHUB_COPILOT_TOKEN_URL.to_string(),
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("token {}", oauth_token),
                ),
                ("Accept".to_string(), "application/json".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to request API token: {}", e))?;

        #[derive(Deserialize)]
        struct ApiTokenResponse {
            token: String,
            endpoints: ApiEndpoints,
        }

        #[derive(Deserialize)]
        struct ApiEndpoints {
            api: String,
        }

        let token_response: ApiTokenResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse API token response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;

        let api_token = ApiToken {
            api_key: token_response.token,
            api_endpoint: token_response.endpoints.api,
        };

        // Cache the token
        *self.api_token.lock().unwrap() = Some(api_token.clone());

        Ok(api_token)
    }

    fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
        let models_url = format!("{}/models", api_token.api_endpoint);

        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: models_url,
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
                ("x-github-api-version".to_string(), "2025-05-01".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to fetch models: {}", e))?;

        #[derive(Deserialize)]
        struct ModelsResponse {
            data: Vec<CopilotModel>,
        }

        let models_response: ModelsResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse models response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;

        // Filter models like the built-in Copilot Chat does
        let mut models: Vec<CopilotModel> = models_response
            .data
            .into_iter()
            .filter(|model| {
                model.model_picker_enabled
                    && model.capabilities.model_type == "chat"
                    && model
                        .policy
                        .as_ref()
                        .map(|p| p.state == "enabled")
                        .unwrap_or(true)
            })
            .collect();

        // Sort so default model is first
        if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
            let default_model = models.remove(pos);
            models.insert(0, default_model);
        }

        Ok(models)
    }
}

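/// Maps Copilot's model catalog entries onto Zed's `LlmModelInfo`, falling back to a
/// 128k-token context window when the API does not report limits.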
fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
    models
        .iter()
        .map(|m| {
            let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 {
                m.capabilities.limits.max_context_window_tokens
            } else {
                128_000 // Default fallback
            };
            let max_output = if m.capabilities.limits.max_output_tokens > 0 {
                Some(m.capabilities.limits.max_output_tokens)
            } else {
                None
            };

            LlmModelInfo {
                id: m.id.clone(),
                name: m.name.clone(),
                max_token_count: max_tokens,
                max_output_tokens: max_output,
                capabilities: LlmModelCapabilities {
                    supports_images: m.capabilities.supports.vision,
                    supports_tools: m.capabilities.supports.tool_calls,
                    supports_tool_choice_auto: m.capabilities.supports.tool_calls,
                    supports_tool_choice_any: m.capabilities.supports.tool_calls,
                    supports_tool_choice_none: m.capabilities.supports.tool_calls,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_chat_default,
                is_default_fast: m.is_chat_fallback,
            }
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_flow_request_body() {
        let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("scope=read:user"));
    }

    #[test]
    fn test_token_poll_request_body() {
        let device_code = "test_device_code_123";
        let body = format!(
            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
            GITHUB_COPILOT_CLIENT_ID, device_code
        );
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("device_code=test_device_code_123"));
        assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code"));
    }
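
    // Additional coverage for the pure helpers above. These exercise only code defined
    // in this file, so they make no assumptions about the Copilot API itself.
    #[test]
    fn test_parse_sse_line() {
        // Non-data lines and the `[DONE]` sentinel are ignored.
        assert!(parse_sse_line(": keep-alive").is_none());
        assert!(parse_sse_line("data: [DONE]").is_none());

        // A content delta parses into a single choice with text and no usage.
        let line = r#"data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let response = parse_sse_line(line).expect("expected a parsed stream response");
        assert_eq!(response.choices.len(), 1);
        assert_eq!(response.choices[0].delta.content.as_deref(), Some("Hello"));
        assert!(response.usage.is_none());
    }

    #[test]
    fn test_context_window_fallback() {
        // A model that reports no token limits should fall back to the 128_000-token
        // default context window and report no output-token cap.
        let model = CopilotModel {
            id: "test-model".to_string(),
            name: "Test Model".to_string(),
            is_chat_default: false,
            is_chat_fallback: false,
            model_picker_enabled: true,
            capabilities: ModelCapabilities::default(),
            policy: None,
        };
        let info = convert_models_to_llm_info(&[model]);
        assert_eq!(info.len(), 1);
        assert_eq!(info[0].max_token_count, 128_000);
        assert_eq!(info[0].max_output_tokens, None);
    }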
}

zed::register_extension!(CopilotChatProvider);