copilot_chat.rs

   1use std::path::PathBuf;
   2use std::sync::Arc;
   3use std::sync::OnceLock;
   4
   5use anyhow::Context as _;
   6use anyhow::{Result, anyhow};
   7use chrono::DateTime;
   8use collections::HashSet;
   9use fs::Fs;
  10use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  11use gpui::WeakEntity;
  12use gpui::{App, AsyncApp, Global, prelude::*};
  13use http_client::HttpRequestExt;
  14use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  15use itertools::Itertools;
  16use paths::home_dir;
  17use serde::{Deserialize, Serialize};
  18
  19use crate::copilot_responses as responses;
  20use settings::watch_config_dir;
  21
  /// Environment variable whose value, when set, seeds the OAuth token at the moment the
  /// `CopilotChat` entity is created; the config-file watcher may later replace that token.
  22pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
  23
  /// Connection settings for GitHub Copilot Chat.
///
/// With no (or a blank) `enterprise_uri`, the public `github.com` endpoints are used; otherwise
/// the host portion of the URI selects a GitHub Enterprise instance.
#[derive(Default, Clone, Debug, PartialEq)]
pub struct CopilotChatConfiguration {
    /// Optional GitHub Enterprise location, e.g. `https://ghe.example.com` or `ghe.example.com`.
    /// Surrounding whitespace, a scheme, trailing slashes and any path are ignored.
    pub enterprise_uri: Option<String>,
}

impl CopilotChatConfiguration {
    /// URL that exchanges an OAuth token for a Copilot API token.
    pub fn token_url(&self) -> String {
        match self.enterprise_domain() {
            Some(domain) => format!("https://api.{}/copilot_internal/v2/token", domain),
            None => "https://api.github.com/copilot_internal/v2/token".to_string(),
        }
    }

    /// Host used to look up the OAuth token in the Copilot config files.
    pub fn oauth_domain(&self) -> String {
        self.enterprise_domain()
            .unwrap_or_else(|| "github.com".to_string())
    }

