1use std::path::PathBuf;
2use std::sync::Arc;
3use std::sync::OnceLock;
4
5use anyhow::Context as _;
6use anyhow::{Result, anyhow};
7use chrono::DateTime;
8use collections::HashSet;
9use fs::Fs;
10use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
11use gpui::WeakEntity;
12use gpui::{App, AsyncApp, Global, prelude::*};
13use http_client::HttpRequestExt;
14use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
15use itertools::Itertools;
16use paths::home_dir;
17use serde::{Deserialize, Serialize};
18
19use crate::copilot_responses as responses;
20use settings::watch_config_dir;
21
pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";

/// Configuration for connecting to GitHub Copilot Chat, optionally through a
/// GitHub Enterprise deployment.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatConfiguration {
    /// URI of a GitHub Enterprise instance (e.g. `https://example.ghe.com`);
    /// `None` means public github.com.
    pub enterprise_uri: Option<String>,
}

impl CopilotChatConfiguration {
    /// URL used to exchange the long-lived OAuth token for a short-lived API token.
    pub fn token_url(&self) -> String {
        match &self.enterprise_uri {
            Some(enterprise_uri) => format!(
                "https://api.{}/copilot_internal/v2/token",
                Self::parse_domain(enterprise_uri)
            ),
            None => "https://api.github.com/copilot_internal/v2/token".to_string(),
        }
    }

    /// Domain whose credentials should be looked up in the Copilot OAuth config files.
    pub fn oauth_domain(&self) -> String {
        match &self.enterprise_uri {
            Some(enterprise_uri) => Self::parse_domain(enterprise_uri),
            None => "github.com".to_string(),
        }
    }

    /// Chat-completions URL relative to the API endpoint returned by the token exchange.
    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/chat/completions", endpoint)
    }

    /// Responses-API URL relative to the API endpoint returned by the token exchange.
    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/responses", endpoint)
    }

    /// Model-listing URL relative to the API endpoint returned by the token exchange.
    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
        format!("{}/models", endpoint)
    }

    /// Extracts the bare domain from an enterprise URI, tolerating an optional
    /// `http(s)://` scheme, trailing path components, and a trailing slash.
    fn parse_domain(enterprise_uri: &str) -> String {
        let uri = enterprise_uri.trim_end_matches('/');
        // Drop an optional scheme, then keep everything up to the first `/`.
        let without_scheme = uri
            .strip_prefix("https://")
            .or_else(|| uri.strip_prefix("http://"))
            .unwrap_or(uri);
        without_scheme
            .split('/')
            .next()
            .unwrap_or(without_scheme)
            .to_string()
    }
}
71
/// Author role of a chat message; serialized in lowercase for the Copilot API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}

/// API endpoint a model can be served through, as advertised by the model list.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
}
87
88#[derive(Deserialize)]
89struct ModelSchema {
90 #[serde(deserialize_with = "deserialize_models_skip_errors")]
91 data: Vec<Model>,
92}
93
94fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
95where
96 D: serde::Deserializer<'de>,
97{
98 let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
99 let models = raw_values
100 .into_iter()
101 .filter_map(|value| match serde_json::from_value::<Model>(value) {
102 Ok(model) => Some(model),
103 Err(err) => {
104 log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
105 None
106 }
107 })
108 .collect();
109
110 Ok(models)
111}
112
/// A model entry from the Copilot `/models` listing.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    // Stable identifier sent in API requests.
    id: String,
    // Human-readable display name.
    name: String,
    // Per-user policy gate; `None` means no policy applies.
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    is_chat_default: bool,
    // The model with this value true is selected by VSCode copilot if a premium request limit is
    // reached. Zed does not currently implement this behaviour
    is_chat_fallback: bool,
    // Whether the model should be offered in the model picker UI.
    model_picker_enabled: bool,
    // Endpoints this model can be served through; empty when not advertised.
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}

/// Billing metadata for a model.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    // Premium-request multiplier applied to usage of this model.
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}

