pub mod responses;

use std::path::PathBuf;
use std::sync::Arc;
use std::sync::OnceLock;

use anyhow::Context as _;
use anyhow::{Result, anyhow};
use chrono::DateTime;
use collections::HashSet;
use fs::Fs;
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use gpui::WeakEntity;
use gpui::{App, AsyncApp, Global, prelude::*};
use http_client::HttpRequestExt;
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use itertools::Itertools;
use paths::home_dir;
use serde::{Deserialize, Serialize};
use settings::watch_config_dir;

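/// Environment variable that can supply a GitHub Copilot OAuth token directly,
/// in place of the token read from Copilot's configuration files.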
pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";

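/// Configuration for Copilot Chat. When `enterprise_uri` is set, authentication
/// and API requests are routed to that GitHub Enterprise host instead of
/// github.com.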
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatConfiguration {
    pub enterprise_uri: Option<String>,
}

impl CopilotChatConfiguration {
    pub fn token_url(&self) -> String {
        if let Some(enterprise_uri) = &self.enterprise_uri {
            let domain = Self::parse_domain(enterprise_uri);
            format!("https://api.{}/copilot_internal/v2/token", domain)
        } else {
            "https://api.github.com/copilot_internal/v2/token".to_string()
        }
    }

    pub fn oauth_domain(&self) -> String {
        if let Some(enterprise_uri) = &self.enterprise_uri {
            Self::parse_domain(enterprise_uri)
        } else {
            "github.com".to_string()
        }
    }

    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/chat/completions", endpoint)
    }

    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/responses", endpoint)
    }

    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/models", endpoint)
    }

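    /// Extracts the bare host name from an enterprise URI, tolerating an
    /// optional `http://`/`https://` scheme, path components, and trailing
    /// slashes.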
    fn parse_domain(enterprise_uri: &str) -> String {
        let uri = enterprise_uri.trim_end_matches('/');

        if let Some(domain) = uri.strip_prefix("https://") {
            domain.split('/').next().unwrap_or(domain).to_string()
        } else if let Some(domain) = uri.strip_prefix("http://") {
            domain.split('/').next().unwrap_or(domain).to_string()
        } else {
            uri.split('/').next().unwrap_or(uri).to_string()
        }
    }
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
}

#[derive(Deserialize)]
struct ModelSchema {
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}

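/// Deserializes the model list while skipping entries that fail to
/// deserialize, so a single unrecognized model shape does not invalidate the
/// whole response. Skipped entries are logged as warnings.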
fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
    let models = raw_values
        .into_iter()
        .filter_map(|value| match serde_json::from_value::<Model>(value) {
            Ok(model) => Some(model),
            Err(err) => {
                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
                None
            }
        })
        .collect();

    Ok(models)
}

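/// A single model entry returned by the Copilot models endpoint.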
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    id: String,
    name: String,
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    is_chat_default: bool,
    // VS Code's Copilot falls back to the model with this flag set when a premium
    // request limit is reached. Zed does not currently implement this behavior.
    is_chat_fallback: bool,
    model_picker_enabled: bool,
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}

#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    multiplier: f64,
    // The list of plans a model is restricted to. Absent when the model is
    // available on all plans.
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    #[serde(rename = "type")]
    model_type: String,
    #[serde(default)]
    tokenizer: Option<String>,
}

#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    #[serde(default)]
    max_prompt_tokens: u64,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}

impl Model {
    pub fn uses_streaming(&self) -> bool {
        self.capabilities.supports.streaming
    }

    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    pub fn display_name(&self) -> &str {
        self.name.as_str()
    }

    pub fn max_token_count(&self) -> u64 {
        self.capabilities.limits.max_prompt_tokens
    }

    pub fn supports_tools(&self) -> bool {
        self.capabilities.supports.tool_calls
    }

    pub fn vendor(&self) -> ModelVendor {
        self.vendor
    }

    pub fn supports_vision(&self) -> bool {
        self.capabilities.supports.vision
    }

    pub fn supports_parallel_tool_calls(&self) -> bool {
        self.capabilities.supports.parallel_tool_calls
    }

    pub fn tokenizer(&self) -> Option<&str> {
        self.capabilities.tokenizer.as_deref()
    }

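    /// Whether the model should be driven through the `/responses` endpoint,
    /// i.e. it advertises support for `/responses` but not `/chat/completions`.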
    pub fn supports_response(&self) -> bool {
        !self.supported_endpoints.is_empty()
            && !self
                .supported_endpoints
                .contains(&ModelSupportedEndpoint::ChatCompletions)
            && self
                .supported_endpoints
                .contains(&ModelSupportedEndpoint::Responses)
    }
}

#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    pub n: usize,
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}

#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}

#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_opaque: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_text: Option<String>,
    },
    User {
        content: ChatMessageContent,
    },
    System {
        content: String,
    },
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}

impl ChatMessageContent {
    pub fn empty() -> Self {
        ChatMessageContent::Multipart(vec![])
    }
}

impl From<Vec<ChatMessagePart>> for ChatMessageContent {
    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
            ChatMessageContent::Plain(std::mem::take(text))
        } else {
            ChatMessageContent::Multipart(parts)
        }
    }
}

impl From<String> for ChatMessageContent {
    fn from(text: String) -> Self {
        ChatMessageContent::Plain(text)
    }
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}

#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    pub usage: Option<Usage>,
}

#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}

#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
    pub reasoning_opaque: Option<String>,
    pub reasoning_text: Option<String>,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    pub thought_signature: Option<String>,
}

#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    expires_at: i64,
    endpoints: ApiTokenResponseEndpoints,
}

#[derive(Deserialize)]
struct ApiTokenResponseEndpoints {
    api: String,
}

#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
    api_endpoint: String,
}

impl ApiToken {
    pub fn remaining_seconds(&self) -> i64 {
        self.expires_at
            .timestamp()
            .saturating_sub(chrono::Utc::now().timestamp())
    }
}

impl TryFrom<ApiTokenResponse> for ApiToken {
    type Error = anyhow::Error;

    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
        let expires_at =
            DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;

        Ok(Self {
            api_key: response.token,
            expires_at,
            api_endpoint: response.endpoints.api,
        })
    }
}

struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

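/// Global Copilot Chat state: the OAuth token, the short-lived API token
/// exchanged from it, the fetched model list, and the HTTP client used for
/// requests.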
pub struct CopilotChat {
    oauth_token: Option<String>,
    api_token: Option<ApiToken>,
    configuration: CopilotChatConfiguration,
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}

pub fn init(
    fs: Arc<dyn Fs>,
    client: Arc<dyn HttpClient>,
    configuration: CopilotChatConfiguration,
    cx: &mut App,
) {
    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
    cx.set_global(GlobalCopilotChat(copilot_chat));
}

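/// Returns the directory where GitHub Copilot stores its configuration files
/// (`hosts.json` / `apps.json`): the local app data directory on Windows,
/// otherwise `$XDG_CONFIG_HOME/github-copilot`, falling back to
/// `~/.config/github-copilot`.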
pub fn copilot_chat_config_dir() -> &'static PathBuf {
    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();

    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
        let config_dir = if cfg!(target_os = "windows") {
            dirs::data_local_dir().expect("failed to determine LocalAppData directory")
        } else {
            std::env::var("XDG_CONFIG_HOME")
                .map(PathBuf::from)
                .unwrap_or_else(|_| home_dir().join(".config"))
        };

        config_dir.join("github-copilot")
    })
}

fn copilot_chat_config_paths() -> [PathBuf; 2] {
    let base_dir = copilot_chat_config_dir();
    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
}

impl CopilotChat {
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    fn new(
        fs: Arc<dyn Fs>,
        client: Arc<dyn HttpClient>,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();

        cx.spawn(async move |this, cx| {
            let mut parent_watch_rx = watch_config_dir(
                cx.background_executor(),
                fs.clone(),
                dir_path.clone(),
                config_paths,
            );
            while let Some(contents) = parent_watch_rx.next().await {
                let oauth_domain =
                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
                let oauth_token = extract_oauth_token(contents, &oauth_domain);

                this.update(cx, |this, cx| {
                    this.oauth_token = oauth_token.clone();
                    cx.notify();
                })?;

                if oauth_token.is_some() {
                    Self::update_models(&this, cx).await?;
                }
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        let this = Self {
            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
            api_token: None,
            models: None,
            configuration,
            client,
        };

        if this.oauth_token.is_some() {
            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
                .detach_and_log_err(cx);
        }

        this
    }

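    /// Exchanges the stored OAuth token for an API token, fetches the available
    /// models from the API endpoint it reports, and stores both on the entity.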
    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token
            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;

        let token_url = configuration.token_url();
        let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;

        let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
        let models =
            get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;

        this.update(cx, |this, cx| {
            this.api_token = Some(api_token);
            this.models = Some(models);
            cx.notify();
        })?;
        anyhow::Ok(())
    }

    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

    pub async fn stream_completion(
        request: Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
        stream_completion(
            client.clone(),
            token.api_key,
            api_url.into(),
            request,
            is_user_initiated,
        )
        .await
    }

    pub async fn stream_response(
        request: responses::Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
        responses::stream_response(
            client.clone(),
            token.api_key,
            api_url,
            request,
            is_user_initiated,
        )
        .await
    }

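    /// Returns the HTTP client, a valid API token, and the current
    /// configuration, requesting a fresh API token when the cached one is
    /// missing or expires within five minutes.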
    async fn get_auth_details(
        cx: &mut AsyncApp,
    ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
        let this = cx
            .update(|cx| Self::global(cx))?
            .context("Copilot chat is not enabled")?;

        let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.api_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token.context("No OAuth token available")?;

        let token = match api_token {
            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
            _ => {
                let token_url = configuration.token_url();
                let token =
                    request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
                this.update(cx, |this, cx| {
                    this.api_token = Some(token.clone());
                    cx.notify();
                })?;
                token
            }
        };

        Ok((client, token, configuration))
    }

    pub fn set_configuration(
        &mut self,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) {
        let same_configuration = self.configuration == configuration;
        self.configuration = configuration;
        if !same_configuration {
            self.api_token = None;
            cx.spawn(async move |this, cx| {
                Self::update_models(&this, cx).await?;
                Ok::<_, anyhow::Error>(())
            })
            .detach();
        }
    }
}

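/// Fetches the model list and filters it down to chat models that are enabled
/// by policy and shown in the model picker, deduplicating consecutive models
/// that share a family and moving the default chat model to the front.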
async fn get_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let all_models = request_models(models_url, api_token, client).await?;

    let mut models: Vec<Model> = all_models
        .into_iter()
        .filter(|model| {
            model.model_picker_enabled
                && model.capabilities.model_type.as_str() == "chat"
                && model
                    .policy
                    .as_ref()
                    .is_none_or(|policy| policy.state == "enabled")
        })
        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
        .collect();

    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
        let default_model = models.remove(default_model_position);
        models.insert(0, default_model);
    }

    Ok(models)
}

async fn request_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(models_url.as_ref())
        .header("Authorization", format!("Bearer {}", api_token))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("Editor-Version", "vscode/1.103.2")
        .header("x-github-api-version", "2025-05-01");

    let request = request_builder.body(AsyncBody::empty())?;

    let mut response = client.send(request).await?;

    anyhow::ensure!(
        response.status().is_success(),
        "Failed to request models: {}",
        response.status()
    );
    let mut body = Vec::new();
    response.body_mut().read_to_end(&mut body).await?;

    let body_str = std::str::from_utf8(&body)?;

    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;

    Ok(models)
}

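/// Exchanges a GitHub OAuth token for a short-lived Copilot API token, which
/// also carries the API endpoint to use for subsequent requests.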
async fn request_api_token(
    oauth_token: &str,
    auth_url: Arc<str>,
    client: Arc<dyn HttpClient>,
) -> Result<ApiToken> {
    let request_builder = HttpRequest::builder()
        .method(Method::GET)
        .uri(auth_url.as_ref())
        .header("Authorization", format!("token {}", oauth_token))
        .header("Accept", "application/json");

    let request = request_builder.body(AsyncBody::empty())?;

    let mut response = client.send(request).await?;

    if response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;

        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
        ApiToken::try_from(parsed)
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;

        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!("Failed to request API token: {body_str}");
    }
}

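/// Extracts the `oauth_token` for the given domain from the JSON contents of a
/// Copilot `hosts.json`/`apps.json` file, whose keys are prefixed with the host
/// they belong to.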
fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
    serde_json::from_str::<serde_json::Value>(&contents)
        .map(|v| {
            v.as_object().and_then(|obj| {
                obj.iter().find_map(|(key, value)| {
                    if key.starts_with(domain) {
                        value["oauth_token"].as_str().map(|v| v.to_string())
                    } else {
                        None
                    }
                })
            })
        })
        .ok()
        .flatten()
}

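/// Sends a chat completion request. For streaming requests the returned stream
/// yields one `ResponseEvent` per SSE `data:` line; for non-streaming requests
/// it yields a single event.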
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    api_key: String,
    completion_url: Arc<str>,
    request: Request,
    is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    let is_vision_request = request.messages.iter().any(|message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        _ => false,
    });

    let request_initiator = if is_user_initiated { "user" } else { "agent" };

    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(completion_url.as_ref())
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("X-Initiator", request_initiator)
        .when(is_vision_request, |builder| {
            builder.header("Copilot-Vision-Request", is_vision_request.to_string())
        });

    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
            "data": [
                {
                    "billing": {
                        "is_premium": false,
                        "multiplier": 0
                    },
                    "capabilities": {
                        "family": "gpt-4",
                        "limits": {
                            "max_context_window_tokens": 32768,
                            "max_output_tokens": 4096,
                            "max_prompt_tokens": 32768
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "tokenizer": "cl100k_base",
                        "type": "chat"
                    },
                    "id": "gpt-4",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": false,
                    "name": "GPT 4",
                    "object": "model",
                    "preview": false,
                    "vendor": "Azure OpenAI",
                    "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                    "billing": {
                        "is_premium": true,
                        "multiplier": 1,
                        "restricted_to": [
                            "pro",
                            "pro_plus",
                            "business",
                            "enterprise"
                        ]
                    },
                    "capabilities": {
                        "family": "claude-3.7-sonnet",
                        "limits": {
                            "max_context_window_tokens": 200000,
                            "max_output_tokens": 16384,
                            "max_prompt_tokens": 90000,
                            "vision": {
                                "max_prompt_image_size": 3145728,
                                "max_prompt_images": 1,
                                "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                            }
                        },
                        "object": "model_capabilities",
                        "supports": {
                            "parallel_tool_calls": true,
                            "streaming": true,
                            "tool_calls": true,
                            "vision": true
                        },
                        "tokenizer": "o200k_base",
                        "type": "chat"
                    },
                    "id": "claude-3.7-sonnet",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Claude 3.7 Sonnet",
                    "object": "model",
                    "policy": {
                        "state": "enabled",
                        "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                    },
                    "preview": false,
                    "vendor": "Anthropic",
                    "version": "claude-3.7-sonnet"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }

    #[test]
    fn test_unknown_vendor_resilience() {
        let json = r#"{
            "data": [
                {
                    "billing": {
                        "is_premium": false,
                        "multiplier": 1
                    },
                    "capabilities": {
                        "family": "future-model",
                        "limits": {
                            "max_context_window_tokens": 128000,
                            "max_output_tokens": 8192,
                            "max_prompt_tokens": 120000
                        },
                        "object": "model_capabilities",
                        "supports": { "streaming": true, "tool_calls": true },
                        "type": "chat"
                    },
                    "id": "future-model-v1",
                    "is_chat_default": false,
                    "is_chat_fallback": false,
                    "model_picker_enabled": true,
                    "name": "Future Model v1",
                    "object": "model",
                    "preview": false,
                    "vendor": "SomeNewVendor",
                    "version": "v1.0"
                }
            ],
            "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "future-model-v1");
        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
    }
}