copilot_chat.rs

use std::collections::HashMap;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

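// GitHub endpoints for the OAuth device flow and the Copilot token exchange,
// plus the OAuth client ID this extension authenticates as.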
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";

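/// State carried between starting the device flow and polling for its result.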
struct DeviceFlowState {
    device_code: String,
    interval: u64,
    expires_in: u64,
}

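/// Short-lived Copilot API token exchanged from the GitHub OAuth token, along
/// with the API endpoint it is valid for.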
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    api_endpoint: String,
}

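/// A model entry as returned by the Copilot models endpoint.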
#[derive(Clone, Deserialize)]
struct CopilotModel {
    id: String,
    name: String,
    #[serde(default)]
    is_chat_default: bool,
    #[serde(default)]
    is_chat_fallback: bool,
    #[serde(default)]
    model_picker_enabled: bool,
    #[serde(default)]
    capabilities: ModelCapabilities,
    #[serde(default)]
    policy: Option<ModelPolicy>,
}

#[derive(Clone, Default, Deserialize)]
struct ModelCapabilities {
    #[serde(default)]
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    #[serde(default)]
    supports: ModelSupportedFeatures,
    #[serde(rename = "type", default)]
    model_type: String,
}

#[derive(Clone, Default, Deserialize)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: u64,
    #[serde(default)]
    max_output_tokens: u64,
}

#[derive(Clone, Default, Deserialize)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

#[derive(Clone, Deserialize)]
struct ModelPolicy {
    state: String,
}

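/// Extension state: active response streams, a stream-id counter, in-flight
/// device-flow state, and the cached API token and model list.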
struct CopilotChatProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
    device_flow_state: Mutex<Option<DeviceFlowState>>,
    api_token: Mutex<Option<ApiToken>>,
    cached_models: Mutex<Option<Vec<CopilotModel>>>,
}

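/// Per-stream state: the HTTP response stream, buffered SSE text that has not
/// yet been parsed, and tool-call deltas accumulated until a finish reason
/// arrives.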
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    tool_calls_emitted: bool,
}

#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

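// Wire types for the OpenAI-style chat completions API exposed by Copilot.
// Only the fields this extension reads or writes are modeled.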
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    Text(String),
    Parts(Vec<OpenAiContentPart>),
}

#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamResponse {
    choices: Vec<OpenAiStreamChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

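/// Converts the host's completion request into an OpenAI-style streaming chat
/// request. System and assistant text parts are joined with newlines, user
/// content becomes text/image parts, and tool results are emitted as separate
/// `tool` messages after the user message.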
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let mut messages: Vec<OpenAiMessage> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let mut text_content = String::new();
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text_content.is_empty() {
                            text_content.push('\n');
                        }
                        text_content.push_str(text);
                    }
                }
                if !text_content.is_empty() {
                    messages.push(OpenAiMessage {
                        role: "system".to_string(),
                        content: Some(OpenAiContent::Text(text_content)),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<OpenAiContentPart> = Vec::new();
                let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(OpenAiContentPart::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            let data_url = format!("data:image/png;base64,{}", img.source);
                            parts.push(OpenAiContentPart::ImageUrl {
                                image_url: ImageUrl { url: data_url },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            tool_result_messages.push(OpenAiMessage {
                                role: "tool".to_string(),
                                content: Some(OpenAiContent::Text(content_text)),
                                tool_calls: None,
                                tool_call_id: Some(result.tool_use_id.clone()),
                            });
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    let content = if parts.len() == 1 {
                        if let OpenAiContentPart::Text { text } = &parts[0] {
                            OpenAiContent::Text(text.clone())
                        } else {
                            OpenAiContent::Parts(parts)
                        }
                    } else {
                        OpenAiContent::Parts(parts)
                    };

                    messages.push(OpenAiMessage {
                        role: "user".to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }

                messages.extend(tool_result_messages);
            }
            LlmMessageRole::Assistant => {
                let mut text_content = String::new();
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                if !text_content.is_empty() {
                                    text_content.push('\n');
                                }
                                text_content.push_str(text);
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage {
                    role: "assistant".to_string(),
                    content: if text_content.is_empty() {
                        None
                    } else {
                        Some(OpenAiContent::Text(text_content))
                    },
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                    tool_call_id: None,
                });
            }
        }
    }

    let tools: Vec<OpenAiTool> = request
        .tools
        .iter()
        .map(|t| OpenAiTool {
            tool_type: "function".to_string(),
            function: OpenAiFunctionDef {
                name: t.name.clone(),
                description: t.description.clone(),
                parameters: serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default())),
            },
        })
        .collect();

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    let max_tokens = request.max_tokens;

    Ok(OpenAiRequest {
        model: model_id.to_string(),
        messages,
        max_tokens,
        tools,
        tool_choice,
        stop: request.stop_sequences.clone(),
        temperature: request.temperature,
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

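/// Parses one SSE line, returning `None` for non-data lines, unparsable
/// payloads, and the terminal `[DONE]` sentinel.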
fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
    let data = line.strip_prefix("data: ")?;
    if data.trim() == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}

impl zed::Extension for CopilotChatProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
            device_flow_state: Mutex::new(None),
            api_token: Mutex::new(None),
            cached_models: Mutex::new(None),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "copilot-chat".into(),
            name: "Copilot Chat".into(),
            icon: Some("icons/copilot.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        // Try to get models from cache first
        if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
            return Ok(convert_models_to_llm_info(models));
        }

        // Need to fetch models - requires authentication
        let oauth_token = match llm_get_credential("copilot-chat") {
            Some(token) => token,
            None => return Ok(Vec::new()), // Not authenticated, return empty
        };

        // Get API token
        let api_token = self.get_api_token(&oauth_token)?;

        // Fetch models from API
        let models = self.fetch_models(&api_token)?;

        // Cache the models
        *self.cached_models.lock().unwrap() = Some(models.clone());

        Ok(convert_models_to_llm_info(&models))
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("copilot-chat").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use Copilot Chat, sign in with your GitHub account. This requires an active [GitHub Copilot subscription](https://github.com/features/copilot).".to_string(),
        )
    }

    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
        // Check if we have existing credentials
        if llm_get_credential("copilot-chat").is_some() {
            return Ok(());
        }

        // No credentials found - return error for background auth checks.
        // The device flow will be triggered by the host when the user clicks
        // the "Sign in with GitHub" button, which calls llm_provider_start_device_flow_sign_in.
        Err("CredentialsNotFound".to_string())
    }

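    /// Begins the GitHub device flow: requests a device/user code pair, opens
    /// the verification page in the browser, and returns the user code for the
    /// host to display.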
    fn llm_provider_start_device_flow_sign_in(
        &mut self,
        _provider_id: &str,
    ) -> Result<String, String> {
        // Step 1: Request device and user verification codes
        let device_code_response = llm_oauth_http_request(&LlmOauthHttpRequest {
            url: GITHUB_DEVICE_CODE_URL.to_string(),
            method: "POST".to_string(),
            headers: vec![
                ("Accept".to_string(), "application/json".to_string()),
                (
                    "Content-Type".to_string(),
                    "application/x-www-form-urlencoded".to_string(),
                ),
            ],
            body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID),
        })?;

        if device_code_response.status != 200 {
            return Err(format!(
                "Failed to get device code: HTTP {}",
                device_code_response.status
            ));
        }

        #[derive(Deserialize)]
        struct DeviceCodeResponse {
            device_code: String,
            user_code: String,
            verification_uri: String,
            #[serde(default)]
            verification_uri_complete: Option<String>,
            expires_in: u64,
            interval: u64,
        }

        let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body)
            .map_err(|e| format!("Failed to parse device code response: {}", e))?;

        // Store device flow state for polling
        *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
            device_code: device_info.device_code,
            interval: device_info.interval,
            expires_in: device_info.expires_in,
        });

        // Step 2: Open browser to verification URL
        // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
        let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
            format!(
                "{}?user_code={}",
                device_info.verification_uri, &device_info.user_code
            )
        });
        llm_oauth_open_browser(&verification_url)?;

        // Return the user code for the host to display
        Ok(device_info.user_code)
    }

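    /// Polls GitHub for the device-flow access token, sleeping for the
    /// server-provided interval and backing off further on `slow_down`.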
    fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
        let state = self
            .device_flow_state
            .lock()
            .unwrap()
            .take()
            .ok_or("No device flow in progress")?;

        let poll_interval = Duration::from_secs(state.interval.max(5));
        let max_attempts = (state.expires_in / state.interval.max(5)) as usize;

        for _ in 0..max_attempts {
            thread::sleep(poll_interval);

            let token_response = llm_oauth_http_request(&LlmOauthHttpRequest {
                url: GITHUB_ACCESS_TOKEN_URL.to_string(),
                method: "POST".to_string(),
                headers: vec![
                    ("Accept".to_string(), "application/json".to_string()),
                    (
                        "Content-Type".to_string(),
                        "application/x-www-form-urlencoded".to_string(),
                    ),
                ],
                body: format!(
                    "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
                    GITHUB_COPILOT_CLIENT_ID, state.device_code
                ),
            })?;

            #[derive(Deserialize)]
            struct TokenResponse {
                access_token: Option<String>,
                error: Option<String>,
                error_description: Option<String>,
            }

            let token_json: TokenResponse = serde_json::from_str(&token_response.body)
                .map_err(|e| format!("Failed to parse token response: {}", e))?;

            if let Some(access_token) = token_json.access_token {
                llm_store_credential("copilot-chat", &access_token)?;
                return Ok(());
            }

            if let Some(error) = &token_json.error {
                match error.as_str() {
                    "authorization_pending" => {
                        // User hasn't authorized yet, keep polling
                        continue;
                    }
                    "slow_down" => {
                        // Need to slow down polling
                        thread::sleep(Duration::from_secs(5));
                        continue;
                    }
                    "expired_token" => {
                        return Err("Device code expired. Please try again.".to_string());
                    }
                    "access_denied" => {
                        return Err("Authorization was denied.".to_string());
                    }
                    _ => {
                        let description = token_json.error_description.unwrap_or_default();
                        return Err(format!("OAuth error: {} - {}", error, description));
                    }
                }
            }
        }

        Err("Authorization timed out. Please try again.".to_string())
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        // Clear cached API token and models
        *self.api_token.lock().unwrap() = None;
        *self.cached_models.lock().unwrap() = None;
        llm_delete_credential("copilot-chat")
    }

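    /// Starts a streaming completion: refreshes the Copilot API token if
    /// needed, sends the OpenAI-style request, and registers the response
    /// stream under a fresh stream id.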
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
            "Not signed in to GitHub Copilot. Sign in from the provider settings first.".to_string()
        })?;

        // Get or refresh API token
        let api_token = self.get_api_token(&oauth_token)?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let completions_url = format!("{}/chat/completions", api_token.api_endpoint);

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: completions_url,
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("copilot-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
            },
        );

        Ok(stream_id)
    }

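    /// Returns the next completion event. Complete SSE lines already in the
    /// buffer are parsed first; when no full line remains, the next HTTP chunk
    /// is read and appended to the buffer.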
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.trim().is_empty() {
                    continue;
                }

                if let Some(response) = parse_sse_line(&line) {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }

                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state
                                    .tool_calls
                                    .entry(tc.index)
                                    .or_insert_with(AccumulatedToolCall::default);

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(finish_reason) = &choice.finish_reason {
                            // Flush accumulated tool-call deltas once a finish
                            // reason arrives; only the first (lowest-index)
                            // accumulated call is surfaced as a ToolUse event.
                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
                                state.tool_calls_emitted = true;
                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
                                tool_calls.sort_by_key(|(idx, _)| *idx);

                                if let Some((_, tc)) = tool_calls.into_iter().next() {
                                    return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                        id: tc.id,
                                        name: tc.name,
                                        input: tc.arguments,
                                        thought_signature: None,
                                    })));
                                }
                            }

                            let stop_reason = match finish_reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };
                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                    }

                    if let Some(usage) = response.usage {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

impl CopilotChatProvider {
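    /// Exchanges the stored GitHub OAuth token for a short-lived Copilot API
    /// token and endpoint, caching the result in memory.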
    fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
        // Check if we have a cached token
        if let Some(token) = self.api_token.lock().unwrap().clone() {
            return Ok(token);
        }

        // Request a new API token
        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: GITHUB_COPILOT_TOKEN_URL.to_string(),
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("token {}", oauth_token),
                ),
                ("Accept".to_string(), "application/json".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to request API token: {}", e))?;

        #[derive(Deserialize)]
        struct ApiTokenResponse {
            token: String,
            endpoints: ApiEndpoints,
        }

        #[derive(Deserialize)]
        struct ApiEndpoints {
            api: String,
        }

        let token_response: ApiTokenResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse API token response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;

        let api_token = ApiToken {
            api_key: token_response.token,
            api_endpoint: token_response.endpoints.api,
        };

        // Cache the token
        *self.api_token.lock().unwrap() = Some(api_token.clone());

        Ok(api_token)
    }

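    /// Fetches the model list from the Copilot API, keeping only picker-enabled
    /// chat models allowed by policy and moving the default model to the front.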
    fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
        let models_url = format!("{}/models", api_token.api_endpoint);

        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: models_url,
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
                ("x-github-api-version".to_string(), "2025-05-01".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to fetch models: {}", e))?;

        #[derive(Deserialize)]
        struct ModelsResponse {
            data: Vec<CopilotModel>,
        }

        let models_response: ModelsResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse models response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;

        // Filter models like the built-in Copilot Chat does
        let mut models: Vec<CopilotModel> = models_response
            .data
            .into_iter()
            .filter(|model| {
                model.model_picker_enabled
                    && model.capabilities.model_type == "chat"
                    && model
                        .policy
                        .as_ref()
                        .map(|p| p.state == "enabled")
                        .unwrap_or(true)
            })
            .collect();

        // Sort so default model is first
        if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
            let default_model = models.remove(pos);
            models.insert(0, default_model);
        }

        Ok(models)
    }
}

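/// Maps Copilot model metadata onto the host's `LlmModelInfo`, falling back to
/// a 128k context window when the API reports no limit.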
fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
    models
        .iter()
        .map(|m| {
            let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 {
                m.capabilities.limits.max_context_window_tokens
            } else {
                128_000 // Default fallback
            };
            let max_output = if m.capabilities.limits.max_output_tokens > 0 {
                Some(m.capabilities.limits.max_output_tokens)
            } else {
                None
            };

            LlmModelInfo {
                id: m.id.clone(),
                name: m.name.clone(),
                max_token_count: max_tokens,
                max_output_tokens: max_output,
                capabilities: LlmModelCapabilities {
                    supports_images: m.capabilities.supports.vision,
                    supports_tools: m.capabilities.supports.tool_calls,
                    supports_tool_choice_auto: m.capabilities.supports.tool_calls,
                    supports_tool_choice_any: m.capabilities.supports.tool_calls,
                    supports_tool_choice_none: m.capabilities.supports.tool_calls,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_chat_default,
                is_default_fast: m.is_chat_fallback,
            }
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_flow_request_body() {
        let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("scope=read:user"));
    }

    #[test]
    fn test_token_poll_request_body() {
        let device_code = "test_device_code_123";
        let body = format!(
            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
            GITHUB_COPILOT_CLIENT_ID, device_code
        );
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("device_code=test_device_code_123"));
        assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code"));
    }
968}
969
970zed::register_extension!(CopilotChatProvider);