copilot_chat.rs

   1pub mod responses;
   2
   3use std::path::PathBuf;
   4use std::sync::Arc;
   5use std::sync::OnceLock;
   6
   7use anyhow::Context as _;
   8use anyhow::{Result, anyhow};
   9use collections::HashSet;
  10use fs::Fs;
  11use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
  12use gpui::WeakEntity;
  13use gpui::{App, AsyncApp, Global, prelude::*};
  14use http_client::HttpRequestExt;
  15use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  16use paths::home_dir;
  17use serde::{Deserialize, Serialize};
  18
  19use settings::watch_config_dir;
  20
/// Environment variable that can supply a Copilot OAuth token directly,
/// bypassing the on-disk GitHub Copilot config files.
pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
/// Fallback Copilot API base URL used when GraphQL endpoint discovery fails.
const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com";
  23
  24#[derive(Default, Clone, Debug, PartialEq)]
  25pub struct CopilotChatConfiguration {
  26    pub enterprise_uri: Option<String>,
  27}
  28
  29impl CopilotChatConfiguration {
  30    pub fn oauth_domain(&self) -> String {
  31        if let Some(enterprise_uri) = &self.enterprise_uri {
  32            Self::parse_domain(enterprise_uri)
  33        } else {
  34            "github.com".to_string()
  35        }
  36    }
  37
  38    pub fn graphql_url(&self) -> String {
  39        if let Some(enterprise_uri) = &self.enterprise_uri {
  40            let domain = Self::parse_domain(enterprise_uri);
  41            format!("https://{}/api/graphql", domain)
  42        } else {
  43            "https://api.github.com/graphql".to_string()
  44        }
  45    }
  46
  47    pub fn chat_completions_url(&self, api_endpoint: &str) -> String {
  48        format!("{}/chat/completions", api_endpoint)
  49    }
  50
  51    pub fn responses_url(&self, api_endpoint: &str) -> String {
  52        format!("{}/responses", api_endpoint)
  53    }
  54
  55    pub fn models_url(&self, api_endpoint: &str) -> String {
  56        format!("{}/models", api_endpoint)
  57    }
  58
  59    fn parse_domain(enterprise_uri: &str) -> String {
  60        let uri = enterprise_uri.trim_end_matches('/');
  61
  62        if let Some(domain) = uri.strip_prefix("https://") {
  63            domain.split('/').next().unwrap_or(domain).to_string()
  64        } else if let Some(domain) = uri.strip_prefix("http://") {
  65            domain.split('/').next().unwrap_or(domain).to_string()
  66        } else {
  67            uri.split('/').next().unwrap_or(uri).to_string()
  68        }
  69    }
  70}
  71
/// Author role of a chat message; serialized lowercase for the Copilot API.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
  79
/// API endpoint paths a model advertises support for in `/models` responses.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq)]
pub enum ModelSupportedEndpoint {
    #[serde(rename = "/chat/completions")]
    ChatCompletions,
    #[serde(rename = "/responses")]
    Responses,
    #[serde(rename = "/v1/messages")]
    Messages,
    /// Unknown endpoint that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}
  92
/// Top-level shape of the `/models` response. Entries that fail to
/// deserialize are skipped rather than failing the whole list.
#[derive(Deserialize)]
struct ModelSchema {
    #[serde(deserialize_with = "deserialize_models_skip_errors")]
    data: Vec<Model>,
}
  98
  99fn deserialize_models_skip_errors<'de, D>(deserializer: D) -> Result<Vec<Model>, D::Error>
 100where
 101    D: serde::Deserializer<'de>,
 102{
 103    let raw_values = Vec::<serde_json::Value>::deserialize(deserializer)?;
 104    let models = raw_values
 105        .into_iter()
 106        .filter_map(|value| match serde_json::from_value::<Model>(value) {
 107            Ok(model) => Some(model),
 108            Err(err) => {
 109                log::warn!("GitHub Copilot Chat model failed to deserialize: {:?}", err);
 110                None
 111            }
 112        })
 113        .collect();
 114
 115    Ok(models)
 116}
 117
/// One model entry from the Copilot `/models` endpoint.
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
pub struct Model {
    billing: ModelBilling,
    capabilities: ModelCapabilities,
    id: String,
    name: String,
    // Absent for models that require no explicit opt-in policy.
    policy: Option<ModelPolicy>,
    vendor: ModelVendor,
    is_chat_default: bool,
    // The model with this value true is selected by VSCode copilot if a premium request limit is
    // reached. Zed does not currently implement this behaviour
    is_chat_fallback: bool,
    model_picker_enabled: bool,
    // Defaults to empty when the API omits the field.
    #[serde(default)]
    supported_endpoints: Vec<ModelSupportedEndpoint>,
}
 134
/// Billing metadata for a model (premium status and request multiplier).
#[derive(Clone, Serialize, Deserialize, Debug, PartialEq)]
struct ModelBilling {
    is_premium: bool,
    multiplier: f64,
    // List of plans a model is restricted to
    // Field is not present if a model is available for all plans
    #[serde(default)]
    restricted_to: Option<Vec<String>>,
}
 144
/// Capability block of a model: family, token limits, feature support, and
/// optional tokenizer name.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelCapabilities {
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    supports: ModelSupportedFeatures,
    // e.g. "chat"; non-chat entries are filtered out in `get_models`.
    #[serde(rename = "type")]
    model_type: String,
    #[serde(default)]
    tokenizer: Option<String>,
}
 156
/// Token limits advertised for a model; all fields default to 0 when absent.
// NOTE(review): `max_prompt_tokens` is `u64` while the other two are `usize`
// — presumably because only `max_prompt_tokens` feeds `Model::max_token_count`;
// confirm whether the types should be unified.
#[derive(Default, Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: usize,
    #[serde(default)]
    max_output_tokens: usize,
    #[serde(default)]
    max_prompt_tokens: u64,
}
 166
/// Opt-in policy state for a model; `"enabled"` means the user may use it.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelPolicy {
    state: String,
}
 171
/// Feature flags a model advertises; all default to `false` when absent.
#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    parallel_tool_calls: bool,
    #[serde(default)]
    vision: bool,
}
 183
/// Upstream vendor serving a model, as reported by the `/models` endpoint.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
pub enum ModelVendor {
    // Azure OpenAI should have no functional difference from OpenAI in Copilot Chat
    #[serde(alias = "Azure OpenAI")]
    OpenAI,
    Google,
    Anthropic,
    #[serde(rename = "xAI")]
    XAI,
    /// Unknown vendor that we don't explicitly support yet
    #[serde(other)]
    Unknown,
}
 197
