copilot_chat.rs

   1use std::collections::HashMap;
   2use std::sync::Mutex;
   3use std::thread;
   4use std::time::Duration;
   5
   6use serde::{Deserialize, Serialize};
   7use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
   8use zed_extension_api::{self as zed, *};
   9
/// GitHub endpoint that issues the device code and user code for the device-flow sign-in.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
/// GitHub endpoint polled to exchange the device code for an OAuth access token.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
/// Endpoint queried with the OAuth token to obtain a Copilot API token and API endpoint.
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
/// OAuth client id sent as `client_id` in the device-flow requests.
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
  14
/// State remembered between starting the device-flow sign-in and polling for its result.
struct DeviceFlowState {
    /// Device code returned by GitHub; sent back on every token poll.
    device_code: String,
    /// Poll interval in seconds as reported by GitHub (polling uses at least 5 seconds).
    interval: u64,
    /// Reported lifetime of the device code; bounds how long polling continues.
    expires_in: u64,
}
  20
/// Copilot API credentials obtained by exchanging the GitHub OAuth token.
#[derive(Clone)]
struct ApiToken {
    /// Token sent as `Authorization: Bearer …` on Copilot API requests.
    api_key: String,
    /// Base URL that `/chat/completions` and `/models` are appended to.
    api_endpoint: String,
}
  26
/// One model entry as deserialized from the Copilot models listing.
///
/// Only `id` and `name` are required; every other field falls back to its
/// default when absent from the payload.
#[derive(Clone, Deserialize)]
struct CopilotModel {
    /// Model identifier.
    id: String,
    /// Model name (presumably the human-readable label — confirm against the consumer).
    name: String,
    /// `is_chat_default` flag from the payload; `false` when missing.
    #[serde(default)]
    is_chat_default: bool,
    /// `is_chat_fallback` flag from the payload; `false` when missing.
    #[serde(default)]
    is_chat_fallback: bool,
    /// `model_picker_enabled` flag from the payload; `false` when missing.
    #[serde(default)]
    model_picker_enabled: bool,
    /// Capability description; defaults to an empty [`ModelCapabilities`].
    #[serde(default)]
    capabilities: ModelCapabilities,
    /// Optional policy block attached to the model.
    #[serde(default)]
    policy: Option<ModelPolicy>,
}
  42
/// Capability section of a [`CopilotModel`]; every field is optional in the payload.
#[derive(Clone, Default, Deserialize)]
struct ModelCapabilities {
    /// Model family string; empty when missing.
    #[serde(default)]
    family: String,
    /// Token limits for the model.
    #[serde(default)]
    limits: ModelLimits,
    /// Feature flags the model advertises.
    #[serde(default)]
    supports: ModelSupportedFeatures,
    /// Value of the JSON `type` key (renamed because `type` is a Rust keyword).
    #[serde(rename = "type", default)]
    model_type: String,
}
  54
/// Token limits reported for a model; both default to `0` when missing.
#[derive(Clone, Default, Deserialize)]
struct ModelLimits {
    /// Value of `max_context_window_tokens` from the payload.
    #[serde(default)]
    max_context_window_tokens: u64,
    /// Value of `max_output_tokens` from the payload.
    #[serde(default)]
    max_output_tokens: u64,
}
  62
/// Feature flags reported for a model; each defaults to `false` when missing.
#[derive(Clone, Default, Deserialize)]
struct ModelSupportedFeatures {
    /// Value of the `streaming` flag.
    #[serde(default)]
    streaming: bool,
    /// Value of the `tool_calls` flag.
    #[serde(default)]
    tool_calls: bool,
    /// Value of the `vision` flag.
    #[serde(default)]
    vision: bool,
}
  72
/// Policy block attached to a [`CopilotModel`].
#[derive(Clone, Deserialize)]
struct ModelPolicy {
    /// Policy state string as sent by the API (set of possible values not visible here).
    state: String,
}
  77
/// Extension state for the Copilot Chat provider.
///
/// Every field sits behind its own `Mutex` so it can be updated from methods
/// that only receive `&self`.
struct CopilotChatProvider {
    /// Open completion streams, keyed by the stream id handed to the host.
    streams: Mutex<HashMap<String, StreamState>>,
    /// Counter used to build the next `copilot-stream-N` id.
    next_stream_id: Mutex<u64>,
    /// Pending device-flow sign-in, set when the flow starts and taken when polling begins.
    device_flow_state: Mutex<Option<DeviceFlowState>>,
    /// Cached Copilot API token; cleared when credentials are reset.
    api_token: Mutex<Option<ApiToken>>,
    /// Cached model listing; cleared when credentials are reset.
    cached_models: Mutex<Option<Vec<CopilotModel>>>,
}
  85
