diff --git a/Cargo.lock b/Cargo.lock index ff21d537bdd92b8395682ad75b4de2be4fe0808b..ce1dfe0a4e5dfcc1009c31d72588bb3aa7c3dc4c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3760,7 +3760,6 @@ name = "copilot_chat" version = "0.1.0" dependencies = [ "anyhow", - "chrono", "collections", "dirs 4.0.0", "fs", diff --git a/crates/copilot_chat/Cargo.toml b/crates/copilot_chat/Cargo.toml index 593d7869264b5f653a7a211d63d948d74557b0c2..991a58ac85227ebc84fad5a6d631fe17811fabd4 100644 --- a/crates/copilot_chat/Cargo.toml +++ b/crates/copilot_chat/Cargo.toml @@ -22,7 +22,6 @@ test-support = [ [dependencies] anyhow.workspace = true -chrono.workspace = true collections.workspace = true dirs.workspace = true fs.workspace = true diff --git a/crates/copilot_chat/src/copilot_chat.rs b/crates/copilot_chat/src/copilot_chat.rs index 56b844ea6bc704a2fb5065222f4c356ecbeded5a..513b813517cc7f929f922842611f78fb617ff396 100644 --- a/crates/copilot_chat/src/copilot_chat.rs +++ b/crates/copilot_chat/src/copilot_chat.rs @@ -6,7 +6,6 @@ use std::sync::OnceLock; use anyhow::Context as _; use anyhow::{Result, anyhow}; -use chrono::DateTime; use collections::HashSet; use fs::Fs; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; @@ -20,6 +19,7 @@ use serde::{Deserialize, Serialize}; use settings::watch_config_dir; pub const COPILOT_OAUTH_ENV_VAR: &str = "GH_COPILOT_TOKEN"; +const DEFAULT_COPILOT_API_ENDPOINT: &str = "https://api.githubcopilot.com"; #[derive(Default, Clone, Debug, PartialEq)] pub struct CopilotChatConfiguration { @@ -27,33 +27,33 @@ pub struct CopilotChatConfiguration { } impl CopilotChatConfiguration { - pub fn token_url(&self) -> String { + pub fn oauth_domain(&self) -> String { if let Some(enterprise_uri) = &self.enterprise_uri { - let domain = Self::parse_domain(enterprise_uri); - format!("https://api.{}/copilot_internal/v2/token", domain) + Self::parse_domain(enterprise_uri) } else { - "https://api.github.com/copilot_internal/v2/token".to_string() + "github.com".to_string() } } - pub fn oauth_domain(&self) -> String { + pub fn graphql_url(&self) -> String { if let Some(enterprise_uri) = &self.enterprise_uri { - Self::parse_domain(enterprise_uri) + let domain = Self::parse_domain(enterprise_uri); + format!("https://{}/api/graphql", domain) } else { - "github.com".to_string() + "https://api.github.com/graphql".to_string() } } - pub fn chat_completions_url_from_endpoint(&self, endpoint: &str) -> String { - format!("{}/chat/completions", endpoint) + pub fn chat_completions_url(&self, api_endpoint: &str) -> String { + format!("{}/chat/completions", api_endpoint) } - pub fn responses_url_from_endpoint(&self, endpoint: &str) -> String { - format!("{}/responses", endpoint) + pub fn responses_url(&self, api_endpoint: &str) -> String { + format!("{}/responses", api_endpoint) } - pub fn models_url_from_endpoint(&self, endpoint: &str) -> String { - format!("{}/models", endpoint) + pub fn models_url(&self, api_endpoint: &str) -> String { + format!("{}/models", api_endpoint) } fn parse_domain(enterprise_uri: &str) -> String { @@ -412,55 +412,13 @@ pub struct FunctionChunk { pub thought_signature: Option, } -#[derive(Deserialize)] -struct ApiTokenResponse { - token: String, - expires_at: i64, - endpoints: ApiTokenResponseEndpoints, -} - -#[derive(Deserialize)] -struct ApiTokenResponseEndpoints { - api: String, -} - -#[derive(Clone)] -struct ApiToken { - api_key: String, - expires_at: DateTime, - api_endpoint: String, -} - -impl ApiToken { - pub fn remaining_seconds(&self) -> i64 { - self.expires_at - .timestamp() - .saturating_sub(chrono::Utc::now().timestamp()) - } -} - -impl TryFrom for ApiToken { - type Error = anyhow::Error; - - fn try_from(response: ApiTokenResponse) -> Result { - let expires_at = - DateTime::from_timestamp(response.expires_at, 0).context("invalid expires_at")?; - - Ok(Self { - api_key: response.token, - expires_at, - api_endpoint: response.endpoints.api, - }) - } -} - struct GlobalCopilotChat(gpui::Entity); impl Global for GlobalCopilotChat {} pub struct CopilotChat { oauth_token: Option, - api_token: Option, + api_endpoint: Option, configuration: CopilotChatConfiguration, models: Option>, client: Arc, @@ -539,7 +497,7 @@ impl CopilotChat { let this = Self { oauth_token: std::env::var(COPILOT_OAUTH_ENV_VAR).ok(), - api_token: None, + api_endpoint: None, models: None, configuration, client, @@ -565,15 +523,13 @@ impl CopilotChat { let oauth_token = oauth_token .ok_or_else(|| anyhow!("OAuth token is missing while updating Copilot Chat models"))?; - let token_url = configuration.token_url(); - let api_token = request_api_token(&oauth_token, token_url.into(), client.clone()).await?; + let api_endpoint = + Self::resolve_api_endpoint(&this, &oauth_token, &configuration, &client, cx).await?; - let models_url = configuration.models_url_from_endpoint(&api_token.api_endpoint); - let models = - get_models(models_url.into(), api_token.api_key.clone(), client.clone()).await?; + let models_url = configuration.models_url(&api_endpoint); + let models = get_models(models_url.into(), oauth_token, client.clone()).await?; this.update(cx, |this, cx| { - this.api_token = Some(api_token); this.models = Some(models); cx.notify(); })?; @@ -593,12 +549,13 @@ impl CopilotChat { is_user_initiated: bool, mut cx: AsyncApp, ) -> Result>> { - let (client, token, configuration) = Self::get_auth_details(&mut cx).await?; + let (client, oauth_token, api_endpoint, configuration) = + Self::get_auth_details(&mut cx).await?; - let api_url = configuration.chat_completions_url_from_endpoint(&token.api_endpoint); + let api_url = configuration.chat_completions_url(&api_endpoint); stream_completion( client.clone(), - token.api_key, + oauth_token, api_url.into(), request, is_user_initiated, @@ -611,12 +568,13 @@ impl CopilotChat { is_user_initiated: bool, mut cx: AsyncApp, ) -> Result>> { - let (client, token, configuration) = Self::get_auth_details(&mut cx).await?; + let (client, oauth_token, api_endpoint, configuration) = + Self::get_auth_details(&mut cx).await?; - let api_url = configuration.responses_url_from_endpoint(&token.api_endpoint); + let api_url = configuration.responses_url(&api_endpoint); responses::stream_response( client.clone(), - token.api_key, + oauth_token, api_url, request, is_user_initiated, @@ -626,15 +584,20 @@ impl CopilotChat { async fn get_auth_details( cx: &mut AsyncApp, - ) -> Result<(Arc, ApiToken, CopilotChatConfiguration)> { + ) -> Result<( + Arc, + String, + String, + CopilotChatConfiguration, + )> { let this = cx .update(|cx| Self::global(cx)) .context("Copilot chat is not enabled")?; - let (oauth_token, api_token, client, configuration) = this.read_with(cx, |this, _| { + let (oauth_token, api_endpoint, client, configuration) = this.read_with(cx, |this, _| { ( this.oauth_token.clone(), - this.api_token.clone(), + this.api_endpoint.clone(), this.client.clone(), this.configuration.clone(), ) @@ -642,21 +605,41 @@ impl CopilotChat { let oauth_token = oauth_token.context("No OAuth token available")?; - let token = match api_token { - Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token, - _ => { - let token_url = configuration.token_url(); - let token = - request_api_token(&oauth_token, token_url.into(), client.clone()).await?; - this.update(cx, |this, cx| { - this.api_token = Some(token.clone()); - cx.notify(); - }); - token + let api_endpoint = match api_endpoint { + Some(endpoint) => endpoint, + None => { + let weak = this.downgrade(); + Self::resolve_api_endpoint(&weak, &oauth_token, &configuration, &client, cx).await? } }; - Ok((client, token, configuration)) + Ok((client, oauth_token, api_endpoint, configuration)) + } + + async fn resolve_api_endpoint( + this: &WeakEntity, + oauth_token: &str, + configuration: &CopilotChatConfiguration, + client: &Arc, + cx: &mut AsyncApp, + ) -> Result { + let api_endpoint = match discover_api_endpoint(oauth_token, configuration, client).await { + Ok(endpoint) => endpoint, + Err(error) => { + log::warn!( + "Failed to discover Copilot API endpoint via GraphQL, \ + falling back to {DEFAULT_COPILOT_API_ENDPOINT}: {error:#}" + ); + DEFAULT_COPILOT_API_ENDPOINT.to_string() + } + }; + + this.update(cx, |this, cx| { + this.api_endpoint = Some(api_endpoint.clone()); + cx.notify(); + })?; + + Ok(api_endpoint) } pub fn set_configuration( @@ -667,7 +650,7 @@ impl CopilotChat { let same_configuration = self.configuration == configuration; self.configuration = configuration; if !same_configuration { - self.api_token = None; + self.api_endpoint = None; cx.spawn(async move |this, cx| { Self::update_models(&this, cx).await?; Ok::<_, anyhow::Error>(()) @@ -679,10 +662,10 @@ impl CopilotChat { async fn get_models( models_url: Arc, - api_token: String, + oauth_token: String, client: Arc, ) -> Result> { - let all_models = request_models(models_url, api_token, client).await?; + let all_models = request_models(models_url, oauth_token, client).await?; let mut models: Vec = all_models .into_iter() @@ -704,69 +687,120 @@ async fn get_models( Ok(models) } -async fn request_models( - models_url: Arc, - api_token: String, - client: Arc, -) -> Result> { - let request_builder = HttpRequest::builder() - .method(Method::GET) - .uri(models_url.as_ref()) - .header("Authorization", format!("Bearer {}", api_token)) - .header("Content-Type", "application/json") - .header("Copilot-Integration-Id", "vscode-chat") - .header("Editor-Version", "vscode/1.103.2") - .header("x-github-api-version", "2025-05-01"); +#[derive(Deserialize)] +struct GraphQLResponse { + data: Option, +} - let request = request_builder.body(AsyncBody::empty())?; +#[derive(Deserialize)] +struct GraphQLData { + viewer: GraphQLViewer, +} + +#[derive(Deserialize)] +struct GraphQLViewer { + #[serde(rename = "copilotEndpoints")] + copilot_endpoints: GraphQLCopilotEndpoints, +} + +#[derive(Deserialize)] +struct GraphQLCopilotEndpoints { + api: String, +} + +pub(crate) async fn discover_api_endpoint( + oauth_token: &str, + configuration: &CopilotChatConfiguration, + client: &Arc, +) -> Result { + let graphql_url = configuration.graphql_url(); + let query = serde_json::json!({ + "query": "query { viewer { copilotEndpoints { api } } }" + }); + + let request = HttpRequest::builder() + .method(Method::POST) + .uri(graphql_url.as_str()) + .header("Authorization", format!("Bearer {}", oauth_token)) + .header("Content-Type", "application/json") + .body(AsyncBody::from(serde_json::to_string(&query)?))?; let mut response = client.send(request).await?; anyhow::ensure!( response.status().is_success(), - "Failed to request models: {}", + "GraphQL endpoint discovery failed: {}", response.status() ); + let mut body = Vec::new(); response.body_mut().read_to_end(&mut body).await?; - let body_str = std::str::from_utf8(&body)?; - let models = serde_json::from_str::(body_str)?.data; + let parsed: GraphQLResponse = serde_json::from_str(body_str) + .context("Failed to parse GraphQL response for Copilot endpoint discovery")?; - Ok(models) + let data = parsed + .data + .context("GraphQL response contained no data field")?; + + Ok(data.viewer.copilot_endpoints.api) } -async fn request_api_token( +pub(crate) fn copilot_request_headers( + builder: http_client::Builder, oauth_token: &str, - auth_url: Arc, + is_user_initiated: Option, +) -> http_client::Builder { + builder + .header("Authorization", format!("Bearer {}", oauth_token)) + .header("Content-Type", "application/json") + .header( + "Editor-Version", + format!( + "Zed/{}", + option_env!("CARGO_PKG_VERSION").unwrap_or("unknown") + ), + ) + .when_some(is_user_initiated, |builder, is_user_initiated| { + builder.header( + "X-Initiator", + if is_user_initiated { "user" } else { "agent" }, + ) + }) +} + +async fn request_models( + models_url: Arc, + oauth_token: String, client: Arc, -) -> Result { - let request_builder = HttpRequest::builder() - .method(Method::GET) - .uri(auth_url.as_ref()) - .header("Authorization", format!("token {}", oauth_token)) - .header("Accept", "application/json"); +) -> Result> { + let request_builder = copilot_request_headers( + HttpRequest::builder() + .method(Method::GET) + .uri(models_url.as_ref()), + &oauth_token, + None, + ) + .header("x-github-api-version", "2025-05-01"); let request = request_builder.body(AsyncBody::empty())?; let mut response = client.send(request).await?; - if response.status().is_success() { - let mut body = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; + anyhow::ensure!( + response.status().is_success(), + "Failed to request models: {}", + response.status() + ); + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; - let body_str = std::str::from_utf8(&body)?; + let body_str = std::str::from_utf8(&body)?; - let parsed: ApiTokenResponse = serde_json::from_str(body_str)?; - ApiToken::try_from(parsed) - } else { - let mut body = Vec::new(); - response.body_mut().read_to_end(&mut body).await?; + let models = serde_json::from_str::(body_str)?.data; - let body_str = std::str::from_utf8(&body)?; - anyhow::bail!("Failed to request API token: {body_str}"); - } + Ok(models) } fn extract_oauth_token(contents: String, domain: &str) -> Option { @@ -788,7 +822,7 @@ fn extract_oauth_token(contents: String, domain: &str) -> Option { async fn stream_completion( client: Arc, - api_key: String, + oauth_token: String, completion_url: Arc, request: Request, is_user_initiated: bool, @@ -802,25 +836,16 @@ async fn stream_completion( _ => false, }); - let request_initiator = if is_user_initiated { "user" } else { "agent" }; - - let request_builder = HttpRequest::builder() - .method(Method::POST) - .uri(completion_url.as_ref()) - .header( - "Editor-Version", - format!( - "Zed/{}", - option_env!("CARGO_PKG_VERSION").unwrap_or("unknown") - ), - ) - .header("Authorization", format!("Bearer {}", api_key)) - .header("Content-Type", "application/json") - .header("Copilot-Integration-Id", "vscode-chat") - .header("X-Initiator", request_initiator) - .when(is_vision_request, |builder| { - builder.header("Copilot-Vision-Request", is_vision_request.to_string()) - }); + let request_builder = copilot_request_headers( + HttpRequest::builder() + .method(Method::POST) + .uri(completion_url.as_ref()), + &oauth_token, + Some(is_user_initiated), + ) + .when(is_vision_request, |builder| { + builder.header("Copilot-Vision-Request", is_vision_request.to_string()) + }); let is_streaming = request.stream; diff --git a/crates/copilot_chat/src/responses.rs b/crates/copilot_chat/src/responses.rs index 8262d8e4c370a66a44fc65a2b4de05da23dc5f18..473e583027bf77f3f7dc43d7914f6d2afff743a0 100644 --- a/crates/copilot_chat/src/responses.rs +++ b/crates/copilot_chat/src/responses.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use super::copilot_request_headers; use anyhow::{Result, anyhow}; use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream}; use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; @@ -275,7 +276,7 @@ pub enum ResponseOutputContent { pub async fn stream_response( client: Arc, - api_key: String, + oauth_token: String, api_url: String, request: Request, is_user_initiated: bool, @@ -290,22 +291,11 @@ pub async fn stream_response( _ => false, }); - let request_initiator = if is_user_initiated { "user" } else { "agent" }; - - let request_builder = HttpRequest::builder() - .method(Method::POST) - .uri(&api_url) - .header( - "Editor-Version", - format!( - "Zed/{}", - option_env!("CARGO_PKG_VERSION").unwrap_or("unknown") - ), - ) - .header("Authorization", format!("Bearer {}", api_key)) - .header("Content-Type", "application/json") - .header("Copilot-Integration-Id", "vscode-chat") - .header("X-Initiator", request_initiator); + let request_builder = copilot_request_headers( + HttpRequest::builder().method(Method::POST).uri(&api_url), + &oauth_token, + Some(is_user_initiated), + ); let request_builder = if is_vision_request { request_builder.header("Copilot-Vision-Request", "true")