copilot_chat.rs

   1use std::path::PathBuf;
   2use std::sync::Arc;
   3use std::sync::OnceLock;
   4
   5use anyhow::Context as _;
   6use anyhow::{Result, anyhow};
   7use chrono::DateTime;
   8use collections::HashSet;
   9use fs::Fs;
  10use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  11use gpui::WeakEntity;
  12use gpui::{App, AsyncApp, Global, prelude::*};
  13use http_client::HttpRequestExt;
  14use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  15use itertools::Itertools;
  16use paths::home_dir;
  17use serde::{Deserialize, Serialize};
  18
  19use crate::copilot_responses as responses;
  20use settings::watch_config_dir;
  21
/// Name of the environment variable that `CopilotChat::new` reads to seed the
/// initial OAuth token.
pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
  23
  24#[derive(Default, Clone, Debug, PartialEq)]
  25pub struct CopilotChatConfiguration {
  26    pub enterprise_uri: Option<String>,
  27}
  28
  29impl CopilotChatConfiguration {
  30    pub fn token_url(&self) -> String {
  31        if let Some(enterprise_uri) = &self.enterprise_uri {
  32            let domain = Self::parse_domain(enterprise_uri);
  33            format!("https://api.{}/copilot_internal/v2/token", domain)
  34        } else {
  35            "https://api.github.com/copilot_internal/v2/token".to_string()
  36        }
  37    }
  38
  39    pub fn oauth_domain(&self) -> String {
  40        if let Some(enterprise_uri) = &self.enterprise_uri {
  41            Self::parse_domain(enterprise_uri)
  42        } else {
  43            "github.com".to_string()
  44        }
  45    }
  46
  47    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
  48        format!("{}/chat/completions", endpoint)
  49    }
  50
  51    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
  52        format!("{}/responses", endpoint)
  53    }
  54
  55    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
  56        format!("{}/models", endpoint)
  57    }
  58
  59    fn parse_domain(enterprise_uri: &str) -> String {
  60        let uri = enterprise_uri.trim_end_matches('/');
  61
  62        if let Some(domain) = uri.strip_prefix("https://") {
  63            domain.split('/').next().unwrap_or(domain).to_string()
  64        } else if let Some(domain) = uri.strip_prefix("http://") {
  65            domain.split('/').next().unwrap_or(domain).to_string()
  66        } else {
  67            uri.split('/').next().unwrap_or(uri).to_string()
  68        }
  69    }
  70}
  71
/// Author role attached to a response delta; (de)serialized in lowercase
/// (`user`, `assistant`, `system`).
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
  79
/// API endpoint a model can be reached through; each variant (de)serializes
/// as the endpoint path string.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    /// The `/chat/completions` endpoint.
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    /// The `/responses` endpoint.
    #[serde(rename = "/responses")]
    Responses,
}
  87
/// Top-level shape of the model-listing response body.
#[derive(Deserialize)]
struct ModelSchema {
    /// Models from the `data` array; entries that fail to deserialize are
    /// dropped by `deserialize_models_skip_errors` instead of failing the
    /// whole response.
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}
  93
  94fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
  95where
  96    D: serde::Deserializer<'de>,
  97{
  98    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
  99    let models = raw_values
 100        .into_iter()
 101        .filter_map(|value| match serde_json::from_value::<Model>(value) {
 102            Ok(model) => Some(model),
 103            Err(err) => {
 104                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
 105                None
 106            }
 107        })
 108        .collect();
 109
 110    Ok(models)
 111}
 112
/// A single model entry from the model-listing response.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    // Identifier returned by `Model::id`.
    id: String,
    // Human-readable name returned by `Model::display_name`.
    name: String,
    // When present, `get_models` only keeps the model if the state is "enabled".
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    // `get_models` moves the first model with this flag to the front of the list.
    is_chat_default: bool,
    // The model with this value true is selected by VSCode copilot if a premium request limit is
    // reached. Zed does not currently implement this behaviour
    is_chat_fallback: bool,
    // `get_models` drops models for which this is false.
    model_picker_enabled: bool,
    // Empty when the field is missing from the response.
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}
 129
