Switch to a Zed user-agent header for Copilot traffic (#48591)

Richard Feldman and Mikayla Maki created

Follow-up to #48528

Release Notes:

- N/A

---------

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>

Change summary

Cargo.lock                              |   1 
crates/copilot_chat/Cargo.toml          |   1 
crates/copilot_chat/src/copilot_chat.rs | 317 ++++++++++++++------------
crates/copilot_chat/src/responses.rs    |  24 -
4 files changed, 178 insertions(+), 165 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3760,7 +3760,6 @@ name = "copilot_chat"
 version = "0.1.0"
 dependencies = [
  "anyhow",
- "chrono",
  "collections",
  "dirs 4.0.0",
  "fs",

crates/copilot_chat/Cargo.toml 🔗

@@ -22,7 +22,6 @@ test-support = [
 
 [dependencies]
 anyhow.workspace = true
-chrono.workspace = true
 collections.workspace = true
 dirs.workspace = true
 fs.workspace = true

crates/copilot_chat/src/copilot_chat.rs 🔗

@@ -6,7 +6,6 @@ use std::sync::OnceLock;
 
 use anyhow::Context as _;
 use anyhow::{Result, anyhow};
-use chrono::DateTime;
 use collections::HashSet;
 use fs::Fs;
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
@@ -20,6 +19,7 @@ use serde::{Deserialize, Serialize};
 use settings::watch_config_dir;
 
 pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN";
+const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com";
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct CopilotChatConfiguration {
@@ -27,33 +27,33 @@ pub struct CopilotChatConfiguration {
 }
 
 impl CopilotChatConfiguration {
-    pub fn token_url(&self) -> String {
+    pub fn oauth_domain(&self) -> String {
         if let Some(enterprise_uri) = &self.enterprise_uri {
-            let domain = Self::parse_domain(enterprise_uri);
-            format!("https://api.{}/copilot_internal/v2/token", domain)
+            Self::parse_domain(enterprise_uri)
         } else {
-            "https://api.github.com/copilot_internal/v2/token".to_string()
+            "github.com".to_string()
         }
     }
 
-    pub fn oauth_domain(&self) -> String {
+    pub fn graphql_url(&self) -> String {
         if let Some(enterprise_uri) = &self.enterprise_uri {
-            Self::parse_domain(enterprise_uri)
+            let domain = Self::parse_domain(enterprise_uri);
+            format!("https://{}/api/graphql", domain)
         } else {
-            "github.com".to_string()
+            "https://api.github.com/graphql".to_string()
         }
     }
 
-    pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String {
-        format!("{}/chat/completions", endpoint)
+    pub fn chat_completions_url(&self, api_endpoint: &str) -> String {
+        format!("{}/chat/completions", api_endpoint)
     }
 
-    pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String {
-        format!("{}/responses", endpoint)
+    pub fn responses_url(&self, api_endpoint: &str) -> String {
+        format!("{}/responses", api_endpoint)
     }
 
-    pub fn models_url_from_endpoint(&self, endpoint: &str) -> String {
-        format!("{}/models", endpoint)
+    pub fn models_url(&self, api_endpoint: &str) -> String {
+        format!("{}/models", api_endpoint)
     }
 
     fn parse_domain(enterprise_uri: &str) -> String {
@@ -412,55 +412,13 @@ pub struct FunctionChunk {
     pub thought_signature: Option<String>,
 }
 
-#[derive(Deserialize)]
-struct ApiTokenResponse {
-    token: String,
-    expires_at: i64,
-    endpoints: ApiTokenResponseEndpoints,
-}
-
-#[derive(Deserialize)]
-struct ApiTokenResponseEndpoints {
-    api: String,
-}
-
-#[derive(Clone)]
-struct ApiToken {
-    api_key: String,
-    expires_at: DateTime<chrono::Utc>,
-    api_endpoint: String,
-}
-
-impl ApiToken {
-    pub fn remaining_seconds(&self) -> i64 {
-        self.expires_at
-            .timestamp()
-            .saturating_sub(chrono::Utc::now().timestamp())
-    }
-}
-
-impl TryFrom<ApiTokenResponse> for ApiToken {
-    type Error = anyhow::Error;
-
-    fn try_from(response: ApiTokenResponse) -> Result<Self, Self::Error> {
-        let expires_at =
-            DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?;
-
-        Ok(Self {
-            api_key: response.token,
-            expires_at,
-            api_endpoint: response.endpoints.api,
-        })
-    }
-}
-
 struct GlobalCopilotChat(gpui::Entity<CopilotChat>);
 
 impl Global for GlobalCopilotChat {}
 
 pub struct CopilotChat {
     oauth_token: Option<String>,
-    api_token: Option<ApiToken>,
+    api_endpoint: Option<String>,
     configuration: CopilotChatConfiguration,
     models: Option<Vec<Model>>,
     client: Arc<dyn HttpClient>,
@@ -539,7 +497,7 @@ impl CopilotChat {
 
         let this = Self {
             oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(),
-            api_token: None,
+            api_endpoint: None,
             models: None,
             configuration,
             client,
@@ -565,15 +523,13 @@ impl CopilotChat {
         let oauth_token = oauth_token
             .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?;
 
-        let token_url = configuration.token_url();
-        let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
+        let api_endpoint =
+            Self::resolve_api_endpoint(&this, &oauth_token, &configuration, &client, cx).await?;
 
-        let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint);
-        let models =
-            get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?;
+        let models_url = configuration.models_url(&api_endpoint);
+        let models = get_models(models_url.into(), oauth_token, client.clone()).await?;
 
         this.update(cx, |this, cx| {
-            this.api_token = Some(api_token);
             this.models = Some(models);
             cx.notify();
         })?;
@@ -593,12 +549,13 @@ impl CopilotChat {
         is_user_initiated: bool,
         mut cx: AsyncApp,
     ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
-        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
+        let (client, oauth_token, api_endpoint, configuration) =
+            Self::get_auth_details(&mut cx).await?;
 
-        let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint);
+        let api_url = configuration.chat_completions_url(&api_endpoint);
         stream_completion(
             client.clone(),
-            token.api_key,
+            oauth_token,
             api_url.into(),
             request,
             is_user_initiated,
@@ -611,12 +568,13 @@ impl CopilotChat {
         is_user_initiated: bool,
         mut cx: AsyncApp,
     ) -> Result<BoxStream<'static, Result<responses::StreamEvent>>> {
-        let (client, token, configuration) = Self::get_auth_details(&mut cx).await?;
+        let (client, oauth_token, api_endpoint, configuration) =
+            Self::get_auth_details(&mut cx).await?;
 
-        let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint);
+        let api_url = configuration.responses_url(&api_endpoint);
         responses::stream_response(
             client.clone(),
-            token.api_key,
+            oauth_token,
             api_url,
             request,
             is_user_initiated,
@@ -626,15 +584,20 @@ impl CopilotChat {
 
     async fn get_auth_details(
         cx: &mut AsyncApp,
-    ) -> Result<(Arc<dyn HttpClient>, ApiToken, CopilotChatConfiguration)> {
+    ) -> Result<(
+        Arc<dyn HttpClient>,
+        String,
+        String,
+        CopilotChatConfiguration,
+    )> {
         let this = cx
             .update(|cx| Self::global(cx))
             .context("Copilot chat is not enabled")?;
 
-        let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| {
+        let (oauth_token, api_endpoint, client, configuration) = this.read_with(cx, |this, _| {
             (
                 this.oauth_token.clone(),
-                this.api_token.clone(),
+                this.api_endpoint.clone(),
                 this.client.clone(),
                 this.configuration.clone(),
             )
@@ -642,21 +605,41 @@ impl CopilotChat {
 
         let oauth_token = oauth_token.context("No OAuth token available")?;
 
-        let token = match api_token {
-            Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token,
-            _ => {
-                let token_url = configuration.token_url();
-                let token =
-                    request_api_token(&oauth_token, token_url.into(), client.clone()).await?;
-                this.update(cx, |this, cx| {
-                    this.api_token = Some(token.clone());
-                    cx.notify();
-                });
-                token
+        let api_endpoint = match api_endpoint {
+            Some(endpoint) => endpoint,
+            None => {
+                let weak = this.downgrade();
+                Self::resolve_api_endpoint(&weak, &oauth_token, &configuration, &client, cx).await?
             }
         };
 
-        Ok((client, token, configuration))
+        Ok((client, oauth_token, api_endpoint, configuration))
+    }
+
+    async fn resolve_api_endpoint(
+        this: &WeakEntity<Self>,
+        oauth_token: &str,
+        configuration: &CopilotChatConfiguration,
+        client: &Arc<dyn HttpClient>,
+        cx: &mut AsyncApp,
+    ) -> Result<String> {
+        let api_endpoint = match discover_api_endpoint(oauth_token, configuration, client).await {
+            Ok(endpoint) => endpoint,
+            Err(error) => {
+                log::warn!(
+                    "Failed to discover Copilot API endpoint via GraphQL, \
+                         falling back to {DEFAULT_COPILOT_API_ENDPOINT}: {error:#}"
+                );
+                DEFAULT_COPILOT_API_ENDPOINT.to_string()
+            }
+        };
+
+        this.update(cx, |this, cx| {
+            this.api_endpoint = Some(api_endpoint.clone());
+            cx.notify();
+        })?;
+
+        Ok(api_endpoint)
     }
 
     pub fn set_configuration(
@@ -667,7 +650,7 @@ impl CopilotChat {
         let same_configuration = self.configuration == configuration;
         self.configuration = configuration;
         if !same_configuration {
-            self.api_token = None;
+            self.api_endpoint = None;
             cx.spawn(async move |this, cx| {
                 Self::update_models(&this, cx).await?;
                 Ok::<_, anyhow::Error>(())
@@ -679,10 +662,10 @@ impl CopilotChat {
 
 async fn get_models(
     models_url: Arc<str>,
-    api_token: String,
+    oauth_token: String,
     client: Arc<dyn HttpClient>,
 ) -> Result<Vec<Model>> {
-    let all_models = request_models(models_url, api_token, client).await?;
+    let all_models = request_models(models_url, oauth_token, client).await?;
 
     let mut models: Vec<Model> = all_models
         .into_iter()
@@ -704,69 +687,120 @@ async fn get_models(
     Ok(models)
 }
 
-async fn request_models(
-    models_url: Arc<str>,
-    api_token: String,
-    client: Arc<dyn HttpClient>,
-) -> Result<Vec<Model>> {
-    let request_builder = HttpRequest::builder()
-        .method(Method::GET)
-        .uri(models_url.as_ref())
-        .header("Authorization", format!("Bearer {}", api_token))
-        .header("Content-Type", "application/json")
-        .header("Copilot-Integration-Id", "vscode-chat")
-        .header("Editor-Version", "vscode/1.103.2")
-        .header("x-github-api-version", "2025-05-01");
+#[derive(Deserialize)]
+struct GraphQLResponse {
+    data: Option<GraphQLData>,
+}
 
-    let request = request_builder.body(AsyncBody::empty())?;
+#[derive(Deserialize)]
+struct GraphQLData {
+    viewer: GraphQLViewer,
+}
+
+#[derive(Deserialize)]
+struct GraphQLViewer {
+    #[serde(rename = "copilotEndpoints")]
+    copilot_endpoints: GraphQLCopilotEndpoints,
+}
+
+#[derive(Deserialize)]
+struct GraphQLCopilotEndpoints {
+    api: String,
+}
+
+pub(crate) async fn discover_api_endpoint(
+    oauth_token: &str,
+    configuration: &CopilotChatConfiguration,
+    client: &Arc<dyn HttpClient>,
+) -> Result<String> {
+    let graphql_url = configuration.graphql_url();
+    let query = serde_json::json!({
+        "query": "query { viewer { copilotEndpoints { api } } }"
+    });
+
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(graphql_url.as_str())
+        .header("Authorization", format!("Bearer {}", oauth_token))
+        .header("Content-Type", "application/json")
+        .body(AsyncBody::from(serde_json::to_string(&query)?))?;
 
     let mut response = client.send(request).await?;
 
     anyhow::ensure!(
         response.status().is_success(),
-        "Failed to request models: {}",
+        "GraphQL endpoint discovery failed: {}",
         response.status()
     );
+
     let mut body = Vec::new();
     response.body_mut().read_to_end(&mut body).await?;
-
     let body_str = std::str::from_utf8(&body)?;
 
-    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
+    let parsed: GraphQLResponse = serde_json::from_str(body_str)
+        .context("Failed to parse GraphQL response for Copilot endpoint discovery")?;
 
-    Ok(models)
+    let data = parsed
+        .data
+        .context("GraphQL response contained no data field")?;
+
+    Ok(data.viewer.copilot_endpoints.api)
 }
 
-async fn request_api_token(
+pub(crate) fn copilot_request_headers(
+    builder: http_client::Builder,
     oauth_token: &str,
-    auth_url: Arc<str>,
+    is_user_initiated: Option<bool>,
+) -> http_client::Builder {
+    builder
+        .header("Authorization", format!("Bearer {}", oauth_token))
+        .header("Content-Type", "application/json")
+        .header(
+            "Editor-Version",
+            format!(
+                "Zed/{}",
+                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
+            ),
+        )
+        .when_some(is_user_initiated, |builder, is_user_initiated| {
+            builder.header(
+                "X-Initiator",
+                if is_user_initiated { "user" } else { "agent" },
+            )
+        })
+}
+
+async fn request_models(
+    models_url: Arc<str>,
+    oauth_token: String,
     client: Arc<dyn HttpClient>,
-) -> Result<ApiToken> {
-    let request_builder = HttpRequest::builder()
-        .method(Method::GET)
-        .uri(auth_url.as_ref())
-        .header("Authorization", format!("token {}", oauth_token))
-        .header("Accept", "application/json");
+) -> Result<Vec<Model>> {
+    let request_builder = copilot_request_headers(
+        HttpRequest::builder()
+            .method(Method::GET)
+            .uri(models_url.as_ref()),
+        &oauth_token,
+        None,
+    )
+    .header("x-github-api-version", "2025-05-01");
 
     let request = request_builder.body(AsyncBody::empty())?;
 
     let mut response = client.send(request).await?;
 
-    if response.status().is_success() {
-        let mut body = Vec::new();
-        response.body_mut().read_to_end(&mut body).await?;
+    anyhow::ensure!(
+        response.status().is_success(),
+        "Failed to request models: {}",
+        response.status()
+    );
+    let mut body = Vec::new();
+    response.body_mut().read_to_end(&mut body).await?;
 
-        let body_str = std::str::from_utf8(&body)?;
+    let body_str = std::str::from_utf8(&body)?;
 
-        let parsed: ApiTokenResponse = serde_json::from_str(body_str)?;
-        ApiToken::try_from(parsed)
-    } else {
-        let mut body = Vec::new();
-        response.body_mut().read_to_end(&mut body).await?;
+    let models = serde_json::from_str::<ModelSchema>(body_str)?.data;
 
-        let body_str = std::str::from_utf8(&body)?;
-        anyhow::bail!("Failed to request API token: {body_str}");
-    }
+    Ok(models)
 }
 
 fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
@@ -788,7 +822,7 @@ fn extract_oauth_token(contents: String, domain: &str) -> Option<String> {
 
 async fn stream_completion(
     client: Arc<dyn HttpClient>,
-    api_key: String,
+    oauth_token: String,
     completion_url: Arc<str>,
     request: Request,
     is_user_initiated: bool,
@@ -802,25 +836,16 @@ async fn stream_completion(
         _ => false,
     });
 
-    let request_initiator = if is_user_initiated { "user" } else { "agent" };
-
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(completion_url.as_ref())
-        .header(
-            "Editor-Version",
-            format!(
-                "Zed/{}",
-                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
-            ),
-        )
-        .header("Authorization", format!("Bearer {}", api_key))
-        .header("Content-Type", "application/json")
-        .header("Copilot-Integration-Id", "vscode-chat")
-        .header("X-Initiator", request_initiator)
-        .when(is_vision_request, |builder| {
-            builder.header("Copilot-Vision-Request", is_vision_request.to_string())
-        });
+    let request_builder = copilot_request_headers(
+        HttpRequest::builder()
+            .method(Method::POST)
+            .uri(completion_url.as_ref()),
+        &oauth_token,
+        Some(is_user_initiated),
+    )
+    .when(is_vision_request, |builder| {
+        builder.header("Copilot-Vision-Request", is_vision_request.to_string())
+    });
 
     let is_streaming = request.stream;
 

crates/copilot_chat/src/responses.rs 🔗

@@ -1,5 +1,6 @@
 use std::sync::Arc;
 
+use super::copilot_request_headers;
 use anyhow::{Result, anyhow};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
@@ -275,7 +276,7 @@ pub enum ResponseOutputContent {
 
 pub async fn stream_response(
     client: Arc<dyn HttpClient>,
-    api_key: String,
+    oauth_token: String,
     api_url: String,
     request: Request,
     is_user_initiated: bool,
@@ -290,22 +291,11 @@ pub async fn stream_response(
         _ => false,
     });
 
-    let request_initiator = if is_user_initiated { "user" } else { "agent" };
-
-    let request_builder = HttpRequest::builder()
-        .method(Method::POST)
-        .uri(&api_url)
-        .header(
-            "Editor-Version",
-            format!(
-                "Zed/{}",
-                option_env!("CARGO_PKG_VERSION").unwrap_or("unknown")
-            ),
-        )
-        .header("Authorization", format!("Bearer {}", api_key))
-        .header("Content-Type", "application/json")
-        .header("Copilot-Integration-Id", "vscode-chat")
-        .header("X-Initiator", request_initiator);
+    let request_builder = copilot_request_headers(
+        HttpRequest::builder().method(Method::POST).uri(&api_url),
+        &oauth_token,
+        Some(is_user_initiated),
+    );
 
     let request_builder = if is_vision_request {
         request_builder.header("Copilot-Vision-Request", "true")