/// Capability metadata for a model (family, limits, feature support).
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    // e.g. "chat"; non-chat models are filtered out in `get_models`.
    #[serde(rename = "type")]
    model_type: String,
    // Tokenizer name for local token counting, when advertised.
    #[serde(default)]
    tokenizer: Option<String>,
}

/// Token limits for a model; fields default to 0 when absent.
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    #[serde(default)]
    max_prompt_tokens: u64,
}

/// Policy gate for a model; only models with state "enabled" are surfaced.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}

/// Feature flags advertised for a model; all default to false when absent.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}

/// Upstream vendor serving a model.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}

/// One part of a multipart chat message: plain text or an image reference.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}

/// Image reference for a vision request (URL or data URI).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}
206
207impl Model {
208 pub fn uses_streaming(&self) -> bool {
209 self.capabilities.supports.streaming
210 }
211
212 pub fn id(&self) -> &str {
213 self.id.as_str()
214 }
215
216 pub fn display_name(&self) -> &str {
217 self.name.as_str()
218 }
219
220 pub fn max_token_count(&self) -> u64 {
221 self.capabilities.limits.max_prompt_tokens
222 }
223
224 pub fn supports_tools(&self) -> bool {
225 self.capabilities.supports.tool_calls
226 }
227
228 pub fn vendor(&self) -> ModelVendor {
229 self.vendor
230 }
231
232 pub fn supports_vision(&self) -> bool {
233 self.capabilities.supports.vision
234 }
235
236 pub fn supports_parallel_tool_calls(&self) -> bool {
237 self.capabilities.supports.parallel_tool_calls
238 }
239
240 pub fn tokenizer(&self) -> Option<&str> {
241 self.capabilities.tokenizer.as_deref()
242 }
243
244 pub fn supports_response(&self) -> bool {
245 self.supported_endpoints.len() > 0
246 && !self
247 .supported_endpoints
248 .contains(&ModelSupportedEndpoint::ChatCompletions)
249 && self
250 .supported_endpoints
251 .contains(&ModelSupportedEndpoint::Responses)
252 }
253}
254
/// Request payload for the Copilot chat-completions endpoint.
#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    // Number of completions to request.
    pub n: usize,
    // Whether to stream the response as server-sent events.
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}

/// A function declaration offered to the model as a tool.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    // JSON Schema describing the function's parameters.
    pub parameters: serde_json::Value,
}

/// A tool the model may call; currently only function tools exist.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}

/// Constraint on whether/which tools the model should call.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}

/// A chat message, tagged by role on the wire.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        // Opaque reasoning blob the API expects echoed back on later turns.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_opaque: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_text: Option<String>,
    },
    User {
        content: ChatMessageContent,
    },
    System {
        content: String,
    },
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}

/// Message content: a bare string or a list of text/image parts.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}

impl ChatMessageContent {
    /// An empty multipart content (serializes as an empty array).
    pub fn empty() -> Self {
        ChatMessageContent::Multipart(vec![])
    }
}

impl From<Vec<ChatMessagePart>> for ChatMessageContent {
    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
        // Collapse a single text part to the plain-string form.
        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
            ChatMessageContent::Plain(std::mem::take(text))
        } else {
            ChatMessageContent::Multipart(parts)
        }
    }
}

impl From<String> for ChatMessageContent {
    fn from(text: String) -> Self {
        ChatMessageContent::Plain(text)
    }
}

/// A tool call emitted by the assistant.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}

/// Payload of a tool call; currently only function calls exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}

/// Name and JSON-encoded arguments of a called function.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    // Arguments as a JSON string, exactly as produced by the model.
    pub arguments: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
363
/// One chunk of a chat-completions response (a full body when not streaming,
/// or a single SSE event when streaming).
// NOTE(review): `tag = "type"` on a struct is unusual — serde's container tag
// is primarily for enums; confirm this attribute is intentional and that
// deserialization of untagged event payloads still succeeds.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    // Present on the final event / non-streaming body.
    pub usage: Option<Usage>,
}

