diff --git a/Cargo.lock b/Cargo.lock index 2a40cfe6927868e4806ae226fafa0b83f37627d4..ecf93e2c0536bee05e23226c2dd21f28eebf5c26 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2749,6 +2749,7 @@ dependencies = [ "async-compression", "async-std", "async-tar", + "chrono", "client", "clock", "collections", @@ -2759,6 +2760,7 @@ dependencies = [ "gpui", "http_client", "indoc", + "isahc", "language", "lsp", "menu", @@ -2767,10 +2769,13 @@ dependencies = [ "paths", "project", "rpc", + "schemars", "serde", "serde_json", "settings", "smol", + "strum", + "task", "theme", "ui", "util", @@ -6020,6 +6025,7 @@ dependencies = [ "anyhow", "client", "collections", + "copilot", "ctor", "editor", "env_logger", @@ -6028,6 +6034,7 @@ dependencies = [ "google_ai", "gpui", "http_client", + "inline_completion_button", "language", "log", "menu", diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index 0d4dbd68240353961d84a65e3845e8801ecb0f24..3242096d814de79ea8abb8fb0086388403acc45d 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -380,6 +380,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema: "ollama".into(), "openai".into(), "zed.dev".into(), + "copilot_chat".into(), ]), ..Default::default() } @@ -419,7 +420,7 @@ pub struct AssistantSettingsContentV1 { default_height: Option, /// The provider of the assistant service. /// - /// This can either be the internal `zed.dev` service or an external `openai` service, + /// This can be "openai", "anthropic", "ollama", "zed.dev" /// each with their respective default models and configurations. provider: Option, } diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 0fc3067eab139901c334db4383d0ff5470f4e395..54abbaa112060b48e51a94967893575ab18660ef 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -13,6 +13,8 @@ path = "src/copilot.rs" doctest = false [features] +default = [] +schemars = ["dep:schemars"] test-support = [ "collections/test-support", "gpui/test-support", @@ -26,13 +28,16 @@ test-support = [ anyhow.workspace = true async-compression.workspace = true async-tar.workspace = true +chrono.workspace = true collections.workspace = true client.workspace = true command_palette_hooks.workspace = true editor.workspace = true +fs.workspace = true futures.workspace = true gpui.workspace = true http_client.workspace = true +isahc.workspace = true language.workspace = true lsp.workspace = true menu.workspace = true @@ -41,8 +46,12 @@ parking_lot.workspace = true paths.workspace = true project.workspace = true serde.workspace = true +serde_json.workspace = true +schemars = { workspace = true, optional = true } +strum.workspace = true settings.workspace = true smol.workspace = true +task.workspace = true ui.workspace = true util.workspace = true workspace.workspace = true diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index c1d482455ad8bb35d9a758d287fee2db573dd711..9357f735f858c8d474e7c97a6f30c54243c37745 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -1,7 +1,9 @@ +pub mod copilot_chat; mod copilot_completion_provider; pub mod request; mod sign_in; +use ::fs::Fs; use anyhow::{anyhow, Context as _, Result}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; @@ -27,6 +29,7 @@ use settings::SettingsStore; use smol::{fs, io::BufReader, stream::StreamExt}; use std::{ any::TypeId, + env, ffi::OsString, mem, ops::Range, @@ -52,10 +55,13 @@ actions!( pub fn init( new_server_id: LanguageServerId, + fs: Arc, http: Arc, node_runtime: Arc, cx: &mut AppContext, ) { + copilot_chat::init(fs, http.clone(), cx); + let copilot = cx.new_model({ let node_runtime = node_runtime.clone(); move |cx| Copilot::start(new_server_id, http, node_runtime, cx) @@ -185,6 +191,10 @@ impl Status { pub fn is_authorized(&self) -> bool { matches!(self, Status::Authorized) } + + pub fn is_disabled(&self) -> bool { + matches!(self, Status::Disabled) + } } struct RegisteredBuffer { @@ -301,6 +311,8 @@ pub struct Copilot { pub enum Event { CopilotLanguageServerStarted, + CopilotAuthSignedIn, + CopilotAuthSignedOut, } impl EventEmitter for Copilot {} @@ -581,7 +593,7 @@ impl Copilot { } } - fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { + pub fn sign_out(&mut self, cx: &mut ModelContext) -> Task> { self.update_sign_in_status(request::SignInStatus::NotSignedIn, cx); if let CopilotServer::Running(RunningCopilotServer { lsp: server, .. }) = &self.server { let server = server.clone(); @@ -928,6 +940,7 @@ impl Copilot { | request::SignInStatus::MaybeOk { .. } | request::SignInStatus::AlreadySignedIn { .. } => { server.sign_in_status = SignInStatus::Authorized; + cx.emit(Event::CopilotAuthSignedIn); for buffer in self.buffers.iter().cloned().collect::>() { if let Some(buffer) = buffer.upgrade() { self.register_buffer(&buffer, cx); @@ -942,6 +955,7 @@ impl Copilot { } request::SignInStatus::Ok { user: None } | request::SignInStatus::NotSignedIn => { server.sign_in_status = SignInStatus::SignedOut; + cx.emit(Event::CopilotAuthSignedOut); for buffer in self.buffers.iter().cloned().collect::>() { self.unregister_buffer(&buffer); } diff --git a/crates/copilot/src/copilot_chat.rs b/crates/copilot/src/copilot_chat.rs new file mode 100644 index 0000000000000000000000000000000000000000..6d3a2ee7dc1188491331c85350497b5d329f2e17 --- /dev/null +++ b/crates/copilot/src/copilot_chat.rs @@ -0,0 +1,364 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::{anyhow, Result}; +use chrono::DateTime; +use fs::Fs; +use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; +use gpui::{AppContext, AsyncAppContext, Global}; +use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; +use isahc::config::Configurable; +use serde::{Deserialize, Serialize}; +use settings::watch_config_file; +use strum::EnumIter; +use ui::Context; + +pub const COPILOT_CHAT_COMPLETION_URL: &'static str = + "https://api.githubcopilot.com/chat/completions"; +pub const COPILOT_CHAT_AUTH_URL: &'static str = "https://api.github.com/copilot_internal/v2/token"; + +#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Assistant, + System, +} + +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)] +pub enum Model { + #[default] + #[serde(alias = "gpt-4", rename = "gpt-4")] + Gpt4, + #[serde(alias = "gpt-3.5-turbo", rename = "gpt-3.5-turbo")] + Gpt3_5Turbo, +} + +impl Model { + pub fn from_id(id: &str) -> Result { + match id { + "gpt-4" => Ok(Self::Gpt4), + "gpt-3.5-turbo" => Ok(Self::Gpt3_5Turbo), + _ => Err(anyhow!("Invalid model id: {}", id)), + } + } + + pub fn id(&self) -> &'static str { + match self { + Self::Gpt3_5Turbo => "gpt-3.5-turbo", + Self::Gpt4 => "gpt-4", + } + } + + pub fn display_name(&self) -> &'static str { + match self { + Self::Gpt3_5Turbo => "GPT-3.5", + Self::Gpt4 => "GPT-4", + } + } + + pub fn max_token_count(&self) -> usize { + match self { + Self::Gpt4 => 8192, + Self::Gpt3_5Turbo => 16385, + } + } +} + +#[derive(Serialize, Deserialize)] +pub struct Request { + pub intent: bool, + pub n: usize, + pub stream: bool, + pub temperature: f32, + pub model: Model, + pub messages: Vec, +} + +impl Request { + pub fn new(model: Model, messages: Vec) -> Self { + Self { + intent: true, + n: 1, + stream: true, + temperature: 0.1, + model, + messages, + } + } +} + +#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] +pub struct ChatMessage { + pub role: Role, + pub content: String, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub struct ResponseEvent { + pub choices: Vec, + pub created: u64, + pub id: String, +} + +#[derive(Debug, Deserialize)] +pub struct ResponseChoice { + pub index: usize, + pub finish_reason: Option, + pub delta: ResponseDelta, +} + +#[derive(Debug, Deserialize)] +pub struct ResponseDelta { + pub content: Option, + pub role: Option, +} + +#[derive(Deserialize)] +struct ApiTokenResponse { + token: String, + expires_at: i64, +} + +#[derive(Clone)] +struct ApiToken { + api_key: String, + expires_at: DateTime, +} + +impl ApiToken { + pub fn remaining_seconds(&self) -> i64 { + self.expires_at + .timestamp() + .saturating_sub(chrono::Utc::now().timestamp()) + } +} + +impl TryFrom for ApiToken { + type Error = anyhow::Error; + + fn try_from(response: ApiTokenResponse) -> Result { + let expires_at = DateTime::from_timestamp(response.expires_at, 0) + .ok_or_else(|| anyhow!("invalid expires_at"))?; + + Ok(Self { + api_key: response.token, + expires_at, + }) + } +} + +struct GlobalCopilotChat(gpui::Model); + +impl Global for GlobalCopilotChat {} + +pub struct CopilotChat { + oauth_token: Option, + api_token: Option, + client: Arc, +} + +pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { + let copilot_chat = cx.new_model(|cx| CopilotChat::new(fs, client, cx)); + cx.set_global(GlobalCopilotChat(copilot_chat)); +} + +impl CopilotChat { + pub fn global(cx: &AppContext) -> Option> { + cx.try_global::() + .map(|model| model.0.clone()) + } + + pub fn new(fs: Arc, client: Arc, cx: &AppContext) -> Self { + let mut config_file_rx = watch_config_file( + cx.background_executor(), + fs, + paths::copilot_chat_config_path().clone(), + ); + + cx.spawn(|cx| async move { + while let Some(contents) = config_file_rx.next().await { + let oauth_token = extract_oauth_token(contents); + + cx.update(|cx| { + if let Some(this) = Self::global(cx).as_ref() { + this.update(cx, |this, cx| { + this.oauth_token = oauth_token; + cx.notify(); + }); + } + })?; + } + anyhow::Ok(()) + }) + .detach_and_log_err(cx); + + Self { + oauth_token: None, + api_token: None, + client, + } + } + + pub fn is_authenticated(&self) -> bool { + self.oauth_token.is_some() + } + + pub async fn stream_completion( + request: Request, + low_speed_timeout: Option, + cx: &mut AsyncAppContext, + ) -> Result>> { + let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else { + return Err(anyhow!("Copilot chat is not enabled")); + }; + + let (oauth_token, api_token, client) = this.read_with(cx, |this, _| { + ( + this.oauth_token.clone(), + this.api_token.clone(), + this.client.clone(), + ) + })?; + + let oauth_token = oauth_token.ok_or_else(|| anyhow!("No OAuth token available"))?; + + let token = match api_token { + Some(api_token) if api_token.remaining_seconds() > 5 * 60 => api_token.clone(), + _ => { + let token = + request_api_token(&oauth_token, client.clone(), low_speed_timeout).await?; + this.update(cx, |this, cx| { + this.api_token = Some(token.clone()); + cx.notify(); + })?; + token + } + }; + + stream_completion(client.clone(), token.api_key, request, low_speed_timeout).await + } +} + +async fn request_api_token( + oauth_token: &str, + client: Arc, + low_speed_timeout: Option, +) -> Result { + let mut request_builder = HttpRequest::builder() + .method(Method::GET) + .uri(COPILOT_CHAT_AUTH_URL) + .header("Authorization", format!("token {}", oauth_token)) + .header("Accept", "application/json"); + + if let Some(low_speed_timeout) = low_speed_timeout { + request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); + } + + let request = request_builder.body(AsyncBody::empty())?; + + let mut response = client.send(request).await?; + + if response.status().is_success() { + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + let body_str = std::str::from_utf8(&body)?; + + let parsed: ApiTokenResponse = serde_json::from_str(body_str)?; + ApiToken::try_from(parsed) + } else { + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + let body_str = std::str::from_utf8(&body)?; + + Err(anyhow!("Failed to request API token: {}", body_str)) + } +} + +fn extract_oauth_token(contents: String) -> Option { + serde_json::from_str::(&contents) + .map(|v| { + v["github.com"]["oauth_token"] + .as_str() + .map(|v| v.to_string()) + }) + .ok() + .flatten() +} + +async fn stream_completion( + client: Arc, + api_key: String, + request: Request, + low_speed_timeout: Option, +) -> Result>> { + let mut request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(COPILOT_CHAT_COMPLETION_URL) + .header( + "Editor-Version", + format!( + "Zed/{}", + option_env!("CARGO_PKG_VERSION").unwrap_or("unknown") + ), + ) + .header("Authorization", format!("Bearer {}", api_key)) + .header("Content-Type", "application/json") + .header("Copilot-Integration-Id", "vscode-chat"); + + if let Some(low_speed_timeout) = low_speed_timeout { + request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); + } + let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let mut response = client.send(request).await?; + if response.status().is_success() { + let reader = BufReader::new(response.into_body()); + Ok(reader + .lines() + .filter_map(|line| async move { + match line { + Ok(line) => { + let line = line.strip_prefix("data: ")?; + if line.starts_with("[DONE]") { + return None; + } + + match serde_json::from_str::(line) { + Ok(response) => { + if response.choices.first().is_none() + || response.choices.first().unwrap().finish_reason.is_some() + { + None + } else { + Some(Ok(response)) + } + } + Err(error) => Some(Err(anyhow!(error))), + } + } + Err(error) => Some(Err(anyhow!(error))), + } + }) + .boxed()) + } else { + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + let body_str = std::str::from_utf8(&body)?; + + match serde_json::from_str::(body_str) { + Ok(_) => Err(anyhow!( + "Unexpected success response while expecting an error: {}", + body_str, + )), + Err(_) => Err(anyhow!( + "Failed to connect to API: {} {}", + response.status(), + body_str, + )), + } + } +} diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index de3ba8ef650c7cf17484a499ef6ed0468726e444..9a5c60a0d8e8a949d92bf4a8a3a3e325e26c2297 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -25,12 +25,14 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true client.workspace = true collections.workspace = true +copilot = { workspace = true, features = ["schemars"] } editor.workspace = true feature_flags.workspace = true futures.workspace = true google_ai = { workspace = true, features = ["schemars"] } gpui.workspace = true http_client.workspace = true +inline_completion_button.workspace = true menu.workspace = true ollama = { workspace = true, features = ["schemars"] } open_ai = { workspace = true, features = ["schemars"] } diff --git a/crates/language_model/src/provider.rs b/crates/language_model/src/provider.rs index 6fe0bfd7a1e71c8c63221600564d5afde0dd9871..d2d162b75e05652437df7b24038167f35f68de06 100644 --- a/crates/language_model/src/provider.rs +++ b/crates/language_model/src/provider.rs @@ -1,5 +1,6 @@ pub mod anthropic; pub mod cloud; +pub mod copilot_chat; #[cfg(any(test, feature = "test-support"))] pub mod fake; pub mod google; diff --git a/crates/language_model/src/provider/copilot_chat.rs b/crates/language_model/src/provider/copilot_chat.rs new file mode 100644 index 0000000000000000000000000000000000000000..bd3bee2c81f859e258259946eff7d6e5c3baa0f2 --- /dev/null +++ b/crates/language_model/src/provider/copilot_chat.rs @@ -0,0 +1,359 @@ +use std::sync::Arc; + +use anyhow::Result; +use copilot::copilot_chat::{ + ChatMessage, CopilotChat, Model as CopilotChatModel, Request as CopilotChatRequest, + Role as CopilotChatRole, +}; +use copilot::{Copilot, Status}; +use futures::future::BoxFuture; +use futures::stream::BoxStream; +use futures::{FutureExt, StreamExt}; +use gpui::{ + percentage, svg, Animation, AnimationExt, AnyView, AppContext, AsyncAppContext, Model, + ModelContext, Render, Subscription, Task, Transformation, +}; +use settings::{Settings, SettingsStore}; +use std::time::Duration; +use strum::IntoEnumIterator; +use ui::{ + div, v_flex, Button, ButtonCommon, Clickable, Color, Context, FixedWidth, IconName, + IconPosition, IconSize, IntoElement, Label, LabelCommon, ParentElement, Styled, ViewContext, + VisualContext, WindowContext, +}; + +use crate::settings::AllLanguageModelSettings; +use crate::LanguageModelProviderState; +use crate::{ + LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider, + LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role, +}; + +use super::open_ai::count_open_ai_tokens; + +const PROVIDER_ID: &str = "copilot_chat"; +const PROVIDER_NAME: &str = "GitHub Copilot Chat"; + +#[derive(Default, Clone, Debug, PartialEq)] +pub struct CopilotChatSettings { + pub low_speed_timeout: Option, +} + +pub struct CopilotChatLanguageModelProvider { + state: Model, +} + +pub struct State { + _copilot_chat_subscription: Option, + _settings_subscription: Subscription, +} + +impl CopilotChatLanguageModelProvider { + pub fn new(cx: &mut AppContext) -> Self { + let state = cx.new_model(|cx| { + let _copilot_chat_subscription = CopilotChat::global(cx) + .map(|copilot_chat| cx.observe(&copilot_chat, |_, _, cx| cx.notify())); + State { + _copilot_chat_subscription, + _settings_subscription: cx.observe_global::(|_, cx| { + cx.notify(); + }), + } + }); + + Self { state } + } +} + +impl LanguageModelProviderState for CopilotChatLanguageModelProvider { + fn subscribe(&self, cx: &mut ModelContext) -> Option { + Some(cx.observe(&self.state, |_, _, cx| { + cx.notify(); + })) + } +} + +impl LanguageModelProvider for CopilotChatLanguageModelProvider { + fn id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn provided_models(&self, _cx: &AppContext) -> Vec> { + CopilotChatModel::iter() + .map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc) + .collect() + } + + fn is_authenticated(&self, cx: &AppContext) -> bool { + CopilotChat::global(cx) + .map(|m| m.read(cx).is_authenticated()) + .unwrap_or(false) + } + + fn authenticate(&self, cx: &AppContext) -> Task> { + let result = if self.is_authenticated(cx) { + Ok(()) + } else if let Some(copilot) = Copilot::global(cx) { + let error_msg = match copilot.read(cx).status() { + Status::Disabled => anyhow::anyhow!("Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again."), + Status::Error(e) => anyhow::anyhow!(format!("Received the following error while signing into Copilot: {e}")), + Status::Starting { task: _ } => anyhow::anyhow!("Copilot is still starting, please wait for Copilot to start then try again"), + Status::Unauthorized => anyhow::anyhow!("Unable to authorize with Copilot. Please make sure that you have an active Copilot and Copilot Chat subscription."), + Status::Authorized => return Task::ready(Ok(())), + Status::SignedOut => anyhow::anyhow!("You have signed out of Copilot. Please sign in to Copilot and try again."), + Status::SigningIn { prompt: _ } => anyhow::anyhow!("Still signing into Copilot..."), + }; + Err(error_msg) + } else { + Err(anyhow::anyhow!( + "Copilot must be enabled for Copilot Chat to work. Please enable Copilot and try again." + )) + }; + Task::ready(result) + } + + fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView { + cx.new_view(|cx| AuthenticationPrompt::new(cx)).into() + } + + fn reset_credentials(&self, cx: &AppContext) -> Task> { + let Some(copilot) = Copilot::global(cx) else { + return Task::ready(Err(anyhow::anyhow!( + "Copilot is not available. Please ensure Copilot is enabled and running and try again." + ))); + }; + + let state = self.state.clone(); + + cx.spawn(|mut cx| async move { + cx.update_model(&copilot, |model, cx| model.sign_out(cx))? + .await?; + + cx.update_model(&state, |_, cx| { + cx.notify(); + })?; + + Ok(()) + }) + } +} + +pub struct CopilotChatLanguageModel { + model: CopilotChatModel, +} + +impl LanguageModel for CopilotChatLanguageModel { + fn id(&self) -> LanguageModelId { + LanguageModelId::from(self.model.id().to_string()) + } + + fn name(&self) -> LanguageModelName { + LanguageModelName::from(self.model.display_name().to_string()) + } + + fn provider_id(&self) -> LanguageModelProviderId { + LanguageModelProviderId(PROVIDER_ID.into()) + } + + fn provider_name(&self) -> LanguageModelProviderName { + LanguageModelProviderName(PROVIDER_NAME.into()) + } + + fn telemetry_id(&self) -> String { + format!("copilot_chat/{}", self.model.id()) + } + + fn max_token_count(&self) -> usize { + self.model.max_token_count() + } + + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + let model = match self.model { + CopilotChatModel::Gpt4 => open_ai::Model::Four, + CopilotChatModel::Gpt3_5Turbo => open_ai::Model::ThreePointFiveTurbo, + }; + + count_open_ai_tokens(request, model, cx) + } + + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncAppContext, + ) -> BoxFuture<'static, Result>>> { + if let Some(message) = request.messages.last() { + if message.content.trim().is_empty() { + const EMPTY_PROMPT_MSG: &str = + "Empty prompts aren't allowed. Please provide a non-empty prompt."; + return futures::future::ready(Err(anyhow::anyhow!(EMPTY_PROMPT_MSG))).boxed(); + } + + // Copilot Chat has a restriction that the final message must be from the user. + // While their API does return an error message for this, we can catch it earlier + // and provide a more helpful error message. + if !matches!(message.role, Role::User) { + const USER_ROLE_MSG: &str = "The final message must be from the user. To provide a system prompt, you must provide the system prompt followed by a user prompt."; + return futures::future::ready(Err(anyhow::anyhow!(USER_ROLE_MSG))).boxed(); + } + } + + let request = self.to_copilot_chat_request(request); + let Ok(low_speed_timeout) = cx.update(|cx| { + AllLanguageModelSettings::get_global(cx) + .copilot_chat + .low_speed_timeout + }) else { + return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed(); + }; + + cx.spawn(|mut cx| async move { + let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?; + let stream = response + .filter_map(|response| async move { + match response { + Ok(result) => { + let choice = result.choices.first(); + match choice { + Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())), + None => Some(Err(anyhow::anyhow!( + "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again." + ))), + } + } + Err(err) => Some(Err(err)), + } + }) + .boxed(); + Ok(stream) + }) + .boxed() + } +} + +impl CopilotChatLanguageModel { + pub fn to_copilot_chat_request(&self, request: LanguageModelRequest) -> CopilotChatRequest { + CopilotChatRequest::new( + self.model.clone(), + request + .messages + .into_iter() + .map(|msg| ChatMessage { + role: match msg.role { + Role::User => CopilotChatRole::User, + Role::Assistant => CopilotChatRole::Assistant, + Role::System => CopilotChatRole::System, + }, + content: msg.content, + }) + .collect(), + ) + } +} + +struct AuthenticationPrompt { + copilot_status: Option, + _subscription: Option, +} + +impl AuthenticationPrompt { + pub fn new(cx: &mut ViewContext) -> Self { + let copilot = Copilot::global(cx); + + Self { + copilot_status: copilot.as_ref().map(|copilot| copilot.read(cx).status()), + _subscription: copilot.as_ref().map(|copilot| { + cx.observe(copilot, |this, model, cx| { + this.copilot_status = Some(model.read(cx).status()); + cx.notify(); + }) + }), + } + } +} + +impl Render for AuthenticationPrompt { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let loading_icon = svg() + .size_8() + .path(IconName::ArrowCircle.path()) + .text_color(cx.text_style().color) + .with_animation( + "icon_circle_arrow", + Animation::new(Duration::from_secs(2)).repeat(), + |svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))), + ); + + const ERROR_LABEL: &str = "Copilot Chat requires the Copilot plugin to be available and running. Please ensure Copilot is running and try again, or use a different Assistant provider."; + match &self.copilot_status { + Some(status) => match status { + Status::Disabled => { + return v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL)); + } + Status::Starting { task: _ } => { + const LABEL: &str = "Starting Copilot..."; + return v_flex() + .gap_6() + .p_4() + .justify_center() + .items_center() + .child(Label::new(LABEL)) + .child(loading_icon); + } + Status::SigningIn { prompt: _ } => { + const LABEL: &str = "Signing in to Copilot..."; + return v_flex() + .gap_6() + .p_4() + .justify_center() + .items_center() + .child(Label::new(LABEL)) + .child(loading_icon); + } + Status::Error(_) => { + const LABEL: &str = "Copilot had issues starting. Please try restarting it. If the issue persists, try reinstalling Copilot."; + return v_flex() + .gap_6() + .p_4() + .child(Label::new(LABEL)) + .child(svg().size_8().path(IconName::CopilotError.path())); + } + _ => { + const LABEL: &str = + "To use the assistant panel or inline assistant, you must login to GitHub Copilot. Your GitHub account must have an active Copilot Chat subscription."; + v_flex().gap_6().p_4().child(Label::new(LABEL)).child( + v_flex() + .gap_2() + .child( + Button::new("sign_in", "Sign In") + .icon_color(Color::Muted) + .icon(IconName::Github) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Medium) + .style(ui::ButtonStyle::Filled) + .full_width() + .on_click(|_, cx| { + inline_completion_button::initiate_sign_in(cx) + }), + ) + .child( + div().flex().w_full().items_center().child( + Label::new("Sign in to start using Github Copilot Chat.") + .color(Color::Muted) + .size(ui::LabelSize::Small), + ), + ), + ) + } + }, + None => v_flex().gap_6().p_4().child(Label::new(ERROR_LABEL)), + } + } +} diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs index 05dcbced5ddedd5df3a13ebed4aa1ad0bcce8886..d90163671e2be01c97ab5627d6aa13f128856ae9 100644 --- a/crates/language_model/src/registry.rs +++ b/crates/language_model/src/registry.rs @@ -1,8 +1,8 @@ use crate::{ provider::{ anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider, - google::GoogleLanguageModelProvider, ollama::OllamaLanguageModelProvider, - open_ai::OpenAiLanguageModelProvider, + copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider, + ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider, }, LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState, }; @@ -44,6 +44,7 @@ fn register_language_model_providers( GoogleLanguageModelProvider::new(client.http_client(), cx), cx, ); + registry.register_provider(CopilotChatLanguageModelProvider::new(cx), cx); cx.observe_flag::(move |enabled, cx| { let client = client.clone(); diff --git a/crates/language_model/src/settings.rs b/crates/language_model/src/settings.rs index 85ae91649a58cb66bda36172d02f8b4f063e3b07..58e38e49719d3b94b602eec6dcf5658f9df11010 100644 --- a/crates/language_model/src/settings.rs +++ b/crates/language_model/src/settings.rs @@ -9,6 +9,7 @@ use settings::{Settings, SettingsSources}; use crate::provider::{ anthropic::AnthropicSettings, cloud::{self, ZedDotDevSettings}, + copilot_chat::CopilotChatSettings, google::GoogleSettings, ollama::OllamaSettings, open_ai::OpenAiSettings, @@ -26,6 +27,7 @@ pub struct AllLanguageModelSettings { pub openai: OpenAiSettings, pub zed_dot_dev: ZedDotDevSettings, pub google: GoogleSettings, + pub copilot_chat: CopilotChatSettings, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -36,6 +38,7 @@ pub struct AllLanguageModelSettingsContent { #[serde(rename = "zed.dev")] pub zed_dot_dev: Option, pub google: Option, + pub copilot_chat: Option, } #[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] @@ -70,6 +73,11 @@ pub struct ZedDotDevSettingsContent { available_models: Option>, } +#[derive(Default, Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)] +pub struct CopilotChatSettingsContent { + low_speed_timeout_in_seconds: Option, +} + impl settings::Settings for AllLanguageModelSettings { const KEY: Option<&'static str> = Some("language_models"); @@ -165,6 +173,15 @@ impl settings::Settings for AllLanguageModelSettings { .as_ref() .and_then(|s| s.available_models.clone()), ); + + if let Some(low_speed_timeout) = value + .copilot_chat + .as_ref() + .and_then(|s| s.low_speed_timeout_in_seconds) + { + settings.copilot_chat.low_speed_timeout = + Some(Duration::from_secs(low_speed_timeout)); + } } Ok(settings) diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index e2fabdc7b54b412bd7f9024e0c5a3a34c14dfcac..445619dd0c3a4651e7fbd09b68815d4739c65bd7 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -180,6 +180,7 @@ impl LogStore { ); } } + _ => {} } }) }); diff --git a/crates/paths/src/paths.rs b/crates/paths/src/paths.rs index 768be162d27359197074eb2b28e85510c724b03d..a28429f6a7508a00cdee451041698dff5101576b 100644 --- a/crates/paths/src/paths.rs +++ b/crates/paths/src/paths.rs @@ -212,6 +212,20 @@ pub fn copilot_dir() -> &'static PathBuf { COPILOT_DIR.get_or_init(|| support_dir().join("copilot")) } +pub fn copilot_chat_config_path() -> &'static PathBuf { + static COPILOT_CHAT_CONFIG_DIR: OnceLock = OnceLock::new(); + + COPILOT_CHAT_CONFIG_DIR.get_or_init(|| { + if cfg!(target_os = "windows") { + home_dir().join("AppData") + } else { + home_dir().join(".config") + } + .join("github-copilot") + .join("hosts.json") + }) +} + /// Returns the path to the Supermaven directory. pub fn supermaven_dir() -> &'static PathBuf { static SUPERMAVEN_DIR: OnceLock = OnceLock::new(); diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 48497087d77218ee35ea0faa4ff24ae01f444a45..d0bda24385d7dfb948000a27e5ca32a969ff4dd4 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -165,9 +165,17 @@ fn init_common(app_state: Arc, cx: &mut AppContext) { SystemAppearance::init(cx); theme::init(theme::LoadThemes::All(Box::new(Assets)), cx); command_palette::init(cx); + let copilot_language_server_id = app_state.languages.next_language_server_id(); + copilot::init( + copilot_language_server_id, + app_state.fs.clone(), + app_state.client.http_client(), + app_state.node_runtime.clone(), + cx, + ); + supermaven::init(app_state.client.clone(), cx); language_model::init(app_state.client.clone(), cx); snippet_provider::init(cx); - supermaven::init(app_state.client.clone(), cx); inline_completion_registry::init(app_state.client.telemetry().clone(), cx); assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); repl::init( @@ -239,15 +247,6 @@ fn init_ui(app_state: Arc, cx: &mut AppContext) -> Result<()> { settings_ui::init(cx); extensions_ui::init(cx); - // Initialize each completion provider. Settings are used for toggling between them. - let copilot_language_server_id = app_state.languages.next_language_server_id(); - copilot::init( - copilot_language_server_id, - app_state.client.http_client(), - app_state.node_runtime.clone(), - cx, - ); - cx.observe_global::({ let languages = app_state.languages.clone(); let http = app_state.client.http_client(); diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index fd9f32d6b6ec6c1ad740b76fd695487565773d8b..46cd10467f994934d009d3cfd84d0179173fd334 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -3456,6 +3456,11 @@ mod tests { project_panel::init((), cx); outline_panel::init((), cx); terminal_view::init(cx); + copilot::copilot_chat::init( + app_state.fs.clone(), + app_state.client.http_client().clone(), + cx, + ); language_model::init(app_state.client.clone(), cx); assistant::init(app_state.fs.clone(), app_state.client.clone(), cx); repl::init(