/// One part of a multipart chat message: plain text or an image reference.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
#[serde(tag = "type")]
pub enum ChatMessagePart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    Image { image_url: ImageUrl },
}
 206
/// Wrapper for an image URL inside a [`ChatMessagePart::Image`] part.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)]
pub struct ImageUrl {
    pub url: String,
}
 211
 212impl Model {
 213    pub fn uses_streaming(&self) -> bool {
 214        self.capabilities.supports.streaming
 215    }
 216
 217    pub fn id(&self) -> &str {
 218        self.id.as_str()
 219    }
 220
 221    pub fn display_name(&self) -> &str {
 222        self.name.as_str()
 223    }
 224
 225    pub fn max_token_count(&self) -> u64 {
 226        self.capabilities.limits.max_prompt_tokens
 227    }
 228
 229    pub fn supports_tools(&self) -> bool {
 230        self.capabilities.supports.tool_calls
 231    }
 232
 233    pub fn vendor(&self) -> ModelVendor {
 234        self.vendor
 235    }
 236
 237    pub fn supports_vision(&self) -> bool {
 238        self.capabilities.supports.vision
 239    }
 240
 241    pub fn supports_parallel_tool_calls(&self) -> bool {
 242        self.capabilities.supports.parallel_tool_calls
 243    }
 244
 245    pub fn tokenizer(&self) -> Option<&str> {
 246        self.capabilities.tokenizer.as_deref()
 247    }
 248
 249    pub fn supports_response(&self) -> bool {
 250        self.supported_endpoints.len() > 0
 251            && !self
 252                .supported_endpoints
 253                .contains(&ModelSupportedEndpoint::ChatCompletions)
 254            && self
 255                .supported_endpoints
 256                .contains(&ModelSupportedEndpoint::Responses)
 257    }
 258}
 259
/// Body of a `/chat/completions` request.
#[derive(Serialize, Deserialize)]
pub struct Request {
    pub intent: bool,
    // Number of completions to request.
    pub n: usize,
    pub stream: bool,
    pub temperature: f32,
    pub model: String,
    pub messages: Vec<ChatMessage>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
}
 273
/// Declaration of a callable tool function, with a JSON-schema parameter spec.
#[derive(Serialize, Deserialize)]
pub struct Function {
    pub name: String,
    pub description: String,
    pub parameters: serde_json::Value,
}
 280
/// A tool made available to the model; currently only function tools exist.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Tool {
    Function { function: Function },
}
 286
/// Policy telling the model whether/which tools it may call.
#[derive(Serialize, Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    None,
}
 294
/// A chat message, tagged by its `role` field on the wire.
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
    Assistant {
        content: ChatMessageContent,
        #[serde(default, skip_serializing_if = "Vec::is_empty")]
        tool_calls: Vec<ToolCall>,
        // Opaque reasoning blob the API may round-trip across turns.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_opaque: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        reasoning_text: Option<String>,
    },
    User {
        content: ChatMessageContent,
    },
    System {
        content: String,
    },
    // Result of a tool call, correlated back via `tool_call_id`.
    Tool {
        content: ChatMessageContent,
        tool_call_id: String,
    },
}
 318
/// Message content: either a bare string or a list of typed parts
/// (untagged on the wire, so either JSON shape deserializes).
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ChatMessageContent {
    Plain(String),
    Multipart(Vec<ChatMessagePart>),
}
 325
 326impl ChatMessageContent {
 327    pub fn empty() -> Self {
 328        ChatMessageContent::Multipart(vec![])
 329    }
 330}
 331
 332impl From<Vec<ChatMessagePart>> for ChatMessageContent {
 333    fn from(mut parts: Vec<ChatMessagePart>) -> Self {
 334        if let [ChatMessagePart::Text { text }] = parts.as_mut_slice() {
 335            ChatMessageContent::Plain(std::mem::take(text))
 336        } else {
 337            ChatMessageContent::Multipart(parts)
 338        }
 339    }
 340}
 341
 342impl From<String> for ChatMessageContent {
 343    fn from(text: String) -> Self {
 344        ChatMessageContent::Plain(text)
 345    }
 346}
 347
/// A tool call emitted by the assistant; `content` is flattened so `type`
/// and `function` appear alongside `id` on the wire.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCall {
    pub id: String,
    #[serde(flatten)]
    pub content: ToolCallContent,
}
 354
/// Payload of a tool call; currently only function calls exist.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolCallContent {
    Function { function: FunctionContent },
}
 360
/// A concrete function invocation: name plus JSON-encoded arguments string.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionContent {
    pub name: String,
    pub arguments: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub thought_signature: Option<String>,
}
 368
/// One chunk of a chat-completions response (an SSE event when streaming, or
/// the whole body otherwise).
// NOTE(review): `#[serde(tag = "type")]` is an enum-representation attribute;
// its use on a struct here looks unusual — confirm it is intentional.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub struct ResponseEvent {
    pub choices: Vec<ResponseChoice>,
    pub id: String,
    pub usage: Option<Usage>,
}
 376
/// Token accounting reported at the end of a completion.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub completion_tokens: u64,
    pub prompt_tokens: u64,
    pub total_tokens: u64,
}
 383
/// One completion choice; streaming events populate `delta`, non-streaming
/// responses populate `message`.
#[derive(Debug, Deserialize)]
pub struct ResponseChoice {
    pub index: Option<usize>,
    pub finish_reason: Option<String>,
    pub delta: Option<ResponseDelta>,
    pub message: Option<ResponseDelta>,
}
 391
/// Incremental (or complete) message payload within a response choice.
#[derive(Debug, Deserialize)]
pub struct ResponseDelta {
    pub content: Option<String>,
    pub role: Option<Role>,
    #[serde(default)]
    pub tool_calls: Vec<ToolCallChunk>,
    // Opaque/visible reasoning fragments some models attach.
    pub reasoning_opaque: Option<String>,
    pub reasoning_text: Option<String>,
}
/// A fragment of a tool call streamed across multiple events; pieces are
/// correlated by `index`.
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct ToolCallChunk {
    pub index: Option<usize>,
    pub id: Option<String>,
    pub function: Option<FunctionChunk>,
}
 407