/// Token accounting reported by the API.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}

/// One completion choice; streaming events populate `delta`, non-streaming
/// bodies populate `message`.
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}

/// Incremental (or full) message content for a choice.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
    // Opaque reasoning blob to echo back on subsequent requests.
    pub reasoning_opaque: Option<String>,
    pub reasoning_text: Option<String>,
}

/// A fragment of a tool call; fragments with the same `index` are accumulated.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}

/// A fragment of a function call's name/arguments.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    pub thought_signature: Option<String>,
}

/// Wire format of the token-exchange response.
#[derive(Deserialize)]
struct ApiTokenResponse {
    token: String,
    // Unix timestamp (seconds) at which the token expires.
    expires_at: i64,
    endpoints: ApiTokenResponseEndpoints,
}

/// Endpoints advertised alongside the API token; only `api` is used.
#[derive(Deserialize)]
struct ApiTokenResponseEndpoints {
    api: String,
}

/// A short-lived Copilot API token plus the endpoint it is valid for.
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    expires_at: DateTime<chrono::Utc>,
    api_endpoint: String,
}
428
429impl ApiToken {
430 pub fn remaining_seconds(&self) -> i64 {
431 self.expires_at
432 .timestamp()
433 .saturating_sub(chrono::Utc::now().timestamp())
434 }
435}
436
437impl TryFrom<ApiTokenResponse> for ApiToken {
438 type Error = anyhow::Error;
439
440 fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
441 let expires_at =
442 DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;
443
444 Ok(Self {
445 api_key: response.token,
446 expires_at,
447 api_endpoint: response.endpoints.api,
448 })
449 }
450}
451
/// Global wrapper giving app-wide access to the single `CopilotChat` entity.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}