/// Billing information attached to a [`Model`].
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    // NOTE(review): presumably the premium-request cost factor — confirm against the API docs.
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}
 139
/// Capability description attached to a [`Model`].
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    // `get_models` collapses consecutive models that share a family.
    family: String,
    // Falls back to all-zero limits when the field is missing.
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    // Read from the JSON field `type`; `get_models` keeps only "chat" models.
    #[serde(rename = "type")]
    model_type: String,
    // Tokenizer name, when the response provides one.
    #[serde(default)]
    tokenizer: Option<String>,
}
 151
/// Token limits of a model. Every field defaults to `0` when it is missing
/// from the response.
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    // Value reported by `Model::max_token_count`.
    #[serde(default)]
    max_prompt_tokens: u64,
}
 161
/// Policy attached to a [`Model`].
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    // `get_models` compares this against the string "enabled".
    state: String,
}
 166
/// Feature flags of a model. Every flag defaults to `false` when it is
/// missing from the response.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}
 178
/// Vendor of a [`Model`], as named in the `vendor` field of the response.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    // Spelled `xAI` on the wire.
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}
 192
/// One part of a multipart chat message, discriminated by the `type` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    /// Text part (`"type": "text"`).
    #[serde(rename = "text")]
    Text { text: String },
    /// Image part (`"type": "image_url"`).
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
 201
/// Payload of a [`ChatMessagePart::Image`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    /// Location of the image, as a string.
    pub url: String,
}
 206
 207impl Model {
 208    pub fn uses_streaming(&self) -> bool {
 209        self.capabilities.supports.streaming
 210    }
 211
 212    pub fn id(&self) -> &str {
 213        self.id.as_str()
 214    }
 215
 216    pub fn display_name(&self) -> &str {
 217        self.name.as_str()
 218    }
 219
 220    pub fn max_token_count(&self) -> u64 {
 221        self.capabilities.limits.max_prompt_tokens
 222    }
 223
 224    pub fn supports_tools(&self) -> bool {
 225        self.capabilities.supports.tool_calls
 226    }
 227
 228    pub fn vendor(&self) -> ModelVendor {
 229        self.vendor
 230    }
 231
 232    pub fn supports_vision(&self) -> bool {
 233        self.capabilities.supports.vision
 234    }
 235
 236    pub fn supports_parallel_tool_calls(&self) -> bool {
 237        self.capabilities.supports.parallel_tool_calls
 238    }
 239
 240    pub fn tokenizer(&self) -> Option<&str> {
 241        self.capabilities.tokenizer.as_deref()
 242    }
 243
 244    pub fn supports_response(&self) -> bool {
 245        self.supported_endpoints.len() > 0
 246            && !self
 247                .supported_endpoints
 248                .contains(&ModelSupportedEndpoint::ChatCompletions)
 249            && self
 250                .supported_endpoints
 251                .contains(&ModelSupportedEndpoint::Responses)
 252    }
 253}
 254
/// Body of a chat-completions request; `stream_completion` serializes it to
/// JSON and POSTs it.
#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    pub n: usize,
    /// When `true`, `stream_completion` reads the response line by line as a
    /// stream of events; otherwise it parses the body as a single event.
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    /// Omitted from the serialized request when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Omitted from the serialized request when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
 268
/// Description of a function offered to the model as a tool.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    /// Arbitrary JSON value describing the parameters.
    // NOTE(review): presumably a JSON Schema object — confirm against callers.
    pub parameters: serde_json::Value,
}
 275
/// A tool offered to the model, discriminated by a snake_case `type` field.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    /// A callable function (`"type": "function"`).
    Function { function: Function },
}
 281
/// Tool-selection mode sent with a request; (de)serialized in lowercase
/// (`auto`, `any`, `none`).
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}
 289