/// A fragment of a streamed function call (name and/or partial arguments).
#[derive(Deserialize, Debug, Eq, PartialEq)]
pub struct FunctionChunk {
    pub name: Option<String>,
    pub arguments: Option<String>,
    pub thought_signature: Option<String>,
}
 414
/// Newtype wrapper so the [`CopilotChat`] entity can be stored as a GPUI global.
struct GlobalCopilotChat(gpui::Entity<CopilotChat>);

impl Global for GlobalCopilotChat {}
 418
/// App-global state for Copilot Chat: credentials, resolved endpoint,
/// cached model list, and the HTTP client used for API calls.
pub struct CopilotChat {
    oauth_token: Option<String>,
    // Cached after GraphQL discovery; cleared when configuration changes.
    api_endpoint: Option<String>,
    configuration: CopilotChatConfiguration,
    // `None` until the first successful `/models` fetch.
    models: Option<Vec<Model>>,
    client: Arc<dyn HttpClient>,
}
 426
/// Creates the global [`CopilotChat`] entity and registers it with the app.
pub fn init(
    fs: Arc<dyn Fs>,
    client: Arc<dyn HttpClient>,
    configuration: CopilotChatConfiguration,
    cx: &mut App,
) {
    let copilot_chat = cx.new(|cx| CopilotChat::new(fs, client, configuration, cx));
    cx.set_global(GlobalCopilotChat(copilot_chat));
}
 436
 437pub fn copilot_chat_config_dir() -> &'static PathBuf {
 438    static COPILOT_CHAT_CONFIG_DIR: OnceLock<PathBuf> = OnceLock::new();
 439
 440    COPILOT_CHAT_CONFIG_DIR.get_or_init(|| {
 441        let config_dir = if cfg!(target_os = "windows") {
 442            dirs::data_local_dir().expect("failed to determine LocalAppData directory")
 443        } else {
 444            std::env::var("XDG_CONFIG_HOME")
 445                .map(PathBuf::from)
 446                .unwrap_or_else(|_| home_dir().join(".config"))
 447        };
 448
 449        config_dir.join("github-copilot")
 450    })
 451}
 452
 453fn copilot_chat_config_paths() -> [PathBuf; 2] {
 454    let base_dir = copilot_chat_config_dir();
 455    [base_dir.join("hosts.json"), base_dir.join("apps.json")]
 456}
 457
impl CopilotChat {
    /// Returns the app-global [`CopilotChat`] entity, if `init` has run.
    pub fn global(cx: &App) -> Option<gpui::Entity<Self>> {
        cx.try_global::<GlobalCopilotChat>()
            .map(|model| model.0.clone())
    }