/// Holds Copilot Chat authentication state and the cached model list.
pub struct CopilotChat {
    // Long-lived GitHub OAuth token, read from the Copilot config files or the
    // `GH_COPILOT_TOKEN` environment variable.
    oauth_token: Option<String>,
    // Short-lived API token exchanged from the OAuth token; refreshed near expiry.
    api_token: Option<ApiToken>,
    configuration: CopilotChatConfiguration,
    // Models advertised by the API; `None` until the first successful fetch.
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}
463
464pub fn init(
465 fs: Arc<dyn Fs>,
466 client: Arc<dyn HttpClient>,
467 configuration: CopilotChatConfiguration,
468 cx: &mut App,
469) {
470 let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
471 cx.set_global(GlobalCopilotChat(copilot_chat));
472}
473
474pub fn copilot_chat_config_dir() -> &'static PathBuf {
475 static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
476
477 COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
478 let config_dir = if cfg!(target_os = "windows") {
479 dirs::data_local_dir().expect("failed to determine LocalAppData directory")
480 } else {
481 std::env::var("XDG_CONFIG_HOME")
482 .map(PathBuf::from)
483 .unwrap_or_else(|_| home_dir().join(".config"))
484 };
485
486 config_dir.join("github-copilot")
487 })
488}
489
490fn copilot_chat_config_paths() -> [PathBuf; 2] {
491 let base_dir = copilot_chat_config_dir();
492 [base_dir.join("hosts.json"), base_dir.join("apps.json")]
493}
494
495impl CopilotChat {
496 pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
497 cx.try_global::<GlobalCopilotChat>()
498 .map(|model| model.0.clone())
499 }
500
501 fn new(
502 fs: Arc<dyn Fs>,
503 client: Arc<dyn HttpClient>,
504 configuration: CopilotChatConfiguration,
505 cx: &mut Context<Self>,
506 ) -> Self {
507 let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
508 let dir_path = copilot_chat_config_dir();
509
510 cx.spawn(async move |this, cx| {
511 let mut parent_watch_rx = watch_config_dir(
512 cx.background_executor(),
513 fs.clone(),
514 dir_path.clone(),
515 config_paths,
516 );
517 while let Some(contents) = parent_watch_rx.next().await {
518 let oauth_domain =
519 this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
520 let oauth_token = extract_oauth_token(contents, &oauth_domain);
521
522 this.update(cx, |this, cx| {
523 this.oauth_token = oauth_token.clone();
524 cx.notify();
525 })?;
526
527 if oauth_token.is_some() {
528 Self::update_models(&this, cx).await?;
529 }
530 }
531 anyhow::Ok(())
532 })
533 .detach_and_log_err(cx);
534
535 let this = Self {
536 oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
537 api_token: None,
538 models: None,
539 configuration,
540 client,
541 };
542
543 if this.oauth_token.is_some() {
544 cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
545 .detach_and_log_err(cx);
546 }
547
548 this
549 }
550
551 async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
552 let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
553 (
554 this.oauth_token.clone(),
555 this.client.clone(),
556 this.configuration.clone(),
557 )
558 })?;
559
560 let oauth_token = oauth_token
561 .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;
562
563 let token_url = configuration.token_url();
564 let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
565
566 let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
567 let models =
568 get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;
569
570 this.update(cx, |this, cx| {
571 this.api_token = Some(api_token);
572 this.models = Some(models);
573 cx.notify();
574 })?;
575 anyhow::Ok(())
576 }
577
578 pub fn is_authenticated(&self) -> bool {
579 self.oauth_token.is_some()
580 }
581
582 pub fn models(&self) -> Option<&[Model]> {
583 self.models.as_deref()
584 }
585
586 pub async fn stream_completion(
587 request: Request,
588 is_user_initiated: bool,
589 mut cx: AsyncApp,
590 ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
591 let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
592
593 let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
594 stream_completion(
595 client.clone(),
596 token.api_key,
597 api_url.into(),
598 request,
599 is_user_initiated,
600 )
601 .await
602 }
603
604 pub async fn stream_response(
605 request: responses::Request,
606 is_user_initiated: bool,
607 mut cx: AsyncApp,
608 ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
609 let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
610
611 let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
612 responses::stream_response(
613 client.clone(),
614 token.api_key,
615 api_url,
616 request,
617 is_user_initiated,
618 )
619 .await
620 }
621
622 async fn get_auth_details(
623 cx: &mut AsyncApp,
624 ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
625 let this = cx
626 .update(|cx| Self::global(cx))
627 .context("Copilot chat is not enabled")?;
628
629 let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
630 (
631 this.oauth_token.clone(),
632 this.api_token.clone(),
633 this.client.clone(),
634 this.configuration.clone(),
635 )
636 });
637
638 let oauth_token = oauth_token.context("No OAuth token available")?;
639
640 let token = match api_token {
641 Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
642 _ => {
643 let token_url = configuration.token_url();
644 let token =
645 request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
646 this.update(cx, |this, cx| {
647 this.api_token = Some(token.clone());
648 cx.notify();
649 });
650 token
651 }
652 };
653
654 Ok((client, token, configuration))
655 }
656
657 pub fn set_configuration(
658 &mut self,
659 configuration: CopilotChatConfiguration,
660 cx: &mut Context<Self>,
661 ) {
662 let same_configuration = self.configuration == configuration;
663 self.configuration = configuration;
664 if !same_configuration {
665 self.api_token = None;
666 cx.spawn(async move |this, cx| {
667 Self::update_models(&this, cx).await?;
668 Ok::<_, anyhow::Error>(())
669 })
670 .detach();
671 }
672 }
673}
674
/// Fetches the model list and narrows it to chat models the user may pick,
/// moving the default chat model to the front.
async fn get_models(
    models_url: Arc<str>,
    api_token: String,
    client: Arc<dyn HttpClient>,
) -> Result<Vec<Model>> {
    let all_models = request_models(models_url, api_token, client).await?;

    let mut models: Vec<Model> = all_models
        .into_iter()
        // Keep only picker-visible chat models whose policy (if any) is enabled.
        .filter(|model| {
            model.model_picker_enabled
                && model.capabilities.model_type.as_str() == "chat"
                && model
                    .policy
                    .as_ref()
                    .is_none_or(|policy| policy.state == "enabled")
        })
        // Collapse *consecutive* entries sharing a family, keeping the first.
        // NOTE(review): this assumes the API lists a family's variants
        // adjacently; non-adjacent duplicates would survive — confirm.
        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
        .collect();

    // Surface the default chat model first in the picker.
    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
        let default_model = models.remove(default_model_position);
        models.insert(0, default_model);
    }

    Ok(models)
}
702
703async fn request_models(
704 models_url: Arc<str>,
705 api_token: String,
706 client: Arc<dyn HttpClient>,
707) -> Result<Vec<Model>> {
708 let request_builder = HttpRequest::builder()
709 .method(Method::GET)
710 .uri(models_url.as_ref())
711 .header("Authorization", format!("Bearer {}", api_token))
712 .header("Content-Type", "application/json")
713 .header("Copilot-Integration-Id", "vscode-chat")
714 .header("Editor-Version", "vscode/1.103.2")
715 .header("x-github-api-version", "2025-05-01");
716
717 let request = request_builder.body(AsyncBody::empty())?;
718
719 let mut response = client.send(request).await?;
720
721 anyhow::ensure!(
722 response.status().is_success(),
723 "Failed to request models: {}",
724 response.status()
725 );
726 let mut body = Vec::new();
727 response.body_mut().read_to_end(&mut body).await?;
728
729 let body_str = std::str::from_utf8(&body)?;
730
731 let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
732
733 Ok(models)
734}
735
736async fn request_api_token(
737 oauth_token: &str,
738 auth_url: Arc<str>,
739 client: Arc<dyn HttpClient>,
740) -> Result<ApiToken> {
741 let request_builder = HttpRequest::builder()
742 .method(Method::GET)
743 .uri(auth_url.as_ref())
744 .header("Authorization", format!("token {}", oauth_token))
745 .header("Accept", "application/json");
746
747 let request = request_builder.body(AsyncBody::empty())?;
748
749 let mut response = client.send(request).await?;
750
751 if response.status().is_success() {
752 let mut body = Vec::new();
753 response.body_mut().read_to_end(&mut body).await?;
754
755 let body_str = std::str::from_utf8(&body)?;
756
757 let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
758 ApiToken::try_from(parsed)
759 } else {
760 let mut body = Vec::new();
761 response.body_mut().read_to_end(&mut body).await?;
762
763 let body_str = std::str::from_utf8(&body)?;
764 anyhow::bail!("Failed to request API token: {body_str}");
765 }
766}
767
768fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
769 serde_json::from_str::<serde_json::Value>(&contents)
770 .map(|v| {
771 v.as_object().and_then(|obj| {
772 obj.iter().find_map(|(key, value)| {
773 if key.starts_with(domain) {
774 value["oauth_token"].as_str().map(|v| v.to_string())
775 } else {
776 None
777 }
778 })
779 })
780 })
781 .ok()
782 .flatten()
783}
784
/// Sends a chat-completions request and returns a stream of response events:
/// parsed SSE events when `request.stream` is set, otherwise a single-item
/// stream with the full response body.
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    api_key: String,
    completion_url: Arc<str>,
    request: Request,
    is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    // Vision requests need an extra header; detect any image part anywhere in
    // the conversation.
    let is_vision_request = request.messages.iter().any(|message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        // System messages carry plain strings and cannot contain images.
        _ => false,
    });

    // Copilot distinguishes user-initiated requests from agentic ones.
    let request_initiator = if is_user_initiated { "user" } else { "agent" };

    let request_builder = HttpRequest::builder()
        .method(Method::POST)
        .uri(completion_url.as_ref())
        .header(
            "Editor-Version",
            format!(
                "Zed/{}",
                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
            ),
        )
        .header("Authorization", format!("Bearer {}", api_key))
        .header("Content-Type", "application/json")
        .header("Copilot-Integration-Id", "vscode-chat")
        .header("X-Initiator", request_initiator)
        .when(is_vision_request, |builder| {
            builder.header("Copilot-Vision-Request", is_vision_request.to_string())
        });

    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        // Parse the body as server-sent events: each data line carries one
        // JSON-encoded ResponseEvent; "[DONE]" terminates the stream.
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        // Non-data lines (blank keep-alives, comments) are dropped.
                        let line = line.strip_prefix("data: ")?;
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                // Events with no choices carry nothing useful.
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-streaming: read the whole body and wrap it in a one-item stream
        // so callers can treat both modes uniformly.
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}
874
#[cfg(test)]
mod tests {
    use super::*;

