copilot_chat.rs

   1use std::path::PathBuf;
   2use std::sync::Arc;
   3use std::sync::OnceLock;
   4
   5use anyhow::Context as _;
   6use anyhow::{Result, anyhow};
   7use chrono::DateTime;
   8use collections::HashSet;
   9use fs::Fs;
  10use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  11use gpui::WeakEntity;
  12use gpui::{App, AsyncApp, Global, prelude::*};
  13use http_client::HttpRequestExt;
  14use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  15use itertools::Itertools;
  16use paths::home_dir;
  17use serde::{Deserialize, Serialize};
  18
  19use crate::copilot_responses as responses;
  20use settings::watch_config_dir;
  21
  22pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
  23
  24#[derive(Default, Clone, Debug, PartialEq)]
  25pub struct CopilotChatConfiguration {
  26    pub enterprise_uri: Option<String>,
  27}
  28
  29impl CopilotChatConfiguration {
  30    pub fn token_url(&self) -> String {
  31        if let Some(enterprise_uri) = &self.enterprise_uri {
  32            let domain = Self::parse_domain(enterprise_uri);
  33            format!("https://api.{}/copilot_internal/v2/token", domain)
  34        } else {
  35            "https://api.github.com/copilot_internal/v2/token".to_string()
  36        }
  37    }
  38
  39    pub fn oauth_domain(&self) -> String {
  40        if let Some(enterprise_uri) = &self.enterprise_uri {
  41            Self::parse_domain(enterprise_uri)
  42        } else {
  43            "github.com".to_string()
  44        }
  45    }
  46
  47    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
  48        format!("{}/chat/completions", endpoint)
  49    }
  50
  51    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
  52        format!("{}/responses", endpoint)
  53    }
  54
  55    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
  56        format!("{}/models", endpoint)
  57    }
  58
  59    fn parse_domain(enterprise_uri: &str) -> String {
  60        let uri = enterprise_uri.trim_end_matches('/');
  61
  62        if let Some(domain) = uri.strip_prefix("https://") {
  63            domain.split('/').next().unwrap_or(domain).to_string()
  64        } else if let Some(domain) = uri.strip_prefix("http://") {
  65            domain.split('/').next().unwrap_or(domain).to_string()
  66        } else {
  67            uri.split('/').next().unwrap_or(uri).to_string()
  68        }
  69    }
  70}
  71
  72#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
  73#[serde(rename_all = "lowercase")]
  74pub enum Role {
  75    User,
  76    Assistant,
  77    System,
  78}
  79
  80#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
  81pub enum ModelSupportedEndpoint {
  82    #[serde(rename = "/chat/completions")]
  83    ChatCompletions,
  84    #[serde(rename = "/responses")]
  85    Responses,
  86}
  87
  88#[derive(Deserialize)]
  89struct ModelSchema {
  90    #[serde(deserialize_with = "deserialize_models_skip_errors")]
  91    data: Vec<Model>,
  92}
  93
  94fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
  95where
  96    D: serde::Deserializer<'de>,
  97{
  98    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
  99    let models = raw_values
 100        .into_iter()
 101        .filter_map(|value| match serde_json::from_value::<Model>(value) {
 102            Ok(model) => Some(model),
 103            Err(err) => {
 104                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
 105                None
 106            }
 107        })
 108        .collect();
 109
 110    Ok(models)
 111}
 112
 113#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
 114pub struct Model {
 115    billing: ModelBilling,
 116    capabilities: ModelCapabilities,
 117    id: String,
 118    name: String,
 119    policy: Option<ModelPolicy>,
 120    vendor: ModelVendor,
 121    is_chat_default: bool,
 122    // The model with this value true is selected by VSCode copilot if a premium request limit is
 123    // reached. Zed does not currently implement this behaviour
 124    is_chat_fallback: bool,
 125    model_picker_enabled: bool,
 126    #[serde(default)]
 127    supported_endpoints: Vec<ModelSupportedEndpoint>,
 128}
 129
 130#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
 131struct ModelBilling {
 132    is_premium: bool,
 133    multiplier: f64,
 134    // List of plans a model is restricted to
 135    // Field is not present if a model is available for all plans
 136    #[serde(default)]
 137    restricted_to: Option<Vec<String>>,
 138}
 139
 140#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
 141struct ModelCapabilities {
 142    family: String,
 143    #[serde(default)]
 144    limits: ModelLimits,
 145    supports: ModelSupportedFeatures,
 146    #[serde(rename = "type")]
 147    model_type: String,
 148    #[serde(default)]
 149    tokenizer: Option<String>,
 150}
 151
 152#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
 153struct ModelLimits {
 154    #[serde(default)]
 155    max_context_window_tokens: usize,
 156    #[serde(default)]
 157    max_output_tokens: usize,
 158    #[serde(default)]
 159    max_prompt_tokens: u64,
 160}
 161
 162#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
 163struct ModelPolicy {
 164    state: String,
 165}
 166
 167#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
 168struct ModelSupportedFeatures {
 169    #[serde(default)]
 170    streaming: bool,
 171    #[serde(default)]
 172    tool_calls: bool,
 173    #[serde(default)]
 174    parallel_tool_calls: bool,
 175    #[serde(default)]
 176    vision: bool,
 177}
 178
 179#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 180pub enum ModelVendor {
 181    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
 182    #[serde(alias = "Azure OpenAI")]
 183    OpenAI,
 184    Google,
 185    Anthropic,
 186    #[serde(rename = "xAI")]
 187    XAI,
 188    /// Unknown vendor that we don't explicitly support yet
 189    #[serde(other)]
 190    Unknown,
 191}
 192
 193#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
 194#[serde(tag = "type")]
 195pub enum ChatMessagePart {
 196    #[serde(rename = "text")]
 197    Text { text: String },
 198    #[serde(rename = "image_url")]
 199    Image { image_url: ImageUrl },
 200}
 201
 202#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
 203pub struct ImageUrl {
 204    pub url: String,
 205}
 206
 207impl Model {
 208    pub fn uses_streaming(&self) -> bool {
 209        self.capabilities.supports.streaming
 210    }
 211
 212    pub fn id(&self) -> &str {
 213        self.id.as_str()
 214    }
 215
 216    pub fn display_name(&self) -> &str {
 217        self.name.as_str()
 218    }
 219
 220    pub fn max_token_count(&self) -> u64 {
 221        self.capabilities.limits.max_prompt_tokens
 222    }
 223
 224    pub fn supports_tools(&self) -> bool {
 225        self.capabilities.supports.tool_calls
 226    }
 227
 228    pub fn vendor(&self) -> ModelVendor {
 229        self.vendor
 230    }
 231
 232    pub fn supports_vision(&self) -> bool {
 233        self.capabilities.supports.vision
 234    }
 235
 236    pub fn supports_parallel_tool_calls(&self) -> bool {
 237        self.capabilities.supports.parallel_tool_calls
 238    }
 239
 240    pub fn tokenizer(&self) -> Option<&str> {
 241        self.capabilities.tokenizer.as_deref()
 242    }
 243
 244    pub fn supports_response(&self) -> bool {
 245        self.supported_endpoints.len() > 0
 246            && !self
 247                .supported_endpoints
 248                .contains(&ModelSupportedEndpoint::ChatCompletions)
 249            && self
 250                .supported_endpoints
 251                .contains(&ModelSupportedEndpoint::Responses)
 252    }
 253}
 254
 255#[derive(Serialize, Deserialize)]
 256pub struct Request {
 257    pub intent: bool,
 258    pub n: usize,
 259    pub stream: bool,
 260    pub temperature: f32,
 261    pub model: String,
 262    pub messages: Vec<ChatMessage>,
 263    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 264    pub tools: Vec<Tool>,
 265    #[serde(default, skip_serializing_if = "Option::is_none")]
 266    pub tool_choice: Option<ToolChoice>,
 267}
 268
 269#[derive(Serialize, Deserialize)]
 270pub struct Function {
 271    pub name: String,
 272    pub description: String,
 273    pub parameters: serde_json::Value,
 274}
 275
 276#[derive(Serialize, Deserialize)]
 277#[serde(tag = "type", rename_all = "snake_case")]
 278pub enum Tool {
 279    Function { function: Function },
 280}
 281
 282#[derive(Serialize, Deserialize, Debug)]
 283#[serde(rename_all = "lowercase")]
 284pub enum ToolChoice {
 285    Auto,
 286    Any,
 287    None,
 288}
 289
 290#[derive(Serialize, Deserialize, Debug)]
 291#[serde(tag = "role", rename_all = "lowercase")]
 292pub enum ChatMessage {
 293    Assistant {
 294        content: ChatMessageContent,
 295        #[serde(default, skip_serializing_if = "Vec::is_empty")]
 296        tool_calls: Vec<ToolCall>,
 297    },
 298    User {
 299        content: ChatMessageContent,
 300    },
 301    System {
 302        content: String,
 303    },
 304    Tool {
 305        content: ChatMessageContent,
 306        tool_call_id: String,
 307    },
 308}
 309
 310#[derive(Debug, Serialize, Deserialize)]
 311#[serde(untagged)]
 312pub enum ChatMessageContent {
 313    Plain(String),
 314    Multipart(Vec<ChatMessagePart>),
 315}
 316
 317impl ChatMessageContent {
 318    pub fn empty() -> Self {
 319        ChatMessageContent::Multipart(vec![])
 320    }
 321}
 322
 323impl From<Vec<ChatMessagePart>> for ChatMessageContent {
 324    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
 325        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
 326            ChatMessageContent::Plain(std::mem::take(text))
 327        } else {
 328            ChatMessageContent::Multipart(parts)
 329        }
 330    }
 331}
 332
 333impl From<String> for ChatMessageContent {
 334    fn from(text: String) -> Self {
 335        ChatMessageContent::Plain(text)
 336    }
 337}
 338
 339#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 340pub struct ToolCall {
 341    pub id: String,
 342    #[serde(flatten)]
 343    pub content: ToolCallContent,
 344}
 345
 346#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 347#[serde(tag = "type", rename_all = "lowercase")]
 348pub enum ToolCallContent {
 349    Function { function: FunctionContent },
 350}
 351
 352#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 353pub struct FunctionContent {
 354    pub name: String,
 355    pub arguments: String,
 356    #[serde(default, skip_serializing_if = "Option::is_none")]
 357    pub thought_signature: Option<String>,
 358}
 359
 360#[derive(Deserialize, Debug)]
 361#[serde(tag = "type", rename_all = "snake_case")]
 362pub struct ResponseEvent {
 363    pub choices: Vec<ResponseChoice>,
 364    pub id: String,
 365    pub usage: Option<Usage>,
 366}
 367
 368#[derive(Deserialize, Debug)]
 369pub struct Usage {
 370    pub completion_tokens: u64,
 371    pub prompt_tokens: u64,
 372    pub total_tokens: u64,
 373}
 374
 375#[derive(Debug, Deserialize)]
 376pub struct ResponseChoice {
 377    pub index: Option<usize>,
 378    pub finish_reason: Option<String>,
 379    pub delta: Option<ResponseDelta>,
 380    pub message: Option<ResponseDelta>,
 381}
 382
 383#[derive(Debug, Deserialize)]
 384pub struct ResponseDelta {
 385    pub content: Option<String>,
 386    pub role: Option<Role>,
 387    #[serde(default)]
 388    pub tool_calls: Vec<ToolCallChunk>,
 389}
 390#[derive(Deserialize, Debug, Eq, PartialEq)]
 391pub struct ToolCallChunk {
 392    pub index: Option<usize>,
 393    pub id: Option<String>,
 394    pub function: Option<FunctionChunk>,
 395}
 396
 397#[derive(Deserialize, Debug, Eq, PartialEq)]
 398pub struct FunctionChunk {
 399    pub name: Option<String>,
 400    pub arguments: Option<String>,
 401    pub thought_signature: Option<String>,
 402}
 403
 404#[derive(Deserialize)]
 405struct ApiTokenResponse {
 406    token: String,
 407    expires_at: i64,
 408    endpoints: ApiTokenResponseEndpoints,
 409}
 410
 411#[derive(Deserialize)]
 412struct ApiTokenResponseEndpoints {
 413    api: String,
 414}
 415
 416#[derive(Clone)]
 417struct ApiToken {
 418    api_key: String,
 419    expires_at: DateTime<chrono::Utc>,
 420    api_endpoint: String,
 421}
 422
 423impl ApiToken {
 424    pub fn remaining_seconds(&self) -> i64 {
 425        self.expires_at
 426            .timestamp()
 427            .saturating_sub(chrono::Utc::now().timestamp())
 428    }
 429}
 430
 431impl TryFrom<ApiTokenResponse> for ApiToken {
 432    type Error = anyhow::Error;
 433
 434    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
 435        let expires_at =
 436            DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;
 437
 438        Ok(Self {
 439            api_key: response.token,
 440            expires_at,
 441            api_endpoint: response.endpoints.api,
 442        })
 443    }
 444}
 445
 446struct GlobalCopilotChat(gpui::Entity<CopilotChat>);
 447
 448impl Global for GlobalCopilotChat {}
 449
 450pub struct CopilotChat {
 451    oauth_token: Option<String>,
 452    api_token: Option<ApiToken>,
 453    configuration: CopilotChatConfiguration,
 454    models: Option<Vec<Model>>,
 455    client: Arc<dyn HttpClient>,
 456}
 457
 458pub fn init(
 459    fs: Arc<dyn Fs>,
 460    client: Arc<dyn HttpClient>,
 461    configuration: CopilotChatConfiguration,
 462    cx: &mut App,
 463) {
 464    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
 465    cx.set_global(GlobalCopilotChat(copilot_chat));
 466}
 467
 468pub fn copilot_chat_config_dir() -> &'static PathBuf {
 469    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
 470
 471    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
 472        let config_dir = if cfg!(target_os = "windows") {
 473            dirs::data_local_dir().expect("failed to determine LocalAppData directory")
 474        } else {
 475            std::env::var("XDG_CONFIG_HOME")
 476                .map(PathBuf::from)
 477                .unwrap_or_else(|_| home_dir().join(".config"))
 478        };
 479
 480        config_dir.join("github-copilot")
 481    })
 482}
 483
 484fn copilot_chat_config_paths() -> [PathBuf; 2] {
 485    let base_dir = copilot_chat_config_dir();
 486    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
 487}
 488
 489impl CopilotChat {
 490    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
 491        cx.try_global::<GlobalCopilotChat>()
 492            .map(|model| model.0.clone())
 493    }
 494
 495    fn new(
 496        fs: Arc<dyn Fs>,
 497        client: Arc<dyn HttpClient>,
 498        configuration: CopilotChatConfiguration,
 499        cx: &mut Context<Self>,
 500    ) -> Self {
 501        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
 502        let dir_path = copilot_chat_config_dir();
 503
 504        cx.spawn(async move |this, cx| {
 505            let mut parent_watch_rx = watch_config_dir(
 506                cx.background_executor(),
 507                fs.clone(),
 508                dir_path.clone(),
 509                config_paths,
 510            );
 511            while let Some(contents) = parent_watch_rx.next().await {
 512                let oauth_domain =
 513                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
 514                let oauth_token = extract_oauth_token(contents, &oauth_domain);
 515
 516                this.update(cx, |this, cx| {
 517                    this.oauth_token = oauth_token.clone();
 518                    cx.notify();
 519                })?;
 520
 521                if oauth_token.is_some() {
 522                    Self::update_models(&this, cx).await?;
 523                }
 524            }
 525            anyhow::Ok(())
 526        })
 527        .detach_and_log_err(cx);
 528
 529        let this = Self {
 530            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
 531            api_token: None,
 532            models: None,
 533            configuration,
 534            client,
 535        };
 536
 537        if this.oauth_token.is_some() {
 538            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
 539                .detach_and_log_err(cx);
 540        }
 541
 542        this
 543    }
 544
 545    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 546        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
 547            (
 548                this.oauth_token.clone(),
 549                this.client.clone(),
 550                this.configuration.clone(),
 551            )
 552        })?;
 553
 554        let oauth_token = oauth_token
 555            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;
 556
 557        let token_url = configuration.token_url();
 558        let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
 559
 560        let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
 561        let models =
 562            get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;
 563
 564        this.update(cx, |this, cx| {
 565            this.api_token = Some(api_token);
 566            this.models = Some(models);
 567            cx.notify();
 568        })?;
 569        anyhow::Ok(())
 570    }
 571
 572    pub fn is_authenticated(&self) -> bool {
 573        self.oauth_token.is_some()
 574    }
 575
 576    pub fn models(&self) -> Option<&[Model]> {
 577        self.models.as_deref()
 578    }
 579
 580    pub async fn stream_completion(
 581        request: Request,
 582        is_user_initiated: bool,
 583        mut cx: AsyncApp,
 584    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
 585        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
 586
 587        let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
 588        stream_completion(
 589            client.clone(),
 590            token.api_key,
 591            api_url.into(),
 592            request,
 593            is_user_initiated,
 594        )
 595        .await
 596    }
 597
 598    pub async fn stream_response(
 599        request: responses::Request,
 600        is_user_initiated: bool,
 601        mut cx: AsyncApp,
 602    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
 603        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
 604
 605        let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
 606        responses::stream_response(
 607            client.clone(),
 608            token.api_key,
 609            api_url,
 610            request,
 611            is_user_initiated,
 612        )
 613        .await
 614    }
 615
 616    async fn get_auth_details(
 617        cx: &mut AsyncApp,
 618    ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
 619        let this = cx
 620            .update(|cx| Self::global(cx))
 621            .ok()
 622            .flatten()
 623            .context("Copilot chat is not enabled")?;
 624
 625        let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
 626            (
 627                this.oauth_token.clone(),
 628                this.api_token.clone(),
 629                this.client.clone(),
 630                this.configuration.clone(),
 631            )
 632        })?;
 633
 634        let oauth_token = oauth_token.context("No OAuth token available")?;
 635
 636        let token = match api_token {
 637            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
 638            _ => {
 639                let token_url = configuration.token_url();
 640                let token =
 641                    request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
 642                this.update(cx, |this, cx| {
 643                    this.api_token = Some(token.clone());
 644                    cx.notify();
 645                })?;
 646                token
 647            }
 648        };
 649
 650        Ok((client, token, configuration))
 651    }
 652
 653    pub fn set_configuration(
 654        &mut self,
 655        configuration: CopilotChatConfiguration,
 656        cx: &mut Context<Self>,
 657    ) {
 658        let same_configuration = self.configuration == configuration;
 659        self.configuration = configuration;
 660        if !same_configuration {
 661            self.api_token = None;
 662            cx.spawn(async move |this, cx| {
 663                Self::update_models(&this, cx).await?;
 664                Ok::<_, anyhow::Error>(())
 665            })
 666            .detach();
 667        }
 668    }
 669}
 670
 671async fn get_models(
 672    models_url: Arc<str>,
 673    api_token: String,
 674    client: Arc<dyn HttpClient>,
 675) -> Result<Vec<Model>> {
 676    let all_models = request_models(models_url, api_token, client).await?;
 677
 678    let mut models: Vec<Model> = all_models
 679        .into_iter()
 680        .filter(|model| {
 681            model.model_picker_enabled
 682                && model.capabilities.model_type.as_str() == "chat"
 683                && model
 684                    .policy
 685                    .as_ref()
 686                    .is_none_or(|policy| policy.state == "enabled")
 687        })
 688        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
 689        .collect();
 690
 691    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
 692        let default_model = models.remove(default_model_position);
 693        models.insert(0, default_model);
 694    }
 695
 696    Ok(models)
 697}
 698
 699async fn request_models(
 700    models_url: Arc<str>,
 701    api_token: String,
 702    client: Arc<dyn HttpClient>,
 703) -> Result<Vec<Model>> {
 704    let request_builder = HttpRequest::builder()
 705        .method(Method::GET)
 706        .uri(models_url.as_ref())
 707        .header("Authorization", format!("Bearer {}", api_token))
 708        .header("Content-Type", "application/json")
 709        .header("Copilot-Integration-Id", "vscode-chat")
 710        .header("Editor-Version", "vscode/1.103.2")
 711        .header("x-github-api-version", "2025-05-01");
 712
 713    let request = request_builder.body(AsyncBody::empty())?;
 714
 715    let mut response = client.send(request).await?;
 716
 717    anyhow::ensure!(
 718        response.status().is_success(),
 719        "Failed to request models: {}",
 720        response.status()
 721    );
 722    let mut body = Vec::new();
 723    response.body_mut().read_to_end(&mut body).await?;
 724
 725    let body_str = std::str::from_utf8(&body)?;
 726
 727    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
 728
 729    Ok(models)
 730}
 731
 732async fn request_api_token(
 733    oauth_token: &str,
 734    auth_url: Arc<str>,
 735    client: Arc<dyn HttpClient>,
 736) -> Result<ApiToken> {
 737    let request_builder = HttpRequest::builder()
 738        .method(Method::GET)
 739        .uri(auth_url.as_ref())
 740        .header("Authorization", format!("token {}", oauth_token))
 741        .header("Accept", "application/json");
 742
 743    let request = request_builder.body(AsyncBody::empty())?;
 744
 745    let mut response = client.send(request).await?;
 746
 747    if response.status().is_success() {
 748        let mut body = Vec::new();
 749        response.body_mut().read_to_end(&mut body).await?;
 750
 751        let body_str = std::str::from_utf8(&body)?;
 752
 753        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
 754        ApiToken::try_from(parsed)
 755    } else {
 756        let mut body = Vec::new();
 757        response.body_mut().read_to_end(&mut body).await?;
 758
 759        let body_str = std::str::from_utf8(&body)?;
 760        anyhow::bail!("Failed to request API token: {body_str}");
 761    }
 762}
 763
 764fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
 765    serde_json::from_str::<serde_json::Value>(&contents)
 766        .map(|v| {
 767            v.as_object().and_then(|obj| {
 768                obj.iter().find_map(|(key, value)| {
 769                    if key.starts_with(domain) {
 770                        value["oauth_token"].as_str().map(|v| v.to_string())
 771                    } else {
 772                        None
 773                    }
 774                })
 775            })
 776        })
 777        .ok()
 778        .flatten()
 779}
 780
 781async fn stream_completion(
 782    client: Arc<dyn HttpClient>,
 783    api_key: String,
 784    completion_url: Arc<str>,
 785    request: Request,
 786    is_user_initiated: bool,
 787) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
 788    let is_vision_request = request.messages.iter().any(|message| match message {
 789      ChatMessage::User { content }
 790      | ChatMessage::Assistant { content, .. }
 791      | ChatMessage::Tool { content, .. } => {
 792          matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
 793      }
 794      _ => false,
 795  });
 796
 797    let request_initiator = if is_user_initiated { "user" } else { "agent" };
 798
 799    let request_builder = HttpRequest::builder()
 800        .method(Method::POST)
 801        .uri(completion_url.as_ref())
 802        .header(
 803            "Editor-Version",
 804            format!(
 805                "Zed/{}",
 806                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
 807            ),
 808        )
 809        .header("Authorization", format!("Bearer {}", api_key))
 810        .header("Content-Type", "application/json")
 811        .header("Copilot-Integration-Id", "vscode-chat")
 812        .header("X-Initiator", request_initiator)
 813        .when(is_vision_request, |builder| {
 814            builder.header("Copilot-Vision-Request", is_vision_request.to_string())
 815        });
 816
 817    let is_streaming = request.stream;
 818
 819    let json = serde_json::to_string(&request)?;
 820    let request = request_builder.body(AsyncBody::from(json))?;
 821    let mut response = client.send(request).await?;
 822
 823    if !response.status().is_success() {
 824        let mut body = Vec::new();
 825        response.body_mut().read_to_end(&mut body).await?;
 826        let body_str = std::str::from_utf8(&body)?;
 827        anyhow::bail!(
 828            "Failed to connect to API: {} {}",
 829            response.status(),
 830            body_str
 831        );
 832    }
 833
 834    if is_streaming {
 835        let reader = BufReader::new(response.into_body());
 836        Ok(reader
 837            .lines()
 838            .filter_map(|line| async move {
 839                match line {
 840                    Ok(line) => {
 841                        let line = line.strip_prefix("data: ")?;
 842                        if line.starts_with("[DONE]") {
 843                            return None;
 844                        }
 845
 846                        match serde_json::from_str::<ResponseEvent>(line) {
 847                            Ok(response) => {
 848                                if response.choices.is_empty() {
 849                                    None
 850                                } else {
 851                                    Some(Ok(response))
 852                                }
 853                            }
 854                            Err(error) => Some(Err(anyhow!(error))),
 855                        }
 856                    }
 857                    Err(error) => Some(Err(anyhow!(error))),
 858                }
 859            })
 860            .boxed())
 861    } else {
 862        let mut body = Vec::new();
 863        response.body_mut().read_to_end(&mut body).await?;
 864        let body_str = std::str::from_utf8(&body)?;
 865        let response: ResponseEvent = serde_json::from_str(body_str)?;
 866
 867        Ok(futures::stream::once(async move { Ok(response) }).boxed())
 868    }
 869}
 870
 871#[cfg(test)]
 872mod tests {
 873    use super::*;
 874
 875    #[test]
 876    fn test_resilient_model_schema_deserialize() {
 877        let json = r#"{
 878              "data": [
 879                {
 880                  "billing": {
 881                    "is_premium": false,
 882                    "multiplier": 0
 883                  },
 884                  "capabilities": {
 885                    "family": "gpt-4",
 886                    "limits": {
 887                      "max_context_window_tokens": 32768,
 888                      "max_output_tokens": 4096,
 889                      "max_prompt_tokens": 32768
 890                    },
 891                    "object": "model_capabilities",
 892                    "supports": { "streaming": true, "tool_calls": true },
 893                    "tokenizer": "cl100k_base",
 894                    "type": "chat"
 895                  },
 896                  "id": "gpt-4",
 897                  "is_chat_default": false,
 898                  "is_chat_fallback": false,
 899                  "model_picker_enabled": false,
 900                  "name": "GPT 4",
 901                  "object": "model",
 902                  "preview": false,
 903                  "vendor": "Azure OpenAI",
 904                  "version": "gpt-4-0613"
 905                },
 906                {
 907                    "some-unknown-field": 123
 908                },
 909                {
 910                  "billing": {
 911                    "is_premium": true,
 912                    "multiplier": 1,
 913                    "restricted_to": [
 914                      "pro",
 915                      "pro_plus",
 916                      "business",
 917                      "enterprise"
 918                    ]
 919                  },
 920                  "capabilities": {
 921                    "family": "claude-3.7-sonnet",
 922                    "limits": {
 923                      "max_context_window_tokens": 200000,
 924                      "max_output_tokens": 16384,
 925                      "max_prompt_tokens": 90000,
 926                      "vision": {
 927                        "max_prompt_image_size": 3145728,
 928                        "max_prompt_images": 1,
 929                        "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
 930                      }
 931                    },
 932                    "object": "model_capabilities",
 933                    "supports": {
 934                      "parallel_tool_calls": true,
 935                      "streaming": true,
 936                      "tool_calls": true,
 937                      "vision": true
 938                    },
 939                    "tokenizer": "o200k_base",
 940                    "type": "chat"
 941                  },
 942                  "id": "claude-3.7-sonnet",
 943                  "is_chat_default": false,
 944                  "is_chat_fallback": false,
 945                  "model_picker_enabled": true,
 946                  "name": "Claude 3.7 Sonnet",
 947                  "object": "model",
 948                  "policy": {
 949                    "state": "enabled",
 950                    "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
 951                  },
 952                  "preview": false,
 953                  "vendor": "Anthropic",
 954                  "version": "claude-3.7-sonnet"
 955                }
 956              ],
 957              "object": "list"
 958            }"#;
 959
 960        let schema: ModelSchema = serde_json::from_str(json).unwrap();
 961
 962        assert_eq!(schema.data.len(), 2);
 963        assert_eq!(schema.data[0].id, "gpt-4");
 964        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
 965    }
 966
 967    #[test]
 968    fn test_unknown_vendor_resilience() {
 969        let json = r#"{
 970              "data": [
 971                {
 972                  "billing": {
 973                    "is_premium": false,
 974                    "multiplier": 1
 975                  },
 976                  "capabilities": {
 977                    "family": "future-model",
 978                    "limits": {
 979                      "max_context_window_tokens": 128000,
 980                      "max_output_tokens": 8192,
 981                      "max_prompt_tokens": 120000
 982                    },
 983                    "object": "model_capabilities",
 984                    "supports": { "streaming": true, "tool_calls": true },
 985                    "type": "chat"
 986                  },
 987                  "id": "future-model-v1",
 988                  "is_chat_default": false,
 989                  "is_chat_fallback": false,
 990                  "model_picker_enabled": true,
 991                  "name": "Future Model v1",
 992                  "object": "model",
 993                  "preview": false,
 994                  "vendor": "SomeNewVendor",
 995                  "version": "v1.0"
 996                }
 997              ],
 998              "object": "list"
 999            }"#;
1000
1001        let schema: ModelSchema = serde_json::from_str(json).unwrap();
1002
1003        assert_eq!(schema.data.len(), 1);
1004        assert_eq!(schema.data[0].id, "future-model-v1");
1005        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
1006    }
1007}