    /// Builds the chat state and spawns a background watcher over the Copilot
    /// config directory so tokens written by other Copilot clients are picked
    /// up automatically.
    fn new(
        fs: Arc<dyn Fs>,
        client: Arc<dyn HttpClient>,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) -> Self {
        let config_paths: HashSet<PathBuf> = copilot_chat_config_paths().into_iter().collect();
        let dir_path = copilot_chat_config_dir();

        cx.spawn(async move |this, cx| {
            let mut parent_watch_rx = watch_config_dir(
                cx.background_executor(),
                fs.clone(),
                dir_path.clone(),
                config_paths,
            );
            // Re-extract the OAuth token whenever hosts.json/apps.json change.
            while let Some(contents) = parent_watch_rx.next().await {
                let oauth_domain =
                    this.read_with(cx, |this, _| this.configuration.oauth_domain())?;
                let oauth_token = extract_oauth_token(contents, &oauth_domain);

                this.update(cx, |this, cx| {
                    this.oauth_token = oauth_token.clone();
                    cx.notify();
                })?;

                // A fresh token means the model list can be (re)fetched.
                if oauth_token.is_some() {
                    Self::update_models(&this, cx).await?;
                }
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        let this = Self {
            // Environment-variable override, useful for development/CI.
            oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
            api_endpoint: None,
            models: None,
            configuration,
            client,
        };

        if this.oauth_token.is_some() {
            cx.spawn(async move |this, cx| Self::update_models(&this, cx).await)
                .detach_and_log_err(cx);
        }

        this
    }

    /// Resolves the API endpoint, fetches `/models`, and stores the result on
    /// the entity, notifying observers.
    async fn update_models(this: &WeakEntity<Self>, cx: &mut AsyncApp) -> Result<()> {
        let (oauth_token, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        })?;

        let oauth_token = oauth_token
            .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;

        let api_endpoint =
            Self::resolve_api_endpoint(&this, &oauth_token, &configuration, &client, cx).await?;

        let models_url = configuration.models_url(&api_endpoint);
        let models = get_models(models_url.into(), oauth_token, client.clone()).await?;

        this.update(cx, |this, cx| {
            this.models = Some(models);
            cx.notify();
        })?;
        anyhow::Ok(())
    }

    /// True when an OAuth token is available (from disk or the environment).
    pub fn is_authenticated(&self) -> bool {
        self.oauth_token.is_some()
    }

    /// The most recently fetched model list, if any.
    pub fn models(&self) -> Option<&[Model]> {
        self.models.as_deref()
    }

    /// Streams a `/chat/completions` request using the global client and
    /// credentials.
    pub async fn stream_completion(
        request: Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
        let (client, oauth_token, api_endpoint, configuration) =
            Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.chat_completions_url(&api_endpoint);
        stream_completion(
            client.clone(),
            oauth_token,
            api_url.into(),
            request,
            is_user_initiated,
        )
        .await
    }

    /// Streams a `/responses` request using the global client and credentials.
    pub async fn stream_response(
        request: responses::Request,
        is_user_initiated: bool,
        mut cx: AsyncApp,
    ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
        let (client, oauth_token, api_endpoint, configuration) =
            Self::get_auth_details(&mut cx).await?;

        let api_url = configuration.responses_url(&api_endpoint);
        responses::stream_response(
            client.clone(),
            oauth_token,
            api_url,
            request,
            is_user_initiated,
        )
        .await
    }

    /// Gathers everything needed for an authenticated request, resolving
    /// (and caching) the API endpoint on first use.
    async fn get_auth_details(
        cx: &mut AsyncApp,
    ) -> Result<(
        Arc<dyn HttpClient>,
        String,
        String,
        CopilotChatConfiguration,
    )> {
        let this = cx
            .update(|cx| Self::global(cx))
            .context("Copilot chat is not enabled")?;

        let (oauth_token, api_endpoint, client, configuration) = this.read_with(cx, |this, _| {
            (
                this.oauth_token.clone(),
                this.api_endpoint.clone(),
                this.client.clone(),
                this.configuration.clone(),
            )
        });

        let oauth_token = oauth_token.context("No OAuth token available")?;

        let api_endpoint = match api_endpoint {
            Some(endpoint) => endpoint,
            None => {
                let weak = this.downgrade();
                Self::resolve_api_endpoint(&weak, &oauth_token, &configuration, &client, cx).await?
            }
        };

        Ok((client, oauth_token, api_endpoint, configuration))
    }

    /// Discovers the Copilot API endpoint via GraphQL, falling back to the
    /// public endpoint on failure, and caches the result on the entity.
    async fn resolve_api_endpoint(
        this: &WeakEntity<Self>,
        oauth_token: &str,
        configuration: &CopilotChatConfiguration,
        client: &Arc<dyn HttpClient>,
        cx: &mut AsyncApp,
    ) -> Result<String> {
        let api_endpoint = match discover_api_endpoint(oauth_token, configuration, client).await {
            Ok(endpoint) => endpoint,
            Err(error) => {
                // Discovery failure is non-fatal; the public endpoint works
                // for non-enterprise setups.
                log::warn!(
                    "Failed to discover Copilot API endpoint via GraphQL, \
                         falling back to {DEFAULT_COPILOT_API_ENDPOINT}: {error:#}"
                );
                DEFAULT_COPILOT_API_ENDPOINT.to_string()
            }
        };

        this.update(cx, |this, cx| {
            this.api_endpoint = Some(api_endpoint.clone());
            cx.notify();
        })?;

        Ok(api_endpoint)
    }

    /// Swaps in a new configuration; when it differs, the cached endpoint is
    /// invalidated and the model list is refreshed in the background.
    pub fn set_configuration(
        &mut self,
        configuration: CopilotChatConfiguration,
        cx: &mut Context<Self>,
    ) {
        let same_configuration = self.configuration == configuration;
        self.configuration = configuration;
        if !same_configuration {
            self.api_endpoint = None;
            cx.spawn(async move |this, cx| {
                Self::update_models(&this, cx).await?;
                Ok::<_, anyhow::Error>(())
            })
            .detach();
        }
    }
}
 662
 663async fn get_models(
 664    models_url: Arc<str>,
 665    oauth_token: String,
 666    client: Arc<dyn HttpClient>,
 667) -> Result<Vec<Model>> {
 668    let all_models = request_models(models_url, oauth_token, client).await?;
 669
 670    let mut models: Vec<Model> = all_models
 671        .into_iter()
 672        .filter(|model| {
 673            model.model_picker_enabled
 674                && model.capabilities.model_type.as_str() == "chat"
 675                && model
 676                    .policy
 677                    .as_ref()
 678                    .is_none_or(|policy| policy.state == "enabled")
 679        })
 680        .collect();
 681
 682    if let Some(default_model_position) = models.iter().position(|model| model.is_chat_default) {
 683        let default_model = models.remove(default_model_position);
 684        models.insert(0, default_model);
 685    }
 686
 687    Ok(models)
 688}
 689
/// Top-level GraphQL response envelope for endpoint discovery.
#[derive(Deserialize)]
struct GraphQLResponse {
    data: Option<GraphQLData>,
}
 694
/// `data` payload of the endpoint-discovery GraphQL response.
#[derive(Deserialize)]
struct GraphQLData {
    viewer: GraphQLViewer,
}
 699
/// `viewer` object carrying the authenticated user's Copilot endpoints.
#[derive(Deserialize)]
struct GraphQLViewer {
    #[serde(rename = "copilotEndpoints")]
    copilot_endpoints: GraphQLCopilotEndpoints,
}
 705
/// The discovered Copilot API base URL.
#[derive(Deserialize)]
struct GraphQLCopilotEndpoints {
    api: String,
}
 710
 711pub(crate) async fn discover_api_endpoint(
 712    oauth_token: &str,
 713    configuration: &CopilotChatConfiguration,
 714    client: &Arc<dyn HttpClient>,
 715) -> Result<String> {
 716    let graphql_url = configuration.graphql_url();
 717    let query = serde_json::json!({
 718        "query": "query { viewer { copilotEndpoints { api } } }"
 719    });
 720
 721    let request = HttpRequest::builder()
 722        .method(Method::POST)
 723        .uri(graphql_url.as_str())
 724        .header("Authorization", format!("Bearer {}", oauth_token))
 725        .header("Content-Type", "application/json")
 726        .body(AsyncBody::from(serde_json::to_string(&query)?))?;
 727
 728    let mut response = client.send(request).await?;
 729
 730    anyhow::ensure!(
 731        response.status().is_success(),
 732        "GraphQL endpoint discovery failed: {}",
 733        response.status()
 734    );
 735
 736    let mut body = Vec::new();
 737    response.body_mut().read_to_end(&mut body).await?;
 738    let body_str = std::str::from_utf8(&body)?;
 739
 740    let parsed: GraphQLResponse = serde_json::from_str(body_str)
 741        .context("Failed to parse GraphQL response for Copilot endpoint discovery")?;
 742
 743    let data = parsed
 744        .data
 745        .context("GraphQL response contained no data field")?;
 746
 747    Ok(data.viewer.copilot_endpoints.api)
 748}
 749
 750pub(crate) fn copilot_request_headers(
 751    builder: http_client::Builder,
 752    oauth_token: &str,
 753    is_user_initiated: Option<bool>,
 754) -> http_client::Builder {
 755    builder
 756        .header("Authorization", format!("Bearer {}", oauth_token))
 757        .header("Content-Type", "application/json")
 758        .header(
 759            "Editor-Version",
 760            format!(
 761                "Zed/{}",
 762                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
 763            ),
 764        )
 765        .when_some(is_user_initiated, |builder, is_user_initiated| {
 766            builder.header(
 767                "X-Initiator",
 768                if is_user_initiated { "user" } else { "agent" },
 769            )
 770        })
 771}
 772
 773async fn request_models(
 774    models_url: Arc<str>,
 775    oauth_token: String,
 776    client: Arc<dyn HttpClient>,
 777) -> Result<Vec<Model>> {
 778    let request_builder = copilot_request_headers(
 779        HttpRequest::builder()
 780            .method(Method::GET)
 781            .uri(models_url.as_ref()),
 782        &oauth_token,
 783        None,
 784    )
 785    .header("x-github-api-version", "2025-05-01");
 786
 787    let request = request_builder.body(AsyncBody::empty())?;
 788
 789    let mut response = client.send(request).await?;
 790
 791    anyhow::ensure!(
 792        response.status().is_success(),
 793        "Failed to request models: {}",
 794        response.status()
 795    );
 796    let mut body = Vec::new();
 797    response.body_mut().read_to_end(&mut body).await?;
 798
 799    let body_str = std::str::from_utf8(&body)?;
 800
 801    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
 802
 803    Ok(models)
 804}
 805
 806fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
 807    serde_json::from_str::<serde_json::Value>(&contents)
 808        .map(|v| {
 809            v.as_object().and_then(|obj| {
 810                obj.iter().find_map(|(key, value)| {
 811                    if key.starts_with(domain) {
 812                        value["oauth_token"].as_str().map(|v| v.to_string())
 813                    } else {
 814                        None
 815                    }
 816                })
 817            })
 818        })
 819        .ok()
 820        .flatten()
 821}
 822
/// Sends a chat-completions request and adapts the reply to a stream of
/// [`ResponseEvent`]s; non-streaming requests yield a single-item stream.
async fn stream_completion(
    client: Arc<dyn HttpClient>,
    oauth_token: String,
    completion_url: Arc<str>,
    request: Request,
    is_user_initiated: bool,
) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
    // The API requires an extra header when any message carries image parts.
    let is_vision_request = request.messages.iter().any(|message| match message {
        ChatMessage::User { content }
        | ChatMessage::Assistant { content, .. }
        | ChatMessage::Tool { content, .. } => {
            matches!(content, ChatMessageContent::Multipart(parts) if parts.iter().any(|part| matches!(part, ChatMessagePart::Image { .. })))
        }
        _ => false,
    });