/// Per-stream bookkeeping for an in-flight chat completion.
struct StreamState {
    /// Underlying HTTP response stream; `None` is reported as "Stream already closed".
    response_stream: Option<HttpResponseStream>,
    /// Received response text that has not yet been consumed line by line.
    buffer: String,
    /// Whether the initial `Started` event has already been returned.
    started: bool,
    /// Tool calls assembled from streamed deltas, keyed by the delta's `index`.
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    /// Tracks whether emission of the accumulated tool calls has begun.
    tool_calls_emitted: bool,
}
  93
/// A tool call being pieced together from streamed deltas.
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    /// Tool-call id; overwritten whenever a delta carries one.
    id: String,
    /// Function name; overwritten whenever a delta carries one.
    name: String,
    /// Argument text, built by appending each delta's fragment.
    arguments: String,
}
 100
/// JSON body sent to the `/chat/completions` endpoint.
///
/// Optional fields and empty lists are omitted from the serialized output.
#[derive(Serialize)]
struct OpenAiRequest {
    /// Identifier of the model to run.
    model: String,
    /// Conversation history in request order.
    messages: Vec<OpenAiMessage>,
    /// Optional output token limit; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    /// Tool definitions; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    /// Tool-choice mode (`"auto"`, `"required"` or `"none"`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    /// Stop sequences; omitted when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    /// Sampling temperature; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Whether the response should be streamed.
    stream: bool,
    /// Extra streaming options; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}
 119
/// Streaming options attached to an [`OpenAiRequest`].
#[derive(Serialize)]
struct StreamOptions {
    /// Serialized as `include_usage`; requests token usage in the streamed response.
    include_usage: bool,
}
 124
/// A single chat message in an [`OpenAiRequest`]; `None` fields are omitted when serialized.
#[derive(Serialize)]
struct OpenAiMessage {
    /// Message role: `"system"`, `"user"`, `"assistant"` or `"tool"`.
    role: String,
    /// Message content, either plain text or a list of parts.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    /// Tool calls made by an assistant message.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    /// Id of the tool call that a `"tool"` message answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
 135
/// Message content; untagged, so it serializes as either a bare string or an array of parts.
#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    /// Plain text content.
    Text(String),
    /// Multi-part content (text and/or images).
    Parts(Vec<OpenAiContentPart>),
}
 142
/// One part of a multi-part message, distinguished by a `type` tag in the JSON.
#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    /// Text part, tagged `"text"`.
    #[serde(rename = "text")]
    Text { text: String },
    /// Image part, tagged `"image_url"`.
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
 151
/// Image reference inside an [`OpenAiContentPart::ImageUrl`] part.
#[derive(Serialize, Clone)]
struct ImageUrl {
    /// Image location; the request conversion fills this with a `data:` URL.
    url: String,
}
 156
/// A tool call recorded on an assistant message.
#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    /// Tool-call id, echoed back by the matching `"tool"` message.
    id: String,
    /// Serialized as `type`; the request conversion always sets it to `"function"`.
    #[serde(rename = "type")]
    call_type: String,
    /// The function that was invoked and its arguments.
    function: OpenAiFunctionCall,
}
 164
/// Function name and arguments of an [`OpenAiToolCall`].
#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    /// Name of the called function.
    name: String,
    /// Arguments as a string (presumably JSON text — confirm against the host API).
    arguments: String,
}
 170
/// A tool definition offered to the model in an [`OpenAiRequest`].
#[derive(Serialize)]
struct OpenAiTool {
    /// Serialized as `type`; the request conversion always sets it to `"function"`.
    #[serde(rename = "type")]
    tool_type: String,
    /// Description of the callable function.
    function: OpenAiFunctionDef,
}
 177
/// Function description inside an [`OpenAiTool`].
#[derive(Serialize)]
struct OpenAiFunctionDef {
    /// Function name.
    name: String,
    /// Description text for the function.
    description: String,
    /// Parameter schema as a JSON value, parsed from the tool's input schema string.
    parameters: serde_json::Value,
}
 184
/// One streamed response chunk, decoded from the payload of an SSE `data:` line.
#[derive(Deserialize, Debug)]
struct OpenAiStreamResponse {
    /// Choices in this chunk; only the first one is consumed by the stream reader.
    choices: Vec<OpenAiStreamChoice>,
    /// Token usage, when the chunk carries it.
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}
 191
