copilot_chat.rs

  1use std::collections::HashMap;
  2use std::sync::Mutex;
  3use std::thread;
  4use std::time::Duration;
  5
  6use serde::{Deserialize, Serialize};
  7use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
  8use zed_extension_api::{self as zed, *};
  9
/// GitHub OAuth endpoint that issues the device code and user code for the device sign-in flow.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
/// GitHub OAuth endpoint polled with the device code until an access token is granted.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
/// Endpoint that exchanges the stored GitHub OAuth token for a Copilot API token and API base URL.
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
/// OAuth client id sent with both device-flow requests.
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
 14
/// Device-flow data kept between starting a GitHub sign-in and polling for its result.
struct DeviceFlowState {
    /// Opaque code sent back to GitHub on every token poll.
    device_code: String,
    /// Polling interval in seconds as reported by GitHub (the poller never goes below 5).
    interval: u64,
    /// Lifetime of the device code in seconds; bounds how long polling continues.
    expires_in: u64,
}
 20
/// Copilot API credential obtained by exchanging the stored GitHub OAuth token.
#[derive(Clone)]
struct ApiToken {
    /// Sent as `Authorization: Bearer <api_key>` on Copilot API requests.
    api_key: String,
    /// Base URL of the Copilot API; `/chat/completions` and `/models` are appended to it.
    api_endpoint: String,
}
 26
/// One entry of the Copilot `/models` response. Fields marked `default` may be absent.
#[derive(Clone, Deserialize)]
struct CopilotModel {
    /// Model identifier; sent as `model` in completion requests.
    id: String,
    /// Display name passed through to the host.
    name: String,
    /// Marks the account's default chat model (sorted first and reported as the default).
    #[serde(default)]
    is_chat_default: bool,
    /// Reported to the host as the default "fast" model.
    #[serde(default)]
    is_chat_fallback: bool,
    /// Only models with this flag set are offered.
    #[serde(default)]
    model_picker_enabled: bool,
    /// Type, limits and feature flags of the model.
    #[serde(default)]
    capabilities: ModelCapabilities,
    /// Access policy; a model whose policy state is not `"enabled"` is filtered out.
    #[serde(default)]
    policy: Option<ModelPolicy>,
}
 42
 43#[derive(Clone, Default, Deserialize)]
 44struct ModelCapabilities {
 45    #[serde(default)]
 46    family: String,
 47    #[serde(default)]
 48    limits: ModelLimits,
 49    #[serde(default)]
 50    supports: ModelSupportedFeatures,
 51    #[serde(rename = "type", default)]
 52    model_type: String,
 53}
 54
 55#[derive(Clone, Default, Deserialize)]
 56struct ModelLimits {
 57    #[serde(default)]
 58    max_context_window_tokens: u64,
 59    #[serde(default)]
 60    max_output_tokens: u64,
 61}
 62
 63#[derive(Clone, Default, Deserialize)]
 64struct ModelSupportedFeatures {
 65    #[serde(default)]
 66    streaming: bool,
 67    #[serde(default)]
 68    tool_calls: bool,
 69    #[serde(default)]
 70    vision: bool,
 71}
 72
/// Access policy attached to a Copilot model.
#[derive(Clone, Deserialize)]
struct ModelPolicy {
    /// Policy state; a model is offered only when this is `"enabled"`.
    state: String,
}
 77
/// Zed LLM provider backed by GitHub Copilot Chat.
///
/// Mutable state sits behind `Mutex`es so it can also be updated from `&self` methods.
struct CopilotChatProvider {
    /// In-flight completion streams, keyed by the id returned from `llm_stream_completion_start`.
    streams: Mutex<HashMap<String, StreamState>>,
    /// Counter used to mint unique `copilot-stream-<n>` ids.
    next_stream_id: Mutex<u64>,
    /// Pending device-flow sign-in; set when the flow starts and taken by the poller.
    device_flow_state: Mutex<Option<DeviceFlowState>>,
    /// Cached Copilot API token; cleared by `llm_provider_reset_credentials`.
    api_token: Mutex<Option<ApiToken>>,
    /// Cached `/models` result; cleared by `llm_provider_reset_credentials`.
    cached_models: Mutex<Option<Vec<CopilotModel>>>,
}
 85
