From 6563330239a142b7e0c2cb02abf7aa5da1571c57 Mon Sep 17 00:00:00 2001 From: Kyle Kelley Date: Fri, 3 May 2024 12:50:42 -0700 Subject: [PATCH] Supermaven (#10788) Adds a supermaven provider for completions. There are various other refactors amidst this branch, primarily to make copilot no longer a dependency of project as well as show LSP Logs for global LSPs like copilot properly. This feature is not enabled by default. We're going to seek to refine it in the coming weeks. Release Notes: - N/A --------- Co-authored-by: Antonio Scandurra Co-authored-by: Nathan Sobo Co-authored-by: Max Co-authored-by: Max Brunsfeld --- Cargo.lock | 97 +++- Cargo.toml | 8 +- assets/icons/supermaven.svg | 8 + assets/icons/supermaven_disabled.svg | 15 + assets/icons/supermaven_error.svg | 11 + assets/icons/supermaven_init.svg | 11 + assets/settings/default.json | 4 +- crates/anthropic/src/anthropic.rs | 4 +- crates/collab/Cargo.toml | 1 + crates/collab/k8s/collab.template.yml | 5 + crates/collab/src/completion.rs | 2 + crates/collab/src/lib.rs | 1 + crates/collab/src/rpc.rs | 74 ++- crates/collab/src/tests/test_server.rs | 1 + crates/copilot/Cargo.toml | 10 + crates/copilot/src/copilot.rs | 23 +- .../src/copilot_completion_provider.rs | 77 +-- crates/{copilot_ui => copilot}/src/sign_in.rs | 4 +- crates/copilot_ui/src/copilot_button.rs | 403 -------------- crates/copilot_ui/src/copilot_ui.rs | 7 - crates/editor/src/editor.rs | 44 +- .../editor/src/inline_completion_provider.rs | 10 +- crates/google_ai/src/google_ai.rs | 6 +- .../Cargo.toml | 7 +- .../LICENSE-GPL | 0 .../src/inline_completion_button.rs | 510 ++++++++++++++++++ crates/language/src/language_settings.rs | 116 ++-- crates/language_tools/Cargo.toml | 2 +- crates/language_tools/src/lsp_log.rs | 323 ++++++----- crates/project/Cargo.toml | 1 - crates/project/src/project.rs | 79 +-- crates/rpc/proto/zed.proto | 13 +- crates/rpc/src/proto.rs | 3 + crates/supermaven/Cargo.toml | 41 ++ crates/supermaven/src/messages.rs | 152 ++++++ crates/supermaven/src/supermaven.rs | 345 ++++++++++++ .../src/supermaven_completion_provider.rs | 131 +++++ crates/supermaven_api/Cargo.toml | 21 + crates/supermaven_api/src/supermaven_api.rs | 291 ++++++++++ crates/ui/src/components/icon.rs | 8 + crates/util/src/paths.rs | 1 + crates/welcome/Cargo.toml | 2 +- crates/welcome/src/welcome.rs | 3 +- crates/zed/Cargo.toml | 3 +- crates/zed/src/main.rs | 57 +- crates/zed/src/zed.rs | 8 +- .../zed/src/zed/inline_completion_registry.rs | 126 +++++ 47 files changed, 2242 insertions(+), 827 deletions(-) create mode 100644 assets/icons/supermaven.svg create mode 100644 assets/icons/supermaven_disabled.svg create mode 100644 assets/icons/supermaven_error.svg create mode 100644 assets/icons/supermaven_init.svg create mode 100644 crates/collab/src/completion.rs rename crates/{copilot_ui => copilot}/src/copilot_completion_provider.rs (94%) rename crates/{copilot_ui => copilot}/src/sign_in.rs (98%) delete mode 100644 crates/copilot_ui/src/copilot_button.rs delete mode 100644 crates/copilot_ui/src/copilot_ui.rs rename crates/{copilot_ui => inline_completion_button}/Cargo.toml (88%) rename crates/{copilot_ui => inline_completion_button}/LICENSE-GPL (100%) create mode 100644 crates/inline_completion_button/src/inline_completion_button.rs create mode 100644 crates/supermaven/Cargo.toml create mode 100644 crates/supermaven/src/messages.rs create mode 100644 crates/supermaven/src/supermaven.rs create mode 100644 crates/supermaven/src/supermaven_completion_provider.rs create mode 100644 crates/supermaven_api/Cargo.toml create mode 100644 crates/supermaven_api/src/supermaven_api.rs create mode 100644 crates/zed/src/zed/inline_completion_registry.rs diff --git a/Cargo.lock b/Cargo.lock index 3417a62d8241979b0024ea8fbc0a0df3d9a8dd13..894b827c4f1055ecea80bf63cdd07167fb0b1127 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2316,6 +2316,7 @@ dependencies = [ "sha2 0.10.7", "sqlx", "subtle", + "supermaven_api", "telemetry_events", "text", "theme", @@ -2512,30 +2513,10 @@ dependencies = [ "async-compression", "async-std", "async-tar", + "client", "clock", "collections", "command_palette_hooks", - "fs", - "futures 0.3.28", - "gpui", - "language", - "lsp", - "node_runtime", - "parking_lot", - "rpc", - "serde", - "settings", - "smol", - "util", -] - -[[package]] -name = "copilot_ui" -version = "0.1.0" -dependencies = [ - "anyhow", - "client", - "copilot", "editor", "fs", "futures 0.3.28", @@ -2544,14 +2525,18 @@ dependencies = [ "language", "lsp", "menu", + "node_runtime", + "parking_lot", "project", + "rpc", + "serde", "serde_json", "settings", + "smol", "theme", "ui", "util", "workspace", - "zed_actions", ] [[package]] @@ -5143,6 +5128,30 @@ dependencies = [ "syn 2.0.59", ] +[[package]] +name = "inline_completion_button" +version = "0.1.0" +dependencies = [ + "anyhow", + "copilot", + "editor", + "fs", + "futures 0.3.28", + "gpui", + "indoc", + "language", + "lsp", + "project", + "serde_json", + "settings", + "supermaven", + "theme", + "ui", + "util", + "workspace", + "zed_actions", +] + [[package]] name = "inotify" version = "0.9.6" @@ -5548,6 +5557,7 @@ dependencies = [ "anyhow", "client", "collections", + "copilot", "editor", "env_logger", "futures 0.3.28", @@ -7422,7 +7432,6 @@ dependencies = [ "client", "clock", "collections", - "copilot", "env_logger", "fs", "futures 0.3.28", @@ -9594,6 +9603,43 @@ dependencies = [ "rayon", ] +[[package]] +name = "supermaven" +version = "0.1.0" +dependencies = [ + "anyhow", + "client", + "collections", + "editor", + "env_logger", + "futures 0.3.28", + "gpui", + "language", + "log", + "postage", + "project", + "serde", + "serde_json", + "settings", + "smol", + "supermaven_api", + "theme", + "ui", + "util", +] + +[[package]] +name = "supermaven_api" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.28", + "serde", + "serde_json", + "smol", + "util", +] + [[package]] name = "sval" version = "2.8.0" @@ -11798,12 +11844,12 @@ version = "0.1.0" dependencies = [ "anyhow", "client", - "copilot_ui", "db", "editor", "extensions_ui", "fuzzy", "gpui", + "inline_completion_button", "install_cli", "picker", "project", @@ -12683,7 +12729,6 @@ dependencies = [ "collections", "command_palette", "copilot", - "copilot_ui", "db", "dev_server_projects", "diagnostics", @@ -12700,6 +12745,7 @@ dependencies = [ "gpui", "headless", "image_viewer", + "inline_completion_button", "install_cli", "isahc", "journal", @@ -12730,6 +12776,7 @@ dependencies = [ "settings", "simplelog", "smol", + "supermaven", "tab_switcher", "task", "tasks_ui", diff --git a/Cargo.toml b/Cargo.toml index ca0e5f35bd372584d3ba2093a1196f3102239cdc..67ce732b6103ef7723d3fdecf387dbd56f7c1744 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ members = [ "crates/command_palette", "crates/command_palette_hooks", "crates/copilot", - "crates/copilot_ui", "crates/db", "crates/diagnostics", "crates/editor", @@ -42,6 +41,7 @@ members = [ "crates/gpui_macros", "crates/headless", "crates/image_viewer", + "crates/inline_completion_button", "crates/install_cli", "crates/journal", "crates/language", @@ -86,6 +86,8 @@ members = [ "crates/storybook", "crates/sum_tree", "crates/tab_switcher", + "crates/supermaven", + "crates/supermaven_api", "crates/terminal", "crates/terminal_view", "crates/text", @@ -159,7 +161,6 @@ color = { path = "crates/color" } command_palette = { path = "crates/command_palette" } command_palette_hooks = { path = "crates/command_palette_hooks" } copilot = { path = "crates/copilot" } -copilot_ui = { path = "crates/copilot_ui" } db = { path = "crates/db" } diagnostics = { path = "crates/diagnostics" } editor = { path = "crates/editor" } @@ -180,6 +181,7 @@ gpui_macros = { path = "crates/gpui_macros" } headless = { path = "crates/headless" } install_cli = { path = "crates/install_cli" } image_viewer = { path = "crates/image_viewer" } +inline_completion_button = { path = "crates/inline_completion_button" } journal = { path = "crates/journal" } language = { path = "crates/language" } language_selector = { path = "crates/language_selector" } @@ -220,6 +222,8 @@ settings = { path = "crates/settings" } snippet = { path = "crates/snippet" } sqlez = { path = "crates/sqlez" } sqlez_macros = { path = "crates/sqlez_macros" } +supermaven = { path = "crates/supermaven" } +supermaven_api = { path = "crates/supermaven_api"} story = { path = "crates/story" } storybook = { path = "crates/storybook" } sum_tree = { path = "crates/sum_tree" } diff --git a/assets/icons/supermaven.svg b/assets/icons/supermaven.svg new file mode 100644 index 0000000000000000000000000000000000000000..19837fbf56ceb653e4b3e5b478747bdaf1654b32 --- /dev/null +++ b/assets/icons/supermaven.svg @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/assets/icons/supermaven_disabled.svg b/assets/icons/supermaven_disabled.svg new file mode 100644 index 0000000000000000000000000000000000000000..39ff8a6122d9ce978998795181b640d1ef4b2eed --- /dev/null +++ b/assets/icons/supermaven_disabled.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + diff --git a/assets/icons/supermaven_error.svg b/assets/icons/supermaven_error.svg new file mode 100644 index 0000000000000000000000000000000000000000..669322b97d9b5e88a83d032caf03dad48742bebb --- /dev/null +++ b/assets/icons/supermaven_error.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/assets/icons/supermaven_init.svg b/assets/icons/supermaven_init.svg new file mode 100644 index 0000000000000000000000000000000000000000..b919d5559bfd66ca1b85fa3c2ead5bc6d54f3db4 --- /dev/null +++ b/assets/icons/supermaven_init.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/assets/settings/default.json b/assets/settings/default.json index c8560d7f159766fd7d65e2357de41437b9054f41..de6da01f87ef0e0ce59214bece7ce5e1a9155dd3 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -12,8 +12,8 @@ "base_keymap": "VSCode", // Features that can be globally enabled or disabled "features": { - // Show Copilot icon in status bar - "copilot": true + // Which inline completion provider to use. + "inline_completion_provider": "copilot" }, // The name of a font to use for rendering text in the editor "buffer_font_family": "Zed Mono", diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index a96a23b166cd246e48e025fdb8836b6c2e57a6e4..aeaae1f34d732bd7bb62169569ed46d50efea014 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -1,7 +1,7 @@ use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; use serde::{Deserialize, Serialize}; -use std::convert::TryFrom; +use std::{convert::TryFrom, sync::Arc}; use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest}; #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)] @@ -141,7 +141,7 @@ pub enum TextDelta { } pub async fn stream_completion( - client: &dyn HttpClient, + client: Arc, api_url: &str, api_key: &str, request: Request, diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index e8dbcf851a53f769528b0779555ebc1577ac0758..5e719739aea480529c1af4e64cfc263650560b5c 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -39,6 +39,7 @@ live_kit_server.workspace = true log.workspace = true nanoid.workspace = true open_ai.workspace = true +supermaven_api.workspace = true parking_lot.workspace = true prometheus = "0.13" prost.workspace = true diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 8bd6a71514ee7cac73e20188db8194fcd475daea..271b146b0b6bba88ae4187e66dabe7c56716442a 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -172,6 +172,11 @@ spec: secretKeyRef: name: slack key: panics_webhook + - name: SUPERMAVEN_ADMIN_API_KEY + valueFrom: + secretKeyRef: + name: supermaven + key: api_key - name: INVITE_LINK_PREFIX value: ${INVITE_LINK_PREFIX} - name: RUST_BACKTRACE diff --git a/crates/collab/src/completion.rs b/crates/collab/src/completion.rs new file mode 100644 index 0000000000000000000000000000000000000000..dd1f4b3be6479b764a714573773f385c4d8a2604 --- /dev/null +++ b/crates/collab/src/completion.rs @@ -0,0 +1,2 @@ +use anyhow::{anyhow, Result}; +use rpc::proto; diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 925d192fc0e439e0a91c6f079fc25cc7422c642e..ae83fccb98cf622600e6b22a83c7f76e2abd95a0 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -138,6 +138,7 @@ pub struct Config { pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, + pub supermaven_admin_api_key: Option>, } impl Config { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index e4a83a43382bb8180be8b98f6e74ebcf501eee98..59f811f0b561f28584a367e32d8a5ead9040dcd1 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -34,6 +34,7 @@ pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL}; use sha2::Digest; +use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; use futures::{ channel::oneshot, @@ -148,7 +149,8 @@ struct Session { peer: Arc, connection_pool: Arc>, live_kit_client: Option>, - http_client: IsahcHttpClient, + supermaven_client: Option>, + http_client: Arc, rate_limiter: Arc, _executor: Executor, } @@ -189,6 +191,14 @@ impl Session { } } + fn is_staff(&self) -> bool { + match &self.principal { + Principal::User(user) => user.admin, + Principal::Impersonated { .. } => true, + Principal::DevServer(_) => false, + } + } + fn dev_server_id(&self) -> Option { match &self.principal { Principal::User(_) | Principal::Impersonated { .. } => None, @@ -233,6 +243,14 @@ impl UserSession { pub fn user_id(&self) -> UserId { self.0.user_id().unwrap() } + + pub fn email(&self) -> Option { + match &self.0.principal { + Principal::User(user) => user.email_address.clone(), + Principal::Impersonated { user, .. } => user.email_address.clone(), + Principal::DevServer(..) => None, + } + } } impl Deref for UserSession { @@ -561,6 +579,7 @@ impl Server { .add_request_handler(user_handler(get_private_user_info)) .add_message_handler(user_message_handler(acknowledge_channel_message)) .add_message_handler(user_message_handler(acknowledge_buffer_version)) + .add_request_handler(user_handler(get_supermaven_api_key)) .add_streaming_request_handler({ let app_state = app_state.clone(); move |request, response, session| { @@ -938,13 +957,22 @@ impl Server { tracing::info!("connection opened"); let http_client = match IsahcHttpClient::new() { - Ok(http_client) => http_client, + Ok(http_client) => Arc::new(http_client), Err(error) => { tracing::error!(?error, "failed to create HTTP client"); return; } }; + let supermaven_client = if let Some(supermaven_admin_api_key) = this.app_state.config.supermaven_admin_api_key.clone() { + Some(Arc::new(SupermavenAdminApi::new( + supermaven_admin_api_key.to_string(), + http_client.clone(), + ))) + } else { + None + }; + let session = Session { principal: principal.clone(), connection_id, @@ -955,6 +983,7 @@ impl Server { http_client, rate_limiter: this.app_state.rate_limiter.clone(), _executor: executor.clone(), + supermaven_client, }; if let Err(error) = this.send_initial_client_update(connection_id, &principal, zed_version, send_connection_id, &session).await { @@ -4210,7 +4239,7 @@ async fn complete_with_open_ai( api_key: Arc, ) -> Result<()> { let mut completion_stream = open_ai::stream_completion( - &session.http_client, + session.http_client.as_ref(), OPEN_AI_API_URL, &api_key, crate::ai::language_model_request_to_open_ai(request)?, @@ -4274,7 +4303,7 @@ async fn complete_with_google_ai( api_key: Arc, ) -> Result<()> { let mut stream = google_ai::stream_generate_content( - &session.http_client, + session.http_client.clone(), google_ai::API_URL, api_key.as_ref(), crate::ai::language_model_request_to_google_ai(request)?, @@ -4358,7 +4387,7 @@ async fn complete_with_anthropic( .collect(); let mut stream = anthropic::stream_completion( - &session.http_client, + session.http_client.clone(), "https://api.anthropic.com", &api_key, anthropic::Request { @@ -4482,7 +4511,7 @@ async fn count_tokens_with_language_model( let api_key = google_ai_api_key .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?; let tokens_response = google_ai::count_tokens( - &session.http_client, + session.http_client.as_ref(), google_ai::API_URL, &api_key, crate::ai::count_tokens_request_to_google_ai(request)?, @@ -4530,7 +4559,7 @@ async fn compute_embeddings( let embeddings = match request.model.as_str() { "openai/text-embedding-3-small" => { open_ai::embed( - &session.http_client, + session.http_client.as_ref(), OPEN_AI_API_URL, &api_key, OpenAiEmbeddingModel::TextEmbedding3Small, @@ -4602,6 +4631,37 @@ async fn authorize_access_to_language_models(session: &UserSession) -> Result<() } } +/// Get a Supermaven API key for the user +async fn get_supermaven_api_key( + _request: proto::GetSupermavenApiKey, + response: Response, + session: UserSession, +) -> Result<()> { + let user_id: String = session.user_id().to_string(); + if !session.is_staff() { + return Err(anyhow!("supermaven not enabled for this account"))?; + } + + let email = session + .email() + .ok_or_else(|| anyhow!("user must have an email"))?; + + let supermaven_admin_api = session + .supermaven_client + .as_ref() + .ok_or_else(|| anyhow!("supermaven not configured"))?; + + let result = supermaven_admin_api + .try_get_or_create_user(CreateExternalUserRequest { id: user_id, email }) + .await?; + + response.send(proto::GetSupermavenApiKeyResponse { + api_key: result.api_key, + })?; + + Ok(()) +} + /// Start receiving chat updates for a channel async fn join_channel_chat( request: proto::JoinChannelChat, diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 2fec21f76e6f206715699f91c391c6e04ab31052..3a456a328e5a05153b9d014f6c12dd2d123feced 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -655,6 +655,7 @@ impl TestServer { auto_join_channel_id: None, migrations_path: None, seed_path: None, + supermaven_admin_api_key: None, }, }) } diff --git a/crates/copilot/Cargo.toml b/crates/copilot/Cargo.toml index 609bd0e3a808ac909b5cc7e9b927cdecd1ca6b02..3f38a81f5b4d503c6596a562c42e41a2320a8a8f 100644 --- a/crates/copilot/Cargo.toml +++ b/crates/copilot/Cargo.toml @@ -27,28 +27,38 @@ anyhow.workspace = true async-compression.workspace = true async-tar.workspace = true collections.workspace = true +client.workspace = true command_palette_hooks.workspace = true +editor.workspace = true futures.workspace = true gpui.workspace = true language.workspace = true lsp.workspace = true +menu.workspace = true node_runtime.workspace = true parking_lot.workspace = true +project.workspace = true serde.workspace = true settings.workspace = true smol.workspace = true +ui.workspace = true util.workspace = true +workspace.workspace = true [target.'cfg(windows)'.dependencies] async-std = { version = "1.12.0", features = ["unstable"] } [dev-dependencies] clock.workspace = true +indoc.workspace = true +serde_json.workspace = true collections = { workspace = true, features = ["test-support"] } fs = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] } lsp = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } rpc = { workspace = true, features = ["test-support"] } settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } util = { workspace = true, features = ["test-support"] } diff --git a/crates/copilot/src/copilot.rs b/crates/copilot/src/copilot.rs index 99f94b55118a0e8d2e31c0cdb7eb1552beccb675..577f335d2aed861476deca2ee7cea5ca1fa4d0a3 100644 --- a/crates/copilot/src/copilot.rs +++ b/crates/copilot/src/copilot.rs @@ -1,4 +1,7 @@ +mod copilot_completion_provider; pub mod request; +mod sign_in; + use anyhow::{anyhow, Context as _, Result}; use async_compression::futures::bufread::GzipDecoder; use async_tar::Archive; @@ -10,9 +13,9 @@ use gpui::{ ModelContext, Task, WeakModel, }; use language::{ - language_settings::{all_language_settings, language_settings}, - point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, - LanguageServerName, PointUtf16, ToPointUtf16, + language_settings::{all_language_settings, language_settings, InlineCompletionProvider}, + point_from_lsp, point_to_lsp, Anchor, Bias, Buffer, BufferSnapshot, Language, PointUtf16, + ToPointUtf16, }; use lsp::{LanguageServer, LanguageServerBinary, LanguageServerId}; use node_runtime::NodeRuntime; @@ -32,6 +35,9 @@ use util::{ fs::remove_matching, github::latest_github_release, http::HttpClient, maybe, paths, ResultExt, }; +pub use copilot_completion_provider::CopilotCompletionProvider; +pub use sign_in::CopilotCodeVerification; + actions!( copilot, [ @@ -144,7 +150,6 @@ impl CopilotServer { } struct RunningCopilotServer { - name: LanguageServerName, lsp: Arc, sign_in_status: SignInStatus, registered_buffers: HashMap, @@ -354,7 +359,9 @@ impl Copilot { let server_id = self.server_id; let http = self.http.clone(); let node_runtime = self.node_runtime.clone(); - if all_language_settings(None, cx).copilot_enabled(None, None) { + if all_language_settings(None, cx).inline_completions.provider + == InlineCompletionProvider::Copilot + { if matches!(self.server, CopilotServer::Disabled) { let start_task = cx .spawn(move |this, cx| { @@ -393,7 +400,6 @@ impl Copilot { http: http.clone(), node_runtime, server: CopilotServer::Running(RunningCopilotServer { - name: LanguageServerName(Arc::from("copilot")), lsp: Arc::new(server), sign_in_status: SignInStatus::Authorized, registered_buffers: Default::default(), @@ -467,7 +473,6 @@ impl Copilot { match server { Ok((server, status)) => { this.server = CopilotServer::Running(RunningCopilotServer { - name: LanguageServerName(Arc::from("copilot")), lsp: server, sign_in_status: SignInStatus::SignedOut, registered_buffers: Default::default(), @@ -607,9 +612,9 @@ impl Copilot { cx.background_executor().spawn(start_task) } - pub fn language_server(&self) -> Option<(&LanguageServerName, &Arc)> { + pub fn language_server(&self) -> Option<&Arc> { if let CopilotServer::Running(server) = &self.server { - Some((&server.name, &server.lsp)) + Some(&server.lsp) } else { None } diff --git a/crates/copilot_ui/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs similarity index 94% rename from crates/copilot_ui/src/copilot_completion_provider.rs rename to crates/copilot/src/copilot_completion_provider.rs index c6226c7bb1630ab03ee48ee6a4dbb6746e02d636..970145a10f7b5d5a60f0a2cb20786e6807d1a7ad 100644 --- a/crates/copilot_ui/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -1,10 +1,12 @@ +use crate::{Completion, Copilot}; use anyhow::Result; use client::telemetry::Telemetry; -use copilot::Copilot; use editor::{Direction, InlineCompletionProvider}; use gpui::{AppContext, EntityId, Model, ModelContext, Task}; -use language::language_settings::AllLanguageSettings; -use language::{language_settings::all_language_settings, Buffer, OffsetRangeExt, ToOffset}; +use language::{ + language_settings::{all_language_settings, AllLanguageSettings}, + Buffer, OffsetRangeExt, ToOffset, +}; use settings::Settings; use std::{path::Path, sync::Arc, time::Duration}; @@ -13,7 +15,7 @@ pub const COPILOT_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); pub struct CopilotCompletionProvider { cycled: bool, buffer_id: Option, - completions: Vec, + completions: Vec, active_completion_index: usize, file_extension: Option, pending_refresh: Task>, @@ -42,11 +44,11 @@ impl CopilotCompletionProvider { self } - fn active_completion(&self) -> Option<&copilot::Completion> { + fn active_completion(&self) -> Option<&Completion> { self.completions.get(self.active_completion_index) } - fn push_completion(&mut self, new_completion: copilot::Completion) { + fn push_completion(&mut self, new_completion: Completion) { for completion in &self.completions { if completion.text == new_completion.text && completion.range == new_completion.range { return; @@ -71,7 +73,7 @@ impl InlineCompletionProvider for CopilotCompletionProvider { let file = buffer.file(); let language = buffer.language_at(cursor_position); let settings = all_language_settings(file, cx); - settings.copilot_enabled(language.as_ref(), file.map(|f| f.path().as_ref())) + settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref())) } fn refresh( @@ -196,7 +198,10 @@ impl InlineCompletionProvider for CopilotCompletionProvider { fn discard(&mut self, cx: &mut ModelContext) { let settings = AllLanguageSettings::get_global(cx); - if !settings.copilot.feature_enabled { + + let copilot_enabled = settings.inline_completions_enabled(None, None); + + if !copilot_enabled { return; } @@ -298,7 +303,9 @@ mod tests { ) .await; let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); - cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx)); + cx.update_editor(|editor, cx| { + editor.set_inline_completion_provider(Some(copilot_provider), cx) + }); // When inserting, ensure autocompletion is favored over Copilot suggestions. cx.set_state(indoc! {" @@ -318,7 +325,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -360,7 +367,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -393,7 +400,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -426,7 +433,7 @@ mod tests { // After debouncing, new Copilot completions should be requested. handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot2".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 5)), ..Default::default() @@ -503,7 +510,7 @@ mod tests { }); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: " let x = 4;".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 2)), ..Default::default() @@ -553,7 +560,9 @@ mod tests { ) .await; let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); - cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx)); + cx.update_editor(|editor, cx| { + editor.set_inline_completion_provider(Some(copilot_provider), cx) + }); // Setup the editor with a completion request. cx.set_state(indoc! {" @@ -573,7 +582,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.copilot1".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -615,7 +624,7 @@ mod tests { ); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "one.123. copilot\n 456".into(), range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 4)), ..Default::default() @@ -675,7 +684,9 @@ mod tests { ) .await; let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); - cx.update_editor(|editor, cx| editor.set_inline_completion_provider(copilot_provider, cx)); + cx.update_editor(|editor, cx| { + editor.set_inline_completion_provider(Some(copilot_provider), cx) + }); cx.set_state(indoc! {" one @@ -685,7 +696,7 @@ mod tests { handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "two.foo()".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 2)), ..Default::default() @@ -756,13 +767,13 @@ mod tests { let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); editor .update(cx, |editor, cx| { - editor.set_inline_completion_provider(copilot_provider, cx) + editor.set_inline_completion_provider(Some(copilot_provider), cx) }) .unwrap(); handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "b = 2 + a".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 5)), ..Default::default() @@ -788,7 +799,7 @@ mod tests { handle_copilot_completion_request( &copilot_lsp, - vec![copilot::request::Completion { + vec![crate::request::Completion { text: "d = 4 + c".into(), range: lsp::Range::new(lsp::Position::new(1, 0), lsp::Position::new(1, 6)), ..Default::default() @@ -833,7 +844,7 @@ mod tests { async fn test_copilot_disabled_globs(executor: BackgroundExecutor, cx: &mut TestAppContext) { init_test(cx, |settings| { settings - .copilot + .inline_completions .get_or_insert(Default::default()) .disabled_globs = Some(vec![".env*".to_string()]); }); @@ -888,15 +899,15 @@ mod tests { let copilot_provider = cx.new_model(|_| CopilotCompletionProvider::new(copilot)); editor .update(cx, |editor, cx| { - editor.set_inline_completion_provider(copilot_provider, cx) + editor.set_inline_completion_provider(Some(copilot_provider), cx) }) .unwrap(); let mut copilot_requests = copilot_lsp - .handle_request::( + .handle_request::( move |_params, _cx| async move { - Ok(copilot::request::GetCompletionsResult { - completions: vec![copilot::request::Completion { + Ok(crate::request::GetCompletionsResult { + completions: vec![crate::request::Completion { text: "next line".into(), range: lsp::Range::new( lsp::Position::new(1, 0), @@ -931,21 +942,21 @@ mod tests { fn handle_copilot_completion_request( lsp: &lsp::FakeLanguageServer, - completions: Vec, - completions_cycling: Vec, + completions: Vec, + completions_cycling: Vec, ) { - lsp.handle_request::(move |_params, _cx| { + lsp.handle_request::(move |_params, _cx| { let completions = completions.clone(); async move { - Ok(copilot::request::GetCompletionsResult { + Ok(crate::request::GetCompletionsResult { completions: completions.clone(), }) } }); - lsp.handle_request::(move |_params, _cx| { + lsp.handle_request::(move |_params, _cx| { let completions_cycling = completions_cycling.clone(); async move { - Ok(copilot::request::GetCompletionsResult { + Ok(crate::request::GetCompletionsResult { completions: completions_cycling.clone(), }) } diff --git a/crates/copilot_ui/src/sign_in.rs b/crates/copilot/src/sign_in.rs similarity index 98% rename from crates/copilot_ui/src/sign_in.rs rename to crates/copilot/src/sign_in.rs index 396b2367f90bca0142290af024053931d8dd7018..abf7252fef10f656bd1d9893640a18cd32a9c56d 100644 --- a/crates/copilot_ui/src/sign_in.rs +++ b/crates/copilot/src/sign_in.rs @@ -1,4 +1,4 @@ -use copilot::{request::PromptUserDeviceFlow, Copilot, Status}; +use crate::{request::PromptUserDeviceFlow, Copilot, Status}; use gpui::{ div, svg, AppContext, ClipboardItem, DismissEvent, Element, EventEmitter, FocusHandle, FocusableView, InteractiveElement, IntoElement, Model, MouseDownEvent, ParentElement, Render, @@ -26,7 +26,7 @@ impl EventEmitter for CopilotCodeVerification {} impl ModalView for CopilotCodeVerification {} impl CopilotCodeVerification { - pub(crate) fn new(copilot: &Model, cx: &mut ViewContext) -> Self { + pub fn new(copilot: &Model, cx: &mut ViewContext) -> Self { let status = copilot.read(cx).status(); Self { status, diff --git a/crates/copilot_ui/src/copilot_button.rs b/crates/copilot_ui/src/copilot_button.rs deleted file mode 100644 index b228a10839fb9c8d730d779088eb06726699f4be..0000000000000000000000000000000000000000 --- a/crates/copilot_ui/src/copilot_button.rs +++ /dev/null @@ -1,403 +0,0 @@ -use crate::sign_in::CopilotCodeVerification; -use anyhow::Result; -use copilot::{Copilot, SignOut, Status}; -use editor::{scroll::Autoscroll, Editor}; -use fs::Fs; -use gpui::{ - div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement, - Render, Subscription, View, ViewContext, WeakView, WindowContext, -}; -use language::{ - language_settings::{self, all_language_settings, AllLanguageSettings}, - File, Language, -}; -use settings::{update_settings_file, Settings, SettingsStore}; -use std::{path::Path, sync::Arc}; -use util::{paths, ResultExt}; -use workspace::notifications::NotificationId; -use workspace::{ - create_and_open_local_file, - item::ItemHandle, - ui::{ - popover_menu, ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, Tooltip, - }, - StatusItemView, Toast, Workspace, -}; -use zed_actions::OpenBrowser; - -const COPILOT_SETTINGS_URL: &str = "https://github.com/settings/copilot"; - -struct CopilotStartingToast; - -struct CopilotErrorToast; - -pub struct CopilotButton { - editor_subscription: Option<(Subscription, usize)>, - editor_enabled: Option, - language: Option>, - file: Option>, - fs: Arc, -} - -impl Render for CopilotButton { - fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let all_language_settings = all_language_settings(None, cx); - if !all_language_settings.copilot.feature_enabled { - return div(); - } - - let Some(copilot) = Copilot::global(cx) else { - return div(); - }; - let status = copilot.read(cx).status(); - - let enabled = self - .editor_enabled - .unwrap_or_else(|| all_language_settings.copilot_enabled(None, None)); - - let icon = match status { - Status::Error(_) => IconName::CopilotError, - Status::Authorized => { - if enabled { - IconName::Copilot - } else { - IconName::CopilotDisabled - } - } - _ => IconName::CopilotInit, - }; - - if let Status::Error(e) = status { - return div().child( - IconButton::new("copilot-error", icon) - .icon_size(IconSize::Small) - .on_click(cx.listener(move |_, _, cx| { - if let Some(workspace) = cx.window_handle().downcast::() { - workspace - .update(cx, |workspace, cx| { - workspace.show_toast( - Toast::new( - NotificationId::unique::(), - format!("Copilot can't be started: {}", e), - ) - .on_click( - "Reinstall Copilot", - |cx| { - if let Some(copilot) = Copilot::global(cx) { - copilot - .update(cx, |copilot, cx| { - copilot.reinstall(cx) - }) - .detach(); - } - }, - ), - cx, - ); - }) - .ok(); - } - })) - .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), - ); - } - let this = cx.view().clone(); - - div().child( - popover_menu("copilot") - .menu(move |cx| match status { - Status::Authorized => { - Some(this.update(cx, |this, cx| this.build_copilot_menu(cx))) - } - _ => Some(this.update(cx, |this, cx| this.build_copilot_start_menu(cx))), - }) - .anchor(AnchorCorner::BottomRight) - .trigger( - IconButton::new("copilot-icon", icon) - .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), - ), - ) - } -} - -impl CopilotButton { - pub fn new(fs: Arc, cx: &mut ViewContext) -> Self { - if let Some(copilot) = Copilot::global(cx) { - cx.observe(&copilot, |_, _, cx| cx.notify()).detach() - } - - cx.observe_global::(move |_, cx| cx.notify()) - .detach(); - - Self { - editor_subscription: None, - editor_enabled: None, - language: None, - file: None, - fs, - } - } - - pub fn build_copilot_start_menu(&mut self, cx: &mut ViewContext) -> View { - let fs = self.fs.clone(); - ContextMenu::build(cx, |menu, _| { - menu.entry("Sign In", None, initiate_sign_in).entry( - "Disable Copilot", - None, - move |cx| hide_copilot(fs.clone(), cx), - ) - }) - } - - pub fn build_copilot_menu(&mut self, cx: &mut ViewContext) -> View { - let fs = self.fs.clone(); - - ContextMenu::build(cx, move |mut menu, cx| { - if let Some(language) = self.language.clone() { - let fs = fs.clone(); - let language_enabled = - language_settings::language_settings(Some(&language), None, cx) - .show_copilot_suggestions; - - menu = menu.entry( - format!( - "{} Suggestions for {}", - if language_enabled { "Hide" } else { "Show" }, - language.name() - ), - None, - move |cx| toggle_copilot_for_language(language.clone(), fs.clone(), cx), - ); - } - - let settings = AllLanguageSettings::get_global(cx); - - if let Some(file) = &self.file { - let path = file.path().clone(); - let path_enabled = settings.copilot_enabled_for_path(&path); - - menu = menu.entry( - format!( - "{} Suggestions for This Path", - if path_enabled { "Hide" } else { "Show" } - ), - None, - move |cx| { - if let Some(workspace) = cx.window_handle().downcast::() { - if let Ok(workspace) = workspace.root_view(cx) { - let workspace = workspace.downgrade(); - cx.spawn(|cx| { - configure_disabled_globs( - workspace, - path_enabled.then_some(path.clone()), - cx, - ) - }) - .detach_and_log_err(cx); - } - } - }, - ); - } - - let globally_enabled = settings.copilot_enabled(None, None); - menu.entry( - if globally_enabled { - "Hide Suggestions for All Files" - } else { - "Show Suggestions for All Files" - }, - None, - move |cx| toggle_copilot_globally(fs.clone(), cx), - ) - .separator() - .link( - "Copilot Settings", - OpenBrowser { - url: COPILOT_SETTINGS_URL.to_string(), - } - .boxed_clone(), - ) - .action("Sign Out", SignOut.boxed_clone()) - }) - } - - pub fn update_enabled(&mut self, editor: View, cx: &mut ViewContext) { - let editor = editor.read(cx); - let snapshot = editor.buffer().read(cx).snapshot(cx); - let suggestion_anchor = editor.selections.newest_anchor().start; - let language = snapshot.language_at(suggestion_anchor); - let file = snapshot.file_at(suggestion_anchor).cloned(); - self.editor_enabled = { - let file = file.as_ref(); - Some( - file.map(|file| !file.is_private()).unwrap_or(true) - && all_language_settings(file, cx) - .copilot_enabled(language, file.map(|file| file.path().as_ref())), - ) - }; - self.language = language.cloned(); - self.file = file; - - cx.notify() - } -} - -impl StatusItemView for CopilotButton { - fn set_active_pane_item(&mut self, item: Option<&dyn ItemHandle>, cx: &mut ViewContext) { - if let Some(editor) = item.and_then(|item| item.act_as::(cx)) { - self.editor_subscription = Some(( - cx.observe(&editor, Self::update_enabled), - editor.entity_id().as_u64() as usize, - )); - self.update_enabled(editor, cx); - } else { - self.language = None; - self.editor_subscription = None; - self.editor_enabled = None; - } - cx.notify(); - } -} - -async fn configure_disabled_globs( - workspace: WeakView, - path_to_disable: Option>, - mut cx: AsyncWindowContext, -) -> Result<()> { - let settings_editor = workspace - .update(&mut cx, |_, cx| { - create_and_open_local_file(&paths::SETTINGS, cx, || { - settings::initial_user_settings_content().as_ref().into() - }) - })? - .await? - .downcast::() - .unwrap(); - - settings_editor.downgrade().update(&mut cx, |item, cx| { - let text = item.buffer().read(cx).snapshot(cx).text(); - - let settings = cx.global::(); - let edits = settings.edits_for_update::(&text, |file| { - let copilot = file.copilot.get_or_insert_with(Default::default); - let globs = copilot.disabled_globs.get_or_insert_with(|| { - settings - .get::(None) - .copilot - .disabled_globs - .iter() - .map(|glob| glob.glob().to_string()) - .collect() - }); - - if let Some(path_to_disable) = &path_to_disable { - globs.push(path_to_disable.to_string_lossy().into_owned()); - } else { - globs.clear(); - } - }); - - if !edits.is_empty() { - item.change_selections(Some(Autoscroll::newest()), cx, |selections| { - selections.select_ranges(edits.iter().map(|e| e.0.clone())); - }); - - // When *enabling* a path, don't actually perform an edit, just select the range. - if path_to_disable.is_some() { - item.edit(edits.iter().cloned(), cx); - } - } - })?; - - anyhow::Ok(()) -} - -fn toggle_copilot_globally(fs: Arc, cx: &mut AppContext) { - let show_copilot_suggestions = all_language_settings(None, cx).copilot_enabled(None, None); - update_settings_file::(fs, cx, move |file| { - file.defaults.show_copilot_suggestions = Some(!show_copilot_suggestions) - }); -} - -fn toggle_copilot_for_language(language: Arc, fs: Arc, cx: &mut AppContext) { - let show_copilot_suggestions = - all_language_settings(None, cx).copilot_enabled(Some(&language), None); - update_settings_file::(fs, cx, move |file| { - file.languages - .entry(language.name()) - .or_default() - .show_copilot_suggestions = Some(!show_copilot_suggestions); - }); -} - -fn hide_copilot(fs: Arc, cx: &mut AppContext) { - update_settings_file::(fs, cx, move |file| { - file.features.get_or_insert(Default::default()).copilot = Some(false); - }); -} - -pub fn initiate_sign_in(cx: &mut WindowContext) { - let Some(copilot) = Copilot::global(cx) else { - return; - }; - let status = copilot.read(cx).status(); - let Some(workspace) = cx.window_handle().downcast::() else { - return; - }; - match status { - Status::Starting { task } => { - let Some(workspace) = cx.window_handle().downcast::() else { - return; - }; - - let Ok(workspace) = workspace.update(cx, |workspace, cx| { - workspace.show_toast( - Toast::new( - NotificationId::unique::(), - "Copilot is starting...", - ), - cx, - ); - workspace.weak_handle() - }) else { - return; - }; - - cx.spawn(|mut cx| async move { - task.await; - if let Some(copilot) = cx.update(|cx| Copilot::global(cx)).ok().flatten() { - workspace - .update(&mut cx, |workspace, cx| match copilot.read(cx).status() { - Status::Authorized => workspace.show_toast( - Toast::new( - NotificationId::unique::(), - "Copilot has started!", - ), - cx, - ), - _ => { - workspace.dismiss_toast( - &NotificationId::unique::(), - cx, - ); - copilot - .update(cx, |copilot, cx| copilot.sign_in(cx)) - .detach_and_log_err(cx); - } - }) - .log_err(); - } - }) - .detach(); - } - _ => { - copilot.update(cx, |this, cx| this.sign_in(cx)).detach(); - workspace - .update(cx, |this, cx| { - this.toggle_modal(cx, |cx| CopilotCodeVerification::new(&copilot, cx)); - }) - .ok(); - } - } -} diff --git a/crates/copilot_ui/src/copilot_ui.rs b/crates/copilot_ui/src/copilot_ui.rs deleted file mode 100644 index 63bd03102fd3bb6e26ee8d1d44008335c31372e0..0000000000000000000000000000000000000000 --- a/crates/copilot_ui/src/copilot_ui.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod copilot_button; -mod copilot_completion_provider; -mod sign_in; - -pub use copilot_button::*; -pub use copilot_completion_provider::*; -pub use sign_in::*; diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index ef5c97a5924803c581975ca7717ba33bfdeb3397..bbd215b23bdd7a677d06e2bf6b99801526cfff9c 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -1757,19 +1757,22 @@ impl Editor { self.completion_provider = Some(hub); } - pub fn set_inline_completion_provider( + pub fn set_inline_completion_provider( &mut self, - provider: Model, + provider: Option>, cx: &mut ViewContext, - ) { - self.inline_completion_provider = Some(RegisteredInlineCompletionProvider { - _subscription: cx.observe(&provider, |this, _, cx| { - if this.focus_handle.is_focused(cx) { - this.update_visible_inline_completion(cx); - } - }), - provider: Arc::new(provider), - }); + ) where + T: InlineCompletionProvider, + { + self.inline_completion_provider = + provider.map(|provider| RegisteredInlineCompletionProvider { + _subscription: cx.observe(&provider, |this, _, cx| { + if this.focus_handle.is_focused(cx) { + this.update_visible_inline_completion(cx); + } + }), + provider: Arc::new(provider), + }); self.refresh_inline_completion(false, cx); } @@ -2676,7 +2679,7 @@ impl Editor { } drop(snapshot); - let had_active_copilot_completion = this.has_active_inline_completion(cx); + let had_active_inline_completion = this.has_active_inline_completion(cx); this.change_selections(Some(Autoscroll::fit()), cx, |s| s.select(new_selections)); if brace_inserted { @@ -2692,7 +2695,7 @@ impl Editor { } } - if had_active_copilot_completion { + if had_active_inline_completion { this.refresh_inline_completion(true, cx); if !this.has_active_inline_completion(cx) { this.trigger_completion_on_input(&text, cx); @@ -4005,7 +4008,7 @@ impl Editor { if !self.show_inline_completions || !provider.is_enabled(&buffer, cursor_buffer_position, cx) { - self.clear_inline_completion(cx); + self.discard_inline_completion(cx); return None; } @@ -4207,13 +4210,6 @@ impl Editor { self.discard_inline_completion(cx); } - fn clear_inline_completion(&mut self, cx: &mut ViewContext) { - if let Some(old_completion) = self.active_inline_completion.take() { - self.splice_inlays(vec![old_completion.id], Vec::new(), cx); - } - self.discard_inline_completion(cx); - } - fn inline_completion_provider(&self) -> Option> { Some(self.inline_completion_provider.as_ref()?.provider.clone()) } @@ -9947,12 +9943,14 @@ impl Editor { .raw_user_settings() .get("vim_mode") == Some(&serde_json::Value::Bool(true)); - let copilot_enabled = all_language_settings(file, cx).copilot_enabled(None, None); + + let copilot_enabled = all_language_settings(file, cx).inline_completions.provider + == language::language_settings::InlineCompletionProvider::Copilot; let copilot_enabled_for_language = self .buffer .read(cx) .settings_at(0, cx) - .show_copilot_suggestions; + .show_inline_completions; let telemetry = project.read(cx).client().telemetry().clone(); telemetry.report_editor_event( diff --git a/crates/editor/src/inline_completion_provider.rs b/crates/editor/src/inline_completion_provider.rs index 31edf806239bf3a22361be9373a42ed5698aa610..2fb2cb608f20fc5517185ed22ac8591f3ec1a7d4 100644 --- a/crates/editor/src/inline_completion_provider.rs +++ b/crates/editor/src/inline_completion_provider.rs @@ -25,11 +25,11 @@ pub trait InlineCompletionProvider: 'static + Sized { ); fn accept(&mut self, cx: &mut ModelContext); fn discard(&mut self, cx: &mut ModelContext); - fn active_completion_text( - &self, + fn active_completion_text<'a>( + &'a self, buffer: &Model, cursor_position: language::Anchor, - cx: &AppContext, + cx: &'a AppContext, ) -> Option<&str>; } @@ -57,7 +57,7 @@ pub trait InlineCompletionProviderHandle { fn accept(&self, cx: &mut AppContext); fn discard(&self, cx: &mut AppContext); fn active_completion_text<'a>( - &self, + &'a self, buffer: &Model, cursor_position: language::Anchor, cx: &'a AppContext, @@ -110,7 +110,7 @@ where } fn active_completion_text<'a>( - &self, + &'a self, buffer: &Model, cursor_position: language::Anchor, cx: &'a AppContext, diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index 4fe461981fe97f313498b6b01e67baf498ce9740..53ed5894dec4666e45343409f2b3888c6fe4d48c 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use anyhow::{anyhow, Result}; use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt}; use serde::{Deserialize, Serialize}; @@ -5,8 +7,8 @@ use util::http::HttpClient; pub const API_URL: &str = "https://generativelanguage.googleapis.com"; -pub async fn stream_generate_content( - client: &T, +pub async fn stream_generate_content( + client: Arc, api_url: &str, api_key: &str, request: GenerateContentRequest, diff --git a/crates/copilot_ui/Cargo.toml b/crates/inline_completion_button/Cargo.toml similarity index 88% rename from crates/copilot_ui/Cargo.toml rename to crates/inline_completion_button/Cargo.toml index 4bf3240aabd95a7cd671a828978a8f0542d272e6..48acdb3ae1d89ae47314c66ba6a3719293ab0e85 100644 --- a/crates/copilot_ui/Cargo.toml +++ b/crates/inline_completion_button/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "copilot_ui" +name = "inline_completion_button" version = "0.1.0" edition = "2021" publish = false @@ -9,19 +9,18 @@ license = "GPL-3.0-or-later" workspace = true [lib] -path = "src/copilot_ui.rs" +path = "src/inline_completion_button.rs" doctest = false [dependencies] anyhow.workspace = true -client.workspace = true copilot.workspace = true editor.workspace = true fs.workspace = true gpui.workspace = true language.workspace = true -menu.workspace = true settings.workspace = true +supermaven.workspace = true ui.workspace = true util.workspace = true workspace.workspace = true diff --git a/crates/copilot_ui/LICENSE-GPL b/crates/inline_completion_button/LICENSE-GPL similarity index 100% rename from crates/copilot_ui/LICENSE-GPL rename to crates/inline_completion_button/LICENSE-GPL diff --git a/crates/inline_completion_button/src/inline_completion_button.rs b/crates/inline_completion_button/src/inline_completion_button.rs new file mode 100644 index 0000000000000000000000000000000000000000..86f6945ac1c4f0a450246fa05ffa139de3cdefcc --- /dev/null +++ b/crates/inline_completion_button/src/inline_completion_button.rs @@ -0,0 +1,510 @@ +use anyhow::Result; +use copilot::{Copilot, CopilotCodeVerification, Status}; +use editor::{scroll::Autoscroll, Editor}; +use fs::Fs; +use gpui::{ + div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement, + Render, Subscription, View, ViewContext, WeakView, WindowContext, +}; +use language::{ + language_settings::{ + self, all_language_settings, AllLanguageSettings, InlineCompletionProvider, + }, + File, Language, +}; +use settings::{update_settings_file, Settings, SettingsStore}; +use std::{path::Path, sync::Arc}; +use supermaven::{AccountStatus, Supermaven}; +use util::{paths, ResultExt}; +use workspace::{ + create_and_open_local_file, + item::ItemHandle, + notifications::NotificationId, + ui::{ + popover_menu, ButtonCommon, Clickable, ContextMenu, IconButton, IconName, IconSize, Tooltip, + }, + StatusItemView, Toast, Workspace, +}; +use zed_actions::OpenBrowser; + +const COPILOT_SETTINGS_URL: &str = "https://github.com/settings/copilot"; + +struct CopilotStartingToast; + +struct CopilotErrorToast; + +pub struct InlineCompletionButton { + editor_subscription: Option<(Subscription, usize)>, + editor_enabled: Option, + language: Option>, + file: Option>, + fs: Arc, +} + +enum SupermavenButtonStatus { + Ready, + Errored(String), + NeedsActivation(String), + Initializing, +} + +impl Render for InlineCompletionButton { + fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { + let all_language_settings = all_language_settings(None, cx); + + match all_language_settings.inline_completions.provider { + InlineCompletionProvider::None => return div(), + + InlineCompletionProvider::Copilot => { + let Some(copilot) = Copilot::global(cx) else { + return div(); + }; + let status = copilot.read(cx).status(); + + let enabled = self.editor_enabled.unwrap_or_else(|| { + all_language_settings.inline_completions_enabled(None, None) + }); + + let icon = match status { + Status::Error(_) => IconName::CopilotError, + Status::Authorized => { + if enabled { + IconName::Copilot + } else { + IconName::CopilotDisabled + } + } + _ => IconName::CopilotInit, + }; + + if let Status::Error(e) = status { + return div().child( + IconButton::new("copilot-error", icon) + .icon_size(IconSize::Small) + .on_click(cx.listener(move |_, _, cx| { + if let Some(workspace) = cx.window_handle().downcast::() + { + workspace + .update(cx, |workspace, cx| { + workspace.show_toast( + Toast::new( + NotificationId::unique::(), + format!("Copilot can't be started: {}", e), + ) + .on_click("Reinstall Copilot", |cx| { + if let Some(copilot) = Copilot::global(cx) { + copilot + .update(cx, |copilot, cx| { + copilot.reinstall(cx) + }) + .detach(); + } + }), + cx, + ); + }) + .ok(); + } + })) + .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), + ); + } + let this = cx.view().clone(); + + div().child( + popover_menu("copilot") + .menu(move |cx| { + Some(match status { + Status::Authorized => { + this.update(cx, |this, cx| this.build_copilot_context_menu(cx)) + } + _ => this.update(cx, |this, cx| this.build_copilot_start_menu(cx)), + }) + }) + .anchor(AnchorCorner::BottomRight) + .trigger( + IconButton::new("copilot-icon", icon) + .tooltip(|cx| Tooltip::text("GitHub Copilot", cx)), + ), + ) + } + + InlineCompletionProvider::Supermaven => { + let Some(supermaven) = Supermaven::global(cx) else { + return div(); + }; + + let supermaven = supermaven.read(cx); + + let status = match supermaven { + Supermaven::Starting => SupermavenButtonStatus::Initializing, + Supermaven::FailedDownload { error } => { + SupermavenButtonStatus::Errored(error.to_string()) + } + Supermaven::Spawned(agent) => { + let account_status = agent.account_status.clone(); + match account_status { + AccountStatus::NeedsActivation { activate_url } => { + SupermavenButtonStatus::NeedsActivation(activate_url.clone()) + } + AccountStatus::Unknown => SupermavenButtonStatus::Initializing, + AccountStatus::Ready => SupermavenButtonStatus::Ready, + } + } + Supermaven::Error { error } => { + SupermavenButtonStatus::Errored(error.to_string()) + } + }; + + let icon = status.to_icon(); + let tooltip_text = status.to_tooltip(); + let this = cx.view().clone(); + + return div().child( + popover_menu("supermaven") + .menu(move |cx| match &status { + SupermavenButtonStatus::NeedsActivation(activate_url) => { + Some(ContextMenu::build(cx, |menu, _| { + let activate_url = activate_url.clone(); + menu.entry("Sign In", None, move |cx| { + cx.open_url(activate_url.as_str()) + }) + })) + } + SupermavenButtonStatus::Ready => Some( + this.update(cx, |this, cx| this.build_supermaven_context_menu(cx)), + ), + _ => None, + }) + .anchor(AnchorCorner::BottomRight) + .trigger( + IconButton::new("supermaven-icon", icon) + .tooltip(move |cx| Tooltip::text(tooltip_text.clone(), cx)), + ), + ); + } + } + } +} + +impl InlineCompletionButton { + pub fn new(fs: Arc, cx: &mut ViewContext) -> Self { + if let Some(copilot) = Copilot::global(cx) { + cx.observe(&copilot, |_, _, cx| cx.notify()).detach() + } + + cx.observe_global::(move |_, cx| cx.notify()) + .detach(); + + Self { + editor_subscription: None, + editor_enabled: None, + language: None, + file: None, + fs, + } + } + + pub fn build_copilot_start_menu(&mut self, cx: &mut ViewContext) -> View { + let fs = self.fs.clone(); + ContextMenu::build(cx, |menu, _| { + menu.entry("Sign In", None, initiate_sign_in).entry( + "Disable Copilot", + None, + move |cx| hide_copilot(fs.clone(), cx), + ) + }) + } + + pub fn build_language_settings_menu( + &self, + mut menu: ContextMenu, + cx: &mut WindowContext, + ) -> ContextMenu { + let fs = self.fs.clone(); + + if let Some(language) = self.language.clone() { + let fs = fs.clone(); + let language_enabled = language_settings::language_settings(Some(&language), None, cx) + .show_inline_completions; + + menu = menu.entry( + format!( + "{} Inline Completions for {}", + if language_enabled { "Hide" } else { "Show" }, + language.name() + ), + None, + move |cx| toggle_inline_completions_for_language(language.clone(), fs.clone(), cx), + ); + } + + let settings = AllLanguageSettings::get_global(cx); + + if let Some(file) = &self.file { + let path = file.path().clone(); + let path_enabled = settings.inline_completions_enabled_for_path(&path); + + menu = menu.entry( + format!( + "{} Inline Completions for This Path", + if path_enabled { "Hide" } else { "Show" } + ), + None, + move |cx| { + if let Some(workspace) = cx.window_handle().downcast::() { + if let Ok(workspace) = workspace.root_view(cx) { + let workspace = workspace.downgrade(); + cx.spawn(|cx| { + configure_disabled_globs( + workspace, + path_enabled.then_some(path.clone()), + cx, + ) + }) + .detach_and_log_err(cx); + } + } + }, + ); + } + + let globally_enabled = settings.inline_completions_enabled(None, None); + menu.entry( + if globally_enabled { + "Hide Inline Completions for All Files" + } else { + "Show Inline Completions for All Files" + }, + None, + move |cx| toggle_inline_completions_globally(fs.clone(), cx), + ) + } + + fn build_copilot_context_menu(&self, cx: &mut ViewContext) -> View { + ContextMenu::build(cx, |menu, cx| { + self.build_language_settings_menu(menu, cx) + .separator() + .link( + "Copilot Settings", + OpenBrowser { + url: COPILOT_SETTINGS_URL.to_string(), + } + .boxed_clone(), + ) + .action("Sign Out", copilot::SignOut.boxed_clone()) + }) + } + + fn build_supermaven_context_menu(&self, cx: &mut ViewContext) -> View { + ContextMenu::build(cx, |menu, cx| { + self.build_language_settings_menu(menu, cx).separator() + }) + } + + pub fn update_enabled(&mut self, editor: View, cx: &mut ViewContext) { + let editor = editor.read(cx); + let snapshot = editor.buffer().read(cx).snapshot(cx); + let suggestion_anchor = editor.selections.newest_anchor().start; + let language = snapshot.language_at(suggestion_anchor); + let file = snapshot.file_at(suggestion_anchor).cloned(); + self.editor_enabled = { + let file = file.as_ref(); + Some( + file.map(|file| !file.is_private()).unwrap_or(true) + && all_language_settings(file, cx).inline_completions_enabled( + language, + file.map(|file| file.path().as_ref()), + ), + ) + }; + self.language = language.cloned(); + self.file = file; + + cx.notify() + } +} + +impl StatusItemView for InlineCompletionButton { + fn set_active_pane_item(&mut self, item: Option<&dyn ItemHandle>, cx: &mut ViewContext) { + if let Some(editor) = item.and_then(|item| item.act_as::(cx)) { + self.editor_subscription = Some(( + cx.observe(&editor, Self::update_enabled), + editor.entity_id().as_u64() as usize, + )); + self.update_enabled(editor, cx); + } else { + self.language = None; + self.editor_subscription = None; + self.editor_enabled = None; + } + cx.notify(); + } +} + +impl SupermavenButtonStatus { + fn to_icon(&self) -> IconName { + match self { + SupermavenButtonStatus::Ready => IconName::Supermaven, + SupermavenButtonStatus::Errored(_) => IconName::SupermavenError, + SupermavenButtonStatus::NeedsActivation(_) => IconName::SupermavenInit, + SupermavenButtonStatus::Initializing => IconName::SupermavenInit, + } + } + + fn to_tooltip(&self) -> String { + match self { + SupermavenButtonStatus::Ready => "Supermaven is ready".to_string(), + SupermavenButtonStatus::Errored(error) => format!("Supermaven error: {}", error), + SupermavenButtonStatus::NeedsActivation(_) => "Supermaven needs activation".to_string(), + SupermavenButtonStatus::Initializing => "Supermaven initializing".to_string(), + } + } +} + +async fn configure_disabled_globs( + workspace: WeakView, + path_to_disable: Option>, + mut cx: AsyncWindowContext, +) -> Result<()> { + let settings_editor = workspace + .update(&mut cx, |_, cx| { + create_and_open_local_file(&paths::SETTINGS, cx, || { + settings::initial_user_settings_content().as_ref().into() + }) + })? + .await? + .downcast::() + .unwrap(); + + settings_editor.downgrade().update(&mut cx, |item, cx| { + let text = item.buffer().read(cx).snapshot(cx).text(); + + let settings = cx.global::(); + let edits = settings.edits_for_update::(&text, |file| { + let copilot = file.inline_completions.get_or_insert_with(Default::default); + let globs = copilot.disabled_globs.get_or_insert_with(|| { + settings + .get::(None) + .inline_completions + .disabled_globs + .iter() + .map(|glob| glob.glob().to_string()) + .collect() + }); + + if let Some(path_to_disable) = &path_to_disable { + globs.push(path_to_disable.to_string_lossy().into_owned()); + } else { + globs.clear(); + } + }); + + if !edits.is_empty() { + item.change_selections(Some(Autoscroll::newest()), cx, |selections| { + selections.select_ranges(edits.iter().map(|e| e.0.clone())); + }); + + // When *enabling* a path, don't actually perform an edit, just select the range. + if path_to_disable.is_some() { + item.edit(edits.iter().cloned(), cx); + } + } + })?; + + anyhow::Ok(()) +} + +fn toggle_inline_completions_globally(fs: Arc, cx: &mut AppContext) { + let show_inline_completions = + all_language_settings(None, cx).inline_completions_enabled(None, None); + update_settings_file::(fs, cx, move |file| { + file.defaults.show_inline_completions = Some(!show_inline_completions) + }); +} + +fn toggle_inline_completions_for_language( + language: Arc, + fs: Arc, + cx: &mut AppContext, +) { + let show_inline_completions = + all_language_settings(None, cx).inline_completions_enabled(Some(&language), None); + update_settings_file::(fs, cx, move |file| { + file.languages + .entry(language.name()) + .or_default() + .show_inline_completions = Some(!show_inline_completions); + }); +} + +fn hide_copilot(fs: Arc, cx: &mut AppContext) { + update_settings_file::(fs, cx, move |file| { + file.features.get_or_insert(Default::default()).copilot = Some(false); + }); +} + +pub fn initiate_sign_in(cx: &mut WindowContext) { + let Some(copilot) = Copilot::global(cx) else { + return; + }; + let status = copilot.read(cx).status(); + let Some(workspace) = cx.window_handle().downcast::() else { + return; + }; + match status { + Status::Starting { task } => { + let Some(workspace) = cx.window_handle().downcast::() else { + return; + }; + + let Ok(workspace) = workspace.update(cx, |workspace, cx| { + workspace.show_toast( + Toast::new( + NotificationId::unique::(), + "Copilot is starting...", + ), + cx, + ); + workspace.weak_handle() + }) else { + return; + }; + + cx.spawn(|mut cx| async move { + task.await; + if let Some(copilot) = cx.update(|cx| Copilot::global(cx)).ok().flatten() { + workspace + .update(&mut cx, |workspace, cx| match copilot.read(cx).status() { + Status::Authorized => workspace.show_toast( + Toast::new( + NotificationId::unique::(), + "Copilot has started!", + ), + cx, + ), + _ => { + workspace.dismiss_toast( + &NotificationId::unique::(), + cx, + ); + copilot + .update(cx, |copilot, cx| copilot.sign_in(cx)) + .detach_and_log_err(cx); + } + }) + .log_err(); + } + }) + .detach(); + } + _ => { + copilot.update(cx, |this, cx| this.sign_in(cx)).detach(); + workspace + .update(cx, |this, cx| { + this.toggle_modal(cx, |cx| CopilotCodeVerification::new(&copilot, cx)); + }) + .ok(); + } + } +} diff --git a/crates/language/src/language_settings.rs b/crates/language/src/language_settings.rs index bea5344be217b21ed556a8d433da92cdde49aad8..537816b983cb3bb5b5e3e3b494c0f2cfe67d77c4 100644 --- a/crates/language/src/language_settings.rs +++ b/crates/language/src/language_settings.rs @@ -51,8 +51,8 @@ pub fn all_language_settings<'a>( /// The settings for all languages. #[derive(Debug, Clone)] pub struct AllLanguageSettings { - /// The settings for GitHub Copilot. - pub copilot: CopilotSettings, + /// The inline completion settings. + pub inline_completions: InlineCompletionSettings, defaults: LanguageSettings, languages: HashMap, LanguageSettings>, pub(crate) file_types: HashMap, Vec>, @@ -101,9 +101,9 @@ pub struct LanguageSettings { /// - `"!"` - A language server ID prefixed with a `!` will be disabled. /// - `"..."` - A placeholder to refer to the **rest** of the registered language servers for this language. pub language_servers: Vec>, - /// Controls whether Copilot provides suggestion immediately (true) - /// or waits for a `copilot::Toggle` (false). - pub show_copilot_suggestions: bool, + /// Controls whether inline completions are shown immediately (true) + /// or manually by triggering `editor::ShowInlineCompletion` (false). + pub show_inline_completions: bool, /// Whether to show tabs and spaces in the editor. pub show_whitespaces: ShowWhitespaceSetting, /// Whether to start a new line with a comment when a previous line is a comment as well. @@ -165,12 +165,23 @@ impl LanguageSettings { } } -/// The settings for [GitHub Copilot](https://github.com/features/copilot). +/// The provider that supplies inline completions. +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Serialize, Deserialize, JsonSchema)] +#[serde(rename_all = "snake_case")] +pub enum InlineCompletionProvider { + None, + #[default] + Copilot, + Supermaven, +} + +/// The settings for inline completions, such as [GitHub Copilot](https://github.com/features/copilot) +/// or [Supermaven](https://supermaven.com). #[derive(Clone, Debug, Default)] -pub struct CopilotSettings { - /// Whether Copilot is enabled. - pub feature_enabled: bool, - /// A list of globs representing files that Copilot should be disabled for. +pub struct InlineCompletionSettings { + /// The provider that supplies inline completions. + pub provider: InlineCompletionProvider, + /// A list of globs representing files that inline completions should be disabled for. pub disabled_globs: Vec, } @@ -180,9 +191,9 @@ pub struct AllLanguageSettingsContent { /// The settings for enabling/disabling features. #[serde(default)] pub features: Option, - /// The settings for GitHub Copilot. - #[serde(default)] - pub copilot: Option, + /// The inline completion settings. + #[serde(default, alias = "copilot")] + pub inline_completions: Option, /// The default language settings. #[serde(flatten)] pub defaults: LanguageSettingsContent, @@ -277,12 +288,12 @@ pub struct LanguageSettingsContent { /// Default: ["..."] #[serde(default)] pub language_servers: Option>>, - /// Controls whether Copilot provides suggestion immediately (true) - /// or waits for a `copilot::Toggle` (false). + /// Controls whether inline completions are shown immediately (true) + /// or manually by triggering `editor::ShowInlineCompletion` (false). /// /// Default: true - #[serde(default)] - pub show_copilot_suggestions: Option, + #[serde(default, alias = "show_copilot_suggestions")] + pub show_inline_completions: Option, /// Whether to show tabs and spaces in the editor. #[serde(default)] pub show_whitespaces: Option, @@ -314,10 +325,10 @@ pub struct LanguageSettingsContent { pub code_actions_on_format: Option>, } -/// The contents of the GitHub Copilot settings. -#[derive(Clone, Debug, PartialEq, Default, Serialize, Deserialize, JsonSchema)] -pub struct CopilotSettingsContent { - /// A list of globs representing files that Copilot should be disabled for. +/// The contents of the inline completion settings. +#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema, PartialEq)] +pub struct InlineCompletionSettingsContent { + /// A list of globs representing files that inline completions should be disabled for. #[serde(default)] pub disabled_globs: Option>, } @@ -328,6 +339,8 @@ pub struct CopilotSettingsContent { pub struct FeaturesContent { /// Whether the GitHub Copilot feature is enabled. pub copilot: Option, + /// Determines which inline completion provider to use. + pub inline_completion_provider: Option, } /// Controls the soft-wrapping behavior in the editor. @@ -475,29 +488,29 @@ impl AllLanguageSettings { &self.defaults } - /// Returns whether GitHub Copilot is enabled for the given path. - pub fn copilot_enabled_for_path(&self, path: &Path) -> bool { + /// Returns whether inline completions are enabled for the given path. + pub fn inline_completions_enabled_for_path(&self, path: &Path) -> bool { !self - .copilot + .inline_completions .disabled_globs .iter() .any(|glob| glob.is_match(path)) } - /// Returns whether GitHub Copilot is enabled for the given language and path. - pub fn copilot_enabled(&self, language: Option<&Arc>, path: Option<&Path>) -> bool { - if !self.copilot.feature_enabled { - return false; - } - + /// Returns whether inline completions are enabled for the given language and path. + pub fn inline_completions_enabled( + &self, + language: Option<&Arc>, + path: Option<&Path>, + ) -> bool { if let Some(path) = path { - if !self.copilot_enabled_for_path(path) { + if !self.inline_completions_enabled_for_path(path) { return false; } } self.language(language.map(|l| l.name()).as_deref()) - .show_copilot_suggestions + .show_inline_completions } } @@ -551,13 +564,13 @@ impl settings::Settings for AllLanguageSettings { languages.insert(language_name.clone(), language_settings); } - let mut copilot_enabled = default_value + let mut copilot_enabled = default_value.features.as_ref().and_then(|f| f.copilot); + let mut inline_completion_provider = default_value .features .as_ref() - .and_then(|f| f.copilot) - .ok_or_else(Self::missing_default)?; - let mut copilot_globs = default_value - .copilot + .and_then(|f| f.inline_completion_provider); + let mut completion_globs = default_value + .inline_completions .as_ref() .and_then(|c| c.disabled_globs.as_ref()) .ok_or_else(Self::missing_default)?; @@ -565,14 +578,21 @@ impl settings::Settings for AllLanguageSettings { let mut file_types: HashMap, Vec> = HashMap::default(); for user_settings in sources.customizations() { if let Some(copilot) = user_settings.features.as_ref().and_then(|f| f.copilot) { - copilot_enabled = copilot; + copilot_enabled = Some(copilot); + } + if let Some(provider) = user_settings + .features + .as_ref() + .and_then(|f| f.inline_completion_provider) + { + inline_completion_provider = Some(provider); } if let Some(globs) = user_settings - .copilot + .inline_completions .as_ref() .and_then(|f| f.disabled_globs.as_ref()) { - copilot_globs = globs; + completion_globs = globs; } // A user's global settings override the default global settings and @@ -601,9 +621,15 @@ impl settings::Settings for AllLanguageSettings { } Ok(Self { - copilot: CopilotSettings { - feature_enabled: copilot_enabled, - disabled_globs: copilot_globs + inline_completions: InlineCompletionSettings { + provider: if let Some(provider) = inline_completion_provider { + provider + } else if copilot_enabled.unwrap_or(true) { + InlineCompletionProvider::Copilot + } else { + InlineCompletionProvider::None + }, + disabled_globs: completion_globs .iter() .filter_map(|g| Some(globset::Glob::new(g).ok()?.compile_matcher())) .collect(), @@ -714,8 +740,8 @@ fn merge_settings(settings: &mut LanguageSettings, src: &LanguageSettingsContent ); merge(&mut settings.language_servers, src.language_servers.clone()); merge( - &mut settings.show_copilot_suggestions, - src.show_copilot_suggestions, + &mut settings.show_inline_completions, + src.show_inline_completions, ); merge(&mut settings.show_whitespaces, src.show_whitespaces); merge( diff --git a/crates/language_tools/Cargo.toml b/crates/language_tools/Cargo.toml index 6d0a1199b3d6a2eef885fb0c7f1c388f1e7cc030..d85f5a6e52b873e5225b3500d524ef3d72fc66b2 100644 --- a/crates/language_tools/Cargo.toml +++ b/crates/language_tools/Cargo.toml @@ -15,6 +15,7 @@ doctest = false [dependencies] anyhow.workspace = true collections.workspace = true +copilot.workspace = true editor.workspace = true futures.workspace = true gpui.workspace = true @@ -26,7 +27,6 @@ settings.workspace = true theme.workspace = true tree-sitter.workspace = true ui.workspace = true -util.workspace = true workspace.workspace = true [dev-dependencies] diff --git a/crates/language_tools/src/lsp_log.rs b/crates/language_tools/src/lsp_log.rs index a35d8b33e59a1a8de212be0faf936a35aab2c400..28a27aac6047fbf07b03c40ead5ff806ceb0a77d 100644 --- a/crates/language_tools/src/lsp_log.rs +++ b/crates/language_tools/src/lsp_log.rs @@ -1,4 +1,5 @@ use collections::{HashMap, VecDeque}; +use copilot::Copilot; use editor::{actions::MoveToEnd, Editor, EditorEvent}; use futures::{channel::mpsc, StreamExt}; use gpui::{ @@ -7,11 +8,10 @@ use gpui::{ View, ViewContext, VisualContext, WeakModel, WindowContext, }; use language::{LanguageServerId, LanguageServerName}; -use lsp::IoKind; +use lsp::{IoKind, LanguageServer}; use project::{search::SearchQuery, Project}; use std::{borrow::Cow, sync::Arc}; use ui::{popover_menu, prelude::*, Button, Checkbox, ContextMenu, Label, Selection}; -use util::maybe; use workspace::{ item::{Item, ItemHandle, TabContentParams}, searchable::{SearchEvent, SearchableItem, SearchableItemHandle}, @@ -24,17 +24,21 @@ const MAX_STORED_LOG_ENTRIES: usize = 2000; pub struct LogStore { projects: HashMap, ProjectState>, - io_tx: mpsc::UnboundedSender<(WeakModel, LanguageServerId, IoKind, String)>, + language_servers: HashMap, + copilot_log_subscription: Option, + _copilot_subscription: Option, + io_tx: mpsc::UnboundedSender<(LanguageServerId, IoKind, String)>, } struct ProjectState { - servers: HashMap, _subscriptions: [gpui::Subscription; 2], } struct LanguageServerState { + name: LanguageServerName, log_messages: VecDeque, rpc_state: Option, + project: Option>, _io_logs_subscription: Option, _lsp_logs_subscription: Option, } @@ -109,15 +113,55 @@ pub fn init(cx: &mut AppContext) { impl LogStore { pub fn new(cx: &mut ModelContext) -> Self { let (io_tx, mut io_rx) = mpsc::unbounded(); + + let copilot_subscription = Copilot::global(cx).map(|copilot| { + let copilot = &copilot; + cx.subscribe( + copilot, + |this, copilot, copilot_event, cx| match copilot_event { + copilot::Event::CopilotLanguageServerStarted => { + if let Some(server) = copilot.read(cx).language_server() { + let server_id = server.server_id(); + let weak_this = cx.weak_model(); + this.copilot_log_subscription = + Some(server.on_notification::( + move |params, mut cx| { + weak_this + .update(&mut cx, |this, cx| { + this.add_language_server_log( + server_id, + ¶ms.message, + cx, + ); + }) + .ok(); + }, + )); + this.add_language_server( + None, + LanguageServerName(Arc::from("copilot")), + server.clone(), + cx, + ); + } + } + }, + ) + }); + let this = Self { + copilot_log_subscription: None, + _copilot_subscription: copilot_subscription, projects: HashMap::default(), + language_servers: HashMap::default(), io_tx, }; + cx.spawn(|this, mut cx| async move { - while let Some((project, server_id, io_kind, message)) = io_rx.next().await { + while let Some((server_id, io_kind, message)) = io_rx.next().await { if let Some(this) = this.upgrade() { this.update(&mut cx, |this, cx| { - this.on_io(project, server_id, io_kind, &message, cx); + this.on_io(server_id, io_kind, &message, cx); })?; } } @@ -132,20 +176,32 @@ impl LogStore { self.projects.insert( project.downgrade(), ProjectState { - servers: HashMap::default(), _subscriptions: [ cx.observe_release(project, move |this, _, _| { this.projects.remove(&weak_project); + this.language_servers + .retain(|_, state| state.project.as_ref() != Some(&weak_project)); }), cx.subscribe(project, |this, project, event, cx| match event { project::Event::LanguageServerAdded(id) => { - this.add_language_server(&project, *id, cx); + let read_project = project.read(cx); + if let Some((server, adapter)) = read_project + .language_server_for_id(*id) + .zip(read_project.language_server_adapter_for_id(*id)) + { + this.add_language_server( + Some(&project.downgrade()), + adapter.name.clone(), + server, + cx, + ); + } } project::Event::LanguageServerRemoved(id) => { - this.remove_language_server(&project, *id, cx); + this.remove_language_server(*id, cx); } project::Event::LanguageServerLog(id, message) => { - this.add_language_server_log(&project, *id, message, cx); + this.add_language_server_log(*id, message, cx); } _ => {} }), @@ -154,74 +210,69 @@ impl LogStore { ); } - fn add_language_server( + fn get_language_server_state( &mut self, - project: &Model, id: LanguageServerId, + ) -> Option<&mut LanguageServerState> { + self.language_servers.get_mut(&id) + } + + fn add_language_server( + &mut self, + project: Option<&WeakModel>, + name: LanguageServerName, + server: Arc, cx: &mut ModelContext, ) -> Option<&mut LanguageServerState> { - let project_state = self.projects.get_mut(&project.downgrade())?; - let server_state = project_state.servers.entry(id).or_insert_with(|| { - cx.notify(); - LanguageServerState { - rpc_state: None, - log_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), - _io_logs_subscription: None, - _lsp_logs_subscription: None, - } - }); + let server_state = self + .language_servers + .entry(server.server_id()) + .or_insert_with(|| { + cx.notify(); + LanguageServerState { + name, + rpc_state: None, + project: project.cloned(), + log_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), + _io_logs_subscription: None, + _lsp_logs_subscription: None, + } + }); - let server = project.read(cx).language_server_for_id(id); - if let Some(server) = server.as_deref() { - if server.has_notification_handler::() { - // Another event wants to re-add the server that was already added and subscribed to, avoid doing it again. - return Some(server_state); - } + if server.has_notification_handler::() { + // Another event wants to re-add the server that was already added and subscribed to, avoid doing it again. + return Some(server_state); } - let weak_project = project.downgrade(); let io_tx = self.io_tx.clone(); - server_state._io_logs_subscription = server.as_ref().map(|server| { - server.on_io(move |io_kind, message| { - io_tx - .unbounded_send((weak_project.clone(), id, io_kind, message.to_string())) - .ok(); - }) - }); + let server_id = server.server_id(); + server_state._io_logs_subscription = Some(server.on_io(move |io_kind, message| { + io_tx + .unbounded_send((server_id, io_kind, message.to_string())) + .ok(); + })); let this = cx.handle().downgrade(); - let weak_project = project.downgrade(); - server_state._lsp_logs_subscription = server.map(|server| { - let server_id = server.server_id(); - server.on_notification::({ + server_state._lsp_logs_subscription = + Some(server.on_notification::({ move |params, mut cx| { - if let Some((project, this)) = weak_project.upgrade().zip(this.upgrade()) { + if let Some(this) = this.upgrade() { this.update(&mut cx, |this, cx| { - this.add_language_server_log(&project, server_id, ¶ms.message, cx); + this.add_language_server_log(server_id, ¶ms.message, cx); }) .ok(); } } - }) - }); + })); Some(server_state) } fn add_language_server_log( &mut self, - project: &Model, id: LanguageServerId, message: &str, cx: &mut ModelContext, ) -> Option<()> { - let language_server_state = match self - .projects - .get_mut(&project.downgrade())? - .servers - .get_mut(&id) - { - Some(existing_state) => existing_state, - None => self.add_language_server(&project, id, cx)?, - }; + let language_server_state = self.get_language_server_state(id)?; let log_lines = &mut language_server_state.log_messages; while log_lines.len() >= MAX_STORED_LOG_ENTRIES { @@ -238,38 +289,43 @@ impl LogStore { Some(()) } - fn remove_language_server( - &mut self, - project: &Model, - id: LanguageServerId, - cx: &mut ModelContext, - ) -> Option<()> { - let project_state = self.projects.get_mut(&project.downgrade())?; - project_state.servers.remove(&id); + fn remove_language_server(&mut self, id: LanguageServerId, cx: &mut ModelContext) { + self.language_servers.remove(&id); cx.notify(); - Some(()) } - fn server_logs( - &self, - project: &Model, - server_id: LanguageServerId, - ) -> Option<&VecDeque> { - let weak_project = project.downgrade(); - let project_state = self.projects.get(&weak_project)?; - let server_state = project_state.servers.get(&server_id)?; - Some(&server_state.log_messages) + fn server_logs(&self, server_id: LanguageServerId) -> Option<&VecDeque> { + Some(&self.language_servers.get(&server_id)?.log_messages) + } + + fn server_ids_for_project<'a>( + &'a self, + project: &'a WeakModel, + ) -> impl Iterator + 'a { + [].into_iter() + .chain(self.language_servers.iter().filter_map(|(id, state)| { + if state.project.as_ref() == Some(project) { + return Some(*id); + } else { + None + } + })) + .chain(self.language_servers.iter().filter_map(|(id, state)| { + if state.project.is_none() { + return Some(*id); + } else { + None + } + })) } fn enable_rpc_trace_for_language_server( &mut self, - project: &Model, server_id: LanguageServerId, ) -> Option<&mut LanguageServerRpcState> { - let weak_project = project.downgrade(); - let project_state = self.projects.get_mut(&weak_project)?; - let server_state = project_state.servers.get_mut(&server_id)?; - let rpc_state = server_state + let rpc_state = self + .language_servers + .get_mut(&server_id)? .rpc_state .get_or_insert_with(|| LanguageServerRpcState { rpc_messages: VecDeque::with_capacity(MAX_STORED_LOG_ENTRIES), @@ -280,20 +336,14 @@ impl LogStore { pub fn disable_rpc_trace_for_language_server( &mut self, - project: &Model, server_id: LanguageServerId, - _: &mut ModelContext, ) -> Option<()> { - let project = project.downgrade(); - let project_state = self.projects.get_mut(&project)?; - let server_state = project_state.servers.get_mut(&server_id)?; - server_state.rpc_state.take(); + self.language_servers.get_mut(&server_id)?.rpc_state.take(); Some(()) } fn on_io( &mut self, - project: WeakModel, language_server_id: LanguageServerId, io_kind: IoKind, message: &str, @@ -303,18 +353,14 @@ impl LogStore { IoKind::StdOut => true, IoKind::StdIn => false, IoKind::StdErr => { - let project = project.upgrade()?; let message = format!("stderr: {}", message.trim()); - self.add_language_server_log(&project, language_server_id, &message, cx); + self.add_language_server_log(language_server_id, &message, cx); return Some(()); } }; let state = self - .projects - .get_mut(&project)? - .servers - .get_mut(&language_server_id)? + .get_language_server_state(language_server_id)? .rpc_state .as_mut()?; let kind = if is_received { @@ -360,42 +406,40 @@ impl LspLogView { ) -> Self { let server_id = log_store .read(cx) - .projects - .get(&project.downgrade()) - .and_then(|project| project.servers.keys().copied().next()); - let model_changes_subscription = cx.observe(&log_store, |this, store, cx| { - maybe!({ - let project_state = store.read(cx).projects.get(&this.project.downgrade())?; - if let Some(current_lsp) = this.current_server_id { - if !project_state.servers.contains_key(¤t_lsp) { - if let Some(server) = project_state.servers.iter().next() { - if this.is_showing_rpc_trace { - this.show_rpc_trace_for_server(*server.0, cx) - } else { - this.show_logs_for_server(*server.0, cx) - } - } else { - this.current_server_id = None; - this.editor.update(cx, |editor, cx| { - editor.set_read_only(false); - editor.clear(cx); - editor.set_read_only(true); - }); - cx.notify(); - } - } - } else { - if let Some(server) = project_state.servers.iter().next() { + .language_servers + .iter() + .find(|(_, server)| server.project == Some(project.downgrade())) + .map(|(id, _)| *id); + + let weak_project = project.downgrade(); + let model_changes_subscription = cx.observe(&log_store, move |this, store, cx| { + let first_server_id_for_project = + store.read(cx).server_ids_for_project(&weak_project).next(); + if let Some(current_lsp) = this.current_server_id { + if !store.read(cx).language_servers.contains_key(¤t_lsp) { + if let Some(server_id) = first_server_id_for_project { if this.is_showing_rpc_trace { - this.show_rpc_trace_for_server(*server.0, cx) + this.show_rpc_trace_for_server(server_id, cx) } else { - this.show_logs_for_server(*server.0, cx) + this.show_logs_for_server(server_id, cx) } + } else { + this.current_server_id = None; + this.editor.update(cx, |editor, cx| { + editor.set_read_only(false); + editor.clear(cx); + editor.set_read_only(true); + }); + cx.notify(); } } - - Some(()) - }); + } else if let Some(server_id) = first_server_id_for_project { + if this.is_showing_rpc_trace { + this.show_rpc_trace_for_server(server_id, cx) + } else { + this.show_logs_for_server(server_id, cx) + } + } cx.notify(); }); @@ -477,14 +521,14 @@ impl LspLogView { pub(crate) fn menu_items<'a>(&'a self, cx: &'a AppContext) -> Option> { let log_store = self.log_store.read(cx); - let state = log_store.projects.get(&self.project.downgrade())?; + let mut rows = self .project .read(cx) .language_servers() .filter_map(|(server_id, language_server_name, worktree_id)| { let worktree = self.project.read(cx).worktree_for_id(worktree_id, cx)?; - let state = state.servers.get(&server_id)?; + let state = log_store.language_servers.get(&server_id)?; Some(LogMenuItem { server_id, server_name: language_server_name, @@ -501,7 +545,7 @@ impl LspLogView { .read(cx) .supplementary_language_servers() .filter_map(|(&server_id, (name, _))| { - let state = state.servers.get(&server_id)?; + let state = log_store.language_servers.get(&server_id)?; Some(LogMenuItem { server_id, server_name: name.clone(), @@ -514,6 +558,27 @@ impl LspLogView { }) }), ) + .chain( + log_store + .language_servers + .iter() + .filter_map(|(server_id, state)| { + if state.project.is_none() { + Some(LogMenuItem { + server_id: *server_id, + server_name: state.name.clone(), + worktree_root_name: "supplementary".to_string(), + rpc_trace_enabled: state.rpc_state.is_some(), + rpc_trace_selected: self.is_showing_rpc_trace + && self.current_server_id == Some(*server_id), + logs_selected: !self.is_showing_rpc_trace + && self.current_server_id == Some(*server_id), + }) + } else { + None + } + }), + ) .collect::>(); rows.sort_by_key(|row| row.server_id); rows.dedup_by_key(|row| row.server_id); @@ -524,7 +589,7 @@ impl LspLogView { let log_contents = self .log_store .read(cx) - .server_logs(&self.project, server_id) + .server_logs(server_id) .map(log_contents); if let Some(log_contents) = log_contents { self.current_server_id = Some(server_id); @@ -544,7 +609,7 @@ impl LspLogView { ) { let rpc_log = self.log_store.update(cx, |log_store, _| { log_store - .enable_rpc_trace_for_language_server(&self.project, server_id) + .enable_rpc_trace_for_language_server(server_id) .map(|state| log_contents(&state.rpc_messages)) }); if let Some(rpc_log) = rpc_log { @@ -585,11 +650,11 @@ impl LspLogView { enabled: bool, cx: &mut ViewContext, ) { - self.log_store.update(cx, |log_store, cx| { + self.log_store.update(cx, |log_store, _| { if enabled { - log_store.enable_rpc_trace_for_language_server(&self.project, server_id); + log_store.enable_rpc_trace_for_language_server(server_id); } else { - log_store.disable_rpc_trace_for_language_server(&self.project, server_id, cx); + log_store.disable_rpc_trace_for_language_server(server_id); } }); if !enabled && Some(server_id) == self.current_server_id { diff --git a/crates/project/Cargo.toml b/crates/project/Cargo.toml index 30766a7b6feb954d6364257af62b40cad874675b..1d943bc080dfd05d1e4216620d8e88912dcee646 100644 --- a/crates/project/Cargo.toml +++ b/crates/project/Cargo.toml @@ -30,7 +30,6 @@ async-trait.workspace = true client.workspace = true clock.workspace = true collections.workspace = true -copilot.workspace = true fs.workspace = true futures.workspace = true fuzzy.workspace = true diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 733a06172b28fa24ba2ebf5a840060f0d7cb9a3e..28c61820164410a728db209f96b1d71a1d2c4397 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -20,7 +20,6 @@ use client::{ }; use clock::ReplicaId; use collections::{hash_map, BTreeMap, HashMap, HashSet, VecDeque}; -use copilot::Copilot; use debounced_delay::DebouncedDelay; use futures::{ channel::{ @@ -200,8 +199,6 @@ pub struct Project { _maintain_buffer_languages: Task<()>, _maintain_workspace_config: Task>, terminals: Terminals, - copilot_lsp_subscription: Option, - copilot_log_subscription: Option, current_lsp_settings: HashMap, LspSettings>, node: Option>, default_prettier: DefaultPrettier, @@ -685,8 +682,6 @@ impl Project { let (tx, rx) = mpsc::unbounded(); cx.spawn(move |this, cx| Self::send_buffer_ordered_messages(this, rx, cx)) .detach(); - let copilot_lsp_subscription = - Copilot::global(cx).map(|copilot| subscribe_for_copilot_events(&copilot, cx)); let tasks = Inventory::new(cx); Self { @@ -735,8 +730,6 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, - copilot_lsp_subscription, - copilot_log_subscription: None, current_lsp_settings: ProjectSettings::get_global(cx).lsp.clone(), node: Some(node), default_prettier: DefaultPrettier::default(), @@ -823,8 +816,6 @@ impl Project { let (tx, rx) = mpsc::unbounded(); cx.spawn(move |this, cx| Self::send_buffer_ordered_messages(this, rx, cx)) .detach(); - let copilot_lsp_subscription = - Copilot::global(cx).map(|copilot| subscribe_for_copilot_events(&copilot, cx)); let mut this = Self { worktrees: Vec::new(), buffer_ordered_messages_tx: tx, @@ -891,8 +882,6 @@ impl Project { terminals: Terminals { local_handles: Vec::new(), }, - copilot_lsp_subscription, - copilot_log_subscription: None, current_lsp_settings: ProjectSettings::get_global(cx).lsp.clone(), node: None, default_prettier: DefaultPrettier::default(), @@ -1184,17 +1173,6 @@ impl Project { self.restart_language_servers(worktree, language, cx); } - if self.copilot_lsp_subscription.is_none() { - if let Some(copilot) = Copilot::global(cx) { - for buffer in self.opened_buffers.values() { - if let Some(buffer) = buffer.upgrade() { - self.register_buffer_with_copilot(&buffer, cx); - } - } - self.copilot_lsp_subscription = Some(subscribe_for_copilot_events(&copilot, cx)); - } - } - cx.notify(); } @@ -2351,7 +2329,7 @@ impl Project { self.detect_language_for_buffer(buffer, cx); self.register_buffer_with_language_servers(buffer, cx); - self.register_buffer_with_copilot(buffer, cx); + // self.register_buffer_with_copilot(buffer, cx); cx.observe_release(buffer, |this, buffer, cx| { if let Some(file) = File::from_dyn(buffer.file()) { if file.is_local() { @@ -2500,15 +2478,15 @@ impl Project { }); } - fn register_buffer_with_copilot( - &self, - buffer_handle: &Model, - cx: &mut ModelContext, - ) { - if let Some(copilot) = Copilot::global(cx) { - copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx)); - } - } + // fn register_buffer_with_copilot( + // &self, + // buffer_handle: &Model, + // cx: &mut ModelContext, + // ) { + // if let Some(copilot) = Copilot::global(cx) { + // copilot.update(cx, |copilot, cx| copilot.register_buffer(buffer_handle, cx)); + // } + // } async fn send_buffer_ordered_messages( this: WeakModel, @@ -10475,43 +10453,6 @@ async fn search_ignored_entry( } } -fn subscribe_for_copilot_events( - copilot: &Model, - cx: &mut ModelContext<'_, Project>, -) -> gpui::Subscription { - cx.subscribe( - copilot, - |project, copilot, copilot_event, cx| match copilot_event { - copilot::Event::CopilotLanguageServerStarted => { - match copilot.read(cx).language_server() { - Some((name, copilot_server)) => { - // Another event wants to re-add the server that was already added and subscribed to, avoid doing it again. - if !copilot_server.has_notification_handler::() { - let new_server_id = copilot_server.server_id(); - let weak_project = cx.weak_model(); - let copilot_log_subscription = copilot_server - .on_notification::( - move |params, mut cx| { - weak_project.update(&mut cx, |_, cx| { - cx.emit(Event::LanguageServerLog( - new_server_id, - params.message, - )); - }).ok(); - }, - ); - project.supplementary_language_servers.insert(new_server_id, (name.clone(), Arc::clone(copilot_server))); - project.copilot_log_subscription = Some(copilot_log_subscription); - cx.emit(Event::LanguageServerAdded(new_server_id)); - } - } - None => debug_panic!("Received Copilot language server started event, but no language server is running"), - } - } - }, - ) -} - fn glob_literal_prefix(glob: &str) -> &str { let mut literal_end = 0; for (i, part) in glob.split(path::MAIN_SEPARATOR).enumerate() { diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 3dfa9508dc31b4678c27403086208c0e3d8d4f1d..5f8af8e1f07f58f86c6183f83148b0a771a8e548 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -207,7 +207,7 @@ message Envelope { GetCachedEmbeddings get_cached_embeddings = 189; GetCachedEmbeddingsResponse get_cached_embeddings_response = 190; ComputeEmbeddings compute_embeddings = 191; - ComputeEmbeddingsResponse compute_embeddings_response = 192; // current max + ComputeEmbeddingsResponse compute_embeddings_response = 192; UpdateChannelMessage update_channel_message = 170; ChannelMessageUpdate channel_message_update = 171; @@ -238,7 +238,10 @@ message Envelope { ValidateDevServerProjectRequest validate_dev_server_project_request = 194; DeleteDevServer delete_dev_server = 195; OpenNewBuffer open_new_buffer = 196; - DeleteDevServerProject delete_dev_server_project = 197; // Current max + DeleteDevServerProject delete_dev_server_project = 197; + + GetSupermavenApiKey get_supermaven_api_key = 198; + GetSupermavenApiKeyResponse get_supermaven_api_key_response = 199; // current max } reserved 158 to 161; @@ -2084,3 +2087,9 @@ message LspResponse { GetCodeActionsResponse get_code_actions_response = 2; } } + +message GetSupermavenApiKey {} + +message GetSupermavenApiKeyResponse { + string api_key = 1; +} diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 966a24ead9913ef70869db299cad2758a9ca89d4..d011f1d1d2dc0c3e749765367715d746722ec3af 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -201,6 +201,8 @@ messages!( (GetProjectSymbolsResponse, Background), (GetReferences, Background), (GetReferencesResponse, Background), + (GetSupermavenApiKey, Background), + (GetSupermavenApiKeyResponse, Background), (GetTypeDefinition, Background), (GetTypeDefinitionResponse, Background), (GetImplementation, Background), @@ -360,6 +362,7 @@ request_messages!( (GetPrivateUserInfo, GetPrivateUserInfoResponse), (GetProjectSymbols, GetProjectSymbolsResponse), (GetReferences, GetReferencesResponse), + (GetSupermavenApiKey, GetSupermavenApiKeyResponse), (GetTypeDefinition, GetTypeDefinitionResponse), (GetUsers, UsersResponse), (IncomingCall, Ack), diff --git a/crates/supermaven/Cargo.toml b/crates/supermaven/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..4abbcd4a436c299a9cf7cea8f2bf60800c4c823d --- /dev/null +++ b/crates/supermaven/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "supermaven" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/supermaven.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +client.workspace = true +collections.workspace = true +editor.workspace = true +gpui.workspace = true +futures.workspace = true +language.workspace = true +log.workspace = true +postage.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +supermaven_api.workspace = true +smol.workspace = true +ui.workspace = true +util.workspace = true + +[dev-dependencies] +editor = { workspace = true, features = ["test-support"] } +env_logger.workspace = true +gpui = { workspace = true, features = ["test-support"] } +language = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +settings = { workspace = true, features = ["test-support"] } +theme = { workspace = true, features = ["test-support"] } +util = { workspace = true, features = ["test-support"] } diff --git a/crates/supermaven/src/messages.rs b/crates/supermaven/src/messages.rs new file mode 100644 index 0000000000000000000000000000000000000000..9082e00d60f0e44ba65dd3089b2d70d8edb938dd --- /dev/null +++ b/crates/supermaven/src/messages.rs @@ -0,0 +1,152 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SetApiKey { + pub api_key: String, +} + +// Outbound messages +#[derive(Debug, Serialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum OutboundMessage { + SetApiKey(SetApiKey), + StateUpdate(StateUpdateMessage), + #[allow(dead_code)] + UseFreeVersion, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct StateUpdateMessage { + pub new_id: String, + pub updates: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum StateUpdate { + FileUpdate(FileUpdateMessage), + CursorUpdate(CursorPositionUpdateMessage), +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct FileUpdateMessage { + pub path: String, + pub content: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct CursorPositionUpdateMessage { + pub path: String, + pub offset: usize, +} + +// Inbound messages coming in on stdout + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum ResponseItem { + // A completion + Text { text: String }, + // Vestigial message type from old versions -- safe to ignore + Del { text: String }, + // Be able to delete whitespace prior to the cursor, likely for the rest of the completion + Dedent { text: String }, + // When the completion is over + End, + // Got the closing parentheses and shouldn't show any more after + Barrier, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenResponse { + pub state_id: String, + pub items: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SupermavenMetadataMessage { + pub dust_strings: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SupermavenTaskUpdateMessage { + pub task: String, + pub status: TaskStatus, + pub percent_complete: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum TaskStatus { + InProgress, + Complete, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SupermavenActiveRepoMessage { + pub repo_simple_name: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum SupermavenPopupAction { + OpenUrl { label: String, url: String }, + NoOp { label: String }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub struct SupermavenPopupMessage { + pub message: String, + pub actions: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "camelCase")] +pub struct ActivationRequest { + pub activate_url: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenSetMessage { + pub key: String, + pub value: serde_json::Value, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum ServiceTier { + FreeNoLicense, + #[serde(other)] + Unknown, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "snake_case")] +pub enum SupermavenMessage { + Response(SupermavenResponse), + Metadata(SupermavenMetadataMessage), + Apology { + message: Option, + }, + ActivationRequest(ActivationRequest), + ActivationSuccess, + Passthrough { + passthrough: Box, + }, + Popup(SupermavenPopupMessage), + TaskStatus(SupermavenTaskUpdateMessage), + ActiveRepo(SupermavenActiveRepoMessage), + ServiceTier { + service_tier: ServiceTier, + }, + + Set(SupermavenSetMessage), + #[serde(other)] + Unknown, +} diff --git a/crates/supermaven/src/supermaven.rs b/crates/supermaven/src/supermaven.rs new file mode 100644 index 0000000000000000000000000000000000000000..c4321163579bb00f4f6b3b7565249fcbcd3e3871 --- /dev/null +++ b/crates/supermaven/src/supermaven.rs @@ -0,0 +1,345 @@ +mod messages; +mod supermaven_completion_provider; + +pub use supermaven_completion_provider::*; + +use anyhow::{Context as _, Result}; +#[allow(unused_imports)] +use client::{proto, Client}; +use collections::BTreeMap; + +use futures::{channel::mpsc, io::BufReader, AsyncBufReadExt, StreamExt}; +use gpui::{AppContext, AsyncAppContext, EntityId, Global, Model, ModelContext, Task, WeakModel}; +use language::{language_settings::all_language_settings, Anchor, Buffer, ToOffset}; +use messages::*; +use postage::watch; +use serde::{Deserialize, Serialize}; +use settings::SettingsStore; +use smol::{ + io::AsyncWriteExt, + process::{Child, ChildStdin, ChildStdout, Command}, +}; +use std::{ops::Range, path::PathBuf, process::Stdio, sync::Arc}; +use ui::prelude::*; +use util::ResultExt; + +pub fn init(client: Arc, cx: &mut AppContext) { + let supermaven = cx.new_model(|_| Supermaven::Starting); + Supermaven::set_global(supermaven.clone(), cx); + + let mut provider = all_language_settings(None, cx).inline_completions.provider; + if provider == language::language_settings::InlineCompletionProvider::Supermaven { + supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx)); + } + + cx.observe_global::(move |cx| { + let new_provider = all_language_settings(None, cx).inline_completions.provider; + if new_provider != provider { + provider = new_provider; + if provider == language::language_settings::InlineCompletionProvider::Supermaven { + supermaven.update(cx, |supermaven, cx| supermaven.start(client.clone(), cx)); + } else { + supermaven.update(cx, |supermaven, _cx| supermaven.stop()); + } + } + }) + .detach(); +} + +pub enum Supermaven { + Starting, + FailedDownload { error: anyhow::Error }, + Spawned(SupermavenAgent), + Error { error: anyhow::Error }, +} + +#[derive(Clone)] +pub enum AccountStatus { + Unknown, + NeedsActivation { activate_url: String }, + Ready, +} + +#[derive(Clone)] +struct SupermavenGlobal(Model); + +impl Global for SupermavenGlobal {} + +impl Supermaven { + pub fn global(cx: &AppContext) -> Option> { + cx.try_global::() + .map(|model| model.0.clone()) + } + + pub fn set_global(supermaven: Model, cx: &mut AppContext) { + cx.set_global(SupermavenGlobal(supermaven)); + } + + pub fn start(&mut self, client: Arc, cx: &mut ModelContext) { + if let Self::Starting = self { + cx.spawn(|this, mut cx| async move { + let binary_path = + supermaven_api::get_supermaven_agent_path(client.http_client()).await?; + + this.update(&mut cx, |this, cx| { + if let Self::Starting = this { + *this = + Self::Spawned(SupermavenAgent::new(binary_path, client.clone(), cx)?); + } + anyhow::Ok(()) + }) + }) + .detach_and_log_err(cx) + } + } + + pub fn stop(&mut self) { + *self = Self::Starting; + } + + pub fn is_enabled(&self) -> bool { + matches!(self, Self::Spawned { .. }) + } + + pub fn complete( + &mut self, + buffer: &Model, + cursor_position: Anchor, + cx: &AppContext, + ) -> Option { + if let Self::Spawned(agent) = self { + let buffer_id = buffer.entity_id(); + let buffer = buffer.read(cx); + let path = buffer + .file() + .and_then(|file| Some(file.as_local()?.abs_path(cx))) + .unwrap_or_else(|| PathBuf::from("untitled")) + .to_string_lossy() + .to_string(); + let content = buffer.text(); + let offset = cursor_position.to_offset(buffer); + let state_id = agent.next_state_id; + agent.next_state_id.0 += 1; + + let (updates_tx, mut updates_rx) = watch::channel(); + postage::stream::Stream::try_recv(&mut updates_rx).unwrap(); + + agent.states.insert( + state_id, + SupermavenCompletionState { + buffer_id, + range: cursor_position.bias_left(buffer)..cursor_position.bias_right(buffer), + completion: Vec::new(), + text: String::new(), + updates_tx, + }, + ); + let _ = agent + .outgoing_tx + .unbounded_send(OutboundMessage::StateUpdate(StateUpdateMessage { + new_id: state_id.0.to_string(), + updates: vec![ + StateUpdate::FileUpdate(FileUpdateMessage { + path: path.clone(), + content, + }), + StateUpdate::CursorUpdate(CursorPositionUpdateMessage { path, offset }), + ], + })); + + Some(SupermavenCompletion { + id: state_id, + updates: updates_rx, + }) + } else { + None + } + } + + pub fn completion( + &self, + id: SupermavenCompletionStateId, + ) -> Option<&SupermavenCompletionState> { + if let Self::Spawned(agent) = self { + agent.states.get(&id) + } else { + None + } + } +} + +pub struct SupermavenAgent { + _process: Child, + next_state_id: SupermavenCompletionStateId, + states: BTreeMap, + outgoing_tx: mpsc::UnboundedSender, + _handle_outgoing_messages: Task>, + _handle_incoming_messages: Task>, + pub account_status: AccountStatus, + service_tier: Option, + #[allow(dead_code)] + client: Arc, +} + +impl SupermavenAgent { + fn new( + binary_path: PathBuf, + client: Arc, + cx: &mut ModelContext, + ) -> Result { + let mut process = Command::new(&binary_path) + .arg("stdio") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .kill_on_drop(true) + .spawn() + .context("failed to start the binary")?; + + let stdin = process + .stdin + .take() + .context("failed to get stdin for process")?; + let stdout = process + .stdout + .take() + .context("failed to get stdout for process")?; + + let (outgoing_tx, outgoing_rx) = mpsc::unbounded(); + + cx.spawn({ + let client = client.clone(); + let outgoing_tx = outgoing_tx.clone(); + move |this, mut cx| async move { + let mut status = client.status(); + while let Some(status) = status.next().await { + if status.is_connected() { + let api_key = client.request(proto::GetSupermavenApiKey {}).await?.api_key; + outgoing_tx + .unbounded_send(OutboundMessage::SetApiKey(SetApiKey { api_key })) + .ok(); + this.update(&mut cx, |this, cx| { + if let Supermaven::Spawned(this) = this { + this.account_status = AccountStatus::Ready; + cx.notify(); + } + })?; + break; + } + } + return anyhow::Ok(()); + } + }) + .detach(); + + Ok(Self { + _process: process, + next_state_id: SupermavenCompletionStateId::default(), + states: BTreeMap::default(), + outgoing_tx, + _handle_outgoing_messages: cx + .spawn(|_, _cx| Self::handle_outgoing_messages(outgoing_rx, stdin)), + _handle_incoming_messages: cx + .spawn(|this, cx| Self::handle_incoming_messages(this, stdout, cx)), + account_status: AccountStatus::Unknown, + service_tier: None, + client, + }) + } + + async fn handle_outgoing_messages( + mut outgoing: mpsc::UnboundedReceiver, + mut stdin: ChildStdin, + ) -> Result<()> { + while let Some(message) = outgoing.next().await { + let bytes = serde_json::to_vec(&message)?; + stdin.write_all(&bytes).await?; + stdin.write_all(&[b'\n']).await?; + } + Ok(()) + } + + async fn handle_incoming_messages( + this: WeakModel, + stdout: ChildStdout, + mut cx: AsyncAppContext, + ) -> Result<()> { + const MESSAGE_PREFIX: &str = "SM-MESSAGE "; + + let stdout = BufReader::new(stdout); + let mut lines = stdout.lines(); + while let Some(line) = lines.next().await { + let Some(line) = line.context("failed to read line from stdout").log_err() else { + continue; + }; + let Some(line) = line.strip_prefix(MESSAGE_PREFIX) else { + continue; + }; + let Some(message) = serde_json::from_str::(&line) + .with_context(|| format!("failed to deserialize line from stdout: {:?}", line)) + .log_err() + else { + continue; + }; + + this.update(&mut cx, |this, _cx| { + if let Supermaven::Spawned(this) = this { + this.handle_message(message); + } + Task::ready(anyhow::Ok(())) + })? + .await?; + } + + Ok(()) + } + + fn handle_message(&mut self, message: SupermavenMessage) { + match message { + SupermavenMessage::ActivationRequest(request) => { + self.account_status = match request.activate_url { + Some(activate_url) => AccountStatus::NeedsActivation { + activate_url: activate_url.clone(), + }, + None => AccountStatus::Ready, + }; + } + SupermavenMessage::ServiceTier { service_tier } => { + self.service_tier = Some(service_tier); + } + SupermavenMessage::Response(response) => { + let state_id = SupermavenCompletionStateId(response.state_id.parse().unwrap()); + if let Some(state) = self.states.get_mut(&state_id) { + for item in &response.items { + if let ResponseItem::Text { text } = item { + state.text.push_str(text); + } + } + state.completion.extend(response.items); + *state.updates_tx.borrow_mut() = (); + } + } + SupermavenMessage::Passthrough { passthrough } => self.handle_message(*passthrough), + _ => { + log::warn!("unhandled message: {:?}", message); + } + } + } +} + +#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)] +pub struct SupermavenCompletionStateId(usize); + +#[allow(dead_code)] +pub struct SupermavenCompletionState { + buffer_id: EntityId, + range: Range, + completion: Vec, + text: String, + updates_tx: watch::Sender<()>, +} + +pub struct SupermavenCompletion { + pub id: SupermavenCompletionStateId, + pub updates: watch::Receiver<()>, +} diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs new file mode 100644 index 0000000000000000000000000000000000000000..8dc06bfac011a4ccdc48fce7773f609f0d2a5e9c --- /dev/null +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -0,0 +1,131 @@ +use crate::{Supermaven, SupermavenCompletionStateId}; +use anyhow::Result; +use editor::{Direction, InlineCompletionProvider}; +use futures::StreamExt as _; +use gpui::{AppContext, Model, ModelContext, Task}; +use language::{ + language_settings::all_language_settings, Anchor, Buffer, OffsetRangeExt as _, ToOffset, +}; +use std::time::Duration; + +pub const DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75); + +pub struct SupermavenCompletionProvider { + supermaven: Model, + completion_id: Option, + pending_refresh: Task>, +} + +impl SupermavenCompletionProvider { + pub fn new(supermaven: Model) -> Self { + Self { + supermaven, + completion_id: None, + pending_refresh: Task::ready(Ok(())), + } + } +} + +impl InlineCompletionProvider for SupermavenCompletionProvider { + fn is_enabled(&self, buffer: &Model, cursor_position: Anchor, cx: &AppContext) -> bool { + if !self.supermaven.read(cx).is_enabled() { + return false; + } + + let buffer = buffer.read(cx); + let file = buffer.file(); + let language = buffer.language_at(cursor_position); + let settings = all_language_settings(file, cx); + settings.inline_completions_enabled(language.as_ref(), file.map(|f| f.path().as_ref())) + } + + fn refresh( + &mut self, + buffer_handle: Model, + cursor_position: Anchor, + debounce: bool, + cx: &mut ModelContext, + ) { + let Some(mut completion) = self.supermaven.update(cx, |supermaven, cx| { + supermaven.complete(&buffer_handle, cursor_position, cx) + }) else { + return; + }; + + self.pending_refresh = cx.spawn(|this, mut cx| async move { + if debounce { + cx.background_executor().timer(DEBOUNCE_TIMEOUT).await; + } + + while let Some(()) = completion.updates.next().await { + this.update(&mut cx, |this, cx| { + this.completion_id = Some(completion.id); + cx.notify(); + })?; + } + Ok(()) + }); + } + + fn cycle( + &mut self, + _buffer: Model, + _cursor_position: Anchor, + _direction: Direction, + _cx: &mut ModelContext, + ) { + // todo!("cycling") + } + + fn accept(&mut self, _cx: &mut ModelContext) { + self.pending_refresh = Task::ready(Ok(())); + self.completion_id = None; + } + + fn discard(&mut self, _cx: &mut ModelContext) { + self.pending_refresh = Task::ready(Ok(())); + self.completion_id = None; + } + + fn active_completion_text<'a>( + &'a self, + buffer: &Model, + cursor_position: Anchor, + cx: &'a AppContext, + ) -> Option<&'a str> { + let completion_id = self.completion_id?; + let buffer = buffer.read(cx); + let cursor_offset = cursor_position.to_offset(buffer); + let completion = self.supermaven.read(cx).completion(completion_id)?; + + let mut completion_range = completion.range.to_offset(buffer); + + let prefix_len = common_prefix( + buffer.chars_for_range(completion_range.clone()), + completion.text.chars(), + ); + completion_range.start += prefix_len; + let suffix_len = common_prefix( + buffer.reversed_chars_for_range(completion_range.clone()), + completion.text[prefix_len..].chars().rev(), + ); + completion_range.end = completion_range.end.saturating_sub(suffix_len); + + let completion_text = &completion.text[prefix_len..completion.text.len() - suffix_len]; + if completion_range.is_empty() + && completion_range.start == cursor_offset + && !completion_text.trim().is_empty() + { + Some(completion_text) + } else { + None + } + } +} + +fn common_prefix, T2: Iterator>(a: T1, b: T2) -> usize { + a.zip(b) + .take_while(|(a, b)| a == b) + .map(|(a, _)| a.len_utf8()) + .sum() +} diff --git a/crates/supermaven_api/Cargo.toml b/crates/supermaven_api/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..69b69652832da39ba98bedc83915554f30dd64b2 --- /dev/null +++ b/crates/supermaven_api/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "supermaven_api" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/supermaven_api.rs" +doctest = false + +[dependencies] +anyhow.workspace = true +futures.workspace = true +serde.workspace = true +serde_json.workspace = true +smol.workspace = true +util.workspace = true diff --git a/crates/supermaven_api/src/supermaven_api.rs b/crates/supermaven_api/src/supermaven_api.rs new file mode 100644 index 0000000000000000000000000000000000000000..9d55bc541319b68a0bac822fd82cc646db9986f6 --- /dev/null +++ b/crates/supermaven_api/src/supermaven_api.rs @@ -0,0 +1,291 @@ +use anyhow::{anyhow, Context, Result}; +use futures::io::BufReader; +use futures::{AsyncReadExt, Future}; +use serde::{Deserialize, Serialize}; +use smol::fs::{self, File}; +use smol::stream::StreamExt; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use util::http::{AsyncBody, HttpClient, Request as HttpRequest}; +use util::paths::SUPERMAVEN_DIR; + +#[derive(Serialize)] +pub struct GetExternalUserRequest { + pub id: String, +} + +#[derive(Serialize)] +pub struct CreateExternalUserRequest { + pub id: String, + pub email: String, +} + +#[derive(Serialize)] +pub struct DeleteExternalUserRequest { + pub id: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateExternalUserResponse { + pub api_key: String, +} + +#[derive(Deserialize)] +pub struct SupermavenApiError { + pub message: String, +} + +pub struct SupermavenBinary {} + +pub struct SupermavenAdminApi { + admin_api_key: String, + api_url: String, + http_client: Arc, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenDownloadResponse { + pub download_url: String, + pub version: u64, + pub sha256_hash: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SupermavenUser { + id: String, + email: String, + api_key: String, +} + +impl SupermavenAdminApi { + pub fn new(admin_api_key: String, http_client: Arc) -> Self { + Self { + admin_api_key, + api_url: "https://supermaven.com/api/".to_string(), + http_client, + } + } + + pub async fn try_get_user( + &self, + request: GetExternalUserRequest, + ) -> Result> { + let uri = format!("{}external-user/{}", &self.api_url, &request.id); + + let request = HttpRequest::get(&uri).header("Authorization", self.admin_api_key.clone()); + + let mut response = self + .http_client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to get Supermaven API Key".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if response.status().is_client_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + if error.message == "User not found" { + return Ok(None); + } else { + return Err(anyhow!("Supermaven API error: {}", error.message)); + } + } else if response.status().is_server_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + return Err(anyhow!("Supermaven API server error").context(error.message)); + } + + let body_str = std::str::from_utf8(&body)?; + + Ok(Some( + serde_json::from_str::(body_str) + .with_context(|| "Unable to parse Supermaven user response".to_string())?, + )) + } + + pub async fn try_create_user( + &self, + request: CreateExternalUserRequest, + ) -> Result { + let uri = format!("{}external-user", &self.api_url); + + let request = HttpRequest::post(&uri) + .header("Authorization", self.admin_api_key.clone()) + .body(AsyncBody::from(serde_json::to_vec(&request)?))?; + + let mut response = self + .http_client + .send(request) + .await + .with_context(|| "Unable to create Supermaven API Key".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + let body_str = std::str::from_utf8(&body)?; + + if !response.status().is_success() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + return Err(anyhow!("Supermaven API server error").context(error.message)); + } + + serde_json::from_str::(body_str) + .with_context(|| "Unable to parse Supermaven API Key response".to_string()) + } + + pub async fn try_delete_user(&self, request: DeleteExternalUserRequest) -> Result<()> { + let uri = format!("{}external-user/{}", &self.api_url, &request.id); + + let request = HttpRequest::delete(&uri).header("Authorization", self.admin_api_key.clone()); + + let mut response = self + .http_client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to delete Supermaven User".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if response.status().is_client_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + if error.message == "User not found" { + return Ok(()); + } else { + return Err(anyhow!("Supermaven API error: {}", error.message)); + } + } else if response.status().is_server_error() { + let error: SupermavenApiError = serde_json::from_slice(&body)?; + return Err(anyhow!("Supermaven API server error").context(error.message)); + } + + Ok(()) + } + + pub async fn try_get_or_create_user( + &self, + request: CreateExternalUserRequest, + ) -> Result { + let get_user_request = GetExternalUserRequest { + id: request.id.clone(), + }; + + match self.try_get_user(get_user_request).await? { + None => self.try_create_user(request).await, + Some(SupermavenUser { api_key, .. }) => Ok(CreateExternalUserResponse { api_key }), + } + } +} + +pub async fn latest_release( + client: Arc, + platform: &str, + arch: &str, +) -> Result { + let uri = format!( + "https://supermaven.com/api/download-path?platform={}&arch={}", + platform, arch + ); + + // Download is not authenticated + let request = HttpRequest::get(&uri); + + let mut response = client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to acquire Supermaven Agent".to_string())?; + + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + + if response.status().is_client_error() || response.status().is_server_error() { + let body_str = std::str::from_utf8(&body)?; + let error: SupermavenApiError = serde_json::from_str(body_str)?; + return Err(anyhow!("Supermaven API error: {}", error.message)); + } + + serde_json::from_slice::(&body) + .with_context(|| "Unable to parse Supermaven Agent response".to_string()) +} + +pub fn version_path(version: u64) -> PathBuf { + SUPERMAVEN_DIR.join(format!("sm-agent-{}", version)) +} + +pub async fn has_version(version_path: &Path) -> bool { + fs::metadata(version_path) + .await + .map_or(false, |m| m.is_file()) +} + +pub fn get_supermaven_agent_path( + client: Arc, +) -> impl Future> { + async move { + fs::create_dir_all(&*SUPERMAVEN_DIR) + .await + .with_context(|| { + format!( + "Could not create Supermaven Agent Directory at {:?}", + &*SUPERMAVEN_DIR + ) + })?; + + let platform = match std::env::consts::OS { + "macos" => "darwin", + "windows" => "windows", + "linux" => "linux", + _ => return Err(anyhow!("unsupported platform")), + }; + + let arch = match std::env::consts::ARCH { + "x86_64" => "amd64", + "aarch64" => "arm64", + _ => return Err(anyhow!("unsupported architecture")), + }; + + let download_info = latest_release(client.clone(), platform, arch).await?; + + let binary_path = version_path(download_info.version); + + if has_version(&binary_path).await { + return Ok(binary_path); + } + + let request = HttpRequest::get(&download_info.download_url); + + let mut response = client + .send(request.body(AsyncBody::default())?) + .await + .with_context(|| "Unable to download Supermaven Agent".to_string())?; + + let mut file = File::create(&binary_path) + .await + .with_context(|| format!("Unable to create file at {:?}", binary_path))?; + + futures::io::copy(BufReader::new(response.body_mut()), &mut file) + .await + .with_context(|| format!("Unable to write binary to file at {:?}", binary_path))?; + + #[cfg(not(windows))] + { + file.set_permissions(::from_mode( + 0o755, + )) + .await?; + } + + let mut old_binary_paths = fs::read_dir(&*SUPERMAVEN_DIR).await?; + while let Some(old_binary_path) = old_binary_paths.next().await { + let old_binary_path = old_binary_path?; + if old_binary_path.path() != binary_path { + fs::remove_file(old_binary_path.path()).await?; + } + } + + Ok(binary_path) + } +} diff --git a/crates/ui/src/components/icon.rs b/crates/ui/src/components/icon.rs index bc05a8f3d37ff4fcf918837fb39b22103f82331d..9c9e05d6b683497347f73c09a3df0121b06c47a9 100644 --- a/crates/ui/src/components/icon.rs +++ b/crates/ui/src/components/icon.rs @@ -155,6 +155,10 @@ pub enum IconName { Space, Split, Spinner, + Supermaven, + SupermavenDisabled, + SupermavenError, + SupermavenInit, Tab, Terminal, Trash, @@ -261,6 +265,10 @@ impl IconName { IconName::Space => "icons/space.svg", IconName::Split => "icons/split.svg", IconName::Spinner => "icons/spinner.svg", + IconName::Supermaven => "icons/supermaven.svg", + IconName::SupermavenDisabled => "icons/supermaven_disabled.svg", + IconName::SupermavenError => "icons/supermaven_error.svg", + IconName::SupermavenInit => "icons/supermaven_init.svg", IconName::Tab => "icons/tab.svg", IconName::Terminal => "icons/terminal.svg", IconName::Trash => "icons/trash.svg", diff --git a/crates/util/src/paths.rs b/crates/util/src/paths.rs index 205ea72f0af71f19e0cd3f4b7dee6e0a02fd3d5b..feb7c195352b8446e546d8efac1953f224f719ee 100644 --- a/crates/util/src/paths.rs +++ b/crates/util/src/paths.rs @@ -52,6 +52,7 @@ lazy_static::lazy_static! { pub static ref EXTENSIONS_DIR: PathBuf = SUPPORT_DIR.join("extensions"); pub static ref LANGUAGES_DIR: PathBuf = SUPPORT_DIR.join("languages"); pub static ref COPILOT_DIR: PathBuf = SUPPORT_DIR.join("copilot"); + pub static ref SUPERMAVEN_DIR: PathBuf = SUPPORT_DIR.join("supermaven"); pub static ref DEFAULT_PRETTIER_DIR: PathBuf = SUPPORT_DIR.join("prettier"); pub static ref DB_DIR: PathBuf = SUPPORT_DIR.join("db"); pub static ref CRASHES_DIR: Option = cfg!(target_os = "macos") diff --git a/crates/welcome/Cargo.toml b/crates/welcome/Cargo.toml index c18a09673ff8c97e8a8b1c557536b2dc3e972d2e..e747072cdeba704686300da0a6cc1bcbfb5acb87 100644 --- a/crates/welcome/Cargo.toml +++ b/crates/welcome/Cargo.toml @@ -17,7 +17,7 @@ test-support = [] [dependencies] anyhow.workspace = true client.workspace = true -copilot_ui.workspace = true +inline_completion_button.workspace = true db.workspace = true extensions_ui.workspace = true fuzzy.workspace = true diff --git a/crates/welcome/src/welcome.rs b/crates/welcome/src/welcome.rs index e6a2a53f2e6cb5483063710e9f65b4dd59250ea7..3ae07cda6801ea6df627c060895d2733c9c0f924 100644 --- a/crates/welcome/src/welcome.rs +++ b/crates/welcome/src/welcome.rs @@ -2,7 +2,6 @@ mod base_keymap_picker; mod base_keymap_setting; use client::{telemetry::Telemetry, TelemetrySettings}; -use copilot_ui; use db::kvp::KEY_VALUE_STORE; use gpui::{ svg, AnyElement, AppContext, EventEmitter, FocusHandle, FocusableView, InteractiveElement, @@ -143,7 +142,7 @@ impl Render for WelcomePage { this.telemetry.report_app_event( "welcome page: sign in to copilot".to_string(), ); - copilot_ui::initiate_sign_in(cx); + inline_completion_button::initiate_sign_in(cx); })), ) .child( diff --git a/crates/zed/Cargo.toml b/crates/zed/Cargo.toml index 9a9f40020a77c4d302dfe51b0a13346cf509a1cb..a8130fe5df90d0e6521943699944633ccb25fd19 100644 --- a/crates/zed/Cargo.toml +++ b/crates/zed/Cargo.toml @@ -35,7 +35,6 @@ collab_ui.workspace = true collections.workspace = true command_palette.workspace = true copilot.workspace = true -copilot_ui.workspace = true db.workspace = true diagnostics.workspace = true editor.workspace = true @@ -51,6 +50,7 @@ go_to_line.workspace = true gpui.workspace = true headless.workspace = true image_viewer.workspace = true +inline_completion_button.workspace = true install_cli.workspace = true isahc.workspace = true journal.workspace = true @@ -83,6 +83,7 @@ settings.workspace = true simplelog = "0.9" smol.workspace = true tab_switcher.workspace = true +supermaven.workspace = true task.workspace = true tasks_ui.workspace = true telemetry_events.workspace = true diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index 9850a2f603411c06363620b51d64e97e1bb5e3b3..3b2e96965ec8042bdaf8ff11b70196fb9b115e88 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -9,16 +9,14 @@ mod zed; use anyhow::{anyhow, Context as _, Result}; use clap::{command, Parser}; use cli::FORCE_CLI_MODE_ENV_VAR_NAME; -use client::{parse_zed_link, telemetry::Telemetry, Client, DevServerToken, UserStore}; +use client::{parse_zed_link, Client, DevServerToken, UserStore}; use collab_ui::channel_view::ChannelView; -use copilot::Copilot; -use copilot_ui::CopilotCompletionProvider; use db::kvp::KEY_VALUE_STORE; -use editor::{Editor, EditorMode}; +use editor::Editor; use env_logger::Builder; use fs::RealFs; use futures::{future, StreamExt}; -use gpui::{App, AppContext, AsyncAppContext, Context, Task, ViewContext, VisualContext}; +use gpui::{App, AppContext, AsyncAppContext, Context, Task, VisualContext}; use image_viewer; use language::LanguageRegistry; use log::LevelFilter; @@ -55,6 +53,8 @@ use zed::{ OpenListener, OpenRequest, }; +use crate::zed::inline_completion_registry; + #[cfg(feature = "mimalloc")] #[global_allocator] static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; @@ -270,17 +270,20 @@ fn init_ui(args: Args) { editor::init(cx); image_viewer::init(cx); diagnostics::init(cx); + + // Initialize each completion provider. Settings are used for toggling between them. copilot::init( copilot_language_server_id, client.http_client(), node_runtime.clone(), cx, ); + supermaven::init(client.clone(), cx); assistant::init(client.clone(), cx); assistant2::init(client.clone(), cx); - init_inline_completion_provider(client.telemetry().clone(), cx); + inline_completion_registry::init(client.telemetry().clone(), cx); extension::init( fs.clone(), @@ -888,45 +891,3 @@ fn watch_file_types(fs: Arc, cx: &mut AppContext) { #[cfg(not(debug_assertions))] fn watch_file_types(_fs: Arc, _cx: &mut AppContext) {} - -fn init_inline_completion_provider(telemetry: Arc, cx: &mut AppContext) { - if let Some(copilot) = Copilot::global(cx) { - cx.observe_new_views(move |editor: &mut Editor, cx: &mut ViewContext| { - if editor.mode() == EditorMode::Full { - // We renamed some of these actions to not be copilot-specific, but that - // would have not been backwards-compatible. So here we are re-registering - // the actions with the old names to not break people's keymaps. - editor - .register_action(cx.listener( - |editor, _: &copilot::Suggest, cx: &mut ViewContext| { - editor.show_inline_completion(&Default::default(), cx); - }, - )) - .register_action(cx.listener( - |editor, _: &copilot::NextSuggestion, cx: &mut ViewContext| { - editor.next_inline_completion(&Default::default(), cx); - }, - )) - .register_action(cx.listener( - |editor, _: &copilot::PreviousSuggestion, cx: &mut ViewContext| { - editor.previous_inline_completion(&Default::default(), cx); - }, - )) - .register_action(cx.listener( - |editor, - _: &editor::actions::AcceptPartialCopilotSuggestion, - cx: &mut ViewContext| { - editor.accept_partial_inline_completion(&Default::default(), cx); - }, - )); - - let provider = cx.new_model(|_| { - CopilotCompletionProvider::new(copilot.clone()) - .with_telemetry(telemetry.clone()) - }); - editor.set_inline_completion_provider(provider, cx) - } - }) - .detach(); - } -} diff --git a/crates/zed/src/zed.rs b/crates/zed/src/zed.rs index 6c0f155ce2a54c5088e9c271577c39e6ba30a300..14cc9febd26b511c16867de8fc3e656124235b01 100644 --- a/crates/zed/src/zed.rs +++ b/crates/zed/src/zed.rs @@ -1,4 +1,5 @@ mod app_menus; +pub mod inline_completion_registry; mod only_instance; mod open_listener; @@ -127,7 +128,10 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { }) .detach(); - let copilot = cx.new_view(|cx| copilot_ui::CopilotButton::new(app_state.fs.clone(), cx)); + let inline_completion_button = cx.new_view(|cx| { + inline_completion_button::InlineCompletionButton::new(app_state.fs.clone(), cx) + }); + let diagnostic_summary = cx.new_view(|cx| diagnostics::items::DiagnosticIndicator::new(workspace, cx)); let activity_indicator = @@ -140,7 +144,7 @@ pub fn initialize_workspace(app_state: Arc, cx: &mut AppContext) { workspace.status_bar().update(cx, |status_bar, cx| { status_bar.add_left_item(diagnostic_summary, cx); status_bar.add_left_item(activity_indicator, cx); - status_bar.add_right_item(copilot, cx); + status_bar.add_right_item(inline_completion_button, cx); status_bar.add_right_item(active_buffer_language, cx); status_bar.add_right_item(vim_mode_indicator, cx); status_bar.add_right_item(cursor_position, cx); diff --git a/crates/zed/src/zed/inline_completion_registry.rs b/crates/zed/src/zed/inline_completion_registry.rs new file mode 100644 index 0000000000000000000000000000000000000000..7ea50322a38cfc9f0806913682a53556fe6e6582 --- /dev/null +++ b/crates/zed/src/zed/inline_completion_registry.rs @@ -0,0 +1,126 @@ +use std::{cell::RefCell, rc::Rc, sync::Arc}; + +use client::telemetry::Telemetry; +use collections::HashMap; +use copilot::{Copilot, CopilotCompletionProvider}; +use editor::{Editor, EditorMode}; +use gpui::{AnyWindowHandle, AppContext, Context, ViewContext, WeakView}; +use language::language_settings::all_language_settings; +use settings::SettingsStore; +use supermaven::{Supermaven, SupermavenCompletionProvider}; + +pub fn init(telemetry: Arc, cx: &mut AppContext) { + let editors: Rc, AnyWindowHandle>>> = Rc::default(); + cx.observe_new_views({ + let editors = editors.clone(); + let telemetry = telemetry.clone(); + move |editor: &mut Editor, cx: &mut ViewContext| { + if editor.mode() != EditorMode::Full { + return; + } + + register_backward_compatible_actions(editor, cx); + + let editor_handle = cx.view().downgrade(); + cx.on_release({ + let editor_handle = editor_handle.clone(); + let editors = editors.clone(); + move |_, _, _| { + editors.borrow_mut().remove(&editor_handle); + } + }) + .detach(); + editors + .borrow_mut() + .insert(editor_handle, cx.window_handle()); + let provider = all_language_settings(None, cx).inline_completions.provider; + assign_inline_completion_provider(editor, provider, &telemetry, cx); + } + }) + .detach(); + + let mut provider = all_language_settings(None, cx).inline_completions.provider; + for (editor, window) in editors.borrow().iter() { + _ = window.update(cx, |_window, cx| { + _ = editor.update(cx, |editor, cx| { + assign_inline_completion_provider(editor, provider, &telemetry, cx); + }) + }); + } + + cx.observe_global::(move |cx| { + let new_provider = all_language_settings(None, cx).inline_completions.provider; + if new_provider != provider { + provider = new_provider; + for (editor, window) in editors.borrow().iter() { + _ = window.update(cx, |_window, cx| { + _ = editor.update(cx, |editor, cx| { + assign_inline_completion_provider(editor, provider, &telemetry, cx); + }) + }); + } + } + }) + .detach(); +} + +fn register_backward_compatible_actions(editor: &mut Editor, cx: &mut ViewContext) { + // We renamed some of these actions to not be copilot-specific, but that + // would have not been backwards-compatible. So here we are re-registering + // the actions with the old names to not break people's keymaps. + editor + .register_action(cx.listener( + |editor, _: &copilot::Suggest, cx: &mut ViewContext| { + editor.show_inline_completion(&Default::default(), cx); + }, + )) + .register_action(cx.listener( + |editor, _: &copilot::NextSuggestion, cx: &mut ViewContext| { + editor.next_inline_completion(&Default::default(), cx); + }, + )) + .register_action(cx.listener( + |editor, _: &copilot::PreviousSuggestion, cx: &mut ViewContext| { + editor.previous_inline_completion(&Default::default(), cx); + }, + )) + .register_action(cx.listener( + |editor, + _: &editor::actions::AcceptPartialCopilotSuggestion, + cx: &mut ViewContext| { + editor.accept_partial_inline_completion(&Default::default(), cx); + }, + )); +} + +fn assign_inline_completion_provider( + editor: &mut Editor, + provider: language::language_settings::InlineCompletionProvider, + telemetry: &Arc, + cx: &mut ViewContext, +) { + match provider { + language::language_settings::InlineCompletionProvider::None => {} + language::language_settings::InlineCompletionProvider::Copilot => { + if let Some(copilot) = Copilot::global(cx) { + if let Some(buffer) = editor.buffer().read(cx).as_singleton() { + if buffer.read(cx).file().is_some() { + copilot.update(cx, |copilot, cx| { + copilot.register_buffer(&buffer, cx); + }); + } + } + let provider = cx.new_model(|_| { + CopilotCompletionProvider::new(copilot).with_telemetry(telemetry.clone()) + }); + editor.set_inline_completion_provider(Some(provider), cx); + } + } + language::language_settings::InlineCompletionProvider::Supermaven => { + if let Some(supermaven) = Supermaven::global(cx) { + let provider = cx.new_model(|_| SupermavenCompletionProvider::new(supermaven)); + editor.set_inline_completion_provider(Some(provider), cx); + } + } + } +}