copilot_chat.rs

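//! GitHub Copilot Chat LLM provider for Zed.
//!
//! Signs in with GitHub's OAuth device flow, exchanges the resulting OAuth token for a
//! Copilot API token, and streams chat completions from the Copilot API using its
//! OpenAI-compatible `/chat/completions` endpoint.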
use std::collections::HashMap;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;

use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};

const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";

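/// State returned by GitHub's device-code request, kept while polling for an access token.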
struct DeviceFlowState {
    device_code: String,
    interval: u64,
    expires_in: u64,
}

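/// A Copilot API token obtained from the GitHub OAuth token, plus the endpoint to call.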
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    api_endpoint: String,
}

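/// A model entry as returned by the Copilot `/models` endpoint.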
#[derive(Clone, Deserialize)]
struct CopilotModel {
    id: String,
    name: String,
    #[serde(default)]
    is_chat_default: bool,
    #[serde(default)]
    is_chat_fallback: bool,
    #[serde(default)]
    model_picker_enabled: bool,
    #[serde(default)]
    capabilities: ModelCapabilities,
    #[serde(default)]
    policy: Option<ModelPolicy>,
}

#[derive(Clone, Default, Deserialize)]
struct ModelCapabilities {
    #[serde(default)]
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    #[serde(default)]
    supports: ModelSupportedFeatures,
    #[serde(rename = "type", default)]
    model_type: String,
}

#[derive(Clone, Default, Deserialize)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: u64,
    #[serde(default)]
    max_output_tokens: u64,
}

#[derive(Clone, Default, Deserialize)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

#[derive(Clone, Deserialize)]
struct ModelPolicy {
    state: String,
}

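/// Extension state: active completion streams, device-flow state, and cached API data.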
struct CopilotChatProvider {
    streams: Mutex<HashMap<String, StreamState>>,
    next_stream_id: Mutex<u64>,
    device_flow_state: Mutex<Option<DeviceFlowState>>,
    api_token: Mutex<Option<ApiToken>>,
    cached_models: Mutex<Option<Vec<CopilotModel>>>,
}

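/// Per-stream state for an in-flight completion request.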
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    buffer: String,
    started: bool,
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    /// Tool calls accumulated from the stream, waiting to be emitted one event at a time
    /// once the upstream response reports a finish reason.
    pending_tool_calls: Vec<AccumulatedToolCall>,
    /// Stop reason to emit after all pending tool calls have been delivered.
    pending_stop: Option<LlmStopReason>,
}

#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

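/// Request body for the OpenAI-compatible chat completions endpoint used by the Copilot API.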
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}

#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

#[derive(Serialize)]
struct OpenAiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    Text(String),
    Parts(Vec<OpenAiContentPart>),
}

#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}

#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamResponse {
    choices: Vec<OpenAiStreamChoice>,
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}

#[derive(Deserialize, Debug)]
struct OpenAiStreamChoice {
    delta: OpenAiDelta,
    finish_reason: Option<String>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}

#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}

#[derive(Deserialize, Debug, Default)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}

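/// Converts Zed's provider-agnostic completion request into the OpenAI-style body expected by
/// the Copilot chat completions endpoint. Tool results become separate `tool`-role messages,
/// and assistant tool uses become `tool_calls` entries.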
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let mut messages: Vec<OpenAiMessage> = Vec::new();

    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let mut text_content = String::new();
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text_content.is_empty() {
                            text_content.push('\n');
                        }
                        text_content.push_str(text);
                    }
                }
                if !text_content.is_empty() {
                    messages.push(OpenAiMessage {
                        role: "system".to_string(),
                        content: Some(OpenAiContent::Text(text_content)),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<OpenAiContentPart> = Vec::new();
                let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(OpenAiContentPart::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            let data_url = format!("data:image/png;base64,{}", img.source);
                            parts.push(OpenAiContentPart::ImageUrl {
                                image_url: ImageUrl { url: data_url },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            tool_result_messages.push(OpenAiMessage {
                                role: "tool".to_string(),
                                content: Some(OpenAiContent::Text(content_text)),
                                tool_calls: None,
                                tool_call_id: Some(result.tool_use_id.clone()),
                            });
                        }
                        _ => {}
                    }
                }

                if !parts.is_empty() {
                    let content = if parts.len() == 1 {
                        if let OpenAiContentPart::Text { text } = &parts[0] {
                            OpenAiContent::Text(text.clone())
                        } else {
                            OpenAiContent::Parts(parts)
                        }
                    } else {
                        OpenAiContent::Parts(parts)
                    };

                    messages.push(OpenAiMessage {
                        role: "user".to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }

                messages.extend(tool_result_messages);
            }
            LlmMessageRole::Assistant => {
                let mut text_content = String::new();
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();

                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                if !text_content.is_empty() {
                                    text_content.push('\n');
                                }
                                text_content.push_str(text);
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                messages.push(OpenAiMessage {
                    role: "assistant".to_string(),
                    content: if text_content.is_empty() {
                        None
                    } else {
                        Some(OpenAiContent::Text(text_content))
                    },
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                    tool_call_id: None,
                });
            }
        }
    }

    let tools: Vec<OpenAiTool> = request
        .tools
        .iter()
        .map(|t| OpenAiTool {
            tool_type: "function".to_string(),
            function: OpenAiFunctionDef {
                name: t.name.clone(),
                description: t.description.clone(),
                parameters: serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default())),
            },
        })
        .collect();

    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });

    let max_tokens = request.max_tokens;

    Ok(OpenAiRequest {
        model: model_id.to_string(),
        messages,
        max_tokens,
        tools,
        tool_choice,
        stop: request.stop_sequences.clone(),
        temperature: request.temperature,
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}

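/// Parses one SSE line from the completions stream. Returns `None` for the `[DONE]` sentinel,
/// non-`data:` lines, and payloads that fail to deserialize.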
fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
    let data = line.strip_prefix("data: ")?;
    if data.trim() == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}

impl zed::Extension for CopilotChatProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
            device_flow_state: Mutex::new(None),
            api_token: Mutex::new(None),
            cached_models: Mutex::new(None),
        }
    }

    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "copilot-chat".into(),
            name: "Copilot Chat".into(),
            icon: Some("icons/copilot.svg".into()),
        }]
    }

    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        // Try to get models from cache first
        if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
            return Ok(convert_models_to_llm_info(models));
        }

        // Need to fetch models - requires authentication
        let oauth_token = match llm_get_credential("copilot-chat") {
            Some(token) => token,
            None => return Ok(Vec::new()), // Not authenticated, return empty
        };

        // Get API token
        let api_token = self.get_api_token(&oauth_token)?;

        // Fetch models from API
        let models = self.fetch_models(&api_token)?;

        // Cache the models
        *self.cached_models.lock().unwrap() = Some(models.clone());

        Ok(convert_models_to_llm_info(&models))
    }

    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("copilot-chat").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use Copilot Chat, sign in with your GitHub account. This requires an active [GitHub Copilot subscription](https://github.com/features/copilot).".to_string(),
        )
    }

    fn llm_provider_start_device_flow_sign_in(
        &mut self,
        _provider_id: &str,
    ) -> Result<LlmDeviceFlowPromptInfo, String> {
        // Step 1: Request device and user verification codes
        let device_code_response = llm_oauth_send_http_request(&HttpRequest {
            method: HttpMethod::Post,
            url: GITHUB_DEVICE_CODE_URL.to_string(),
            headers: vec![
                ("Accept".to_string(), "application/json".to_string()),
                (
                    "Content-Type".to_string(),
                    "application/x-www-form-urlencoded".to_string(),
                ),
            ],
            body: Some(
                format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID).into_bytes(),
            ),
            redirect_policy: RedirectPolicy::NoFollow,
        })?;

        if device_code_response.status != 200 {
            return Err(format!(
                "Failed to get device code: HTTP {}",
                device_code_response.status
            ));
        }

        #[derive(Deserialize)]
        struct DeviceCodeResponse {
            device_code: String,
            user_code: String,
            verification_uri: String,
            #[serde(default)]
            verification_uri_complete: Option<String>,
            expires_in: u64,
            interval: u64,
        }

        let device_info: DeviceCodeResponse = serde_json::from_slice(&device_code_response.body)
            .map_err(|e| format!("Failed to parse device code response: {}", e))?;

        // Store device flow state for polling
        *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
            device_code: device_info.device_code,
            interval: device_info.interval,
            expires_in: device_info.expires_in,
        });

        // Step 2: Construct the verification URL. Prefer verification_uri_complete when present
        // (it has the user code pre-filled); otherwise append the code as a query parameter.
        let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
            format!(
                "{}?user_code={}",
                device_info.verification_uri, &device_info.user_code
            )
        });

        // Return prompt info for the host to display in the modal
        Ok(LlmDeviceFlowPromptInfo {
            user_code: device_info.user_code,
            verification_url,
            headline: "Use GitHub Copilot in Zed.".to_string(),
            description: "Using Copilot requires an active subscription on GitHub.".to_string(),
            connect_button_label: "Connect to GitHub".to_string(),
            success_headline: "Copilot Enabled!".to_string(),
            success_message:
                "You can update your settings or sign out from the Copilot menu in the status bar."
                    .to_string(),
        })
    }

    fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
        let state = self
            .device_flow_state
            .lock()
            .unwrap()
            .take()
            .ok_or("No device flow in progress")?;

        let poll_interval = Duration::from_secs(state.interval.max(5));
        let max_attempts = (state.expires_in / state.interval.max(5)) as usize;

        for _ in 0..max_attempts {
            thread::sleep(poll_interval);

            let token_response = llm_oauth_send_http_request(&HttpRequest {
                method: HttpMethod::Post,
                url: GITHUB_ACCESS_TOKEN_URL.to_string(),
                headers: vec![
                    ("Accept".to_string(), "application/json".to_string()),
                    (
                        "Content-Type".to_string(),
                        "application/x-www-form-urlencoded".to_string(),
                    ),
                ],
                body: Some(
                    format!(
                        "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
                        GITHUB_COPILOT_CLIENT_ID, state.device_code
                    )
                    .into_bytes(),
                ),
                redirect_policy: RedirectPolicy::NoFollow,
            })?;

            #[derive(Deserialize)]
            struct TokenResponse {
                access_token: Option<String>,
                error: Option<String>,
                error_description: Option<String>,
            }

            let token_json: TokenResponse = serde_json::from_slice(&token_response.body)
                .map_err(|e| format!("Failed to parse token response: {}", e))?;

            if let Some(access_token) = token_json.access_token {
                llm_store_credential("copilot-chat", &access_token)?;
                return Ok(());
            }

            if let Some(error) = &token_json.error {
                match error.as_str() {
                    "authorization_pending" => {
                        // User hasn't authorized yet, keep polling
                        continue;
                    }
                    "slow_down" => {
                        // Need to slow down polling
                        thread::sleep(Duration::from_secs(5));
                        continue;
                    }
                    "expired_token" => {
                        return Err("Device code expired. Please try again.".to_string());
                    }
                    "access_denied" => {
                        return Err("Authorization was denied.".to_string());
                    }
                    _ => {
                        let description = token_json.error_description.unwrap_or_default();
                        return Err(format!("OAuth error: {} - {}", error, description));
                    }
                }
            }
        }

        Err("Authorization timed out. Please try again.".to_string())
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        // Clear cached API token and models
        *self.api_token.lock().unwrap() = None;
        *self.cached_models.lock().unwrap() = None;
        llm_delete_credential("copilot-chat")
    }

    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
            "Not signed in to GitHub Copilot. Sign in from the provider settings to continue."
                .to_string()
        })?;

        // Get or refresh API token
        let api_token = self.get_api_token(&oauth_token)?;

        let openai_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let completions_url = format!("{}/chat/completions", api_token.api_endpoint);

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: completions_url,
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("copilot-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                pending_tool_calls: Vec::new(),
                pending_stop: None,
            },
        );

        Ok(stream_id)
    }

    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        // Drain events queued when the upstream response reported a finish reason: first any
        // remaining tool calls (one per call), then the stop reason.
        if !state.pending_tool_calls.is_empty() {
            let tc = state.pending_tool_calls.remove(0);
            return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                id: tc.id,
                name: tc.name,
                input: tc.arguments,
                is_input_complete: true,
                thought_signature: None,
            })));
        }
        if let Some(stop_reason) = state.pending_stop.take() {
            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            // Emit at most one event per call: take the next complete SSE line from the buffer.
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                if line.trim().is_empty() {
                    continue;
                }

                if let Some(response) = parse_sse_line(&line) {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }

                        // Tool call arguments arrive as partial deltas; accumulate them by
                        // index until the stream reports a finish reason.
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state
                                    .tool_calls
                                    .entry(tc.index)
                                    .or_insert_with(AccumulatedToolCall::default);

                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }

                        if let Some(finish_reason) = &choice.finish_reason {
                            let stop_reason = match finish_reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };

                            if !state.tool_calls.is_empty() {
                                // Queue the accumulated tool calls in index order along with the
                                // stop reason; they are emitted one event at a time on subsequent
                                // calls, starting with the first tool call here.
                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
                                tool_calls.sort_by_key(|(idx, _)| *idx);
                                state.pending_tool_calls =
                                    tool_calls.into_iter().map(|(_, tc)| tc).collect();
                                state.pending_stop = Some(stop_reason);

                                let tc = state.pending_tool_calls.remove(0);
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tc.id,
                                    name: tc.name,
                                    input: tc.arguments,
                                    is_input_complete: true,
                                    thought_signature: None,
                                })));
                            }

                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                    }

                    if let Some(usage) = response.usage {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }

                continue;
            }

            // No complete line buffered yet: read the next chunk from the HTTP response stream.
            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}

impl CopilotChatProvider {
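    /// Exchanges the stored GitHub OAuth token for a Copilot API token, caching the result
    /// until credentials are reset.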
    fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
        // Check if we have a cached token
        if let Some(token) = self.api_token.lock().unwrap().clone() {
            return Ok(token);
        }

        // Request a new API token
        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: GITHUB_COPILOT_TOKEN_URL.to_string(),
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("token {}", oauth_token),
                ),
                ("Accept".to_string(), "application/json".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to request API token: {}", e))?;

        #[derive(Deserialize)]
        struct ApiTokenResponse {
            token: String,
            endpoints: ApiEndpoints,
        }

        #[derive(Deserialize)]
        struct ApiEndpoints {
            api: String,
        }

        let token_response: ApiTokenResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse API token response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;

        let api_token = ApiToken {
            api_key: token_response.token,
            api_endpoint: token_response.endpoints.api,
        };

        // Cache the token
        *self.api_token.lock().unwrap() = Some(api_token.clone());

        Ok(api_token)
    }

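    /// Fetches the model catalog from the Copilot API and filters it to chat models that are
    /// enabled in the model picker and permitted by policy, with the default model first.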
    fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
        let models_url = format!("{}/models", api_token.api_endpoint);

        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: models_url,
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
                ("x-github-api-version".to_string(), "2025-05-01".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to fetch models: {}", e))?;

        #[derive(Deserialize)]
        struct ModelsResponse {
            data: Vec<CopilotModel>,
        }

        let models_response: ModelsResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse models response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;

        // Filter models like the built-in Copilot Chat does
        let mut models: Vec<CopilotModel> = models_response
            .data
            .into_iter()
            .filter(|model| {
                model.model_picker_enabled
                    && model.capabilities.model_type == "chat"
                    && model
                        .policy
                        .as_ref()
                        .map(|p| p.state == "enabled")
                        .unwrap_or(true)
            })
            .collect();

        // Sort so default model is first
        if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
            let default_model = models.remove(pos);
            models.insert(0, default_model);
        }

        Ok(models)
    }
}

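/// Maps Copilot model metadata into the host's `LlmModelInfo`, falling back to a 128k context
/// window when the API reports no limit.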
fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
    models
        .iter()
        .map(|m| {
            let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 {
                m.capabilities.limits.max_context_window_tokens
            } else {
                128_000 // Default fallback
            };
            let max_output = if m.capabilities.limits.max_output_tokens > 0 {
                Some(m.capabilities.limits.max_output_tokens)
            } else {
                None
            };

            LlmModelInfo {
                id: m.id.clone(),
                name: m.name.clone(),
                max_token_count: max_tokens,
                max_output_tokens: max_output,
                capabilities: LlmModelCapabilities {
                    supports_images: m.capabilities.supports.vision,
                    supports_tools: m.capabilities.supports.tool_calls,
                    supports_tool_choice_auto: m.capabilities.supports.tool_calls,
                    supports_tool_choice_any: m.capabilities.supports.tool_calls,
                    supports_tool_choice_none: m.capabilities.supports.tool_calls,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_chat_default,
                is_default_fast: m.is_chat_fallback,
            }
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_flow_request_body() {
        let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("scope=read:user"));
    }

    #[test]
    fn test_token_poll_request_body() {
        let device_code = "test_device_code_123";
        let body = format!(
            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
            GITHUB_COPILOT_CLIENT_ID, device_code
        );
        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
        assert!(body.contains("device_code=test_device_code_123"));
        assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code"));
    }
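
    // A small sketch of how parse_sse_line behaves, using a hand-written SSE line that follows
    // the OpenAI streaming shape this file deserializes; the JSON payload is an illustrative
    // example, not captured Copilot output.
    #[test]
    fn test_parse_sse_line() {
        let line = r#"data: {"choices":[{"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let response = parse_sse_line(line).expect("data line should parse");
        assert_eq!(response.choices[0].delta.content.as_deref(), Some("Hello"));

        // The terminal sentinel and non-data lines both yield None.
        assert!(parse_sse_line("data: [DONE]").is_none());
        assert!(parse_sse_line(": keep-alive").is_none());
    }

    // Sketch of the model conversion fallback: a model with zero-valued limits is assumed here
    // purely for illustration ("example-model" is not a real Copilot model id).
    #[test]
    fn test_convert_models_default_limits() {
        let model = CopilotModel {
            id: "example-model".to_string(),
            name: "Example Model".to_string(),
            is_chat_default: true,
            is_chat_fallback: false,
            model_picker_enabled: true,
            capabilities: ModelCapabilities {
                supports: ModelSupportedFeatures {
                    streaming: true,
                    tool_calls: true,
                    vision: false,
                },
                ..Default::default()
            },
            policy: None,
        };

        let info = convert_models_to_llm_info(&[model]);
        assert_eq!(info.len(), 1);
        assert_eq!(info[0].max_token_count, 128_000);
        assert_eq!(info[0].max_output_tokens, None);
        assert!(info[0].capabilities.supports_tools);
        assert!(info[0].is_default);
    }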
}

zed::register_extension!(CopilotChatProvider);