From 09967ac3d0a50a8afc821ceb8756b18b960ebaa9 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 6 Feb 2025 13:07:26 -0500 Subject: [PATCH] zeta: Send up diagnostics with prediction requests (#24384) This PR makes it so we send up the diagnostic groups as additional data with the edit prediction request. We're not yet making use of them, but we are recording them so we can use them later (e.g., to train the model). Release Notes: - N/A --------- Co-authored-by: Nathan --- Cargo.lock | 13 ++- Cargo.toml | 2 +- .../src/copilot_completion_provider.rs | 2 + crates/editor/src/editor.rs | 8 +- crates/editor/src/inline_completion_tests.rs | 2 + crates/inline_completion/Cargo.toml | 1 + .../src/inline_completion.rs | 6 +- crates/language/src/buffer.rs | 2 +- crates/language/src/diagnostic_set.rs | 19 +++- crates/language/src/language.rs | 9 +- crates/languages/src/lib.rs | 1 - crates/project/src/lsp_store.rs | 13 +++ crates/supermaven/Cargo.toml | 1 + .../src/supermaven_completion_provider.rs | 2 + crates/zeta/Cargo.toml | 1 + crates/zeta/src/zeta.rs | 94 ++++++++++++++++--- 16 files changed, 145 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3065f2f68d9296c6aa587807f12467b29763ed09..5ad5d048bc88a23e547ff8eb62f38951d135f3fb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6428,6 +6428,7 @@ version = "0.1.0" dependencies = [ "gpui", "language", + "project", ] [[package]] @@ -7160,7 +7161,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -10215,7 +10216,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes 1.10.0", - "heck 0.4.1", + "heck 0.5.0", "itertools 0.12.1", "log", "multimap 0.10.0", @@ -15484,7 +15485,7 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.59.0", ] [[package]] @@ -16729,11 +16730,12 @@ dependencies = [ [[package]] name = "zed_llm_client" -version = "0.1.2" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ab9496dc5c80b2c5fb9654a76d7208d31b53130fb282085fcdde07653831843" +checksum = "9ea4d8ead1e1158e5ebdd6735df25973781da70de5c8008e3a13595865ca4f31" dependencies = [ "serde", + "serde_json", ] [[package]] @@ -16956,6 +16958,7 @@ dependencies = [ "log", "menu", "postage", + "project", "regex", "reqwest_client", "rpc", diff --git a/Cargo.toml b/Cargo.toml index 73160b0cd17179b57ee839a78b27833ce992be19..ff50372ed342d2c4d6a60b6e8c1df06ffc89f5c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -557,7 +557,7 @@ wasmtime = { version = "24", default-features = false, features = [ wasmtime-wasi = "24" which = "6.0.0" wit-component = "0.201" -zed_llm_client = "0.1.1" +zed_llm_client = "0.2" zstd = "0.11" metal = "0.31" diff --git a/crates/copilot/src/copilot_completion_provider.rs b/crates/copilot/src/copilot_completion_provider.rs index f953e5a1100371c6990e71e1208bb6e33b15d8bd..93ffeaf2e2d92164a4fd40062ba69aa2802d0b00 100644 --- a/crates/copilot/src/copilot_completion_provider.rs +++ b/crates/copilot/src/copilot_completion_provider.rs @@ -3,6 +3,7 @@ use anyhow::Result; use gpui::{App, Context, Entity, EntityId, Task}; use inline_completion::{Direction, InlineCompletion, InlineCompletionProvider}; use language::{language_settings::AllLanguageSettings, Buffer, OffsetRangeExt, ToOffset}; +use project::Project; use settings::Settings; use std::{path::Path, time::Duration}; @@ -79,6 +80,7 @@ impl InlineCompletionProvider for CopilotCompletionProvider { fn refresh( &mut self, + _project: Option>, buffer: Entity, cursor_position: language::Anchor, debounce: bool, diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index 95373616f00eddee8a80ce57af96bbc921de6b19..b533345e66aa9abfaa66a93c8795e8f2582ad958 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -4648,7 +4648,13 @@ impl Editor { } self.update_visible_inline_completion(window, cx); - provider.refresh(buffer, cursor_buffer_position, debounce, cx); + provider.refresh( + self.project.clone(), + buffer, + cursor_buffer_position, + debounce, + cx, + ); Some(()) } diff --git a/crates/editor/src/inline_completion_tests.rs b/crates/editor/src/inline_completion_tests.rs index 40f77bd35b97f1b049e684282132d69762b76dbc..c0ad941b7a67e06ac7697ac94d91d71306a17035 100644 --- a/crates/editor/src/inline_completion_tests.rs +++ b/crates/editor/src/inline_completion_tests.rs @@ -3,6 +3,7 @@ use indoc::indoc; use inline_completion::InlineCompletionProvider; use language::{Language, LanguageConfig}; use multi_buffer::{Anchor, MultiBufferSnapshot, ToPoint}; +use project::Project; use std::{num::NonZeroU32, ops::Range, sync::Arc}; use text::{Point, ToOffset}; @@ -394,6 +395,7 @@ impl InlineCompletionProvider for FakeInlineCompletionProvider { fn refresh( &mut self, + _project: Option>, _buffer: gpui::Entity, _cursor_position: language::Anchor, _debounce: bool, diff --git a/crates/inline_completion/Cargo.toml b/crates/inline_completion/Cargo.toml index b6b5e2a92ec84d08b333ccb177458787b4a77d95..b478db6f948ad139e127fc0e9ebf7f332c0d8547 100644 --- a/crates/inline_completion/Cargo.toml +++ b/crates/inline_completion/Cargo.toml @@ -14,3 +14,4 @@ path = "src/inline_completion.rs" [dependencies] gpui.workspace = true language.workspace = true +project.workspace = true diff --git a/crates/inline_completion/src/inline_completion.rs b/crates/inline_completion/src/inline_completion.rs index 7c1d89f097e9250a54f2ce5f26306f21ce3644b8..d262112e0380da4b9d352e9e54a9d95da8b31160 100644 --- a/crates/inline_completion/src/inline_completion.rs +++ b/crates/inline_completion/src/inline_completion.rs @@ -1,5 +1,6 @@ use gpui::{App, Context, Entity}; use language::Buffer; +use project::Project; use std::ops::Range; // TODO: Find a better home for `Direction`. @@ -58,6 +59,7 @@ pub trait InlineCompletionProvider: 'static + Sized { fn is_refreshing(&self) -> bool; fn refresh( &mut self, + project: Option>, buffer: Entity, cursor_position: language::Anchor, debounce: bool, @@ -101,6 +103,7 @@ pub trait InlineCompletionProviderHandle { fn is_refreshing(&self, cx: &App) -> bool; fn refresh( &self, + project: Option>, buffer: Entity, cursor_position: language::Anchor, debounce: bool, @@ -174,13 +177,14 @@ where fn refresh( &self, + project: Option>, buffer: Entity, cursor_position: language::Anchor, debounce: bool, cx: &mut App, ) { self.update(cx, |this, cx| { - this.refresh(buffer, cursor_position, debounce, cx) + this.refresh(project, buffer, cursor_position, debounce, cx) }) } diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 5c9798c9b199b02a9fe409a32dc20553c318a7f1..defa935c2827152437b00d29e3bccbd35621403e 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -197,7 +197,7 @@ struct SelectionSet { } /// A diagnostic associated with a certain range of a buffer. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Diagnostic { /// The name of the service that produced this diagnostic. pub source: Option, diff --git a/crates/language/src/diagnostic_set.rs b/crates/language/src/diagnostic_set.rs index 0f2e39275cd6f2d6ddd31d7fc5bc9b3836c3166b..cff59c8004ce47f28c99ed542dbfb146cbb1041a 100644 --- a/crates/language/src/diagnostic_set.rs +++ b/crates/language/src/diagnostic_set.rs @@ -2,6 +2,7 @@ use crate::{range_to_lsp, Diagnostic}; use anyhow::Result; use collections::HashMap; use lsp::LanguageServerId; +use serde::Serialize; use std::{ cmp::{Ordering, Reverse}, iter, @@ -25,7 +26,7 @@ pub struct DiagnosticSet { /// the diagnostics are stored internally as [`Anchor`]s, but can be /// resolved to different coordinates types like [`usize`] byte offsets or /// [`Point`](gpui::Point)s. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize)] pub struct DiagnosticEntry { /// The range of the buffer where the diagnostic applies. pub range: Range, @@ -35,7 +36,7 @@ pub struct DiagnosticEntry { /// A group of related diagnostics, ordered by their start position /// in the buffer. -#[derive(Debug)] +#[derive(Debug, Serialize)] pub struct DiagnosticGroup { /// The diagnostics. pub entries: Vec>, @@ -43,6 +44,20 @@ pub struct DiagnosticGroup { pub primary_ix: usize, } +impl DiagnosticGroup { + /// Converts the entries in this [`DiagnosticGroup`] to a different buffer coordinate type. + pub fn resolve(&self, buffer: &text::BufferSnapshot) -> DiagnosticGroup { + DiagnosticGroup { + entries: self + .entries + .iter() + .map(|entry| entry.resolve(buffer)) + .collect(), + primary_ix: self.primary_ix, + } + } +} + #[derive(Clone, Debug)] pub struct Summary { start: Anchor, diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index f7aa838c7f096bb3ecf30a70e411ff53c32a5533..48438757fbf6cc143465db63190c0e76d37a8ba2 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -32,10 +32,7 @@ use gpui::{App, AsyncApp, Entity, SharedString, Task}; pub use highlight_map::HighlightMap; use http_client::HttpClient; pub use language_registry::{LanguageName, LoadedLanguage}; -use lsp::{ - CodeActionKind, InitializeParams, LanguageServerBinary, LanguageServerBinaryOptions, - LanguageServerName, -}; +use lsp::{CodeActionKind, InitializeParams, LanguageServerBinary, LanguageServerBinaryOptions}; use parking_lot::Mutex; use regex::Regex; use schemars::{ @@ -73,12 +70,12 @@ use util::serde::default_true; pub use buffer::Operation; pub use buffer::*; -pub use diagnostic_set::DiagnosticEntry; +pub use diagnostic_set::{DiagnosticEntry, DiagnosticGroup}; pub use language_registry::{ AvailableLanguage, LanguageNotFound, LanguageQueries, LanguageRegistry, LanguageServerBinaryStatus, QUERY_FILENAME_PREFIXES, }; -pub use lsp::LanguageServerId; +pub use lsp::{LanguageServerId, LanguageServerName}; pub use outline::*; pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer, ToTreeSitterPoint, TreeSitterOptions}; pub use text::{AnchorRangeExt, LineEnding}; diff --git a/crates/languages/src/lib.rs b/crates/languages/src/lib.rs index fbfe7b371ce1fc3b26a41b464064a342d8d9f34b..fc14962720853a30baad8c76f3e1a65cab8f6a12 100644 --- a/crates/languages/src/lib.rs +++ b/crates/languages/src/lib.rs @@ -2,7 +2,6 @@ use anyhow::Context as _; use gpui::{App, UpdateGlobal}; use json::json_task_context; pub use language::*; -use lsp::LanguageServerName; use node_runtime::NodeRuntime; use python::{PythonContextProvider, PythonToolchainProvider}; use rust_embed::RustEmbed; diff --git a/crates/project/src/lsp_store.rs b/crates/project/src/lsp_store.rs index e73bef795b59a0d4e6739ef1006c246a2d84cdd6..f85ba369f7590b3b1af20ee21ca85086dc76046f 100644 --- a/crates/project/src/lsp_store.rs +++ b/crates/project/src/lsp_store.rs @@ -166,6 +166,19 @@ pub struct LocalLspStore { } impl LocalLspStore { + /// Returns the running language server for the given ID. Note if the language server is starting, it will not be returned. + pub fn running_language_server_for_id( + &self, + id: LanguageServerId, + ) -> Option<&Arc> { + let language_server_state = self.language_servers.get(&id)?; + + match language_server_state { + LanguageServerState::Running { server, .. } => Some(server), + LanguageServerState::Starting(_) => None, + } + } + fn start_language_server( &mut self, worktree_handle: &Entity, diff --git a/crates/supermaven/Cargo.toml b/crates/supermaven/Cargo.toml index a4748754bcaa673a447f460d27884ad16b5250bf..aa173266fe5367e7d3fd7e86f1ac888bd601971d 100644 --- a/crates/supermaven/Cargo.toml +++ b/crates/supermaven/Cargo.toml @@ -22,6 +22,7 @@ inline_completion.workspace = true language.workspace = true log.workspace = true postage.workspace = true +project.workspace = true serde.workspace = true serde_json.workspace = true settings.workspace = true diff --git a/crates/supermaven/src/supermaven_completion_provider.rs b/crates/supermaven/src/supermaven_completion_provider.rs index f80551a3f39d3f3a417bded1f4affa1bce46253b..c17053ca5514bf10600b88098c1c802a13edb879 100644 --- a/crates/supermaven/src/supermaven_completion_provider.rs +++ b/crates/supermaven/src/supermaven_completion_provider.rs @@ -4,6 +4,7 @@ use futures::StreamExt as _; use gpui::{App, Context, Entity, EntityId, Task}; use inline_completion::{Direction, InlineCompletion, InlineCompletionProvider}; use language::{Anchor, Buffer, BufferSnapshot}; +use project::Project; use std::{ ops::{AddAssign, Range}, path::Path, @@ -123,6 +124,7 @@ impl InlineCompletionProvider for SupermavenCompletionProvider { fn refresh( &mut self, + _project: Option>, buffer_handle: Entity, cursor_position: Anchor, debounce: bool, diff --git a/crates/zeta/Cargo.toml b/crates/zeta/Cargo.toml index e981460256eb977e9c166b1a89a89cd4fde9def7..1904a4d2bac484394e07ce3f708358c78a79e81d 100644 --- a/crates/zeta/Cargo.toml +++ b/crates/zeta/Cargo.toml @@ -37,6 +37,7 @@ language_models.workspace = true log.workspace = true menu.workspace = true postage.workspace = true +project.workspace = true regex.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index c20522b00b8610dc3b6dc4c20fc653712a9c0f5b..a2d660134294e33c86ae028a40da6cb9c3c0e973 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -30,6 +30,7 @@ use language::{ }; use language_models::LlmApiToken; use postage::watch; +use project::Project; use settings::WorktreeId; use std::{ borrow::Cow, @@ -363,6 +364,7 @@ impl Zeta { pub fn request_completion_impl( &mut self, + project: Option<&Entity>, buffer: &Entity, cursor: language::Anchor, can_collect_data: bool, @@ -374,6 +376,7 @@ impl Zeta { R: Future> + Send + 'static, { let snapshot = self.report_changes_for_buffer(&buffer, cx); + let diagnostic_groups = snapshot.diagnostic_groups(None); let cursor_point = cursor.to_point(&snapshot); let cursor_offset = cursor_point.to_offset(&snapshot); let events = self.events.clone(); @@ -387,10 +390,39 @@ impl Zeta { let is_staff = cx.is_staff(); let buffer = buffer.clone(); + + let local_lsp_store = + project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local()); + let diagnostic_groups = if let Some(local_lsp_store) = local_lsp_store { + Some( + diagnostic_groups + .into_iter() + .filter_map(|(language_server_id, diagnostic_group)| { + let language_server = + local_lsp_store.running_language_server_for_id(language_server_id)?; + + Some(( + language_server.name(), + diagnostic_group.resolve::(&snapshot), + )) + }) + .collect::>(), + ) + } else { + None + }; + cx.spawn(|_, cx| async move { let request_sent_at = Instant::now(); - let (input_events, input_excerpt, excerpt_range, input_outline) = cx + struct BackgroundValues { + input_events: String, + input_excerpt: String, + excerpt_range: Range, + input_outline: String, + } + + let values = cx .background_executor() .spawn({ let snapshot = snapshot.clone(); @@ -419,18 +451,36 @@ impl Zeta { // is not counted towards TOTAL_BYTE_LIMIT. let input_outline = prompt_for_outline(&snapshot); - anyhow::Ok((input_events, input_excerpt, excerpt_range, input_outline)) + anyhow::Ok(BackgroundValues { + input_events, + input_excerpt, + excerpt_range, + input_outline, + }) } }) .await?; - log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt); + log::debug!( + "Events:\n{}\nExcerpt:\n{}", + values.input_events, + values.input_excerpt + ); let body = PredictEditsBody { - input_events: input_events.clone(), - input_excerpt: input_excerpt.clone(), - outline: Some(input_outline.clone()), + input_events: values.input_events.clone(), + input_excerpt: values.input_excerpt.clone(), + outline: Some(values.input_outline.clone()), can_collect_data, + diagnostic_groups: diagnostic_groups.and_then(|diagnostic_groups| { + diagnostic_groups + .into_iter() + .map(|(name, diagnostic_group)| { + Ok((name.to_string(), serde_json::to_value(diagnostic_group)?)) + }) + .collect::>>() + .log_err() + }), }; let response = perform_predict_edits(client, llm_token, is_staff, body).await?; @@ -442,12 +492,12 @@ impl Zeta { output_excerpt, buffer, &snapshot, - excerpt_range, + values.excerpt_range, cursor_offset, path, - input_outline, - input_events, - input_excerpt, + values.input_outline, + values.input_events, + values.input_excerpt, request_sent_at, &cx, ) @@ -466,11 +516,13 @@ impl Zeta { and then another "#}; + let project = None; let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx)); let position = buffer.read(cx).anchor_before(Point::new(1, 0)); let completion_tasks = vec![ self.fake_completion( + project, &buffer, position, PredictEditsResponse { @@ -486,6 +538,7 @@ and then another cx, ), self.fake_completion( + project, &buffer, position, PredictEditsResponse { @@ -501,6 +554,7 @@ and then another cx, ), self.fake_completion( + project, &buffer, position, PredictEditsResponse { @@ -517,6 +571,7 @@ and then another cx, ), self.fake_completion( + project, &buffer, position, PredictEditsResponse { @@ -533,6 +588,7 @@ and then another cx, ), self.fake_completion( + project, &buffer, position, PredictEditsResponse { @@ -548,6 +604,7 @@ and then another cx, ), self.fake_completion( + project, &buffer, position, PredictEditsResponse { @@ -562,6 +619,7 @@ and then another cx, ), self.fake_completion( + project, &buffer, position, PredictEditsResponse { @@ -594,6 +652,7 @@ and then another #[cfg(any(test, feature = "test-support"))] pub fn fake_completion( &mut self, + project: Option<&Entity>, buffer: &Entity, position: language::Anchor, response: PredictEditsResponse, @@ -601,19 +660,21 @@ and then another ) -> Task>> { use std::future::ready; - self.request_completion_impl(buffer, position, false, cx, |_, _, _, _| { + self.request_completion_impl(project, buffer, position, false, cx, |_, _, _, _| { ready(Ok(response)) }) } pub fn request_completion( &mut self, + project: Option<&Entity>, buffer: &Entity, position: language::Anchor, can_collect_data: bool, cx: &mut Context, ) -> Task>> { self.request_completion_impl( + project, buffer, position, can_collect_data, @@ -1494,6 +1555,7 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide fn refresh( &mut self, + project: Option>, buffer: Entity, position: language::Anchor, _debounce: bool, @@ -1529,7 +1591,13 @@ impl inline_completion::InlineCompletionProvider for ZetaInlineCompletionProvide let completion_request = this.update(&mut cx, |this, cx| { this.last_request_timestamp = Instant::now(); this.zeta.update(cx, |zeta, cx| { - zeta.request_completion(&buffer, position, can_collect_data, cx) + zeta.request_completion( + project.as_ref(), + &buffer, + position, + can_collect_data, + cx, + ) }) }); @@ -1858,7 +1926,7 @@ mod tests { let buffer = cx.new(|cx| Buffer::local(buffer_content, cx)); let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 0))); let completion_task = zeta.update(cx, |zeta, cx| { - zeta.request_completion(&buffer, cursor, false, cx) + zeta.request_completion(None, &buffer, cursor, false, cx) }); let token_request = server.receive::().await.unwrap();