    /// `<endpoint>/chat/completions`, without doubling a trailing `/` on `endpoint`.
    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
        Self::join_endpoint(endpoint, "chat/completions")
    }

    /// `<endpoint>/responses`, without doubling a trailing `/` on `endpoint`.
    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
        Self::join_endpoint(endpoint, "responses")
    }

    /// `<endpoint>/models`, without doubling a trailing `/` on `endpoint`.
    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
        Self::join_endpoint(endpoint, "models")
    }

    /// Appends `path` to `endpoint` with exactly one `/` between them.
    fn join_endpoint(endpoint: &str, path: &str) -> String {
        format!("{}/{}", endpoint.trim_end_matches('/'), path)
    }

    /// The configured enterprise host, or `None` when no usable enterprise URI is set
    /// (absent, empty, or whitespace-only), so callers fall back to `github.com`.
    fn enterprise_domain(&self) -> Option<String> {
        let domain = Self::parse_domain(self.enterprise_uri.as_deref()?);
        (!domain.is_empty()).then_some(domain)
    }

    /// Extracts the host (including any `:port`) from an enterprise URI.
    ///
    /// Whitespace and trailing slashes are trimmed, any `scheme://` prefix is removed, and
    /// everything from the first remaining `/` on is discarded. May return an empty string.
    fn parse_domain(enterprise_uri: &str) -> String {
        let uri = enterprise_uri.trim().trim_end_matches('/');
        let without_scheme = uri.split_once("://").map_or(uri, |(_, rest)| rest);
        without_scheme
            .split('/')
            .next()
            .unwrap_or_default()
            .to_string()
    }
}
  71
  /// Author of a chat message; serialized in lowercase (`"user"`, `"assistant"`, `"system"`).
  72#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
  73#[serde(rename_all = "lowercase")]
  74pub enum Role {
  75    User,
  76    Assistant,
  77    System,
  78}
  79
  /// An API endpoint listed in a model's `supported_endpoints`, matched against the literal
  /// path strings `"/chat/completions"` and `"/responses"`.
  ///
  /// NOTE(review): there is no catch-all variant, so a model that lists any other endpoint
  /// string fails to deserialize and is dropped from the listing entirely — consider a
  /// `#[serde(other)]` variant as `ModelVendor` has.
  80#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
  81pub enum ModelSupportedEndpoint {
  82    #[serde(rename = "/chat/completions")]
  83    ChatCompletions,
  84    #[serde(rename = "/responses")]
  85    Responses,
  86}
  87
  /// Top-level shape of the models listing. Only `data` is read; entries that fail to
  /// deserialize are logged and skipped instead of failing the whole listing.
  88#[derive(Deserialize)]
  89struct ModelSchema {
  90    #[serde(deserialize_with = "deserialize_models_skip_errors")]
  91    data: Vec<Model>,
  92}
  93
  94fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
  95where
  96    D: serde::Deserializer<'de>,
  97{
  98    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
  99    let models = raw_values
 100        .into_iter()
 101        .filter_map(|value| match serde_json::from_value::<Model>(value) {
 102            Ok(model) => Some(model),
 103            Err(err) => {
 104                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
 105                None
 106            }
 107        })
 108        .collect();
 109
 110    Ok(models)
 111}
 112
 /// One entry of the Copilot models listing. JSON fields not declared here (such as `object`
 /// or `version`) are ignored; a missing required field makes the whole entry fail to
 /// deserialize, and `deserialize_models_skip_errors` then skips it.
 113#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
 114pub struct Model {
 115    billing: ModelBilling,
 116    capabilities: ModelCapabilities,
        /// Identifier sent in requests (exposed via `Model::id`).
 117    id: String,
        /// Human-readable name (exposed via `Model::display_name`).
 118    name: String,
        /// `None` when the entry carries no policy; `get_models` treats that like `"enabled"`.
 119    policy: Option<ModelPolicy>,
 120    vendor: ModelVendor,
        /// `get_models` moves the first model with this set to the front of the list.
 121    is_chat_default: bool,
 122    // The model with this value true is selected by VSCode copilot if a premium request limit is
 123    // reached. Zed does not currently implement this behaviour
 124    is_chat_fallback: bool,
        /// Models with this unset are filtered out by `get_models`.
 125    model_picker_enabled: bool,
        /// Empty when the field is absent; consulted by `Model::supports_response`.
 126    #[serde(default)]
 127    supported_endpoints: Vec<ModelSupportedEndpoint>,
 128}
 129
 /// Billing metadata for a model.
 /// NOTE(review): `multiplier` presumably scales premium-request usage — confirm against the
 /// Copilot API before relying on it.
 130#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
 131struct ModelBilling {
 132    is_premium: bool,
 133    multiplier: f64,
 134    // List of plans a model is restricted to
 135    // Field is not present if a model is available for all plans
 136    #[serde(default)]
 137    restricted_to: Option<Vec<String>>,
 138}
 139
 /// Capability metadata; `get_models` keeps only models whose `type` is `"chat"`.
 140#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
 141struct ModelCapabilities {
        /// Model family; `get_models` collapses runs of adjacent models sharing a family.
 142    family: String,
        /// All-zero limits when the field is absent.
 143    #[serde(default)]
 144    limits: ModelLimits,
 145    supports: ModelSupportedFeatures,
        /// JSON key `type`, renamed because `type` is a Rust keyword.
 146    #[serde(rename = "type")]
 147    model_type: String,
 148    #[serde(default)]
 149    tokenizer: Option<String>,
 150}
 151
 /// Token limits reported by the API; each defaults to 0 when absent. Within this file only
 /// `max_prompt_tokens` is read (via `Model::max_token_count`).
 152#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
 153struct ModelLimits {
 154    #[serde(default)]
 155    max_context_window_tokens: usize,
 156    #[serde(default)]
 157    max_output_tokens: usize,
 158    #[serde(default)]
 159    max_prompt_tokens: u64,
 160}
 161
 /// Per-model policy; only `state` is read, and `get_models` keeps models whose state is
 /// `"enabled"`.
 162#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
 163struct ModelPolicy {
 164    state: String,
 165}
 166
 /// Feature flags advertised by a model; each defaults to `false` when absent.
 167#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
 168struct ModelSupportedFeatures {
 169    #[serde(default)]
 170    streaming: bool,
 171    #[serde(default)]
 172    tool_calls: bool,
 173    #[serde(default)]
 174    parallel_tool_calls: bool,
 175    #[serde(default)]
 176    vision: bool,
 177}
 178
 /// Vendor of a model. Any vendor string not matched below deserializes as `Unknown` instead
 /// of failing the entry.
 179#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 180pub enum ModelVendor {
 181    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
 182    #[serde(alias = "Azure OpenAI")]
 183    OpenAI,
 184    Google,
 185    Anthropic,
 186    #[serde(rename = "xAI")]
 187    XAI,
 188    /// Unknown vendor that we don't explicitly support yet
 189    #[serde(other)]
 190    Unknown,
 191}
 192
 /// One element of a multipart message body, internally tagged by `type`
 /// (`"text"` or `"image_url"`).
 193#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
 194#[serde(tag = "type")]
 195pub enum ChatMessagePart {
 196    #[serde(rename = "text")]
 197    Text { text: String },
 198    #[serde(rename = "image_url")]
 199    Image { image_url: ImageUrl },
 200}
 201
 /// Image reference carried by an `image_url` part; `url` is sent verbatim.
 /// NOTE(review): presumably an http(s) or `data:` URL — confirm with callers.
 202#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
 203pub struct ImageUrl {
 204    pub url: String,
 205}
 206
 207impl Model {
 208    pub fn uses_streaming(&self) -> bool {
 209        self.capabilities.supports.streaming
 210    }
 211
 212    pub fn id(&self) -> &str {
 213        self.id.as_str()
 214    }
 215
 216    pub fn display_name(&self) -> &str {
 217        self.name.as_str()
 218    }
 219
 220    pub fn max_token_count(&self) -> u64 {
 221        self.capabilities.limits.max_prompt_tokens
 222    }
 223
 224    pub fn supports_tools(&self) -> bool {
 225        self.capabilities.supports.tool_calls
 226    }
 227
 228    pub fn vendor(&self) -> ModelVendor {
 229        self.vendor
 230    }
 231
 232    pub fn supports_vision(&self) -> bool {
 233        self.capabilities.supports.vision
 234    }
 235
 236    pub fn supports_parallel_tool_calls(&self) -> bool {
 237        self.capabilities.supports.parallel_tool_calls
 238    }
 239
 240    pub fn tokenizer(&self) -> Option<&str> {
 241        self.capabilities.tokenizer.as_deref()
 242    }
 243
 244    pub fn supports_response(&self) -> bool {
 245        self.supported_endpoints.len() > 0
 246            && !self
 247                .supported_endpoints
 248                .contains(&ModelSupportedEndpoint::ChatCompletions)
 249            && self
 250                .supported_endpoints
 251                .contains(&ModelSupportedEndpoint::Responses)
 252    }
 253}
 254
 /// Body of a `/chat/completions` request, serialized to JSON by `stream_completion`.
 255#[derive(Serialize, Deserialize)]
 256pub struct Request {
        // NOTE(review): `intent` and `n` are forwarded unchanged; their meaning is defined by
        // the Copilot API, not by this file.
 257    pub intent: bool,
 258    pub n: usize,
        /// When true the response is read line by line as `data:` events; otherwise the whole
        /// body is parsed as a single event.
 259    pub stream: bool,
 260    pub temperature: f32,
 261    pub model: String,
 262    pub messages: Vec<ChatMessage>,
        /// Omitted from the JSON when empty.
 263    #[serde(default, skip_serializing_if = "Vec::is_empty")]
 264    pub tools: Vec<Tool>,
        /// Omitted from the JSON when `None`.
 265    #[serde(default, skip_serializing_if = "Option::is_none")]
 266    pub tool_choice: Option<ToolChoice>,
 267}
 268
 /// A callable function offered to the model.
 /// NOTE(review): `parameters` is arbitrary JSON, presumably a JSON Schema for the arguments
 /// — confirm with callers.
 269#[derive(Serialize, Deserialize)]
 270pub struct Function {
 271    pub name: String,
 272    pub description: String,
 273    pub parameters: serde_json::Value,
 274}
 275
 /// A tool definition, serialized as `{"type": "function", "function": { ... }}`.
 276#[derive(Serialize, Deserialize)]
 277#[serde(tag = "type", rename_all = "snake_case")]
 278pub enum Tool {
 279    Function { function: Function },
 280}
 281
 /// Tool-selection mode; serialized as `"auto"`, `"any"`, or `"none"`.
 282#[derive(Serialize, Deserialize, Debug)]
 283#[serde(rename_all = "lowercase")]
 284pub enum ToolChoice {
 285    Auto,
 286    Any,
 287    None,
 288}
 289
 /// One conversation message, internally tagged by `role`
 /// (`"assistant"`, `"user"`, `"system"`, or `"tool"`).
 290#[derive(Serialize, Deserialize, Debug)]
 291#[serde(tag = "role", rename_all = "lowercase")]
 292pub enum ChatMessage {
 293    Assistant {
 294        content: ChatMessageContent,
            /// Omitted from the JSON when empty.
 295        #[serde(default, skip_serializing_if = "Vec::is_empty")]
 296        tool_calls: Vec<ToolCall>,
            /// Omitted from the JSON when `None`.
 297        #[serde(default, skip_serializing_if = "Option::is_none")]
 298        reasoning_opaque: Option<String>,
            /// Omitted from the JSON when `None`.
 299        #[serde(default, skip_serializing_if = "Option::is_none")]
 300        reasoning_text: Option<String>,
 301    },
 302    User {
 303        content: ChatMessageContent,
 304    },
 305    System {
 306        content: String,
 307    },
 308    Tool {
 309        content: ChatMessageContent,
 310        tool_call_id: String,
 311    },
 312}
 313
 /// Message content. Untagged: serialized either as a plain JSON string or as an array of
 /// parts, and deserialized by trying those shapes in that order.
 314#[derive(Debug, Serialize, Deserialize)]
 315#[serde(untagged)]
 316pub enum ChatMessageContent {
 317    Plain(String),
 318    Multipart(Vec<ChatMessagePart>),
 319}
 320
 321impl ChatMessageContent {
 322    pub fn empty() -> Self {
 323        ChatMessageContent::Multipart(vec![])
 324    }
 325}
 326
 327impl From<Vec<ChatMessagePart>> for ChatMessageContent {
 328    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
 329        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
 330            ChatMessageContent::Plain(std::mem::take(text))
 331        } else {
 332            ChatMessageContent::Multipart(parts)
 333        }
 334    }
 335}
 336
 337impl From<String> for ChatMessageContent {
 338    fn from(text: String) -> Self {
 339        ChatMessageContent::Plain(text)
 340    }
 341}
 342
 /// A tool call in an assistant message. `content` is flattened, so the JSON carries `id`,
 /// `type`, and `function` side by side.
 343#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 344pub struct ToolCall {
 345    pub id: String,
 346    #[serde(flatten)]
 347    pub content: ToolCallContent,
 348}
 349
 /// Payload of a tool call, internally tagged by `type` (currently only `"function"`).
 350#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 351#[serde(tag = "type", rename_all = "lowercase")]
 352pub enum ToolCallContent {
 353    Function { function: FunctionContent },
 354}
 355
 /// A completed function call. `arguments` is kept as a raw string.
 /// NOTE(review): presumably JSON-encoded arguments — confirm with producers.
 356#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 357pub struct FunctionContent {
 358    pub name: String,
 359    pub arguments: String,
        /// Omitted from the JSON when `None`.
 360    #[serde(default, skip_serializing_if = "Option::is_none")]
 361    pub thought_signature: Option<String>,
 362}
 363
 /// One chat-completions event: a parsed `data:` line when streaming, or the whole response
 /// body otherwise (see `stream_completion`, which drops events whose `choices` is empty).
 ///
 /// NOTE(review): serde documents `tag` on a struct as adding a field during serialization;
 /// this type only derives `Deserialize`, so the attribute appears to have no effect — confirm
 /// before removing it.
 364#[derive(Deserialize, Debug)]
 365#[serde(tag = "type", rename_all = "snake_case")]
 366pub struct ResponseEvent {
 367    pub choices: Vec<ResponseChoice>,
 368    pub id: String,
 369    pub usage: Option<Usage>,
 370}
 371
 /// Token accounting reported with a response.
 372#[derive(Deserialize, Debug)]
 373pub struct Usage {
 374    pub completion_tokens: u64,
 375    pub prompt_tokens: u64,
 376    pub total_tokens: u64,
 377}
 378
 /// One choice within an event.
 /// NOTE(review): presumably `delta` is filled for streamed chunks and `message` for
 /// non-streamed responses — confirm against the API.
 379#[derive(Debug, Deserialize)]
 380pub struct ResponseChoice {
 381    pub index: Option<usize>,
 382    pub finish_reason: Option<String>,
 383    pub delta: Option<ResponseDelta>,
 384    pub message: Option<ResponseDelta>,
 385}
 386
 /// Message content within a choice; every field is optional.
 387#[derive(Debug, Deserialize)]
 388pub struct ResponseDelta {
 389    pub content: Option<String>,
 390    pub role: Option<Role>,
        /// Empty when the field is absent.
 391    #[serde(default)]
 392    pub tool_calls: Vec<ToolCallChunk>,
 393    pub reasoning_opaque: Option<String>,
 394    pub reasoning_text: Option<String>,
 395}
 /// A possibly partial tool call; every field is optional.
 /// NOTE(review): presumably fragments are merged by `index` downstream — confirm.
 396#[derive(Deserialize, Debug, Eq, PartialEq)]
 397pub struct ToolCallChunk {
 398    pub index: Option<usize>,
 399    pub id: Option<String>,
 400    pub function: Option<FunctionChunk>,
 401}
 402
 /// A possibly partial function call within a `ToolCallChunk`.
 403#[derive(Deserialize, Debug, Eq, PartialEq)]
 404pub struct FunctionChunk {
 405    pub name: Option<String>,
 406    pub arguments: Option<String>,
 407    pub thought_signature: Option<String>,
 408}
 409
 /// JSON body returned by the token endpoint (parsed in `request_api_token`).
 /// `expires_at` is a Unix timestamp in whole seconds (it is converted with
 /// `DateTime::from_timestamp(expires_at, 0)`).
 410#[derive(Deserialize)]
 411struct ApiTokenResponse {
 412    token: String,
 413    expires_at: i64,
 414    endpoints: ApiTokenResponseEndpoints,
 415}
 416
 /// Endpoint URLs from the token response; `api` is the base URL onto which the chat,
 /// responses, and models paths are appended.
 417#[derive(Deserialize)]
 418struct ApiTokenResponseEndpoints {
 419    api: String,
 420}
 421
 /// A Copilot API token together with the API base URL it was issued for.
 /// `get_auth_details` fetches a replacement once fewer than five minutes remain.
 422#[derive(Clone)]
 423struct ApiToken {
 424    api_key: String,
 425    expires_at: DateTime<chrono::Utc>,
 426    api_endpoint: String,
 427}
 428
 429impl ApiToken {
 430    pub fn remaining_seconds(&self) -> i64 {
 431        self.expires_at
 432            .timestamp()
 433            .saturating_sub(chrono::Utc::now().timestamp())
 434    }
 435}
 436
 437impl TryFrom<ApiTokenResponse> for ApiToken {
 438    type Error = anyhow::Error;
 439
 440    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
 441        let expires_at =
 442            DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;
 443
 444        Ok(Self {
 445            api_key: response.token,
 446            expires_at,
 447            api_endpoint: response.endpoints.api,
 448        })
 449    }
 450}
 451
 /// App-global wrapper around the single `CopilotChat` entity; installed by `init` and
 /// looked up by `CopilotChat::global`.
 452struct GlobalCopilotChat(gpui::Entity<CopilotChat>);
 453
 454impl Global for GlobalCopilotChat {}
 455
 /// Shared Copilot Chat state: OAuth credentials, the cached API token, the active
 /// configuration, and the most recently fetched model list.
 456pub struct CopilotChat {
        /// Seeded from `COPILOT_OAUTH_ENV_VAR`, then replaced by the config-file watcher.
 457    oauth_token: Option<String>,
        /// Cached API token; cleared when the configuration changes and refetched on demand.
 458    api_token: Option<ApiToken>,
 459    configuration: CopilotChatConfiguration,
        /// `None` until a model refresh has succeeded.
 460    models: Option<Vec<Model>>,
 461    client: Arc<dyn HttpClient>,
 462}
 463
 464pub fn init(
 465    fs: Arc<dyn Fs>,
 466    client: Arc<dyn HttpClient>,
 467    configuration: CopilotChatConfiguration,
 468    cx: &mut App,
 469) {
 470    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
 471    cx.set_global(GlobalCopilotChat(copilot_chat));
 472}
 473
 474pub fn copilot_chat_config_dir() -> &'static PathBuf {
 475    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
 476
 477    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
 478        let config_dir = if cfg!(target_os = "windows") {
 479            dirs::data_local_dir().expect("failed to determine LocalAppData directory")
 480        } else {
 481            std::env::var("XDG_CONFIG_HOME")
 482                .map(PathBuf::from)
 483                .unwrap_or_else(|_| home_dir().join(".config"))
 484        };
 485
 486        config_dir.join("github-copilot")
 487    })
 488}
 489
 490fn copilot_chat_config_paths() -> [PathBuf; 2] {
 491    let base_dir = copilot_chat_config_dir();
 492    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
 493}
 494
 495impl CopilotChat {
 496    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
 497        cx.try_global::<GlobalCopilotChat>()
 498            .map(|model| model.0.clone())
 499    }
 500
 501    fn new(
 502        fs: Arc<dyn Fs>,
 503        client: Arc<dyn HttpClient>,
 504        configuration: CopilotChatConfiguration,
 505        cx: &mut Context<Self>,
 506    ) -> Self {
 507        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
 508        let dir_path = copilot_chat_config_dir();
 509
 510        cx.spawn(async move |this, cx| {
 511            let mut parent_watch_rx = watch_config_dir(
 512                cx.background_executor(),
 513                fs.clone(),
 514                dir_path.clone(),
 515                config_paths,
 516            );
 517            while let Some(contents) = parent_watch_rx.next().await {
 518                let oauth_domain =
 519                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
 520                let oauth_token = extract_oauth_token(contents, &oauth_domain);
 521
 522                this.update(cx, |this, cx| {
 523                    this.oauth_token = oauth_token.clone();
 524                    cx.notify();
 525                })?;
 526
 527                if oauth_token.is_some() {
 528                    Self::update_models(&this, cx).await?;
 529                }
 530            }
 531            anyhow::Ok(())
 532        })
 533        .detach_and_log_err(cx);
 534
 535        let this = Self {
 536            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
 537            api_token: None,
 538            models: None,
 539            configuration,
 540            client,
 541        };
 542
 543        if this.oauth_token.is_some() {
 544            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
 545                .detach_and_log_err(cx);
 546        }
 547
 548        this
 549    }
 550
 551    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
 552        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
 553            (
 554                this.oauth_token.clone(),
 555                this.client.clone(),
 556                this.configuration.clone(),
 557            )
 558        })?;
 559
 560        let oauth_token = oauth_token
 561            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;
 562
 563        let token_url = configuration.token_url();
 564        let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
 565
 566        let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
 567        let models =
 568            get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;
 569
 570        this.update(cx, |this, cx| {
 571            this.api_token = Some(api_token);
 572            this.models = Some(models);
 573            cx.notify();
 574        })?;
 575        anyhow::Ok(())
 576    }
 577
 578    pub fn is_authenticated(&self) -> bool {
 579        self.oauth_token.is_some()
 580    }
 581
 582    pub fn models(&self) -> Option<&[Model]> {
 583        self.models.as_deref()
 584    }
 585
 586    pub async fn stream_completion(
 587        request: Request,
 588        is_user_initiated: bool,
 589        mut cx: AsyncApp,
 590    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
 591        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
 592
 593        let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
 594        stream_completion(
 595            client.clone(),
 596            token.api_key,
 597            api_url.into(),
 598            request,
 599            is_user_initiated,
 600        )
 601        .await
 602    }
 603
 604    pub async fn stream_response(
 605        request: responses::Request,
 606        is_user_initiated: bool,
 607        mut cx: AsyncApp,
 608    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
 609        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
 610
 611        let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
 612        responses::stream_response(
 613            client.clone(),
 614            token.api_key,
 615            api_url,
 616            request,
 617            is_user_initiated,
 618        )
 619        .await
 620    }
 621
 622    async fn get_auth_details(
 623        cx: &mut AsyncApp,
 624    ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
 625        let this = cx
 626            .update(|cx| Self::global(cx))
 627            .ok()
 628            .flatten()
 629            .context("Copilot chat is not enabled")?;
 630
 631        let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
 632            (
 633                this.oauth_token.clone(),
 634                this.api_token.clone(),
 635                this.client.clone(),
 636                this.configuration.clone(),
 637            )
 638        })?;
 639
 640        let oauth_token = oauth_token.context("No OAuth token available")?;
 641
 642        let token = match api_token {
 643            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
 644            _ => {
 645                let token_url = configuration.token_url();
 646                let token =
 647                    request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
 648                this.update(cx, |this, cx| {
 649                    this.api_token = Some(token.clone());
 650                    cx.notify();
 651                })?;
 652                token
 653            }
 654        };
 655
 656        Ok((client, token, configuration))
 657    }
 658
 659    pub fn set_configuration(
 660        &mut self,
 661        configuration: CopilotChatConfiguration,
 662        cx: &mut Context<Self>,
 663    ) {
 664        let same_configuration = self.configuration == configuration;
 665        self.configuration = configuration;
 666        if !same_configuration {
 667            self.api_token = None;
 668            cx.spawn(async move |this, cx| {
 669                Self::update_models(&this, cx).await?;
 670                Ok::<_, anyhow::Error>(())
 671            })
 672            .detach();
 673        }
 674    }
 675}
 676
 677async fn get_models(
 678    models_url: Arc<str>,
 679    api_token: String,
 680    client: Arc<dyn HttpClient>,
 681) -> Result<Vec<Model>> {
 682    let all_models = request_models(models_url, api_token, client).await?;
 683
 684    let mut models: Vec<Model> = all_models
 685        .into_iter()
 686        .filter(|model| {
 687            model.model_picker_enabled
 688                && model.capabilities.model_type.as_str() == "chat"
 689                && model
 690                    .policy
 691                    .as_ref()
 692                    .is_none_or(|policy| policy.state == "enabled")
 693        })
 694        .dedup_by(|a, b| a.capabilities.family == b.capabilities.family)
 695        .collect();
 696
 697    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
 698        let default_model = models.remove(default_model_position);
 699        models.insert(0, default_model);
 700    }
 701
 702    Ok(models)
 703}
 704
 705async fn request_models(
 706    models_url: Arc<str>,
 707    api_token: String,
 708    client: Arc<dyn HttpClient>,
 709) -> Result<Vec<Model>> {
 710    let request_builder = HttpRequest::builder()
 711        .method(Method::GET)
 712        .uri(models_url.as_ref())
 713        .header("Authorization", format!("Bearer {}", api_token))
 714        .header("Content-Type", "application/json")
 715        .header("Copilot-Integration-Id", "vscode-chat")
 716        .header("Editor-Version", "vscode/1.103.2")
 717        .header("x-github-api-version", "2025-05-01");
 718
 719    let request = request_builder.body(AsyncBody::empty())?;
 720
 721    let mut response = client.send(request).await?;
 722
 723    anyhow::ensure!(
 724        response.status().is_success(),
 725        "Failed to request models: {}",
 726        response.status()
 727    );
 728    let mut body = Vec::new();
 729    response.body_mut().read_to_end(&mut body).await?;
 730
 731    let body_str = std::str::from_utf8(&body)?;
 732
 733    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
 734
 735    Ok(models)
 736}
 737
 738async fn request_api_token(
 739    oauth_token: &str,
 740    auth_url: Arc<str>,
 741    client: Arc<dyn HttpClient>,
 742) -> Result<ApiToken> {
 743    let request_builder = HttpRequest::builder()
 744        .method(Method::GET)
 745        .uri(auth_url.as_ref())
 746        .header("Authorization", format!("token {}", oauth_token))
 747        .header("Accept", "application/json");
 748
 749    let request = request_builder.body(AsyncBody::empty())?;
 750
 751    let mut response = client.send(request).await?;
 752
 753    if response.status().is_success() {
 754        let mut body = Vec::new();
 755        response.body_mut().read_to_end(&mut body).await?;
 756
 757        let body_str = std::str::from_utf8(&body)?;
 758
 759        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
 760        ApiToken::try_from(parsed)
 761    } else {
 762        let mut body = Vec::new();
 763        response.body_mut().read_to_end(&mut body).await?;
 764
 765        let body_str = std::str::from_utf8(&body)?;
 766        anyhow::bail!("Failed to request API token: {body_str}");
 767    }
 768}
 769
 770fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
 771    serde_json::from_str::<serde_json::Value>(&contents)
 772        .map(|v| {
 773            v.as_object().and_then(|obj| {
 774                obj.iter().find_map(|(key, value)| {
 775                    if key.starts_with(domain) {
 776                        value["oauth_token"].as_str().map(|v| v.to_string())
 777                    } else {
 778                        None
 779                    }
 780                })
 781            })
 782        })
 783        .ok()
 784        .flatten()
 785}
 786
 787async fn stream_completion(
 788    client: Arc<dyn HttpClient>,
 789    api_key: String,
 790    completion_url: Arc<str>,
 791    request: Request,
 792    is_user_initiated: bool,
 793) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
 794    let is_vision_request = request.messages.iter().any(|message| match message {
 795        ChatMessage::User { content }
 796        | ChatMessage::Assistant { content, .. }
 797        | ChatMessage::Tool { content, .. } => {
 798            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
 799        }
 800        _ => false,
 801    });
 802
 803    let request_initiator = if is_user_initiated { "user" } else { "agent" };
 804
 805    let request_builder = HttpRequest::builder()
 806        .method(Method::POST)
 807        .uri(completion_url.as_ref())
 808        .header(
 809            "Editor-Version",
 810            format!(
 811                "Zed/{}",
 812                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
 813            ),
 814        )
 815        .header("Authorization", format!("Bearer {}", api_key))
 816        .header("Content-Type", "application/json")
 817        .header("Copilot-Integration-Id", "vscode-chat")
 818        .header("X-Initiator", request_initiator)
 819        .when(is_vision_request, |builder| {
 820            builder.header("Copilot-Vision-Request", is_vision_request.to_string())
 821        });
 822
 823    let is_streaming = request.stream;
 824
 825    let json = serde_json::to_string(&request)?;
 826    let request = request_builder.body(AsyncBody::from(json))?;
 827    let mut response = client.send(request).await?;
 828
 829    if !response.status().is_success() {
 830        let mut body = Vec::new();
 831        response.body_mut().read_to_end(&mut body).await?;
 832        let body_str = std::str::from_utf8(&body)?;
 833        anyhow::bail!(
 834            "Failed to connect to API: {} {}",
 835            response.status(),
 836            body_str
 837        );
 838    }
 839
 840    if is_streaming {
 841        let reader = BufReader::new(response.into_body());
 842        Ok(reader
 843            .lines()
 844            .filter_map(|line| async move {
 845                match line {
 846                    Ok(line) => {
 847                        let line = line.strip_prefix("data: ")?;
 848                        if line.starts_with("[DONE]") {
 849                            return None;
 850                        }
 851
 852                        match serde_json::from_str::<ResponseEvent>(line) {
 853                            Ok(response) => {
 854                                if response.choices.is_empty() {
 855                                    None
 856                                } else {
 857                                    Some(Ok(response))
 858                                }
 859                            }
 860                            Err(error) => Some(Err(anyhow!(error))),
 861                        }
 862                    }
 863                    Err(error) => Some(Err(anyhow!(error))),
 864                }
 865            })
 866            .boxed())
 867    } else {
 868        let mut body = Vec::new();
 869        response.body_mut().read_to_end(&mut body).await?;
 870        let body_str = std::str::from_utf8(&body)?;
 871        let response: ResponseEvent = serde_json::from_str(body_str)?;
 872
 873        Ok(futures::stream::once(async move { Ok(response) }).boxed())
 874    }
 875}
 876
 // Unit tests for lenient deserialization of the Copilot models listing.
 877#[cfg(test)]
 878mod tests {
 879    use super::*;
 880
         // A malformed entry in `data` must be skipped without discarding its valid neighbours.
 881    #[test]
 882    fn test_resilient_model_schema_deserialize() {
 883        let json = r#"{
 884              "data": [
 885                {
 886                  "billing": {
 887                    "is_premium": false,
 888                    "multiplier": 0
 889                  },
 890                  "capabilities": {
 891                    "family": "gpt-4",
 892                    "limits": {
 893                      "max_context_window_tokens": 32768,
 894                      "max_output_tokens": 4096,
 895                      "max_prompt_tokens": 32768
 896                    },
 897                    "object": "model_capabilities",
 898                    "supports": { "streaming": true, "tool_calls": true },
 899                    "tokenizer": "cl100k_base",
 900                    "type": "chat"
 901                  },
 902                  "id": "gpt-4",
 903                  "is_chat_default": false,
 904                  "is_chat_fallback": false,
 905                  "model_picker_enabled": false,
 906                  "name": "GPT 4",
 907                  "object": "model",
 908                  "preview": false,
 909                  "vendor": "Azure OpenAI",
 910                  "version": "gpt-4-0613"
 911                },
 912                {
 913                    "some-unknown-field": 123
 914                },
 915                {
 916                  "billing": {
 917                    "is_premium": true,
 918                    "multiplier": 1,
 919                    "restricted_to": [
 920                      "pro",
 921                      "pro_plus",
 922                      "business",
 923                      "enterprise"
 924                    ]
 925                  },
 926                  "capabilities": {
 927                    "family": "claude-3.7-sonnet",
 928                    "limits": {
 929                      "max_context_window_tokens": 200000,
 930                      "max_output_tokens": 16384,
 931                      "max_prompt_tokens": 90000,
 932                      "vision": {
 933                        "max_prompt_image_size": 3145728,
 934                        "max_prompt_images": 1,
 935                        "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
 936                      }
 937                    },
 938                    "object": "model_capabilities",
 939                    "supports": {
 940                      "parallel_tool_calls": true,
 941                      "streaming": true,
 942                      "tool_calls": true,
 943                      "vision": true
 944                    },
 945                    "tokenizer": "o200k_base",
 946                    "type": "chat"
 947                  },
 948                  "id": "claude-3.7-sonnet",
 949                  "is_chat_default": false,
 950                  "is_chat_fallback": false,
 951                  "model_picker_enabled": true,
 952                  "name": "Claude 3.7 Sonnet",
 953                  "object": "model",
 954                  "policy": {
 955                    "state": "enabled",
 956                    "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
 957                  },
 958                  "preview": false,
 959                  "vendor": "Anthropic",
 960                  "version": "claude-3.7-sonnet"
 961                }
 962              ],
 963              "object": "list"
 964            }"#;
 965
 966        let schema: ModelSchema = serde_json::from_str(json).unwrap();
 967
             // The middle entry lacks every required field and is dropped. `gpt-4` is kept even
             // though `model_picker_enabled` is false: that filter lives in `get_models`, not here.
 968        assert_eq!(schema.data.len(), 2);
 969        assert_eq!(schema.data[0].id, "gpt-4");
 970        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
 971    }
 972
         // An unrecognised vendor string must map to `ModelVendor::Unknown` instead of failing.
         // The entry also omits `tokenizer` and `policy`, which fall back to `None`.
 973    #[test]
 974    fn test_unknown_vendor_resilience() {
 975        let json = r#"{
 976              "data": [
 977                {
 978                  "billing": {
 979                    "is_premium": false,
 980                    "multiplier": 1
 981                  },
 982                  "capabilities": {
 983                    "family": "future-model",
 984                    "limits": {
 985                      "max_context_window_tokens": 128000,
 986                      "max_output_tokens": 8192,
 987                      "max_prompt_tokens": 120000
 988                    },
 989                    "object": "model_capabilities",
 990                    "supports": { "streaming": true, "tool_calls": true },
 991                    "type": "chat"
 992                  },
 993                  "id": "future-model-v1",
 994                  "is_chat_default": false,
 995                  "is_chat_fallback": false,
 996                  "model_picker_enabled": true,
 997                  "name": "Future Model v1",
 998                  "object": "model",
 999                  "preview": false,
1000                  "vendor": "SomeNewVendor",
1001                  "version": "v1.0"
1002                }
1003              ],
1004              "object": "list"
1005            }"#;
1006
1007        let schema: ModelSchema = serde_json::from_str(json).unwrap();
1008
1009        assert_eq!(schema.data.len(), 1);
1010        assert_eq!(schema.data[0].id, "future-model-v1");
1011        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
1012    }
1013}