/// A single choice inside an [`OpenAiStreamResponse`].
#[derive(Deserialize, Debug)]
struct OpenAiStreamChoice {
    /// Incremental content for this choice.
    delta: OpenAiDelta,
    /// Reason generation finished, when present (e.g. `"stop"`, `"length"`, `"tool_calls"`).
    finish_reason: Option<String>,
}
 197
/// Incremental update carried by a streamed choice; both parts are optional.
#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    /// Text fragment, if any.
    #[serde(default)]
    content: Option<String>,
    /// Tool-call fragments, if any.
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
 205
/// Fragment of a tool call delivered in a streamed delta.
#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    /// Index identifying which tool call this fragment belongs to.
    index: usize,
    /// Tool-call id, when the fragment carries it.
    #[serde(default)]
    id: Option<String>,
    /// Function name and/or argument fragment, when present.
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}
 214
/// Function part of an [`OpenAiToolCallDelta`].
#[derive(Deserialize, Debug, Default)]
struct OpenAiFunctionDelta {
    /// Function name, when the fragment carries it.
    #[serde(default)]
    name: Option<String>,
    /// Argument text fragment; fragments are concatenated by the stream reader.
    #[serde(default)]
    arguments: Option<String>,
}
 222
/// Token usage reported in a streamed chunk.
#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    /// Reported as input tokens to the host.
    prompt_tokens: u64,
    /// Reported as output tokens to the host.
    completion_tokens: u64,
}
 228
 229fn convert_request(
 230    model_id: &str,
 231    request: &LlmCompletionRequest,
 232) -> Result<OpenAiRequest, String> {
 233    let mut messages: Vec<OpenAiMessage> = Vec::new();
 234
 235    for msg in &request.messages {
 236        match msg.role {
 237            LlmMessageRole::System => {
 238                let mut text_content = String::new();
 239                for content in &msg.content {
 240                    if let LlmMessageContent::Text(text) = content {
 241                        if !text_content.is_empty() {
 242                            text_content.push('\n');
 243                        }
 244                        text_content.push_str(text);
 245                    }
 246                }
 247                if !text_content.is_empty() {
 248                    messages.push(OpenAiMessage {
 249                        role: "system".to_string(),
 250                        content: Some(OpenAiContent::Text(text_content)),
 251                        tool_calls: None,
 252                        tool_call_id: None,
 253                    });
 254                }
 255            }
 256            LlmMessageRole::User => {
 257                let mut parts: Vec<OpenAiContentPart> = Vec::new();
 258                let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();
 259
 260                for content in &msg.content {
 261                    match content {
 262                        LlmMessageContent::Text(text) => {
 263                            if !text.is_empty() {
 264                                parts.push(OpenAiContentPart::Text { text: text.clone() });
 265                            }
 266                        }
 267                        LlmMessageContent::Image(img) => {
 268                            let data_url = format!("data:image/png;base64,{}", img.source);
 269                            parts.push(OpenAiContentPart::ImageUrl {
 270                                image_url: ImageUrl { url: data_url },
 271                            });
 272                        }
 273                        LlmMessageContent::ToolResult(result) => {
 274                            let content_text = match &result.content {
 275                                LlmToolResultContent::Text(t) => t.clone(),
 276                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
 277                            };
 278                            tool_result_messages.push(OpenAiMessage {
 279                                role: "tool".to_string(),
 280                                content: Some(OpenAiContent::Text(content_text)),
 281                                tool_calls: None,
 282                                tool_call_id: Some(result.tool_use_id.clone()),
 283                            });
 284                        }
 285                        _ => {}
 286                    }
 287                }
 288
 289                if !parts.is_empty() {
 290                    let content = if parts.len() == 1 {
 291                        if let OpenAiContentPart::Text { text } = &parts[0] {
 292                            OpenAiContent::Text(text.clone())
 293                        } else {
 294                            OpenAiContent::Parts(parts)
 295                        }
 296                    } else {
 297                        OpenAiContent::Parts(parts)
 298                    };
 299
 300                    messages.push(OpenAiMessage {
 301                        role: "user".to_string(),
 302                        content: Some(content),
 303                        tool_calls: None,
 304                        tool_call_id: None,
 305                    });
 306                }
 307
 308                messages.extend(tool_result_messages);
 309            }
 310            LlmMessageRole::Assistant => {
 311                let mut text_content = String::new();
 312                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();
 313
 314                for content in &msg.content {
 315                    match content {
 316                        LlmMessageContent::Text(text) => {
 317                            if !text.is_empty() {
 318                                if !text_content.is_empty() {
 319                                    text_content.push('\n');
 320                                }
 321                                text_content.push_str(text);
 322                            }
 323                        }
 324                        LlmMessageContent::ToolUse(tool_use) => {
 325                            tool_calls.push(OpenAiToolCall {
 326                                id: tool_use.id.clone(),
 327                                call_type: "function".to_string(),
 328                                function: OpenAiFunctionCall {
 329                                    name: tool_use.name.clone(),
 330                                    arguments: tool_use.input.clone(),
 331                                },
 332                            });
 333                        }
 334                        _ => {}
 335                    }
 336                }
 337
 338                messages.push(OpenAiMessage {
 339                    role: "assistant".to_string(),
 340                    content: if text_content.is_empty() {
 341                        None
 342                    } else {
 343                        Some(OpenAiContent::Text(text_content))
 344                    },
 345                    tool_calls: if tool_calls.is_empty() {
 346                        None
 347                    } else {
 348                        Some(tool_calls)
 349                    },
 350                    tool_call_id: None,
 351                });
 352            }
 353        }
 354    }
 355
 356    let tools: Vec<OpenAiTool> = request
 357        .tools
 358        .iter()
 359        .map(|t| OpenAiTool {
 360            tool_type: "function".to_string(),
 361            function: OpenAiFunctionDef {
 362                name: t.name.clone(),
 363                description: t.description.clone(),
 364                parameters: serde_json::from_str(&t.input_schema)
 365                    .unwrap_or(serde_json::Value::Object(Default::default())),
 366            },
 367        })
 368        .collect();
 369
 370    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
 371        LlmToolChoice::Auto => "auto".to_string(),
 372        LlmToolChoice::Any => "required".to_string(),
 373        LlmToolChoice::None => "none".to_string(),
 374    });
 375
 376    let max_tokens = request.max_tokens;
 377
 378    Ok(OpenAiRequest {
 379        model: model_id.to_string(),
 380        messages,
 381        max_tokens,
 382        tools,
 383        tool_choice,
 384        stop: request.stop_sequences.clone(),
 385        temperature: request.temperature,
 386        stream: true,
 387        stream_options: Some(StreamOptions {
 388            include_usage: true,
 389        }),
 390    })
 391}
 392
 393fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
 394    let data = line.strip_prefix("data: ")?;
 395    if data.trim() == "[DONE]" {
 396        return None;
 397    }
 398    serde_json::from_str(data).ok()
 399}
 400
 401impl zed::Extension for CopilotChatProvider {
 402    fn new() -> Self {
 403        Self {
 404            streams: Mutex::new(HashMap::new()),
 405            next_stream_id: Mutex::new(0),
 406            device_flow_state: Mutex::new(None),
 407            api_token: Mutex::new(None),
 408            cached_models: Mutex::new(None),
 409        }
 410    }
 411
 412    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
 413        vec![LlmProviderInfo {
 414            id: "copilot-chat".into(),
 415            name: "Copilot Chat".into(),
 416            icon: Some("icons/copilot.svg".into()),
 417        }]
 418    }
 419
 420    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
 421        // Try to get models from cache first
 422        if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
 423            return Ok(convert_models_to_llm_info(models));
 424        }
 425
 426        // Need to fetch models - requires authentication
 427        let oauth_token = match llm_get_credential("copilot-chat") {
 428            Some(token) => token,
 429            None => return Ok(Vec::new()), // Not authenticated, return empty
 430        };
 431
 432        // Get API token
 433        let api_token = self.get_api_token(&oauth_token)?;
 434
 435        // Fetch models from API
 436        let models = self.fetch_models(&api_token)?;
 437
 438        // Cache the models
 439        *self.cached_models.lock().unwrap() = Some(models.clone());
 440
 441        Ok(convert_models_to_llm_info(&models))
 442    }
 443
 444    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
 445        llm_get_credential("copilot-chat").is_some()
 446    }
 447
 448    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
 449        Some(
 450            r#"# Copilot Chat Setup
 451
 452Welcome to **Copilot Chat**! This extension provides access to GitHub Copilot's chat models.
 453
 454## Authentication
 455
 456Click **Sign in with GitHub** to authenticate with your GitHub account. You'll be redirected to GitHub to authorize access. This requires an active GitHub Copilot subscription.
 457
 458Alternatively, you can set the `GH_COPILOT_TOKEN` environment variable with your token.
 459
 460## Available Models
 461
 462| Model | Context | Output |
 463|-------|---------|--------|
 464| GPT-4o | 128K | 16K |
 465| GPT-4o Mini | 128K | 16K |
 466| GPT-4.1 | 1M | 32K |
 467| o1 | 200K | 100K |
 468| o3-mini | 200K | 100K |
 469| Claude 3.5 Sonnet | 200K | 8K |
 470| Claude 3.7 Sonnet | 200K | 8K |
 471| Gemini 2.0 Flash | 1M | 8K |
 472
 473## Features
 474
 475- ✅ Full streaming support
 476- ✅ Tool/function calling
 477- ✅ Vision (image inputs)
 478- ✅ Multiple model providers via Copilot
 479
 480## Note
 481
 482This extension requires an active GitHub Copilot subscription.
 483"#
 484            .to_string(),
 485        )
 486    }
 487
 488    fn llm_provider_authenticate(&mut self, _provider_id: &str) -> Result<(), String> {
 489        // Check if we have existing credentials
 490        if llm_get_credential("copilot-chat").is_some() {
 491            return Ok(());
 492        }
 493
 494        // No credentials found - return error for background auth checks.
 495        // The device flow will be triggered by the host when the user clicks
 496        // the "Sign in with GitHub" button, which calls llm_provider_start_device_flow_sign_in.
 497        Err("CredentialsNotFound".to_string())
 498    }
 499
 500    fn llm_provider_start_device_flow_sign_in(
 501        &mut self,
 502        _provider_id: &str,
 503    ) -> Result<String, String> {
 504        // Step 1: Request device and user verification codes
 505        let device_code_response = llm_oauth_http_request(&LlmOauthHttpRequest {
 506            url: GITHUB_DEVICE_CODE_URL.to_string(),
 507            method: "POST".to_string(),
 508            headers: vec![
 509                ("Accept".to_string(), "application/json".to_string()),
 510                (
 511                    "Content-Type".to_string(),
 512                    "application/x-www-form-urlencoded".to_string(),
 513                ),
 514            ],
 515            body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID),
 516        })?;
 517
 518        if device_code_response.status != 200 {
 519            return Err(format!(
 520                "Failed to get device code: HTTP {}",
 521                device_code_response.status
 522            ));
 523        }
 524
 525        #[derive(Deserialize)]
 526        struct DeviceCodeResponse {
 527            device_code: String,
 528            user_code: String,
 529            verification_uri: String,
 530            #[serde(default)]
 531            verification_uri_complete: Option<String>,
 532            expires_in: u64,
 533            interval: u64,
 534        }
 535
 536        let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body)
 537            .map_err(|e| format!("Failed to parse device code response: {}", e))?;
 538
 539        // Store device flow state for polling
 540        *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
 541            device_code: device_info.device_code,
 542            interval: device_info.interval,
 543            expires_in: device_info.expires_in,
 544        });
 545
 546        // Step 2: Open browser to verification URL
 547        // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
 548        let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
 549            format!(
 550                "{}?user_code={}",
 551                device_info.verification_uri, &device_info.user_code
 552            )
 553        });
 554        llm_oauth_open_browser(&verification_url)?;
 555
 556        // Return the user code for the host to display
 557        Ok(device_info.user_code)
 558    }
 559
 560    fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
 561        let state = self
 562            .device_flow_state
 563            .lock()
 564            .unwrap()
 565            .take()
 566            .ok_or("No device flow in progress")?;
 567
 568        let poll_interval = Duration::from_secs(state.interval.max(5));
 569        let max_attempts = (state.expires_in / state.interval.max(5)) as usize;
 570
 571        for _ in 0..max_attempts {
 572            thread::sleep(poll_interval);
 573
 574            let token_response = llm_oauth_http_request(&LlmOauthHttpRequest {
 575                url: GITHUB_ACCESS_TOKEN_URL.to_string(),
 576                method: "POST".to_string(),
 577                headers: vec![
 578                    ("Accept".to_string(), "application/json".to_string()),
 579                    (
 580                        "Content-Type".to_string(),
 581                        "application/x-www-form-urlencoded".to_string(),
 582                    ),
 583                ],
 584                body: format!(
 585                    "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
 586                    GITHUB_COPILOT_CLIENT_ID, state.device_code
 587                ),
 588            })?;
 589
 590            #[derive(Deserialize)]
 591            struct TokenResponse {
 592                access_token: Option<String>,
 593                error: Option<String>,
 594                error_description: Option<String>,
 595            }
 596
 597            let token_json: TokenResponse = serde_json::from_str(&token_response.body)
 598                .map_err(|e| format!("Failed to parse token response: {}", e))?;
 599
 600            if let Some(access_token) = token_json.access_token {
 601                llm_store_credential("copilot-chat", &access_token)?;
 602                return Ok(());
 603            }
 604
 605            if let Some(error) = &token_json.error {
 606                match error.as_str() {
 607                    "authorization_pending" => {
 608                        // User hasn't authorized yet, keep polling
 609                        continue;
 610                    }
 611                    "slow_down" => {
 612                        // Need to slow down polling
 613                        thread::sleep(Duration::from_secs(5));
 614                        continue;
 615                    }
 616                    "expired_token" => {
 617                        return Err("Device code expired. Please try again.".to_string());
 618                    }
 619                    "access_denied" => {
 620                        return Err("Authorization was denied.".to_string());
 621                    }
 622                    _ => {
 623                        let description = token_json.error_description.unwrap_or_default();
 624                        return Err(format!("OAuth error: {} - {}", error, description));
 625                    }
 626                }
 627            }
 628        }
 629
 630        Err("Authorization timed out. Please try again.".to_string())
 631    }
 632
 633    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
 634        // Clear cached API token and models
 635        *self.api_token.lock().unwrap() = None;
 636        *self.cached_models.lock().unwrap() = None;
 637        llm_delete_credential("copilot-chat")
 638    }
 639
 640    fn llm_stream_completion_start(
 641        &mut self,
 642        _provider_id: &str,
 643        model_id: &str,
 644        request: &LlmCompletionRequest,
 645    ) -> Result<String, String> {
 646        let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
 647            "No token configured. Please add your GitHub Copilot token in settings.".to_string()
 648        })?;
 649
 650        // Get or refresh API token
 651        let api_token = self.get_api_token(&oauth_token)?;
 652
 653        let openai_request = convert_request(model_id, request)?;
 654
 655        let body = serde_json::to_vec(&openai_request)
 656            .map_err(|e| format!("Failed to serialize request: {}", e))?;
 657
 658        let completions_url = format!("{}/chat/completions", api_token.api_endpoint);
 659
 660        let http_request = HttpRequest {
 661            method: HttpMethod::Post,
 662            url: completions_url,
 663            headers: vec![
 664                ("Content-Type".to_string(), "application/json".to_string()),
 665                (
 666                    "Authorization".to_string(),
 667                    format!("Bearer {}", api_token.api_key),
 668                ),
 669                (
 670                    "Copilot-Integration-Id".to_string(),
 671                    "vscode-chat".to_string(),
 672                ),
 673                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
 674            ],
 675            body: Some(body),
 676            redirect_policy: RedirectPolicy::FollowAll,
 677        };
 678
 679        let response_stream = http_request
 680            .fetch_stream()
 681            .map_err(|e| format!("HTTP request failed: {}", e))?;
 682
 683        let stream_id = {
 684            let mut id_counter = self.next_stream_id.lock().unwrap();
 685            let id = format!("copilot-stream-{}", *id_counter);
 686            *id_counter += 1;
 687            id
 688        };
 689
 690        self.streams.lock().unwrap().insert(
 691            stream_id.clone(),
 692            StreamState {
 693                response_stream: Some(response_stream),
 694                buffer: String::new(),
 695                started: false,
 696                tool_calls: HashMap::new(),
 697                tool_calls_emitted: false,
 698            },
 699        );
 700
 701        Ok(stream_id)
 702    }
 703
 704    fn llm_stream_completion_next(
 705        &mut self,
 706        stream_id: &str,
 707    ) -> Result<Option<LlmCompletionEvent>, String> {
 708        let mut streams = self.streams.lock().unwrap();
 709        let state = streams
 710            .get_mut(stream_id)
 711            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
 712
 713        if !state.started {
 714            state.started = true;
 715            return Ok(Some(LlmCompletionEvent::Started));
 716        }
 717
 718        let response_stream = state
 719            .response_stream
 720            .as_mut()
 721            .ok_or_else(|| "Stream already closed".to_string())?;
 722
 723        loop {
 724            if let Some(newline_pos) = state.buffer.find('\n') {
 725                let line = state.buffer[..newline_pos].to_string();
 726                state.buffer = state.buffer[newline_pos + 1..].to_string();
 727
 728                if line.trim().is_empty() {
 729                    continue;
 730                }
 731
 732                if let Some(response) = parse_sse_line(&line) {
 733                    if let Some(choice) = response.choices.first() {
 734                        if let Some(content) = &choice.delta.content {
 735                            if !content.is_empty() {
 736                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
 737                            }
 738                        }
 739
 740                        if let Some(tool_calls) = &choice.delta.tool_calls {
 741                            for tc in tool_calls {
 742                                let entry = state
 743                                    .tool_calls
 744                                    .entry(tc.index)
 745                                    .or_insert_with(AccumulatedToolCall::default);
 746
 747                                if let Some(id) = &tc.id {
 748                                    entry.id = id.clone();
 749                                }
 750                                if let Some(func) = &tc.function {
 751                                    if let Some(name) = &func.name {
 752                                        entry.name = name.clone();
 753                                    }
 754                                    if let Some(args) = &func.arguments {
 755                                        entry.arguments.push_str(args);
 756                                    }
 757                                }
 758                            }
 759                        }
 760
 761                        if let Some(finish_reason) = &choice.finish_reason {
 762                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
 763                                state.tool_calls_emitted = true;
 764                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
 765                                tool_calls.sort_by_key(|(idx, _)| *idx);
 766
 767                                if let Some((_, tc)) = tool_calls.into_iter().next() {
 768                                    return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
 769                                        id: tc.id,
 770                                        name: tc.name,
 771                                        input: tc.arguments,
 772                                        thought_signature: None,
 773                                    })));
 774                                }
 775                            }
 776
 777                            let stop_reason = match finish_reason.as_str() {
 778                                "stop" => LlmStopReason::EndTurn,
 779                                "length" => LlmStopReason::MaxTokens,
 780                                "tool_calls" => LlmStopReason::ToolUse,
 781                                "content_filter" => LlmStopReason::Refusal,
 782                                _ => LlmStopReason::EndTurn,
 783                            };
 784                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
 785                        }
 786                    }
 787
 788                    if let Some(usage) = response.usage {
 789                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
 790                            input_tokens: usage.prompt_tokens,
 791                            output_tokens: usage.completion_tokens,
 792                            cache_creation_input_tokens: None,
 793                            cache_read_input_tokens: None,
 794                        })));
 795                    }
 796                }
 797
 798                continue;
 799            }
 800
 801            match response_stream.next_chunk() {
 802                Ok(Some(chunk)) => {
 803                    let text = String::from_utf8_lossy(&chunk);
 804                    state.buffer.push_str(&text);
 805                }
 806                Ok(None) => {
 807                    return Ok(None);
 808                }
 809                Err(e) => {
 810                    return Err(format!("Stream error: {}", e));
 811                }
 812            }
 813        }
 814    }
 815
 816    fn llm_stream_completion_close(&mut self, stream_id: &str) {
 817        self.streams.lock().unwrap().remove(stream_id);
 818    }
 819}
 820
 821impl CopilotChatProvider {
 822    fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
 823        // Check if we have a cached token
 824        if let Some(token) = self.api_token.lock().unwrap().clone() {
 825            return Ok(token);
 826        }
 827
 828        // Request a new API token
 829        let http_request = HttpRequest {
 830            method: HttpMethod::Get,
 831            url: GITHUB_COPILOT_TOKEN_URL.to_string(),
 832            headers: vec![
 833                (
 834                    "Authorization".to_string(),
 835                    format!("token {}", oauth_token),
 836                ),
 837                ("Accept".to_string(), "application/json".to_string()),
 838            ],
 839            body: None,
 840            redirect_policy: RedirectPolicy::FollowAll,
 841        };
 842
 843        let response = http_request
 844            .fetch()
 845            .map_err(|e| format!("Failed to request API token: {}", e))?;
 846
 847        #[derive(Deserialize)]
 848        struct ApiTokenResponse {
 849            token: String,
 850            endpoints: ApiEndpoints,
 851        }
 852
 853        #[derive(Deserialize)]
 854        struct ApiEndpoints {
 855            api: String,
 856        }
 857
 858        let token_response: ApiTokenResponse =
 859            serde_json::from_slice(&response.body).map_err(|e| {
 860                format!(
 861                    "Failed to parse API token response: {} - body: {}",
 862                    e,
 863                    String::from_utf8_lossy(&response.body)
 864                )
 865            })?;
 866
 867        let api_token = ApiToken {
 868            api_key: token_response.token,
 869            api_endpoint: token_response.endpoints.api,
 870        };
 871
 872        // Cache the token
 873        *self.api_token.lock().unwrap() = Some(api_token.clone());
 874
 875        Ok(api_token)
 876    }
 877
 878    fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
 879        let models_url = format!("{}/models", api_token.api_endpoint);
 880
 881        let http_request = HttpRequest {
 882            method: HttpMethod::Get,
 883            url: models_url,
 884            headers: vec![
 885                (
 886                    "Authorization".to_string(),
 887                    format!("Bearer {}", api_token.api_key),
 888                ),
 889                ("Content-Type".to_string(), "application/json".to_string()),
 890                (
 891                    "Copilot-Integration-Id".to_string(),
 892                    "vscode-chat".to_string(),
 893                ),
 894                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
 895                ("x-github-api-version".to_string(), "2025-05-01".to_string()),
 896            ],
 897            body: None,
 898            redirect_policy: RedirectPolicy::FollowAll,
 899        };
 900
 901        let response = http_request
 902            .fetch()
 903            .map_err(|e| format!("Failed to fetch models: {}", e))?;
 904
 905        #[derive(Deserialize)]
 906        struct ModelsResponse {
 907            data: Vec<CopilotModel>,
 908        }
 909
 910        let models_response: ModelsResponse =
 911            serde_json::from_slice(&response.body).map_err(|e| {
 912                format!(
 913                    "Failed to parse models response: {} - body: {}",
 914                    e,
 915                    String::from_utf8_lossy(&response.body)
 916                )
 917            })?;
 918
 919        // Filter models like the built-in Copilot Chat does
 920        let mut models: Vec<CopilotModel> = models_response
 921            .data
 922            .into_iter()
 923            .filter(|model| {
 924                model.model_picker_enabled
 925                    && model.capabilities.model_type == "chat"
 926                    && model
 927                        .policy
 928                        .as_ref()
 929                        .map(|p| p.state == "enabled")
 930                        .unwrap_or(true)
 931            })
 932            .collect();
 933
 934        // Sort so default model is first
 935        if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
 936            let default_model = models.remove(pos);
 937            models.insert(0, default_model);
 938        }
 939
 940        Ok(models)
 941    }
 942}
 943
 944fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
 945    models
 946        .iter()
 947        .map(|m| {
 948            let max_tokens = if m.capabilities.limits.max_context_window_tokens > 0 {
 949                m.capabilities.limits.max_context_window_tokens
 950            } else {
 951                128_000 // Default fallback
 952            };
 953            let max_output = if m.capabilities.limits.max_output_tokens > 0 {
 954                Some(m.capabilities.limits.max_output_tokens)
 955            } else {
 956                None
 957            };
 958
 959            LlmModelInfo {
 960                id: m.id.clone(),
 961                name: m.name.clone(),
 962                max_token_count: max_tokens,
 963                max_output_tokens: max_output,
 964                capabilities: LlmModelCapabilities {
 965                    supports_images: m.capabilities.supports.vision,
 966                    supports_tools: m.capabilities.supports.tool_calls,
 967                    supports_tool_choice_auto: m.capabilities.supports.tool_calls,
 968                    supports_tool_choice_any: m.capabilities.supports.tool_calls,
 969                    supports_tool_choice_none: m.capabilities.supports.tool_calls,
 970                    supports_thinking: false,
 971                    tool_input_format: LlmToolInputFormat::JsonSchema,
 972                },
 973                is_default: m.is_chat_default,
 974                is_default_fast: m.is_chat_fallback,
 975            }
 976        })
 977        .collect()
 978}
 979
 980#[cfg(test)]
 981mod tests {
 982    use super::*;
 983
 984    #[test]
 985    fn test_device_flow_request_body() {
 986        let body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
 987        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
 988        assert!(body.contains("scope=read:user"));
 989    }
 990
 991    #[test]
 992    fn test_token_poll_request_body() {
 993        let device_code = "test_device_code_123";
 994        let body = format!(
 995            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
 996            GITHUB_COPILOT_CLIENT_ID, device_code
 997        );
 998        assert!(body.contains("client_id=Iv1.b507a08c87ecfe98"));
 999        assert!(body.contains("device_code=test_device_code_123"));
1000        assert!(body.contains("grant_type=urn:ietf:params:oauth:grant-type:device_code"));
1001    }
1002}
1003
// Registers `CopilotChatProvider` with Zed through the extension API's
// `register_extension!` macro.
zed::register_extension!(CopilotChatProvider);