/// One message of a chat conversation, discriminated by a lowercase `role`
/// field (`assistant`, `user`, `system`, `tool`).
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    /// Message authored by the assistant.
    Assistant {
        content: ChatMessageContent,
        // Omitted from the serialized message when empty.
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
    },
    /// Message authored by the user.
    User {
        content: ChatMessageContent,
    },
    /// System message; its content is always plain text.
    System {
        content: String,
    },
    /// Tool message carrying the id of the tool call it belongs to.
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}
 309
/// Content of a chat message. Untagged: serialized either as a bare string
/// or as an array of parts.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    /// A single plain-text string.
    Plain(String),
    /// A list of text and/or image parts.
    Multipart(Vec<ChatMessagePart>),
}
 316
 317impl ChatMessageContent {
 318    pub fn empty() -> Self {
 319        ChatMessageContent::Multipart(vec![])
 320    }
 321}
 322
 323impl From<Vec<ChatMessagePart>> for ChatMessageContent {
 324    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
 325        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
 326            ChatMessageContent::Plain(std::mem::take(text))
 327        } else {
 328            ChatMessageContent::Multipart(parts)
 329        }
 330    }
 331}
 332
 333impl From<String> for ChatMessageContent {
 334    fn from(text: String) -> Self {
 335        ChatMessageContent::Plain(text)
 336    }
 337}
 338
/// A complete tool call attached to an assistant message.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    /// Flattened: its fields appear next to `id` in the serialized object.
    #[serde(flatten)]
    pub content: ToolCallContent,
}
 345
/// Payload of a [`ToolCall`], discriminated by a lowercase `type` field.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    /// A function invocation (`"type": "function"`).
    Function { function: FunctionContent },
}
 351
/// Name and arguments of a function invocation.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    /// Arguments carried as a string.
    // NOTE(review): presumably JSON-encoded — confirm against callers.
    pub arguments: String,
}
 357
/// One event parsed by `stream_completion`: either a single `data: ` line of
/// a streamed response or the whole body of a non-streamed one.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    /// Token accounting, when the event includes it.
    pub usage: Option<Usage>,
}
 365
/// Token counts reported with a [`ResponseEvent`].
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}
 372
/// One entry of [`ResponseEvent::choices`]. Every field is optional.
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    // NOTE(review): presumably `delta` is set on streamed events and `message` on
    // non-streamed ones — confirm against the API.
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}
 380
/// Content carried by a [`ResponseChoice`].
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    /// Empty when the field is missing.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
}
/// A possibly partial tool call inside a [`ResponseDelta`]; every field is
/// optional.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
 394
/// Function part of a [`ToolCallChunk`]; both fields are optional.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
}
 400
/// Raw JSON body of a successful token request, before conversion into
/// [`ApiToken`].
#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    // Interpreted as whole seconds since the Unix epoch when converted to `ApiToken`.
    expires_at: i64,
    endpoints: ApiTokenResponseEndpoints,
}
 407
/// Endpoint section of an [`ApiTokenResponse`].
#[derive(Deserialize)]
struct ApiTokenResponseEndpoints {
    // Becomes `ApiToken::api_endpoint`, the base for the chat, responses and models URLs.
    api: String,
}
 412
/// API token obtained in exchange for the OAuth token, together with its
/// expiry and the API endpoint it applies to.
#[derive(Clone)]
struct ApiToken {
    // Sent as the `Bearer` credential on API requests.
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
    // Base URL the chat, responses and models URLs are built from.
    api_endpoint: String,
}
 419
 420impl ApiToken {
 421    pub fn remaining_seconds(&self) -> i64 {
 422        self.expires_at
 423            .timestamp()
 424            .saturating_sub(chrono::Utc::now().timestamp())
 425    }
 426}
 427
 428impl TryFrom<ApiTokenResponse> for ApiToken {
 429    type Error = anyhow::Error;
 430
 431    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
 432        let expires_at =
 433            DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;
 434
 435        Ok(Self {
 436            api_key: response.token,
 437            expires_at,
 438            api_endpoint: response.endpoints.api,
 439        })
 440    }
 441}
 442