    let request_builder = copilot_request_headers(
        HttpRequest::builder()
            .method(Method::POST)
            .uri(completion_url.as_ref()),
        &oauth_token,
        Some(is_user_initiated),
    )
    .when(is_vision_request, |builder| {
        builder.header("Copilot-Vision-Request", is_vision_request.to_string())
    });

    // Captured before `request` is shadowed by the HTTP request below.
    let is_streaming = request.stream;

    let json = serde_json::to_string(&request)?;
    let request = request_builder.body(AsyncBody::from(json))?;
    let mut response = client.send(request).await?;

    if !response.status().is_success() {
        // Surface the response body alongside the status for diagnosis.
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        anyhow::bail!(
            "Failed to connect to API: {} {}",
            response.status(),
            body_str
        );
    }

    if is_streaming {
        // Server-sent events: each payload line is prefixed with "data: ".
        let reader = BufReader::new(response.into_body());
        Ok(reader
            .lines()
            .filter_map(|line| async move {
                match line {
                    Ok(line) => {
                        let line = line.strip_prefix("data: ")?;
                        // "[DONE]" is a sentinel with no payload; skip it.
                        if line.starts_with("[DONE]") {
                            return None;
                        }

                        match serde_json::from_str::<ResponseEvent>(line) {
                            Ok(response) => {
                                // Events without choices carry nothing useful.
                                if response.choices.is_empty() {
                                    None
                                } else {
                                    Some(Ok(response))
                                }
                            }
                            Err(error) => Some(Err(anyhow!(error))),
                        }
                    }
                    Err(error) => Some(Err(anyhow!(error))),
                }
            })
            .boxed())
    } else {
        // Non-streaming: parse the whole body as one event.
        let mut body = Vec::new();
        response.body_mut().read_to_end(&mut body).await?;
        let body_str = std::str::from_utf8(&body)?;
        let response: ResponseEvent = serde_json::from_str(body_str)?;

        Ok(futures::stream::once(async move { Ok(response) }).boxed())
    }
}
 903
 904#[cfg(test)]
 905mod tests {
 906    use super::*;
 907
    /// Entries that fail to deserialize (e.g. unknown shapes) are skipped
    /// rather than failing the whole `/models` payload.
    #[test]
    fn test_resilient_model_schema_deserialize() {
        // Two valid models surrounding one malformed entry.
        let json = r#"{
              "data": [
                {
                  "billing": {
                    "is_premium": false,
                    "multiplier": 0
                  },
                  "capabilities": {
                    "family": "gpt-4",
                    "limits": {
                      "max_context_window_tokens": 32768,
                      "max_output_tokens": 4096,
                      "max_prompt_tokens": 32768
                    },
                    "object": "model_capabilities",
                    "supports": { "streaming": true, "tool_calls": true },
                    "tokenizer": "cl100k_base",
                    "type": "chat"
                  },
                  "id": "gpt-4",
                  "is_chat_default": false,
                  "is_chat_fallback": false,
                  "model_picker_enabled": false,
                  "name": "GPT 4",
                  "object": "model",
                  "preview": false,
                  "vendor": "Azure OpenAI",
                  "version": "gpt-4-0613"
                },
                {
                    "some-unknown-field": 123
                },
                {
                  "billing": {
                    "is_premium": true,
                    "multiplier": 1,
                    "restricted_to": [
                      "pro",
                      "pro_plus",
                      "business",
                      "enterprise"
                    ]
                  },
                  "capabilities": {
                    "family": "claude-3.7-sonnet",
                    "limits": {
                      "max_context_window_tokens": 200000,
                      "max_output_tokens": 16384,
                      "max_prompt_tokens": 90000,
                      "vision": {
                        "max_prompt_image_size": 3145728,
                        "max_prompt_images": 1,
                        "supported_media_types": ["image/jpeg", "image/png", "image/webp"]
                      }
                    },
                    "object": "model_capabilities",
                    "supports": {
                      "parallel_tool_calls": true,
                      "streaming": true,
                      "tool_calls": true,
                      "vision": true
                    },
                    "tokenizer": "o200k_base",
                    "type": "chat"
                  },
                  "id": "claude-3.7-sonnet",
                  "is_chat_default": false,
                  "is_chat_fallback": false,
                  "model_picker_enabled": true,
                  "name": "Claude 3.7 Sonnet",
                  "object": "model",
                  "policy": {
                    "state": "enabled",
                    "terms": "Enable access to the latest Claude 3.7 Sonnet model from Anthropic. [Learn more about how GitHub Copilot serves Claude 3.7 Sonnet](https://docs.github.com/copilot/using-github-copilot/using-claude-sonnet-in-github-copilot)."
                  },
                  "preview": false,
                  "vendor": "Anthropic",
                  "version": "claude-3.7-sonnet"
                }
              ],
              "object": "list"
            }"#;

        let schema: ModelSchema = serde_json::from_str(json).unwrap();

        // The malformed middle entry is dropped; order is preserved.
        assert_eq!(schema.data.len(), 2);
        assert_eq!(schema.data[0].id, "gpt-4");
        assert_eq!(schema.data[1].id, "claude-3.7-sonnet");
    }
 999