    /// A malformed entry in the model list must be skipped, not fail the
    /// whole deserialization (see `deserialize_models_skip_errors`).
    #[test]
    fn test_resilient_model_schema_deserialize() {
        let json = r#"{
          "data": [
            {
              "billing": {
                "is_premium": false,
                "multiplier": 0
              },
              "capabilities": {
                "family": "gpt-4",
                "limits": {
                  "max_context_window_tokens": 32768,
                  "max_output_tokens": 4096,
                  "max_prompt_tokens": 32768
                },
                "object": "model_capabilities",
                "supports": { "streaming": true, "tool_calls": true },
                "tokenizer": "cl100k_base",
                "type": "chat"
              },
              "id": "gpt-4",
              "is_chat_default": false,
              "is_chat_fallback": false,
              "model_picker_enabled": false,
              "name": "GPT 4",
              "object": "model",
              "preview": false,
              "vendor": "Azure OpenAI",
              "version": "gpt-4-0613"
            },
            {
              "some-unknown-field": 123
            },
            {
              "billing": {
                "is_premium": true,
                "multiplier": 1,
                "restricted_to": [
                  "pro",
                  "pro_plus",
                  "business",
                  "enterprise"
                ]
              },
              "capabilities": {
                "family": "claude-3.7-sonnet",
                "limits": {
                  "max_context_window_tokens": 200000,
                  "max_output_tokens": 16384,
                  "max_prompt_tokens": 90000,
                  "vision": {
                    "max_prompt_image_size": 3145728,
                    "max_prompt_images": 1,
                    "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                  }
                },
                "object": "model_capabilities",
                "supports": {
                  "parallel_tool_calls": true,
                  "streaming": true,
                  "tool_calls": true,
                  "vision": true
                },
                "tokenizer": "o200k_base",
                "type": "chat"
              },
              "id": "claude-3.7-sonnet",
              "is_chat_default": false,
              "is_chat_fallback": false,
              "model_picker_enabled": true,
              "name": "Claude 3.7 Sonnet",
              "object": "model",
              "policy": {
                "state": "enabled",
                "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
              },
              "preview": false,
              "vendor": "Anthropic",
              "version": "claude-3.7-sonnet"
            }
          ],
          "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // The unknown-field entry is dropped; the two valid models survive.
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }

    /// An unrecognized vendor string must map to `ModelVendor::Unknown`
    /// (via `#[serde(other)]`) instead of failing deserialization.
    #[test]
    fn test_unknown_vendor_resilience() {
        let json = r#"{
          "data": [
            {
              "billing": {
                "is_premium": false,
                "multiplier": 1
              },
              "capabilities": {
                "family": "future-model",
                "limits": {
                  "max_context_window_tokens": 128000,
                  "max_output_tokens": 8192,
                  "max_prompt_tokens": 120000
                },
                "object": "model_capabilities",
                "supports": { "streaming": true, "tool_calls": true },
                "type": "chat"
              },
              "id": "future-model-v1",
              "is_chat_default": false,
              "is_chat_fallback": false,
              "model_picker_enabled": true,
              "name": "Future Model v1",
              "object": "model",
              "preview": false,
              "vendor": "SomeNewVendor",
              "version": "v1.0"
            }
          ],
          "object": "list"
        }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        assert_eq!(schema.data.len(), 1);
        assert_eq!(schema.data[0].id, "future-model-v1");
        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
    }
}