/// Newtype that stores the [`CopilotChat`] entity as an app-wide global; set
/// by `init` and read by `CopilotChat::global`.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}
 446
/// State of the Copilot Chat integration: credentials, configuration, the
/// fetched model list, and the HTTP client used for requests.
pub struct CopilotChat {
    // Seeded from the environment and replaced whenever the watched config files change.
    oauth_token: Option<String>,
    // Cached API token; cleared when the configuration changes.
    api_token: Option<ApiToken>,
    configuration: CopilotChatConfiguration,
    // `None` until the first successful model fetch.
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}
 454
 455pub fn init(
 456    fs: Arc<dyn Fs>,
 457    client: Arc<dyn HttpClient>,
 458    configuration: CopilotChatConfiguration,
 459    cx: &mut App,
 460) {
 461    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
 462    cx.set_global(GlobalCopilotChat(copilot_chat));
 463}
 464
 465pub fn copilot_chat_config_dir() -> &'static PathBuf {
 466    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
 467
 468    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
 469        let config_dir = if cfg!(target_os = "windows") {
 470            dirs::data_local_dir().expect("failed to determine LocalAppData directory")
 471        } else {
 472            std::env::var("XDG_CONFIG_HOME")
 473                .map(PathBuf::from)
 474                .unwrap_or_else(|_| home_dir().join(".config"))
 475        };
 476
 477        config_dir.join("github-copilot")
 478    })
 479}
 480
 481fn copilot_chat_config_paths() -> [PathBuf; 2] {
 482    let base_dir = copilot_chat_config_dir();
 483    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
 484}
 485
 486impl CopilotChat {
 487    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
 488        cx.try_global::<GlobalCopilotChat>()
 489            .map(|model| model.0.clone())
 490    }
 491
 492    fn new(
 493        fs: Arc<dyn Fs>,
 494        client: Arc<dyn HttpClient>,
 495        configuration: CopilotChatConfiguration,
 496        cx: &mut Context<Self>,
 497    ) -> Self {
 498        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
 499        let dir_path = copilot_chat_config_dir();
 500
 501        cx.spawn(async move |this, cx| {
 502            let mut parent_watch_rx = watch_config_dir(
 503                cx.background_executor(),
 504                fs.clone(),
 505                dir_path.clone(),
 506                config_paths,
 507            );
 508            while let Some(contents) = parent_watch_rx.next().await {
 509                let oauth_domain =
 510                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
 511                let oauth_token = extract_oauth_token(contents, &oauth_domain);
 512
 513                this.update(cx, |this, cx| {
 514                    this.oauth_token = oauth_token.clone();
 515                    cx.notify();
 516                })?;
 517
 518                if oauth_token.is_some() {
 519                    Self::update_models(&this, cx).await?;
 520                }
 521            }
 522            anyhow::Ok(())
 523        })
 524        .detach_and_log_err(cx);
 525
 526        let this = Self {
 527            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
 528            api_token: None,
 529            models: None,
 530            configuration,
 531            client,
 532        };
 533
 534        if this.oauth_token.is_some() {
 535            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
 536                .detach_and_log_err(cx);
 537        }
 538
 539        this
 540    }
 541
 542    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 543        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
 544            (
 545                this.oauth_token.clone(),
 546                this.client.clone(),
 547                this.configuration.clone(),
 548            )
 549        })?;
 550
 551        let oauth_token = oauth_token
 552            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;
 553
 554        let token_url = configuration.token_url();
 555        let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
 556
 557        let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
 558        let models =
 559            get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;
 560
 561        this.update(cx, |this, cx| {
 562            this.api_token = Some(api_token);
 563            this.models = Some(models);
 564            cx.notify();
 565        })?;
 566        anyhow::Ok(())
 567    }
 568
 569    pub fn is_authenticated(&self) -> bool {
 570        self.oauth_token.is_some()
 571    }
 572
 573    pub fn models(&self) -> Option<&[Model]> {
 574        self.models.as_deref()
 575    }
 576
 577    pub async fn stream_completion(
 578        request: Request,
 579        is_user_initiated: bool,
 580        mut cx: AsyncApp,
 581    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
 582        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
 583
 584        let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
 585        stream_completion(
 586            client.clone(),
 587            token.api_key,
 588            api_url.into(),
 589            request,
 590            is_user_initiated,
 591        )
 592        .await
 593    }
 594
 595    pub async fn stream_response(
 596        request: responses::Request,
 597        is_user_initiated: bool,
 598        mut cx: AsyncApp,
 599    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
 600        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
 601
 602        let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
 603        responses::stream_response(
 604            client.clone(),
 605            token.api_key,
 606            api_url,
 607            request,
 608            is_user_initiated,
 609        )
 610        .await
 611    }
 612
 613    async fn get_auth_details(
 614        cx: &mut AsyncApp,
 615    ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
 616        let this = cx
 617            .update(|cx| Self::global(cx))
 618            .ok()
 619            .flatten()
 620            .context("Copilot chat is not enabled")?;
 621
 622        let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
 623            (
 624                this.oauth_token.clone(),
 625                this.api_token.clone(),
 626                this.client.clone(),
 627                this.configuration.clone(),
 628            )
 629        })?;
 630
 631        let oauth_token = oauth_token.context("No OAuth token available")?;
 632
 633        let token = match api_token {
 634            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
 635            _ => {
 636                let token_url = configuration.token_url();
 637                let token =
 638                    request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
 639                this.update(cx, |this, cx| {
 640                    this.api_token = Some(token.clone());
 641                    cx.notify();
 642                })?;
 643                token
 644            }
 645        };
 646
 647        Ok((client, token, configuration))
 648    }
 649
 650    pub fn set_configuration(
 651        &mut self,
 652        configuration: CopilotChatConfiguration,
 653        cx: &mut Context<Self>,
 654    ) {
 655        let same_configuration = self.configuration == configuration;
 656        self.configuration = configuration;
 657        if !same_configuration {
 658            self.api_token = None;
 659            cx.spawn(async move |this, cx| {
 660                Self::update_models(&this, cx).await?;
 661                Ok::<_, anyhow::Error>(())
 662            })
 663            .detach();
 664        }
 665    }
 666}
 667
 668async fn get_models(
 669    models_url: Arc<str>,
 670    api_token: String,
 671    client: Arc<dyn HttpClient>,
 672) -> Result<Vec<Model>> {
 673    let all_models = request_models(models_url, api_token, client).await?;
 674
 675    let mut models: Vec<Model> = all_models
 676        .into_iter()
 677        .filter(|model| {
 678            model.model_picker_enabled
 679                && model.capabilities.model_type.as_str() == "chat"
 680                && model
 681                    .policy
 682                    .as_ref()
 683                    .is_none_or(|policy| policy.state == "enabled")
 684        })
 685        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
 686        .collect();
 687
 688    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
 689        let default_model = models.remove(default_model_position);
 690        models.insert(0, default_model);
 691    }
 692
 693    Ok(models)
 694}
 695
 696async fn request_models(
 697    models_url: Arc<str>,
 698    api_token: String,
 699    client: Arc<dyn HttpClient>,
 700) -> Result<Vec<Model>> {
 701    let request_builder = HttpRequest::builder()
 702        .method(Method::GET)
 703        .uri(models_url.as_ref())
 704        .header("Authorization", format!("Bearer {}", api_token))
 705        .header("Content-Type", "application/json")
 706        .header("Copilot-Integration-Id", "vscode-chat")
 707        .header("Editor-Version", "vscode/1.103.2")
 708        .header("x-github-api-version", "2025-05-01");
 709
 710    let request = request_builder.body(AsyncBody::empty())?;
 711
 712    let mut response = client.send(request).await?;
 713
 714    anyhow::ensure!(
 715        response.status().is_success(),
 716        "Failed to request models: {}",
 717        response.status()
 718    );
 719    let mut body = Vec::new();
 720    response.body_mut().read_to_end(&mut body).await?;
 721
 722    let body_str = std::str::from_utf8(&body)?;
 723
 724    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
 725
 726    Ok(models)
 727}
 728
 729async fn request_api_token(
 730    oauth_token: &str,
 731    auth_url: Arc<str>,
 732    client: Arc<dyn HttpClient>,
 733) -> Result<ApiToken> {
 734    let request_builder = HttpRequest::builder()
 735        .method(Method::GET)
 736        .uri(auth_url.as_ref())
 737        .header("Authorization", format!("token {}", oauth_token))
 738        .header("Accept", "application/json");
 739
 740    let request = request_builder.body(AsyncBody::empty())?;
 741
 742    let mut response = client.send(request).await?;
 743
 744    if response.status().is_success() {
 745        let mut body = Vec::new();
 746        response.body_mut().read_to_end(&mut body).await?;
 747
 748        let body_str = std::str::from_utf8(&body)?;
 749
 750        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
 751        ApiToken::try_from(parsed)
 752    } else {
 753        let mut body = Vec::new();
 754        response.body_mut().read_to_end(&mut body).await?;
 755
 756        let body_str = std::str::from_utf8(&body)?;
 757        anyhow::bail!("Failed to request API token: {body_str}");
 758    }
 759}
 760
 761fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
 762    serde_json::from_str::<serde_json::Value>(&contents)
 763        .map(|v| {
 764            v.as_object().and_then(|obj| {
 765                obj.iter().find_map(|(key, value)| {
 766                    if key.starts_with(domain) {
 767                        value["oauth_token"].as_str().map(|v| v.to_string())
 768                    } else {
 769                        None
 770                    }
 771                })
 772            })
 773        })
 774        .ok()
 775        .flatten()
 776}
 777
 778async fn stream_completion(
 779    client: Arc<dyn HttpClient>,
 780    api_key: String,
 781    completion_url: Arc<str>,
 782    request: Request,
 783    is_user_initiated: bool,
 784) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
 785    let is_vision_request = request.messages.iter().any(|message| match message {
 786      ChatMessage::User { content }
 787      | ChatMessage::Assistant { content, .. }
 788      | ChatMessage::Tool { content, .. } => {
 789          matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
 790      }
 791      _ => false,
 792  });
 793
 794    let request_initiator = if is_user_initiated { "user" } else { "agent" };
 795
 796    let request_builder = HttpRequest::builder()
 797        .method(Method::POST)
 798        .uri(completion_url.as_ref())
 799        .header(
 800            "Editor-Version",
 801            format!(
 802                "Zed/{}",
 803                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
 804            ),
 805        )
 806        .header("Authorization", format!("Bearer {}", api_key))
 807        .header("Content-Type", "application/json")
 808        .header("Copilot-Integration-Id", "vscode-chat")
 809        .header("X-Initiator", request_initiator)
 810        .when(is_vision_request, |builder| {
 811            builder.header("Copilot-Vision-Request", is_vision_request.to_string())
 812        });
 813
 814    let is_streaming = request.stream;
 815
 816    let json = serde_json::to_string(&request)?;
 817    let request = request_builder.body(AsyncBody::from(json))?;
 818    let mut response = client.send(request).await?;
 819
 820    if !response.status().is_success() {
 821        let mut body = Vec::new();
 822        response.body_mut().read_to_end(&mut body).await?;
 823        let body_str = std::str::from_utf8(&body)?;
 824        anyhow::bail!(
 825            "Failed to connect to API: {} {}",
 826            response.status(),
 827            body_str
 828        );
 829    }
 830
 831    if is_streaming {
 832        let reader = BufReader::new(response.into_body());
 833        Ok(reader
 834            .lines()
 835            .filter_map(|line| async move {
 836                match line {
 837                    Ok(line) => {
 838                        let line = line.strip_prefix("data: ")?;
 839                        if line.starts_with("[DONE]") {
 840                            return None;
 841                        }
 842
 843                        match serde_json::from_str::<ResponseEvent>(line) {
 844                            Ok(response) => {
 845                                if response.choices.is_empty() {
 846                                    None
 847                                } else {
 848                                    Some(Ok(response))
 849                                }
 850                            }
 851                            Err(error) => Some(Err(anyhow!(error))),
 852                        }
 853                    }
 854                    Err(error) => Some(Err(anyhow!(error))),
 855                }
 856            })
 857            .boxed())
 858    } else {
 859        let mut body = Vec::new();
 860        response.body_mut().read_to_end(&mut body).await?;
 861        let body_str = std::str::from_utf8(&body)?;
 862        let response: ResponseEvent = serde_json::from_str(body_str)?;
 863
 864        Ok(futures::stream::once(async move { Ok(response) }).boxed())
 865    }
 866}
 867