1000    #[test]
1001    fn test_unknown_vendor_resilience() {
1002        let json = r#"{
1003              "data": [
1004                {
1005                  "billing": {
1006                    "is_premium": false,
1007                    "multiplier": 1
1008                  },
1009                  "capabilities": {
1010                    "family": "future-model",
1011                    "limits": {
1012                      "max_context_window_tokens": 128000,
1013                      "max_output_tokens": 8192,
1014                      "max_prompt_tokens": 120000
1015                    },
1016                    "object": "model_capabilities",
1017                    "supports": { "streaming": true, "tool_calls": true },
1018                    "type": "chat"
1019                  },
1020                  "id": "future-model-v1",
1021                  "is_chat_default": false,
1022                  "is_chat_fallback": false,
1023                  "model_picker_enabled": true,
1024                  "name": "Future Model v1",
1025                  "object": "model",
1026                  "preview": false,
1027                  "vendor": "SomeNewVendor",
1028                  "version": "v1.0"
1029                }
1030              ],
1031              "object": "list"
1032            }"#;
1033
1034        let schema: ModelSchema = serde_json::from_str(json).unwrap();
1035
1036        assert_eq!(schema.data.len(), 1);
1037        assert_eq!(schema.data[0].id, "future-model-v1");
1038        assert_eq!(schema.data[0].vendor, ModelVendor::Unknown);
1039    }
1040
1041    #[test]
1042    fn test_models_with_pending_policy_deserialize() {
1043        // This test verifies that models with policy states other than "enabled"
1044        // (such as "pending" or "requires_consent") are properly deserialized.
1045        // Note: These models will be filtered out by get_models() and won't appear
1046        // in the model picker until the user enables them on GitHub.
1047        let json = r#"{
1048              "data": [
1049                {
1050                  "billing": { "is_premium": true, "multiplier": 1 },
1051                  "capabilities": {
1052                    "family": "claude-sonnet-4",
1053                    "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1054                    "object": "model_capabilities",
1055                    "supports": { "streaming": true, "tool_calls": true },
1056                    "type": "chat"
1057                  },
1058                  "id": "claude-sonnet-4",
1059                  "is_chat_default": false,
1060                  "is_chat_fallback": false,
1061                  "model_picker_enabled": true,
1062                  "name": "Claude Sonnet 4",
1063                  "object": "model",
1064                  "policy": {
1065                    "state": "pending",
1066                    "terms": "Enable access to Claude models from Anthropic."
1067                  },
1068                  "preview": false,
1069                  "vendor": "Anthropic",
1070                  "version": "claude-sonnet-4"
1071                },
1072                {
1073                  "billing": { "is_premium": true, "multiplier": 1 },
1074                  "capabilities": {
1075                    "family": "claude-opus-4",
1076                    "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1077                    "object": "model_capabilities",
1078                    "supports": { "streaming": true, "tool_calls": true },
1079                    "type": "chat"
1080                  },
1081                  "id": "claude-opus-4",
1082                  "is_chat_default": false,
1083                  "is_chat_fallback": false,
1084                  "model_picker_enabled": true,
1085                  "name": "Claude Opus 4",
1086                  "object": "model",
1087                  "policy": {
1088                    "state": "requires_consent",
1089                    "terms": "Enable access to Claude models from Anthropic."
1090                  },
1091                  "preview": false,
1092                  "vendor": "Anthropic",
1093                  "version": "claude-opus-4"
1094                }
1095              ],
1096              "object": "list"
1097            }"#;
1098
1099        let schema: ModelSchema = serde_json::from_str(json).unwrap();
1100
1101        // Both models should deserialize successfully (filtering happens in get_models)
1102        assert_eq!(schema.data.len(), 2);
1103        assert_eq!(schema.data[0].id, "claude-sonnet-4");
1104        assert_eq!(schema.data[1].id, "claude-opus-4");
1105    }
1106
1107    #[test]
1108    fn test_multiple_anthropic_models_preserved() {
1109        // This test verifies that multiple Claude models from Anthropic
1110        // are all preserved and not incorrectly deduplicated.
1111        // This was the root cause of issue #47540.
1112        let json = r#"{
1113              "data": [
1114                {
1115                  "billing": { "is_premium": true, "multiplier": 1 },
1116                  "capabilities": {
1117                    "family": "claude-sonnet-4",
1118                    "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1119                    "object": "model_capabilities",
1120                    "supports": { "streaming": true, "tool_calls": true },
1121                    "type": "chat"
1122                  },
1123                  "id": "claude-sonnet-4",
1124                  "is_chat_default": false,
1125                  "is_chat_fallback": false,
1126                  "model_picker_enabled": true,
1127                  "name": "Claude Sonnet 4",
1128                  "object": "model",
1129                  "preview": false,
1130                  "vendor": "Anthropic",
1131                  "version": "claude-sonnet-4"
1132                },
1133                {
1134                  "billing": { "is_premium": true, "multiplier": 1 },
1135                  "capabilities": {
1136                    "family": "claude-opus-4",
1137                    "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1138                    "object": "model_capabilities",
1139                    "supports": { "streaming": true, "tool_calls": true },
1140                    "type": "chat"
1141                  },
1142                  "id": "claude-opus-4",
1143                  "is_chat_default": false,
1144                  "is_chat_fallback": false,
1145                  "model_picker_enabled": true,
1146                  "name": "Claude Opus 4",
1147                  "object": "model",
1148                  "preview": false,
1149                  "vendor": "Anthropic",
1150                  "version": "claude-opus-4"
1151                },
1152                {
1153                  "billing": { "is_premium": true, "multiplier": 1 },
1154                  "capabilities": {
1155                    "family": "claude-sonnet-4.5",
1156                    "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1157                    "object": "model_capabilities",
1158                    "supports": { "streaming": true, "tool_calls": true },
1159                    "type": "chat"
1160                  },
1161                  "id": "claude-sonnet-4.5",
1162                  "is_chat_default": false,
1163                  "is_chat_fallback": false,
1164                  "model_picker_enabled": true,
1165                  "name": "Claude Sonnet 4.5",
1166                  "object": "model",
1167                  "preview": false,
1168                  "vendor": "Anthropic",
1169                  "version": "claude-sonnet-4.5"
1170                }
1171              ],
1172              "object": "list"
1173            }"#;
1174
1175        let schema: ModelSchema = serde_json::from_str(json).unwrap();
1176
1177        // All three Anthropic models should be preserved
1178        assert_eq!(schema.data.len(), 3);
1179        assert_eq!(schema.data[0].id, "claude-sonnet-4");
1180        assert_eq!(schema.data[1].id, "claude-opus-4");
1181        assert_eq!(schema.data[2].id, "claude-sonnet-4.5");
1182    }
1183
1184    #[test]
1185    fn test_models_with_same_family_both_preserved() {
1186        // Test that models sharing the same family (e.g., thinking variants)
1187        // are both preserved in the model list.
1188        let json = r#"{
1189              "data": [
1190                {
1191                  "billing": { "is_premium": true, "multiplier": 1 },
1192                  "capabilities": {
1193                    "family": "claude-sonnet-4",
1194                    "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1195                    "object": "model_capabilities",
1196                    "supports": { "streaming": true, "tool_calls": true },
1197                    "type": "chat"
1198                  },
1199                  "id": "claude-sonnet-4",
1200                  "is_chat_default": false,
1201                  "is_chat_fallback": false,
1202                  "model_picker_enabled": true,
1203                  "name": "Claude Sonnet 4",
1204                  "object": "model",
1205                  "preview": false,
1206                  "vendor": "Anthropic",
1207                  "version": "claude-sonnet-4"
1208                },
1209                {
1210                  "billing": { "is_premium": true, "multiplier": 1 },
1211                  "capabilities": {
1212                    "family": "claude-sonnet-4",
1213                    "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1214                    "object": "model_capabilities",
1215                    "supports": { "streaming": true, "tool_calls": true },
1216                    "type": "chat"
1217                  },
1218                  "id": "claude-sonnet-4-thinking",
1219                  "is_chat_default": false,
1220                  "is_chat_fallback": false,
1221                  "model_picker_enabled": true,
1222                  "name": "Claude Sonnet 4 (Thinking)",
1223                  "object": "model",
1224                  "preview": false,
1225                  "vendor": "Anthropic",
1226                  "version": "claude-sonnet-4-thinking"
1227                }
1228              ],
1229              "object": "list"
1230            }"#;
1231
1232        let schema: ModelSchema = serde_json::from_str(json).unwrap();
1233
1234        // Both models should be preserved even though they share the same family
1235        assert_eq!(schema.data.len(), 2);
1236        assert_eq!(schema.data[0].id, "claude-sonnet-4");
1237        assert_eq!(schema.data[1].id, "claude-sonnet-4-thinking");
1238    }
1239
1240    #[test]
1241    fn test_mixed_vendor_models_all_preserved() {
1242        // Test that models from different vendors are all preserved.
1243        let json = r#"{
1244              "data": [
1245                {
1246                  "billing": { "is_premium": false, "multiplier": 1 },
1247                  "capabilities": {
1248                    "family": "gpt-4o",
1249                    "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 16384, "max_prompt_tokens": 110000 },
1250                    "object": "model_capabilities",
1251                    "supports": { "streaming": true, "tool_calls": true },
1252                    "type": "chat"
1253                  },
1254                  "id": "gpt-4o",
1255                  "is_chat_default": true,
1256                  "is_chat_fallback": false,
1257                  "model_picker_enabled": true,
1258                  "name": "GPT-4o",
1259                  "object": "model",
1260                  "preview": false,
1261                  "vendor": "Azure OpenAI",
1262                  "version": "gpt-4o"
1263                },
1264                {
1265                  "billing": { "is_premium": true, "multiplier": 1 },
1266                  "capabilities": {
1267                    "family": "claude-sonnet-4",
1268                    "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1269                    "object": "model_capabilities",
1270                    "supports": { "streaming": true, "tool_calls": true },
1271                    "type": "chat"
1272                  },
1273                  "id": "claude-sonnet-4",
1274                  "is_chat_default": false,
1275                  "is_chat_fallback": false,
1276                  "model_picker_enabled": true,
1277                  "name": "Claude Sonnet 4",
1278                  "object": "model",
1279                  "preview": false,
1280                  "vendor": "Anthropic",
1281                  "version": "claude-sonnet-4"
1282                },
1283                {
1284                  "billing": { "is_premium": true, "multiplier": 1 },
1285                  "capabilities": {
1286                    "family": "gemini-2.0-flash",
1287                    "limits": { "max_context_window_tokens": 1000000, "max_output_tokens": 8192, "max_prompt_tokens": 900000 },
1288                    "object": "model_capabilities",
1289                    "supports": { "streaming": true, "tool_calls": true },
1290                    "type": "chat"
1291                  },
1292                  "id": "gemini-2.0-flash",
1293                  "is_chat_default": false,
1294                  "is_chat_fallback": false,
1295                  "model_picker_enabled": true,
1296                  "name": "Gemini 2.0 Flash",
1297                  "object": "model",
1298                  "preview": false,
1299                  "vendor": "Google",
1300                  "version": "gemini-2.0-flash"
1301                }
1302              ],
1303              "object": "list"
1304            }"#;
1305
1306        let schema: ModelSchema = serde_json::from_str(json).unwrap();
1307
1308        // All three models from different vendors should be preserved
1309        assert_eq!(schema.data.len(), 3);
1310        assert_eq!(schema.data[0].id, "gpt-4o");
1311        assert_eq!(schema.data[1].id, "claude-sonnet-4");
1312        assert_eq!(schema.data[2].id, "gemini-2.0-flash");
1313    }
1314
1315    #[test]
1316    fn test_model_with_messages_endpoint_deserializes() {
1317        // Anthropic Claude models use /v1/messages endpoint.
1318        // This test verifies such models deserialize correctly (issue #47540 root cause).
1319        let json = r#"{
1320              "data": [
1321                {
1322                  "billing": { "is_premium": true, "multiplier": 1 },
1323                  "capabilities": {
1324                    "family": "claude-sonnet-4",
1325                    "limits": { "max_context_window_tokens": 200000, "max_output_tokens": 16384, "max_prompt_tokens": 90000 },
1326                    "object": "model_capabilities",
1327                    "supports": { "streaming": true, "tool_calls": true },
1328                    "type": "chat"
1329                  },
1330                  "id": "claude-sonnet-4",
1331                  "is_chat_default": false,
1332                  "is_chat_fallback": false,
1333                  "model_picker_enabled": true,
1334                  "name": "Claude Sonnet 4",
1335                  "object": "model",
1336                  "preview": false,
1337                  "vendor": "Anthropic",
1338                  "version": "claude-sonnet-4",
1339                  "supported_endpoints": ["/v1/messages"]
1340                }
1341              ],
1342              "object": "list"
1343            }"#;
1344
1345        let schema: ModelSchema = serde_json::from_str(json).unwrap();
1346
1347        assert_eq!(schema.data.len(), 1);
1348        assert_eq!(schema.data[0].id, "claude-sonnet-4");
1349        assert_eq!(
1350            schema.data[0].supported_endpoints,
1351            vec![ModelSupportedEndpoint::Messages]
1352        );
1353    }
1354
1355    #[test]
1356    fn test_model_with_unknown_endpoint_deserializes() {
1357        // Future-proofing: unknown endpoints should deserialize to Unknown variant
1358        // instead of causing the entire model to fail deserialization.
1359        let json = r#"{
1360              "data": [
1361                {
1362                  "billing": { "is_premium": false, "multiplier": 1 },
1363                  "capabilities": {
1364                    "family": "future-model",
1365                    "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 8192, "max_prompt_tokens": 120000 },
1366                    "object": "model_capabilities",
1367                    "supports": { "streaming": true, "tool_calls": true },
1368                    "type": "chat"
1369                  },
1370                  "id": "future-model-v2",
1371                  "is_chat_default": false,
1372                  "is_chat_fallback": false,
1373                  "model_picker_enabled": true,
1374                  "name": "Future Model v2",
1375                  "object": "model",
1376                  "preview": false,
1377                  "vendor": "OpenAI",
1378                  "version": "v2.0",
1379                  "supported_endpoints": ["/v2/completions", "/chat/completions"]
1380                }
1381              ],
1382              "object": "list"
1383            }"#;
1384
1385        let schema: ModelSchema = serde_json::from_str(json).unwrap();
1386
1387        assert_eq!(schema.data.len(), 1);
1388        assert_eq!(schema.data[0].id, "future-model-v2");
1389        assert_eq!(
1390            schema.data[0].supported_endpoints,
1391            vec![
1392                ModelSupportedEndpoint::Unknown,
1393                ModelSupportedEndpoint::ChatCompletions
1394            ]
1395        );
1396    }
1397
1398    #[test]
1399    fn test_model_with_multiple_endpoints() {
1400        // Test model with multiple supported endpoints (common for newer models).
1401        let json = r#"{
1402              "data": [
1403                {
1404                  "billing": { "is_premium": true, "multiplier": 1 },
1405                  "capabilities": {
1406                    "family": "gpt-4o",
1407                    "limits": { "max_context_window_tokens": 128000, "max_output_tokens": 16384, "max_prompt_tokens": 110000 },
1408                    "object": "model_capabilities",
1409                    "supports": { "streaming": true, "tool_calls": true },
1410                    "type": "chat"
1411                  },
1412                  "id": "gpt-4o",
1413                  "is_chat_default": true,
1414                  "is_chat_fallback": false,
1415                  "model_picker_enabled": true,
1416                  "name": "GPT-4o",
1417                  "object": "model",
1418                  "preview": false,
1419                  "vendor": "OpenAI",
1420                  "version": "gpt-4o",
1421                  "supported_endpoints": ["/chat/completions", "/responses"]
1422                }
1423              ],
1424              "object": "list"
1425            }"#;
1426
1427        let schema: ModelSchema = serde_json::from_str(json).unwrap();
1428
1429        assert_eq!(schema.data.len(), 1);
1430        assert_eq!(schema.data[0].id, "gpt-4o");
1431        assert_eq!(
1432            schema.data[0].supported_endpoints,
1433            vec![
1434                ModelSupportedEndpoint::ChatCompletions,
1435                ModelSupportedEndpoint::Responses
1436            ]
1437        );
1438    }
1439
1440    #[test]
1441    fn test_supports_response_method() {
1442        // Test the supports_response() method which determines endpoint routing.
1443        let model_with_responses_only = Model {
1444            billing: ModelBilling {
1445                is_premium: false,
1446                multiplier: 1.0,
1447                restricted_to: None,
1448            },
1449            capabilities: ModelCapabilities {
1450                family: "test".to_string(),
1451                limits: ModelLimits::default(),
1452                supports: ModelSupportedFeatures {
1453                    streaming: true,
1454                    tool_calls: true,
1455                    parallel_tool_calls: false,
1456                    vision: false,
1457                },
1458                model_type: "chat".to_string(),
1459                tokenizer: None,
1460            },
1461            id: "test-model".to_string(),
1462            name: "Test Model".to_string(),
1463            policy: None,
1464            vendor: ModelVendor::OpenAI,
1465            is_chat_default: false,
1466            is_chat_fallback: false,
1467            model_picker_enabled: true,
1468            supported_endpoints: vec![ModelSupportedEndpoint::Responses],
1469        };
1470
1471        let model_with_chat_completions = Model {
1472            supported_endpoints: vec![ModelSupportedEndpoint::ChatCompletions],
1473            ..model_with_responses_only.clone()
1474        };
1475
1476        let model_with_both = Model {
1477            supported_endpoints: vec![
1478                ModelSupportedEndpoint::ChatCompletions,
1479                ModelSupportedEndpoint::Responses,
1480            ],
1481            ..model_with_responses_only.clone()
1482        };
1483
1484        let model_with_messages = Model {
1485            supported_endpoints: vec![ModelSupportedEndpoint::Messages],
1486            ..model_with_responses_only.clone()
1487        };
1488
1489        // Only /responses endpoint -> supports_response = true
1490        assert!(model_with_responses_only.supports_response());
1491
1492        // Only /chat/completions endpoint -> supports_response = false
1493        assert!(!model_with_chat_completions.supports_response());
1494
1495        // Both endpoints (has /chat/completions) -> supports_response = false
1496        assert!(!model_with_both.supports_response());
1497
1498        // Only /v1/messages endpoint -> supports_response = false (doesn't have /responses)
1499        assert!(!model_with_messages.supports_response());
1500    }
1501}