diff --git a/Cargo.lock b/Cargo.lock index 3855baec953955dd114f6d8910bf35d560ed3b7f..13d951f37867763e9a3f624475bc337c3daa04b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -21639,11 +21639,13 @@ name = "zeta2" version = "0.1.0" dependencies = [ "client", + "cloud_llm_client", "edit_prediction", + "edit_prediction_context", "gpui", "language", + "log", "project", - "util", "workspace-hack", ] diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 076dc6c5cb2f5c50d460a9c8e01172461ca9b123..cdcb02287017f10211a1d5056d9e56d95459a4ab 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -6,7 +6,7 @@ use crate::PredictEditsGitInfo; // TODO: snippet ordering within file / relative to excerpt #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Body { +pub struct PredictEditsRequest { pub excerpt: String, /// Within `signatures` pub excerpt_parent: Option, @@ -15,8 +15,8 @@ pub struct Body { pub events: Vec, #[serde(default)] pub can_collect_data: bool, - #[serde(skip_serializing_if = "Option::is_none", default)] - pub diagnostic_groups: Option>, + #[serde(skip_serializing_if = "Vec::is_empty", default)] + pub diagnostic_groups: Vec, /// Info about the git repository state, only present when can_collect_data is true. #[serde(skip_serializing_if = "Option::is_none", default)] pub git_info: Option, @@ -68,6 +68,12 @@ pub struct ScoreComponents { pub adjacent_vs_signature_weighted_overlap: f32, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiagnosticGroup { + pub language_server: String, + pub diagnostic_group: serde_json::Value, +} + /* #[derive(Debug, Clone)] pub struct SerializedJson { diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs index accb92901c173aecf9ad76116f4fc43e9c32c981..1638723857d225d05efce512bb9025fa89fb38f3 100644 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ b/crates/edit_prediction_context/src/declaration_scoring.rs @@ -16,10 +16,6 @@ use crate::{ const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16; -// TODO: -// -// * Consider adding declaration_file_count - #[derive(Clone, Debug)] pub struct ScoredSnippet { pub identifier: Identifier, @@ -28,7 +24,6 @@ pub struct ScoredSnippet { pub scores: Scores, } -// TODO: Consider having "Concise" style corresponding to `concise_text` #[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)] pub enum SnippetStyle { Signature, @@ -244,6 +239,7 @@ fn score_snippet( let adjacent_vs_signature_weighted_overlap = weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences); + // TODO: Consider adding declaration_file_count let score_components = ScoreComponents { is_same_file, is_referenced_nearby, diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index 41f4b02aa616587e5338ce76b6bc783898fb88bf..0176d3ceca61e64d932f882004289015e778e1c5 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -6,23 +6,15 @@ mod reference; mod syntax_index; mod text_similarity; -use cloud_llm_client::predict_edits_v3::{self, Signature}; -use collections::HashMap; -pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier}; -pub use declaration_scoring::SnippetStyle; -pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; - use gpui::{App, AppContext as _, Entity, Task}; use language::BufferSnapshot; -pub use reference::references_in_excerpt; -pub use syntax_index::SyntaxIndex; use text::{Point, ToOffset as _}; -use crate::{ - declaration::DeclarationId, - declaration_scoring::{ScoredSnippet, scored_snippets}, - syntax_index::SyntaxIndexState, -}; +pub use declaration::*; +pub use declaration_scoring::*; +pub use excerpt::*; +pub use reference::*; +pub use syntax_index::*; #[derive(Debug)] pub struct EditPredictionContext { @@ -32,7 +24,7 @@ pub struct EditPredictionContext { } impl EditPredictionContext { - pub fn gather( + pub fn gather_context_in_background( cursor_point: Point, buffer: BufferSnapshot, excerpt_options: EditPredictionExcerptOptions, @@ -42,25 +34,25 @@ impl EditPredictionContext { let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone()); cx.background_spawn(async move { let index_state = index_state.lock().await; - Self::gather_context(cursor_point, buffer, excerpt_options, &index_state) + Self::gather_context(cursor_point, &buffer, &excerpt_options, &index_state) }) } - fn gather_context( + pub fn gather_context( cursor_point: Point, - buffer: BufferSnapshot, - excerpt_options: EditPredictionExcerptOptions, + buffer: &BufferSnapshot, + excerpt_options: &EditPredictionExcerptOptions, index_state: &SyntaxIndexState, ) -> Option { let excerpt = EditPredictionExcerpt::select_from_buffer( cursor_point, - &buffer, - &excerpt_options, + buffer, + excerpt_options, Some(index_state), )?; - let excerpt_text = excerpt.text(&buffer); - let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer); - let cursor_offset = cursor_point.to_offset(&buffer); + let excerpt_text = excerpt.text(buffer); + let references = references_in_excerpt(&excerpt, &excerpt_text, buffer); + let cursor_offset = cursor_point.to_offset(buffer); let snippets = scored_snippets( &index_state, @@ -68,7 +60,7 @@ impl EditPredictionContext { &excerpt_text, references, cursor_offset, - &buffer, + buffer, ); Some(Self { @@ -77,97 +69,6 @@ impl EditPredictionContext { snippets, }) } - - pub fn cloud_request( - cursor_point: Point, - buffer: BufferSnapshot, - excerpt_options: EditPredictionExcerptOptions, - syntax_index: Entity, - cx: &mut App, - ) -> Task> { - let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone()); - cx.background_spawn(async move { - let index_state = index_state.lock().await; - Self::gather_context(cursor_point, buffer, excerpt_options, &index_state) - .map(|context| context.into_cloud_request(&index_state)) - }) - } - - pub fn into_cloud_request(self, index: &SyntaxIndexState) -> predict_edits_v3::Body { - let mut signatures = Vec::new(); - let mut declaration_to_signature_index = HashMap::default(); - let mut referenced_declarations = Vec::new(); - let excerpt_parent = self - .excerpt - .parent_declarations - .last() - .and_then(|(parent, _)| { - add_signature( - *parent, - &mut declaration_to_signature_index, - &mut signatures, - index, - ) - }); - for snippet in self.snippets { - let parent_index = snippet.declaration.parent().and_then(|parent| { - add_signature( - parent, - &mut declaration_to_signature_index, - &mut signatures, - index, - ) - }); - let (text, text_is_truncated) = snippet.declaration.item_text(); - referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { - text: text.into(), - text_is_truncated, - signature_range: snippet.declaration.signature_range_in_item_text(), - parent_index, - score_components: snippet.score_components, - signature_score: snippet.scores.signature, - declaration_score: snippet.scores.declaration, - }); - } - predict_edits_v3::Body { - excerpt: self.excerpt_text.body, - referenced_declarations, - signatures, - excerpt_parent, - // todo! - events: vec![], - can_collect_data: false, - diagnostic_groups: None, - git_info: None, - } - } -} - -fn add_signature( - declaration_id: DeclarationId, - declaration_to_signature_index: &mut HashMap, - signatures: &mut Vec, - index: &SyntaxIndexState, -) -> Option { - if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) { - return Some(*signature_index); - } - let Some(parent_declaration) = index.declaration(declaration_id) else { - log::error!("bug: missing parent declaration"); - return None; - }; - let parent_index = parent_declaration.parent().and_then(|parent| { - add_signature(parent, declaration_to_signature_index, signatures, index) - }); - let (text, text_is_truncated) = parent_declaration.signature_text(); - let signature_index = signatures.len(); - signatures.push(Signature { - text: text.into(), - text_is_truncated, - parent_index, - }); - declaration_to_signature_index.insert(declaration_id, signature_index); - Some(signature_index) } #[cfg(test)] @@ -205,7 +106,7 @@ mod tests { let context = cx .update(|cx| { - EditPredictionContext::gather( + EditPredictionContext::gather_context_in_background( cursor_point, buffer_snapshot, EditPredictionExcerptOptions { diff --git a/crates/edit_prediction_context/src/syntax_index.rs b/crates/edit_prediction_context/src/syntax_index.rs index bd80245c6b7ae51a634e9718f6516bac275abfef..d234e975d504c145d7bc2fc0680569c388ba0d1c 100644 --- a/crates/edit_prediction_context/src/syntax_index.rs +++ b/crates/edit_prediction_context/src/syntax_index.rs @@ -17,12 +17,6 @@ use crate::declaration::{ }; use crate::outline::declarations_in_buffer; -// TODO: -// -// * Skip for remote projects -// -// * Consider making SyntaxIndex not an Entity. - // Potential future improvements: // // * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which @@ -41,7 +35,6 @@ use crate::outline::declarations_in_buffer; // * Concurrent slotmap // // * Use queue for parsing -// pub struct SyntaxIndex { state: Arc>, diff --git a/crates/edit_prediction_context/src/text_similarity.rs b/crates/edit_prediction_context/src/text_similarity.rs index f7a9822ecca01f1b6ff1dc04bdc12fbcddc5159b..2ace7bf10cc6fd13b8a5636212211a3274d3c259 100644 --- a/crates/edit_prediction_context/src/text_similarity.rs +++ b/crates/edit_prediction_context/src/text_similarity.rs @@ -9,8 +9,12 @@ use crate::reference::Reference; // That implementation could actually be more efficient - no need to track words in the window that // are not in the query. +// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the +// two in parallel. + static IDENTIFIER_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap()); +// TODO: use &str or Cow keys? #[derive(Debug)] pub struct IdentifierOccurrences { identifier_to_count: HashMap, diff --git a/crates/edit_prediction_tools/src/edit_prediction_tools.rs b/crates/edit_prediction_tools/src/edit_prediction_tools.rs index ac84689baa84fa337a75c29f58d8d42f9060fe6f..7dcdf8724dc4c0c7b56d21627f7a87b30d83bd79 100644 --- a/crates/edit_prediction_tools/src/edit_prediction_tools.rs +++ b/crates/edit_prediction_tools/src/edit_prediction_tools.rs @@ -222,7 +222,7 @@ impl EditPredictionTools { start_time = Some(Instant::now()); - EditPredictionContext::gather( + EditPredictionContext::gather_context_in_background( cursor_position, current_buffer_snapshot, options, diff --git a/crates/zeta2/Cargo.toml b/crates/zeta2/Cargo.toml index 5c975bd84b4a8b9fc660ea39c7cf4887cbb0b46e..1f250c451a0482ff754623829b528f4c6602731c 100644 --- a/crates/zeta2/Cargo.toml +++ b/crates/zeta2/Cargo.toml @@ -13,9 +13,11 @@ path = "src/zeta2.rs" [dependencies] client.workspace = true +cloud_llm_client.workspace = true edit_prediction.workspace = true +edit_prediction_context.workspace = true gpui.workspace = true language.workspace = true +log.workspace = true project.workspace = true workspace-hack.workspace = true -util.workspace = true diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 46de8a1a8ec36d47bb6bcc3cd68ec3492fae98a1..a15e39d799b2f42aa49a47c1d78f9df5bb12cd91 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -1,9 +1,14 @@ -use std::{ops::Range, sync::Arc}; - -use gpui::{App, Entity, EntityId, Task, prelude::*}; - +use cloud_llm_client::predict_edits_v3::{self, Signature}; use edit_prediction::{DataCollectionState, Direction, EditPrediction, EditPredictionProvider}; +use edit_prediction_context::{ + DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex, + SyntaxIndexState, +}; +use gpui::{App, Entity, EntityId, Task, prelude::*}; use language::{Anchor, ToPoint}; +use language::{BufferSnapshot, Point}; +use std::collections::HashMap; +use std::{ops::Range, sync::Arc}; pub struct Zeta2EditPredictionProvider { current: Option, @@ -152,3 +157,116 @@ impl EditPredictionProvider for Zeta2EditPredictionProvider { Some(current_prediction.prediction) } } + +pub fn make_cloud_request_in_background( + cursor_point: Point, + buffer: BufferSnapshot, + events: Vec, + can_collect_data: bool, + diagnostic_groups: Vec, + git_info: Option, + excerpt_options: EditPredictionExcerptOptions, + syntax_index: Entity, + cx: &mut App, +) -> Task> { + let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone()); + cx.background_spawn(async move { + let index_state = index_state.lock().await; + EditPredictionContext::gather_context(cursor_point, &buffer, &excerpt_options, &index_state) + .map(|context| { + make_cloud_request( + context, + events, + can_collect_data, + diagnostic_groups, + git_info, + &index_state, + ) + }) + }) +} + +pub fn make_cloud_request( + context: EditPredictionContext, + events: Vec, + can_collect_data: bool, + diagnostic_groups: Vec, + git_info: Option, + index_state: &SyntaxIndexState, +) -> predict_edits_v3::PredictEditsRequest { + let mut signatures = Vec::new(); + let mut declaration_to_signature_index = HashMap::default(); + let mut referenced_declarations = Vec::new(); + for snippet in context.snippets { + let parent_index = snippet.declaration.parent().and_then(|parent| { + add_signature( + parent, + &mut declaration_to_signature_index, + &mut signatures, + index_state, + ) + }); + let (text, text_is_truncated) = snippet.declaration.item_text(); + referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { + text: text.into(), + text_is_truncated, + signature_range: snippet.declaration.signature_range_in_item_text(), + parent_index, + score_components: snippet.score_components, + signature_score: snippet.scores.signature, + declaration_score: snippet.scores.declaration, + }); + } + + let excerpt_parent = context + .excerpt + .parent_declarations + .last() + .and_then(|(parent, _)| { + add_signature( + *parent, + &mut declaration_to_signature_index, + &mut signatures, + index_state, + ) + }); + + predict_edits_v3::PredictEditsRequest { + excerpt: context.excerpt_text.body, + referenced_declarations, + signatures, + excerpt_parent, + // todo! + events, + can_collect_data, + diagnostic_groups, + git_info, + } +} + +fn add_signature( + declaration_id: DeclarationId, + declaration_to_signature_index: &mut HashMap, + signatures: &mut Vec, + index: &SyntaxIndexState, +) -> Option { + if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) { + return Some(*signature_index); + } + let Some(parent_declaration) = index.declaration(declaration_id) else { + log::error!("bug: missing parent declaration"); + return None; + }; + let parent_index = parent_declaration.parent().and_then(|parent| { + add_signature(parent, declaration_to_signature_index, signatures, index) + }); + let (text, text_is_truncated) = parent_declaration.signature_text(); + let signature_index = signatures.len(); + signatures.push(Signature { + text: text.into(), + text_is_truncated, + parent_index, + }); + declaration_to_signature_index.insert(declaration_id, signature_index); + Some(signature_index) +}