/// Per-stream state of a streamed chat completion.
struct StreamState {
    /// Upstream HTTP response body of the `/chat/completions` request, if still being read.
    response_stream: Option<HttpResponseStream>,
    /// Decoded response text that has not yet been split into complete lines.
    buffer: String,
    /// Whether the `Started` event has been emitted for this stream.
    started: bool,
    /// Tool-call fragments accumulated so far, keyed by the `index` the API assigns to each call.
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    /// Set once the accumulated tool calls have been turned into `ToolUse` events,
    /// so they are emitted at most once per stream.
    tool_calls_emitted: bool,
}
 93
/// A tool call assembled from streamed fragments.
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    /// Call id; replaced whenever a fragment carries one.
    id: String,
    /// Function name; replaced whenever a fragment carries one.
    name: String,
    /// JSON argument text, concatenated across fragments in arrival order.
    arguments: String,
}
100
/// Body of a Copilot `/chat/completions` request in the OpenAI wire format.
///
/// Fields serialize in declaration order; `None` and empty-`Vec` fields are omitted.
#[derive(Serialize)]
struct OpenAiRequest {
    /// Copilot model id.
    model: String,
    /// Conversation history.
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Function tools the model may call; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    /// `"auto"`, `"required"` or `"none"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    /// Stop sequences; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Requests a streamed (SSE) response.
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}
119
/// Options for a streamed request.
#[derive(Serialize)]
struct StreamOptions {
    /// Presumably asks the server to include token usage in the stream; the stream
    /// reader parses a `usage` object from chunks.
    include_usage: bool,
}
124
/// One chat message of the request.
#[derive(Serialize)]
struct OpenAiMessage {
    /// `"system"`, `"user"`, `"assistant"` or `"tool"`.
    role: String,
    /// Message body; omitted when `None` (e.g. an assistant turn that only calls tools).
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    /// Tool invocations requested by an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    /// For `tool` messages: the id of the tool call this message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
135
/// Message content: serialized (untagged) either as a plain string or as an array of parts.
#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    Text(String),
    Parts(Vec<OpenAiContentPart>),
}
142
/// One typed content part, serialized with a `"type"` tag of `"text"` or `"image_url"`.
#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
151
/// Image reference; `convert_request` fills `url` with a `data:` URL.
#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}
156
/// A tool call replayed inside an assistant message.
#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    /// Id that the matching `tool` message refers back to.
    id: String,
    /// Serialized as `type`; this file always sets `"function"`.
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}
164
/// Function name plus its arguments as a JSON-encoded string.
#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}
170
/// Tool declaration sent with the request.
#[derive(Serialize)]
struct OpenAiTool {
    /// Serialized as `type`; this file always sets `"function"`.
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}
177
/// Function signature of a declared tool.
#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    /// The tool's input JSON schema.
    parameters: serde_json::Value,
}
184
185#[derive(Deserialize, Debug)]
186struct OpenAiStreamResponse {
187    choices: Vec<OpenAiStreamChoice>,
188    #[serde(default)]
189    usage: Option<OpenAiUsage>,
190}
191
192#[derive(Deserialize, Debug)]
193struct OpenAiStreamChoice {
194    delta: OpenAiDelta,
195    finish_reason: Option<String>,
196}
197
/// Incremental content of a streamed choice.
#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    /// Next piece of assistant text, if any.
    #[serde(default)]
    content: Option<String>,
    /// Fragments of tool calls, if any.
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
205
/// Fragment of a streamed tool call; `index` ties fragments of the same call together.
#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}
214
/// Fragment of a function name and/or its arguments; argument pieces are concatenated.
#[derive(Deserialize, Debug, Default)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
222
223#[derive(Deserialize, Debug)]
224struct OpenAiUsage {
225    prompt_tokens: u64,
226    completion_tokens: u64,
227}
228
229fn convert_request(
230    model_id: &str,
231    request: &LlmCompletionRequest,
232) -> Result<OpenAiRequest, String> {
233    let mut messages: Vec<OpenAiMessage> = Vec::new();
234
235    for msg in &request.messages {
236        match msg.role {
237            LlmMessageRole::System => {
238                let mut text_content = String::new();
239                for content in &msg.content {
240                    if let LlmMessageContent::Text(text) = content {
241                        if !text_content.is_empty() {
242                            text_content.push('\n');
243                        }
244                        text_content.push_str(text);
245                    }
246                }
247                if !text_content.is_empty() {
248                    messages.push(OpenAiMessage {
249                        role: "system".to_string(),
250                        content: Some(OpenAiContent::Text(text_content)),
251                        tool_calls: None,
252                        tool_call_id: None,
253                    });
254                }
255            }
256            LlmMessageRole::User => {
257                let mut parts: Vec<OpenAiContentPart> = Vec::new();
258                let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();
259
260                for content in &msg.content {
261                    match content {
262                        LlmMessageContent::Text(text) => {
263                            if !text.is_empty() {
264                                parts.push(OpenAiContentPart::Text { text: text.clone() });
265                            }
266                        }
267                        LlmMessageContent::Image(img) => {
268                            let data_url = format!("data:image/png;base64,{}", img.source);
269                            parts.push(OpenAiContentPart::ImageUrl {
270                                image_url: ImageUrl { url: data_url },
271                            });
272                        }
273                        LlmMessageContent::ToolResult(result) => {
274                            let content_text = match &result.content {
275                                LlmToolResultContent::Text(t) => t.clone(),
276                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
277                            };
278                            tool_result_messages.push(OpenAiMessage {
279                                role: "tool".to_string(),
280                                content: Some(OpenAiContent::Text(content_text)),
281                                tool_calls: None,
282                                tool_call_id: Some(result.tool_use_id.clone()),
283                            });
284                        }
285                        _ => {}
286                    }
287                }
288
289                if !parts.is_empty() {
290                    let content = if parts.len() == 1 {
291                        if let OpenAiContentPart::Text { text } = &parts[0] {
292                            OpenAiContent::Text(text.clone())
293                        } else {
294                            OpenAiContent::Parts(parts)
295                        }
296                    } else {
297                        OpenAiContent::Parts(parts)
298                    };
299
300                    messages.push(OpenAiMessage {
301                        role: "user".to_string(),
302                        content: Some(content),
303                        tool_calls: None,
304                        tool_call_id: None,
305                    });
306                }
307
308                messages.extend(tool_result_messages);
309            }
310            LlmMessageRole::Assistant => {
311                let mut text_content = String::new();
312                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();
313
314                for content in &msg.content {
315                    match content {
316                        LlmMessageContent::Text(text) => {
317                            if !text.is_empty() {
318                                if !text_content.is_empty() {
319                                    text_content.push('\n');
320                                }
321                                text_content.push_str(text);
322                            }
323                        }
324                        LlmMessageContent::ToolUse(tool_use) => {
325                            tool_calls.push(OpenAiToolCall {
326                                id: tool_use.id.clone(),
327                                call_type: "function".to_string(),
328                                function: OpenAiFunctionCall {
329                                    name: tool_use.name.clone(),
330                                    arguments: tool_use.input.clone(),
331                                },
332                            });
333                        }
334                        _ => {}
335                    }
336                }
337
338                messages.push(OpenAiMessage {
339                    role: "assistant".to_string(),
340                    content: if text_content.is_empty() {
341                        None
342                    } else {
343                        Some(OpenAiContent::Text(text_content))
344                    },
345                    tool_calls: if tool_calls.is_empty() {
346                        None
347                    } else {
348                        Some(tool_calls)
349                    },
350                    tool_call_id: None,
351                });
352            }
353        }
354    }
355
356    let tools: Vec<OpenAiTool> = request
357        .tools
358        .iter()
359        .map(|t| OpenAiTool {
360            tool_type: "function".to_string(),
361            function: OpenAiFunctionDef {
362                name: t.name.clone(),
363                description: t.description.clone(),
364                parameters: serde_json::from_str(&t.input_schema)
365                    .unwrap_or(serde_json::Value::Object(Default::default())),
366            },
367        })
368        .collect();
369
370    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
371        LlmToolChoice::Auto => "auto".to_string(),
372        LlmToolChoice::Any => "required".to_string(),
373        LlmToolChoice::None => "none".to_string(),
374    });
375
376    let max_tokens = request.max_tokens;
377
378    Ok(OpenAiRequest {
379        model: model_id.to_string(),
380        messages,
381        max_tokens,
382        tools,
383        tool_choice,
384        stop: request.stop_sequences.clone(),
385        temperature: request.temperature,
386        stream: true,
387        stream_options: Some(StreamOptions {
388            include_usage: true,
389        }),
390    })
391}
392
393fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
394    let data = line.strip_prefix("data: ")?;
395    if data.trim() == "[DONE]" {
396        return None;
397    }
398    serde_json::from_str(data).ok()
399}
400
401impl zed::Extension for CopilotChatProvider {
402    fn new() -> Self {
403        Self {
404            streams: Mutex::new(HashMap::new()),
405            next_stream_id: Mutex::new(0),
406            device_flow_state: Mutex::new(None),
407            api_token: Mutex::new(None),
408            cached_models: Mutex::new(None),
409        }
410    }
411
412    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
413        vec![LlmProviderInfo {
414            id: "copilot-chat".into(),
415            name: "Copilot Chat".into(),
416            icon: Some("icons/copilot.svg".into()),
417        }]
418    }
419
420    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
421        // Try to get models from cache first
422        if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
423            return Ok(convert_models_to_llm_info(models));
424        }
425
426        // Need to fetch models - requires authentication
427        let oauth_token = match llm_get_credential("copilot-chat") {
428            Some(token) => token,
429            None => return Ok(Vec::new()), // Not authenticated, return empty
430        };
431
432        // Get API token
433        let api_token = self.get_api_token(&oauth_token)?;
434
435        // Fetch models from API
436        let models = self.fetch_models(&api_token)?;
437
438        // Cache the models
439        *self.cached_models.lock().unwrap() = Some(models.clone());
440
441        Ok(convert_models_to_llm_info(&models))
442    }
443
444    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
445        llm_get_credential("copilot-chat").is_some()
446    }
447
448    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
449        Some(
450            "To use Copilot Chat, sign in with your GitHub account. This requires an active [GitHub Copilot subscription](https://github.com/features/copilot).".to_string(),
451        )
452    }
453
454    fn llm_provider_start_device_flow_sign_in(
455        &mut self,
456        _provider_id: &str,
457    ) -> Result<String, String> {
458        // Step 1: Request device and user verification codes
459        let device_code_response = llm_oauth_send_http_request(&LlmOauthHttpRequest {
460            url: GITHUB_DEVICE_CODE_URL.to_string(),
461            method: "POST".to_string(),
462            headers: vec![
463                ("Accept".to_string(), "application/json".to_string()),
464                (
465                    "Content-Type".to_string(),
466                    "application/x-www-form-urlencoded".to_string(),
467                ),
468            ],
469            body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID),
470        })?;
471
472        if device_code_response.status != 200 {
473            return Err(format!(
474                "Failed to get device code: HTTP {}",
475                device_code_response.status
476            ));
477        }
478
479        #[derive(Deserialize)]
480        struct DeviceCodeResponse {
481            device_code: String,
482            user_code: String,
483            verification_uri: String,
484            #[serde(default)]
485            verification_uri_complete: Option<String>,
486            expires_in: u64,
487            interval: u64,
488        }
489
490        let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body)
491            .map_err(|e| format!("Failed to parse device code response: {}", e))?;
492
493        // Store device flow state for polling
494        *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
495            device_code: device_info.device_code,
496            interval: device_info.interval,
497            expires_in: device_info.expires_in,
498        });
499
500        // Step 2: Open browser to verification URL
501        // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
502        let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
503            format!(
504                "{}?user_code={}",
505                device_info.verification_uri, &device_info.user_code
506            )
507        });
508        llm_oauth_open_browser(&verification_url)?;
509
510        // Return the user code for the host to display
511        Ok(device_info.user_code)
512    }
513
514    fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
515        let state = self
516            .device_flow_state
517            .lock()
518            .unwrap()
519            .take()
520            .ok_or("No device flow in progress")?;
521
522        let poll_interval = Duration::from_secs(state.interval.max(5));
523        let max_attempts = (state.expires_in / state.interval.max(5)) as usize;
524
525        for _ in 0..max_attempts {
526            thread::sleep(poll_interval);
527
528            let token_response = llm_oauth_send_http_request(&LlmOauthHttpRequest {
529                url: GITHUB_ACCESS_TOKEN_URL.to_string(),
530                method: "POST".to_string(),
531                headers: vec![
532                    ("Accept".to_string(), "application/json".to_string()),
533                    (
534                        "Content-Type".to_string(),
535                        "application/x-www-form-urlencoded".to_string(),
536                    ),
537                ],
538                body: format!(
539                    "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
540                    GITHUB_COPILOT_CLIENT_ID, state.device_code
541                ),
542            })?;
543
544            #[derive(Deserialize)]
545            struct TokenResponse {
546                access_token: Option<String>,
547                error: Option<String>,
548                error_description: Option<String>,
549            }
550
551            let token_json: TokenResponse = serde_json::from_str(&token_response.body)
552                .map_err(|e| format!("Failed to parse token response: {}", e))?;
553
554            if let Some(access_token) = token_json.access_token {
555                llm_store_credential("copilot-chat", &access_token)?;
556                return Ok(());
557            }
558
559            if let Some(error) = &token_json.error {
560                match error.as_str() {
561                    "authorization_pending" => {
562                        // User hasn't authorized yet, keep polling
563                        continue;
564                    }
565                    "slow_down" => {
566                        // Need to slow down polling
567                        thread::sleep(Duration::from_secs(5));
568                        continue;
569                    }
570                    "expired_token" => {
571                        return Err("Device code expired. Please try again.".to_string());
572                    }
573                    "access_denied" => {
574                        return Err("Authorization was denied.".to_string());
575                    }
576                    _ => {
577                        let description = token_json.error_description.unwrap_or_default();
578                        return Err(format!("OAuth error: {} - {}", error, description));
579                    }
580                }
581            }
582        }
583
584        Err("Authorization timed out. Please try again.".to_string())
585    }
586
587    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
588        // Clear cached API token and models
589        *self.api_token.lock().unwrap() = None;
590        *self.cached_models.lock().unwrap() = None;
591        llm_delete_credential("copilot-chat")
592    }
593
594    fn llm_stream_completion_start(
595        &mut self,
596        _provider_id: &str,
597        model_id: &str,
598        request: &LlmCompletionRequest,
599    ) -> Result<String, String> {
600        let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
601            "No token configured. Please add your GitHub Copilot token in settings.".to_string()
602        })?;
603
604        // Get or refresh API token
605        let api_token = self.get_api_token(&oauth_token)?;
606
607        let openai_request = convert_request(model_id, request)?;
608
609        let body = serde_json::to_vec(&openai_request)
610            .map_err(|e| format!("Failed to serialize request: {}", e))?;
611
612        let completions_url = format!("{}/chat/completions", api_token.api_endpoint);
613
614        let http_request = HttpRequest {
615            method: HttpMethod::Post,
616            url: completions_url,
617            headers: vec![
618                ("Content-Type".to_string(), "application/json".to_string()),
619                (
620                    "Authorization".to_string(),
621                    format!("Bearer {}", api_token.api_key),
622                ),
623                (
624                    "Copilot-Integration-Id".to_string(),
625                    "vscode-chat".to_string(),
626                ),
627                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
628            ],
629            body: Some(body),
630            redirect_policy: RedirectPolicy::FollowAll,
631        };
632
633        let response_stream = http_request
634            .fetch_stream()
635            .map_err(|e| format!("HTTP request failed: {}", e))?;
636
637        let stream_id = {
638            let mut id_counter = self.next_stream_id.lock().unwrap();
639            let id = format!("copilot-stream-{}", *id_counter);
640            *id_counter += 1;
641            id
642        };
643
644        self.streams.lock().unwrap().insert(
645            stream_id.clone(),
646            StreamState {
647                response_stream: Some(response_stream),
648                buffer: String::new(),
649                started: false,
650                tool_calls: HashMap::new(),
651                tool_calls_emitted: false,
652            },
653        );
654
655        Ok(stream_id)
656    }
657
658    fn llm_stream_completion_next(
659        &mut self,
660        stream_id: &str,
661    ) -> Result<Option<LlmCompletionEvent>, String> {
662        let mut streams = self.streams.lock().unwrap();
663        let state = streams
664            .get_mut(stream_id)
665            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
666
667        if !state.started {
668            state.started = true;
669            return Ok(Some(LlmCompletionEvent::Started));
670        }
671
672        let response_stream = state
673            .response_stream
674            .as_mut()
675            .ok_or_else(|| "Stream already closed".to_string())?;
676
677        loop {
678            if let Some(newline_pos) = state.buffer.find('\n') {
679                let line = state.buffer[..newline_pos].to_string();
680                state.buffer = state.buffer[newline_pos + 1..].to_string();
681
682                if line.trim().is_empty() {
683                    continue;
684                }
685
686                if let Some(response) = parse_sse_line(&line) {
687                    if let Some(choice) = response.choices.first() {
688                        if let Some(content) = &choice.delta.content {
689                            if !content.is_empty() {
690                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
691                            }
692                        }
693
694                        if let Some(tool_calls) = &choice.delta.tool_calls {
695                            for tc in tool_calls {
696                                let entry = state
697                                    .tool_calls
698                                    .entry(tc.index)
699                                    .or_insert_with(AccumulatedToolCall::default);
700
701                                if let Some(id) = &tc.id {
702                                    entry.id = id.clone();
703                                }
704                                if let Some(func) = &tc.function {
705                                    if let Some(name) = &func.name {
706                                        entry.name = name.clone();
707                                    }
708                                    if let Some(args) = &func.arguments {
709                                        entry.arguments.push_str(args);
710                                    }
711                                }
712                            }
713                        }
714
715                        if let Some(finish_reason) = &choice.finish_reason {
716                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
717                                state.tool_calls_emitted = true;
718                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
719                                tool_calls.sort_by_key(|(idx, _)| *idx);
720
721                                if let Some((_, tc)) = tool_calls.into_iter().next() {
722                                    return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
723                                        id: tc.id,
724                                        name: tc.name,
725                                        input: tc.arguments,
726                                        thought_signature: None,
727                                    })));
728                                }
729                            }
730
731                            let stop_reason = match finish_reason.as_str() {
732                                "stop" => LlmStopReason::EndTurn,
733                                "length" => LlmStopReason::MaxTokens,
734                                "tool_calls" => LlmStopReason::ToolUse,
735                                "content_filter" => LlmStopReason::Refusal,
736                                _ => LlmStopReason::EndTurn,
737                            };
738                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
739                        }
740                    }
741
742                    if let Some(usage) = response.usage {
743                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
744                            input_tokens: usage.prompt_tokens,
745                            output_tokens: usage.completion_tokens,
746                            cache_creation_input_tokens: None,
747                            cache_read_input_tokens: None,
748                        })));
749                    }
750                }
751
752                continue;
753            }
754
755            match response_stream.next_chunk() {
756                Ok(Some(chunk)) => {
757                    let text = String::from_utf8_lossy(&chunk);
758                    state.buffer.push_str(&text);
759                }
760                Ok(None) => {
761                    return Ok(None);
762                }
763                Err(e) => {
764                    return Err(format!("Stream error: {}", e));
765                }
766            }
767        }
768    }
769
770    fn llm_stream_completion_close(&mut self, stream_id: &str) {
771        self.streams.lock().unwrap().remove(stream_id);
772    }
773}
774
775impl CopilotChatProvider {
776    fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
777        // Check if we have a cached token
778        if let Some(token) = self.api_token.lock().unwrap().clone() {
779            return Ok(token);
780        }
781
782        // Request a new API token
783        let http_request = HttpRequest {
784            method: HttpMethod::Get,
785            url: GITHUB_COPILOT_TOKEN_URL.to_string(),
786            headers: vec![
787                (
788                    "Authorization".to_string(),
789                    format!("token {}", oauth_token),
790                ),
791                ("Accept".to_string(), "application/json".to_string()),
792            ],
793            body: None,
794            redirect_policy: RedirectPolicy::FollowAll,
795        };
796
797        let response = http_request
798            .fetch()
799            .map_err(|e| format!("Failed to request API token: {}", e))?;
800
801        #[derive(Deserialize)]
802        struct ApiTokenResponse {
803            token: String,
804            endpoints: ApiEndpoints,
805        }
806
807        #[derive(Deserialize)]
808        struct ApiEndpoints {
809            api: String,
810        }
811
812        let token_response: ApiTokenResponse =
813            serde_json::from_slice(&response.body).map_err(|e| {
814                format!(
815                    "Failed to parse API token response: {} - body: {}",
816                    e,
817                    String::from_utf8_lossy(&response.body)
818                )
819            })?;
820
821        let api_token = ApiToken {
822            api_key: token_response.token,
823            api_endpoint: token_response.endpoints.api,
824        };
825
826        // Cache the token
827        *self.api_token.lock().unwrap() = Some(api_token.clone());
828
829        Ok(api_token)
830    }
831
832    fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
833        let models_url = format!("{}/models", api_token.api_endpoint);
834
835        let http_request = HttpRequest {
836            method: HttpMethod::Get,
837            url: models_url,
838            headers: vec![
839                (
840                    "Authorization".to_string(),
841                    format!("Bearer {}", api_token.api_key),
842                ),
843                ("Content-Type".to_string(), "application/json".to_string()),
844                (
845                    "Copilot-Integration-Id".to_string(),
846                    "vscode-chat".to_string(),
847                ),
848                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
849                ("x-github-api-version".to_string(), "2025-05-01".to_string()),
850            ],
851            body: None,
852            redirect_policy: RedirectPolicy::FollowAll,
853        };
854
855        let response = http_request
856            .fetch()
857            .map_err(|e| format!("Failed to fetch models: {}", e))?;
858
859        #[derive(Deserialize)]
860        struct ModelsResponse {
861            data: Vec<CopilotModel>,
862        }
863
864        let models_response: ModelsResponse =
865            serde_json::from_slice(&response.body).map_err(|e| {
866                format!(
867                    "Failed to parse models response: {} - body: {}",
868                    e,
869                    String::from_utf8_lossy(&response.body)
870                )
871            })?;
872
873        // Filter models like the built-in Copilot Chat does
874        let mut models: Vec<CopilotModel> = models_response
875            .data
876            .into_iter()
877            .filter(|model| {
878                model.model_picker_enabled
879                    && model.capabilities.model_type == "chat"
880                    && model
881                        .policy
882                        .as_ref()
883                        .map(|p| p.state == "enabled")
884                        .unwrap_or(true)
885            })
886            .collect();
887
888        // Sort so default model is first
889        if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
890            let default_model = models.remove(pos);
891            models.insert(0, default_model);
892        }
893
894        Ok(models)
895    }
896}
897
898fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
899    models
900        .iter()
901        .map(|m| {
902            let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 {
903                m.capabilities.limits.max_context_window_tokens
904            } else {
905                128_000 // Default fallback
906            };
907            let max_output = if m.capabilities.limits.max_output_tokens > 0 {
908                Some(m.capabilities.limits.max_output_tokens)
909            } else {
910                None
911            };
912
913            LlmModelInfo {
914                id: m.id.clone(),
915                name: m.name.clone(),
916                max_token_count: max_tokens,
917                max_output_tokens: max_output,
918                capabilities: LlmModelCapabilities {
919                    supports_images: m.capabilities.supports.vision,
920                    supports_tools: m.capabilities.supports.tool_calls,
921                    supports_tool_choice_auto: m.capabilities.supports.tool_calls,
922                    supports_tool_choice_any: m.capabilities.supports.tool_calls,
923                    supports_tool_choice_none: m.capabilities.supports.tool_calls,
924                    supports_thinking: false,
925                    tool_input_format: LlmToolInputFormat::JsonSchema,
926                },
927                is_default: m.is_chat_default,
928                is_default_fast: m.is_chat_fallback,
929            }
930        })
931        .collect()
932}
933
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the client id and scope used to start the device flow.
    #[test]
    fn test_device_flow_request_body() {
        let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("scope=read:user"));
    }

    /// Pins the form fields sent when polling for the access token.
    #[test]
    fn test_token_poll_request_body() {
        let device_code = "test_device_code_123";
        let body = format!(
            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
            GITHUB_COPILOT_CLIENT_ID, device_code
        );
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("device_code=test_device_code_123"));
        assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code"));
    }

    #[test]
    fn parse_sse_line_ignores_non_data_lines_and_done() {
        assert!(parse_sse_line("").is_none());
        assert!(parse_sse_line(": keep-alive").is_none());
        assert!(parse_sse_line("event: message").is_none());
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line("data: not json").is_none());
    }

    #[test]
    fn parse_sse_line_decodes_text_tool_call_and_usage_chunks() {
        let text = parse_sse_line(
            r#"data: {"choices":[{"delta":{"content":"Hi"},"finish_reason":null}]}"#,
        )
        .expect("text chunk parses");
        assert_eq!(text.choices.len(), 1);
        assert_eq!(text.choices[0].delta.content.as_deref(), Some("Hi"));
        assert!(text.choices[0].finish_reason.is_none());

        let tool = parse_sse_line(
            r#"data: {"choices":[{"delta":{"tool_calls":[{"index":1,"id":"call_1","function":{"name":"grep","arguments":"{\"q\""}}]},"finish_reason":null}]}"#,
        )
        .expect("tool-call chunk parses");
        let fragments = tool.choices[0]
            .delta
            .tool_calls
            .as_ref()
            .expect("tool call fragments present");
        assert_eq!(fragments[0].index, 1);
        assert_eq!(fragments[0].id.as_deref(), Some("call_1"));
        let function = fragments[0].function.as_ref().expect("function present");
        assert_eq!(function.name.as_deref(), Some("grep"));
        assert_eq!(function.arguments.as_deref(), Some("{\"q\""));

        let usage = parse_sse_line(
            r#"data: {"choices":[],"usage":{"prompt_tokens":3,"completion_tokens":4}}"#,
        )
        .expect("usage chunk parses")
        .usage
        .expect("usage present");
        assert_eq!((usage.prompt_tokens, usage.completion_tokens), (3, 4));
    }

    fn model(json: &str) -> CopilotModel {
        serde_json::from_str(json).expect("valid model json")
    }

    #[test]
    fn convert_models_applies_limits_and_fallbacks() {
        let models = vec![
            model(
                r#"{"id":"a","name":"A","is_chat_default":true,"capabilities":{"type":"chat","supports":{"tool_calls":true,"vision":true},"limits":{"max_context_window_tokens":64000,"max_output_tokens":4096}}}"#,
            ),
            model(r#"{"id":"b","name":"B"}"#),
        ];

        let info = convert_models_to_llm_info(&models);
        assert_eq!(info.len(), 2);

        assert_eq!(info[0].id, "a");
        assert_eq!(info[0].max_token_count, 64_000);
        assert_eq!(info[0].max_output_tokens, Some(4096));
        assert!(info[0].capabilities.supports_images);
        assert!(info[0].capabilities.supports_tools);
        assert!(info[0].capabilities.supports_tool_choice_none);
        assert!(info[0].is_default);

        assert_eq!(info[1].id, "b");
        assert_eq!(info[1].max_token_count, 128_000);
        assert_eq!(info[1].max_output_tokens, None);
        assert!(!info[1].capabilities.supports_tools);
        assert!(!info[1].is_default);
    }
}
957
// Export `CopilotChatProvider` as this crate's implementation of `zed::Extension`.
zed::register_extension!(CopilotChatProvider);