// Unit tests for the tolerant deserialization of the model-listing response.
#[cfg(test)]
mod tests {
    use super::*;

    // The second entry of `data` lacks every required field. It must be
    // skipped while the two valid models around it are still returned, in
    // their original order.
    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
              "data": [
                {
                  "billing": {
                    "is_premium": false,
                    "multiplier": 0
                  },
                  "capabilities": {
                    "family": "gpt-4",
                    "limits": {
                      "max_context_window_tokens": 32768,
                      "max_output_tokens": 4096,
                      "max_prompt_tokens": 32768
                    },
                    "object": "model_capabilities",
                    "supports": { "streaming": true, "tool_calls": true },
                    "tokenizer": "cl100k_base",
                    "type": "chat"
                  },
                  "id": "gpt-4",
                  "is_chat_default": false,
                  "is_chat_fallback": false,
                  "model_picker_enabled": false,
                  "name": "GPT 4",
                  "object": "model",
                  "preview": false,
                  "vendor": "Azure OpenAI",
                  "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                  "billing": {
                    "is_premium": true,
                    "multiplier": 1,
                    "restricted_to": [
                      "pro",
                      "pro_plus",
                      "business",
                      "enterprise"
                    ]
                  },
                  "capabilities": {
                    "family": "claude-3.7-sonnet",
                    "limits": {
                      "max_context_window_tokens": 200000,
                      "max_output_tokens": 16384,
                      "max_prompt_tokens": 90000,
                      "vision": {
                        "max_prompt_image_size": 3145728,
                        "max_prompt_images": 1,
                        "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                      }
                    },
                    "object": "model_capabilities",
                    "supports": {
                      "parallel_tool_calls": true,
                      "streaming": true,
                      "tool_calls": true,
                      "vision": true
                    },
                    "tokenizer": "o200k_base",
                    "type": "chat"
                  },
                  "id": "claude-3.7-sonnet",
                  "is_chat_default": false,
                  "is_chat_fallback": false,
                  "model_picker_enabled": true,
                  "name": "Claude 3.7 Sonnet",
                  "object": "model",
                  "policy": {
                    "state": "enabled",
                    "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                  },
                  "preview": false,
                  "vendor": "Anthropic",
                  "version": "claude-3.7-sonnet"
                }
              ],
              "object": "list"
            }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }

    // A vendor string that matches no known variant must deserialize to
    // `ModelVendor::Unknown` instead of causing the model to be dropped.
    #[test]
    fn test_unknown_vendor_resilience() {
        let json = r#"{
              "data": [
                {
                  "billing": {
                    "is_premium": false,
                    "multiplier": 1
                  },
                  "capabilities": {
                    "family": "future-model",
                    "limits": {
                      "max_context_window_tokens": 128000,
                      "max_output_tokens": 8192,
                      "max_prompt_tokens": 120000
                    },
                    "object": "model_capabilities",
                    "supports": { "streaming": true, "tool_calls": true },
                    "type": "chat"
                  },
                  "id": "future-model-v1",
                  "is_chat_default": false,
                  "is_chat_fallback": false,
                  "model_picker_enabled": true,
                  "name": "Future Model v1",
                  "object": "model",
                  "preview": false,
                  "vendor": "SomeNewVendor",
                  "version": "v1.0"
                }
              ],
              "object": "list"
            }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "future-model-v1");
        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
    }
}
1004}