From f562e7e157b7f53b1dcd64d06778a68bb9abddaa Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Thu, 18 Sep 2025 06:44:40 -0600 Subject: [PATCH] edit predictions: Initial Tree-sitter context gathering (#38372) Release Notes: - N/A Co-authored-by: Agus Co-authored-by: Oleksiy Co-authored-by: Finn --- Cargo.lock | 6 + crates/edit_prediction_context/Cargo.toml | 7 + .../src/declaration.rs | 193 ++++++ .../src/declaration_scoring.rs | 326 +++++++++ .../src/edit_prediction_context.rs | 216 +++++- crates/edit_prediction_context/src/excerpt.rs | 2 +- crates/edit_prediction_context/src/outline.rs | 12 +- .../edit_prediction_context/src/reference.rs | 2 +- .../{tree_sitter_index.rs => syntax_index.rs} | 638 +++++++++--------- .../src/text_similarity.rs | 241 +++++++ .../src/wip_requests.rs | 35 + 11 files changed, 1361 insertions(+), 317 deletions(-) create mode 100644 crates/edit_prediction_context/src/declaration.rs create mode 100644 crates/edit_prediction_context/src/declaration_scoring.rs rename crates/edit_prediction_context/src/{tree_sitter_index.rs => syntax_index.rs} (58%) create mode 100644 crates/edit_prediction_context/src/text_similarity.rs create mode 100644 crates/edit_prediction_context/src/wip_requests.rs diff --git a/Cargo.lock b/Cargo.lock index 294127b598926ddae6924f0af1dd273416cf1e1c..4966106b79599db9ece65fd9d9e6d794196a46c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5140,17 +5140,23 @@ version = "0.1.0" dependencies = [ "anyhow", "arrayvec", + "clap", "collections", "futures 0.3.31", "gpui", "indoc", + "itertools 0.14.0", "language", "log", + "ordered-float 2.10.1", "pretty_assertions", "project", + "regex", + "serde", "serde_json", "settings", "slotmap", + "strum 0.27.1", "text", "tree-sitter", "util", diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index ad455b0a4ecb1746debafd23f0503b4365f9a0cf..48f51da3912ea5bca589e7b559d5b665b9b762d6 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -15,17 +15,24 @@ path = "src/edit_prediction_context.rs" anyhow.workspace = true arrayvec.workspace = true collections.workspace = true +futures.workspace = true gpui.workspace = true +itertools.workspace = true language.workspace = true log.workspace = true +ordered-float.workspace = true project.workspace = true +regex.workspace = true +serde.workspace = true slotmap.workspace = true +strum.workspace = true text.workspace = true tree-sitter.workspace = true util.workspace = true workspace-hack.workspace = true [dev-dependencies] +clap.workspace = true futures.workspace = true gpui = { workspace = true, features = ["test-support"] } indoc.workspace = true diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs new file mode 100644 index 0000000000000000000000000000000000000000..fcf54fead80194fe97a2719971f86318a57ad75c --- /dev/null +++ b/crates/edit_prediction_context/src/declaration.rs @@ -0,0 +1,193 @@ +use language::LanguageId; +use project::ProjectEntryId; +use std::borrow::Cow; +use std::ops::Range; +use std::sync::Arc; +use text::{Bias, BufferId, Rope}; + +use crate::outline::OutlineDeclaration; + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct Identifier { + pub name: Arc, + pub language_id: LanguageId, +} + +slotmap::new_key_type! 
{
+    pub struct DeclarationId;
+}
+
+#[derive(Debug, Clone)]
+pub enum Declaration {
+    File {
+        project_entry_id: ProjectEntryId,
+        declaration: FileDeclaration,
+    },
+    Buffer {
+        project_entry_id: ProjectEntryId,
+        buffer_id: BufferId,
+        rope: Rope,
+        declaration: BufferDeclaration,
+    },
+}
+
+const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
+
+impl Declaration {
+    pub fn identifier(&self) -> &Identifier {
+        match self {
+            Declaration::File { declaration, .. } => &declaration.identifier,
+            Declaration::Buffer { declaration, .. } => &declaration.identifier,
+        }
+    }
+
+    pub fn project_entry_id(&self) -> Option<ProjectEntryId> {
+        match self {
+            Declaration::File {
+                project_entry_id, ..
+            } => Some(*project_entry_id),
+            Declaration::Buffer {
+                project_entry_id, ..
+            } => Some(*project_entry_id),
+        }
+    }
+
+    pub fn item_text(&self) -> (Cow<'_, str>, bool) {
+        match self {
+            Declaration::File { declaration, .. } => (
+                declaration.text.as_ref().into(),
+                declaration.text_is_truncated,
+            ),
+            Declaration::Buffer {
+                rope, declaration, ..
+            } => (
+                rope.chunks_in_range(declaration.item_range.clone())
+                    .collect::<Cow<str>>(),
+                declaration.item_range_is_truncated,
+            ),
+        }
+    }
+
+    pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
+        match self {
+            Declaration::File { declaration, .. } => (
+                declaration.text[declaration.signature_range_in_text.clone()].into(),
+                declaration.signature_is_truncated,
+            ),
+            Declaration::Buffer {
+                rope, declaration, ..
+            } => (
+                rope.chunks_in_range(declaration.signature_range.clone())
+                    .collect::<Cow<str>>(),
+                declaration.signature_range_is_truncated,
+            ),
+        }
+    }
+}
+
+fn expand_range_to_line_boundaries_and_truncate(
+    range: &Range<usize>,
+    limit: usize,
+    rope: &Rope,
+) -> (Range<usize>, bool) {
+    let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
+    point_range.start.column = 0;
+    point_range.end.row += 1;
+    point_range.end.column = 0;
+
+    let mut item_range =
+        rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
+    let is_truncated = item_range.len() > limit;
+    if is_truncated {
+        item_range.end = item_range.start + limit;
+    }
+    item_range.end = rope.clip_offset(item_range.end, Bias::Left);
+    (item_range, is_truncated)
+}
+
+#[derive(Debug, Clone)]
+pub struct FileDeclaration {
+    pub parent: Option<DeclarationId>,
+    pub identifier: Identifier,
+    /// offset range of the declaration in the file, expanded to line boundaries and truncated
+    pub item_range_in_file: Range<usize>,
+    /// text of `item_range_in_file`
+    pub text: Arc<str>,
+    /// whether `text` was truncated
+    pub text_is_truncated: bool,
+    /// offset range of the signature within `text`
+    pub signature_range_in_text: Range<usize>,
+    /// whether the signature was truncated
+    pub signature_is_truncated: bool,
+}
+
+impl FileDeclaration {
+    pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
+        let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate(
+            &declaration.item_range,
+            ITEM_TEXT_TRUNCATION_LENGTH,
+            rope,
+        );
+
+        // TODO: consider logging if unexpected
+        let signature_start = declaration
+            .signature_range
+            .start
+            .saturating_sub(item_range_in_file.start);
+        let mut signature_end = declaration
+            .signature_range
+            .end
+            .saturating_sub(item_range_in_file.start);
+        let signature_is_truncated = signature_end > item_range_in_file.len();
+        if signature_is_truncated {
+            signature_end = item_range_in_file.len();
+        }
+
+        FileDeclaration {
+            parent: None,
+            identifier: declaration.identifier,
+            signature_range_in_text:
signature_start..signature_end,
+            signature_is_truncated,
+            text: rope
+                .chunks_in_range(item_range_in_file.clone())
+                .collect::<String>()
+                .into(),
+            text_is_truncated,
+            item_range_in_file,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct BufferDeclaration {
+    pub parent: Option<DeclarationId>,
+    pub identifier: Identifier,
+    pub item_range: Range<usize>,
+    pub item_range_is_truncated: bool,
+    pub signature_range: Range<usize>,
+    pub signature_range_is_truncated: bool,
+}
+
+impl BufferDeclaration {
+    pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
+        let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate(
+            &declaration.item_range,
+            ITEM_TEXT_TRUNCATION_LENGTH,
+            rope,
+        );
+        let (signature_range, signature_range_is_truncated) =
+            expand_range_to_line_boundaries_and_truncate(
+                &declaration.signature_range,
+                ITEM_TEXT_TRUNCATION_LENGTH,
+                rope,
+            );
+        Self {
+            parent: None,
+            identifier: declaration.identifier,
+            item_range,
+            item_range_is_truncated,
+            signature_range,
+            signature_range_is_truncated,
+        }
+    }
+}
diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs
new file mode 100644
index 0000000000000000000000000000000000000000..dc442710516a935a65e755393fdfc15026ff1f0e
--- /dev/null
+++ b/crates/edit_prediction_context/src/declaration_scoring.rs
@@ -0,0 +1,326 @@
+use itertools::Itertools as _;
+use language::BufferSnapshot;
+use ordered_float::OrderedFloat;
+use serde::Serialize;
+use std::{collections::HashMap, ops::Range};
+use strum::EnumIter;
+use text::{OffsetRangeExt, Point, ToPoint};
+
+use crate::{
+    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
+    reference::{Reference, ReferenceRegion},
+    syntax_index::SyntaxIndexState,
+    text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
+};
+
+const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
+
+// TODO:
+//
+// * Consider adding declaration_file_count
+
+#[derive(Clone, Debug)]
+pub struct ScoredSnippet {
+    pub identifier: Identifier,
+    pub declaration: Declaration,
+    pub score_components: ScoreInputs,
+    pub scores: Scores,
+}
+
+// TODO: Consider having "Concise" style corresponding to `concise_text`
+#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
+pub enum SnippetStyle {
+    Signature,
+    Declaration,
+}
+
+impl ScoredSnippet {
+    /// Returns the score for this snippet with the specified style.
+    pub fn score(&self, style: SnippetStyle) -> f32 {
+        match style {
+            SnippetStyle::Signature => self.scores.signature,
+            SnippetStyle::Declaration => self.scores.declaration,
+        }
+    }
+
+    pub fn size(&self, style: SnippetStyle) -> usize {
+        // TODO: how to handle truncation?
+        match &self.declaration {
+            Declaration::File { declaration, .. } => match style {
+                SnippetStyle::Signature => declaration.signature_range_in_text.len(),
+                SnippetStyle::Declaration => declaration.text.len(),
+            },
+            Declaration::Buffer { declaration, ..
} => match style {
+                SnippetStyle::Signature => declaration.signature_range.len(),
+                SnippetStyle::Declaration => declaration.item_range.len(),
+            },
+        }
+    }
+
+    pub fn score_density(&self, style: SnippetStyle) -> f32 {
+        self.score(style) / (self.size(style)) as f32
+    }
+}
+
+pub fn scored_snippets(
+    index: &SyntaxIndexState,
+    excerpt: &EditPredictionExcerpt,
+    excerpt_text: &EditPredictionExcerptText,
+    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
+    cursor_offset: usize,
+    current_buffer: &BufferSnapshot,
+) -> Vec<ScoredSnippet> {
+    let containing_range_identifier_occurrences =
+        IdentifierOccurrences::within_string(&excerpt_text.body);
+    let cursor_point = cursor_offset.to_point(&current_buffer);
+
+    let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
+    let end_point = Point::new(cursor_point.row + 1, 0);
+    let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
+        &current_buffer
+            .text_for_range(start_point..end_point)
+            .collect::<String>(),
+    );
+
+    let mut snippets = identifier_to_references
+        .into_iter()
+        .flat_map(|(identifier, references)| {
+            let declarations =
+                index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
+            let declaration_count = declarations.len();
+
+            declarations
+                .iter()
+                .filter_map(|declaration| match declaration {
+                    Declaration::Buffer {
+                        buffer_id,
+                        declaration: buffer_declaration,
+                        ..
+                    } => {
+                        let is_same_file = buffer_id == &current_buffer.remote_id();
+
+                        if is_same_file {
+                            range_intersection(
+                                &buffer_declaration.item_range.to_offset(&current_buffer),
+                                &excerpt.range,
+                            )
+                            .is_none()
+                            .then(|| {
+                                let declaration_line = buffer_declaration
+                                    .item_range
+                                    .start
+                                    .to_point(current_buffer)
+                                    .row;
+                                (
+                                    true,
+                                    (cursor_point.row as i32 - declaration_line as i32)
+                                        .unsigned_abs(),
+                                    declaration,
+                                )
+                            })
+                        } else {
+                            // TODO should we prefer the current file instead?
+                            Some((false, 0, declaration))
+                        }
+                    }
+                    Declaration::File { .. } => {
+                        // TODO should we prefer the current file instead?
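+                        // The tuples built here are (is_same_file, line_distance, declaration).
+                        // File-backed declarations use a distance of 0, so the ascending sort
+                        // below ranks them ahead of same-file declarations that sit far from
+                        // the cursor.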
+ // We can assume that a file declaration is in a different file, + // because the current one must be open + Some((false, 0, declaration)) + } + }) + .sorted_by_key(|&(_, distance, _)| distance) + .enumerate() + .map( + |( + declaration_line_distance_rank, + (is_same_file, declaration_line_distance, declaration), + )| { + let same_file_declaration_count = index.file_declaration_count(declaration); + + score_snippet( + &identifier, + &references, + declaration.clone(), + is_same_file, + declaration_line_distance, + declaration_line_distance_rank, + same_file_declaration_count, + declaration_count, + &containing_range_identifier_occurrences, + &adjacent_identifier_occurrences, + cursor_point, + current_buffer, + ) + }, + ) + .collect::>() + }) + .flatten() + .collect::>(); + + snippets.sort_unstable_by_key(|snippet| { + OrderedFloat( + snippet + .score_density(SnippetStyle::Declaration) + .max(snippet.score_density(SnippetStyle::Signature)), + ) + }); + + snippets +} + +fn range_intersection(a: &Range, b: &Range) -> Option> { + let start = a.start.clone().max(b.start.clone()); + let end = a.end.clone().min(b.end.clone()); + if start < end { + Some(Range { start, end }) + } else { + None + } +} + +fn score_snippet( + identifier: &Identifier, + references: &[Reference], + declaration: Declaration, + is_same_file: bool, + declaration_line_distance: u32, + declaration_line_distance_rank: usize, + same_file_declaration_count: usize, + declaration_count: usize, + containing_range_identifier_occurrences: &IdentifierOccurrences, + adjacent_identifier_occurrences: &IdentifierOccurrences, + cursor: Point, + current_buffer: &BufferSnapshot, +) -> Option { + let is_referenced_nearby = references + .iter() + .any(|r| r.region == ReferenceRegion::Nearby); + let is_referenced_in_breadcrumb = references + .iter() + .any(|r| r.region == ReferenceRegion::Breadcrumb); + let reference_count = references.len(); + let reference_line_distance = references + .iter() + .map(|r| { + let reference_line = r.range.start.to_point(current_buffer).row as i32; + (cursor.row as i32 - reference_line).unsigned_abs() + }) + .min() + .unwrap(); + + let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0); + let item_signature_occurrences = + IdentifierOccurrences::within_string(&declaration.signature_text().0); + let containing_range_vs_item_jaccard = jaccard_similarity( + containing_range_identifier_occurrences, + &item_source_occurrences, + ); + let containing_range_vs_signature_jaccard = jaccard_similarity( + containing_range_identifier_occurrences, + &item_signature_occurrences, + ); + let adjacent_vs_item_jaccard = + jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences); + let adjacent_vs_signature_jaccard = + jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences); + + let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient( + containing_range_identifier_occurrences, + &item_source_occurrences, + ); + let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient( + containing_range_identifier_occurrences, + &item_signature_occurrences, + ); + let adjacent_vs_item_weighted_overlap = + weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences); + let adjacent_vs_signature_weighted_overlap = + weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences); + + let score_components = ScoreInputs { + is_same_file, + is_referenced_nearby, + 
is_referenced_in_breadcrumb, + reference_line_distance, + declaration_line_distance, + declaration_line_distance_rank, + reference_count, + same_file_declaration_count, + declaration_count, + containing_range_vs_item_jaccard, + containing_range_vs_signature_jaccard, + adjacent_vs_item_jaccard, + adjacent_vs_signature_jaccard, + containing_range_vs_item_weighted_overlap, + containing_range_vs_signature_weighted_overlap, + adjacent_vs_item_weighted_overlap, + adjacent_vs_signature_weighted_overlap, + }; + + Some(ScoredSnippet { + identifier: identifier.clone(), + declaration: declaration, + scores: score_components.score(), + score_components, + }) +} + +#[derive(Clone, Debug, Serialize)] +pub struct ScoreInputs { + pub is_same_file: bool, + pub is_referenced_nearby: bool, + pub is_referenced_in_breadcrumb: bool, + pub reference_count: usize, + pub same_file_declaration_count: usize, + pub declaration_count: usize, + pub reference_line_distance: u32, + pub declaration_line_distance: u32, + pub declaration_line_distance_rank: usize, + pub containing_range_vs_item_jaccard: f32, + pub containing_range_vs_signature_jaccard: f32, + pub adjacent_vs_item_jaccard: f32, + pub adjacent_vs_signature_jaccard: f32, + pub containing_range_vs_item_weighted_overlap: f32, + pub containing_range_vs_signature_weighted_overlap: f32, + pub adjacent_vs_item_weighted_overlap: f32, + pub adjacent_vs_signature_weighted_overlap: f32, +} + +#[derive(Clone, Debug, Serialize)] +pub struct Scores { + pub signature: f32, + pub declaration: f32, +} + +impl ScoreInputs { + fn score(&self) -> Scores { + // Score related to how likely this is the correct declaration, range 0 to 1 + let accuracy_score = if self.is_same_file { + // TODO: use declaration_line_distance_rank + 1.0 / self.same_file_declaration_count as f32 + } else { + 1.0 / self.declaration_count as f32 + }; + + // Score related to the distance between the reference and cursor, range 0 to 1 + let distance_score = if self.is_referenced_nearby { + 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0) + } else { + // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures + 0.5 + }; + + // For now instead of linear combination, the scores are just multiplied together. + let combined_score = 10.0 * accuracy_score * distance_score; + + Scores { + signature: combined_score * self.containing_range_vs_signature_weighted_overlap, + // declaration score gets boosted both by being multiplied by 2 and by there being more + // weighted overlap. 
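+            // For example: accuracy_score 0.5 (one of two same-file candidates) and
+            // distance_score 1.0 (reference on the cursor's line) give combined_score 5.0;
+            // a weighted overlap of 0.4 then yields a declaration score of
+            // 2.0 * 5.0 * 0.4 = 4.0 below.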
+ declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap, + } + } +} diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index acfb89880c3ed9e7b1ebcacd4b5fa313830165ba..5d73dc7f7dcf2223ae1f23b22c2f104842206e12 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -1,8 +1,220 @@ +mod declaration; +mod declaration_scoring; mod excerpt; mod outline; mod reference; -mod tree_sitter_index; +mod syntax_index; +mod text_similarity; +pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier}; pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; +use gpui::{App, AppContext as _, Entity, Task}; +use language::BufferSnapshot; pub use reference::references_in_excerpt; -pub use tree_sitter_index::{BufferDeclaration, Declaration, FileDeclaration, TreeSitterIndex}; +pub use syntax_index::SyntaxIndex; +use text::{Point, ToOffset as _}; + +use crate::declaration_scoring::{ScoredSnippet, scored_snippets}; + +pub struct EditPredictionContext { + pub excerpt: EditPredictionExcerpt, + pub excerpt_text: EditPredictionExcerptText, + pub snippets: Vec, +} + +impl EditPredictionContext { + pub fn gather( + cursor_point: Point, + buffer: BufferSnapshot, + excerpt_options: EditPredictionExcerptOptions, + syntax_index: Entity, + cx: &mut App, + ) -> Task { + let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone()); + cx.background_spawn(async move { + let index_state = index_state.lock().await; + + let excerpt = + EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options) + .unwrap(); + let excerpt_text = excerpt.text(&buffer); + let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer); + let cursor_offset = cursor_point.to_offset(&buffer); + + let snippets = scored_snippets( + &index_state, + &excerpt, + &excerpt_text, + references, + cursor_offset, + &buffer, + ); + + Self { + excerpt, + excerpt_text, + snippets, + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use gpui::{Entity, TestAppContext}; + use indoc::indoc; + use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + use crate::{EditPredictionExcerptOptions, SyntaxIndex}; + + #[gpui::test] + async fn test_call_site(cx: &mut TestAppContext) { + let (project, index, _rust_lang_id) = init_test(cx).await; + + let buffer = project + .update(cx, |project, cx| { + let project_path = project.find_project_path("c.rs", cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + cx.run_until_parked(); + + // first process_data call site + let cursor_point = language::Point::new(8, 21); + let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); + + let context = cx + .update(|cx| { + EditPredictionContext::gather( + cursor_point, + buffer_snapshot, + EditPredictionExcerptOptions { + max_bytes: 40, + min_bytes: 10, + target_before_cursor_over_total_bytes: 0.5, + include_parent_signatures: false, + }, + index, + cx, + ) + }) + .await; + + assert_eq!(context.snippets.len(), 1); + assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data"); + drop(buffer); + } + + async fn init_test( + cx: &mut 
TestAppContext, + ) -> (Entity, Entity, LanguageId) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "a.rs": indoc! {r#" + fn main() { + let x = 1; + let y = 2; + let z = add(x, y); + println!("Result: {}", z); + } + + fn add(a: i32, b: i32) -> i32 { + a + b + } + "#}, + "b.rs": indoc! {" + pub struct Config { + pub name: String, + pub value: i32, + } + + impl Config { + pub fn new(name: String, value: i32) -> Self { + Config { name, value } + } + } + "}, + "c.rs": indoc! {r#" + use std::collections::HashMap; + + fn main() { + let args: Vec = std::env::args().collect(); + let data: Vec = args[1..] + .iter() + .filter_map(|s| s.parse().ok()) + .collect(); + let result = process_data(data); + println!("{:?}", result); + } + + fn process_data(data: Vec) -> HashMap { + let mut counts = HashMap::new(); + for value in data { + *counts.entry(value).or_insert(0) += 1; + } + counts + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_process_data() { + let data = vec![1, 2, 2, 3]; + let result = process_data(data); + assert_eq!(result.get(&2), Some(&2)); + } + } + "#} + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + let lang = rust_lang(); + let lang_id = lang.id(); + language_registry.add(Arc::new(lang)); + + let index = cx.new(|cx| SyntaxIndex::new(&project, cx)); + cx.run_until_parked(); + + (project, index, lang_id) + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) + .unwrap() + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } +} diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs index c6caa6a1b7b4076cf739c1ac198656b9fba431a6..da1de042623167d17f078c1e85b461fb0ecc8c24 100644 --- a/crates/edit_prediction_context/src/excerpt.rs +++ b/crates/edit_prediction_context/src/excerpt.rs @@ -31,7 +31,7 @@ pub struct EditPredictionExcerptOptions { pub include_parent_signatures: bool, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct EditPredictionExcerpt { pub range: Range, pub parent_signature_ranges: Vec>, diff --git a/crates/edit_prediction_context/src/outline.rs b/crates/edit_prediction_context/src/outline.rs index 492352add1fd4c666eab3b12989f9b801d03570f..ec02c869dfae4cb861206cb801c285462e734f36 100644 --- a/crates/edit_prediction_context/src/outline.rs +++ b/crates/edit_prediction_context/src/outline.rs @@ -1,5 +1,7 @@ -use language::{BufferSnapshot, LanguageId, SyntaxMapMatches}; -use std::{cmp::Reverse, ops::Range, sync::Arc}; +use language::{BufferSnapshot, SyntaxMapMatches}; +use std::{cmp::Reverse, ops::Range}; + +use crate::declaration::Identifier; // TODO: // @@ -18,12 +20,6 @@ pub struct OutlineDeclaration { pub signature_range: Range, } -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -pub struct Identifier { - pub name: Arc, - pub language_id: LanguageId, -} - pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec { 
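+    // Convenience wrapper: collects declarations for the whole buffer by
+    // delegating to declarations_overlapping_range over 0..buffer.len().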
declarations_overlapping_range(0..buffer.len(), buffer) } diff --git a/crates/edit_prediction_context/src/reference.rs b/crates/edit_prediction_context/src/reference.rs index 65d34e73bf20f62b24ac2a654af43fc3b83041a9..ee2fc7ba573c3909b5a650e3ca0ff20155272b9f 100644 --- a/crates/edit_prediction_context/src/reference.rs +++ b/crates/edit_prediction_context/src/reference.rs @@ -3,8 +3,8 @@ use std::collections::HashMap; use std::ops::Range; use crate::{ + declaration::Identifier, excerpt::{EditPredictionExcerpt, EditPredictionExcerptText}, - outline::Identifier, }; #[derive(Debug)] diff --git a/crates/edit_prediction_context/src/tree_sitter_index.rs b/crates/edit_prediction_context/src/syntax_index.rs similarity index 58% rename from crates/edit_prediction_context/src/tree_sitter_index.rs rename to crates/edit_prediction_context/src/syntax_index.rs index f905aa7a01f29d26083d219bc8d2bd600847036a..852973dd7296647b0f868c3f9242ed59b81b6743 100644 --- a/crates/edit_prediction_context/src/tree_sitter_index.rs +++ b/crates/edit_prediction_context/src/syntax_index.rs @@ -1,20 +1,26 @@ +use std::sync::Arc; + use collections::{HashMap, HashSet}; +use futures::lock::Mutex; use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity}; -use language::{Buffer, BufferEvent, BufferSnapshot}; +use language::{Buffer, BufferEvent}; use project::buffer_store::{BufferStore, BufferStoreEvent}; use project::worktree_store::{WorktreeStore, WorktreeStoreEvent}; use project::{PathChange, Project, ProjectEntryId, ProjectPath}; use slotmap::SlotMap; -use std::ops::Range; -use std::sync::Arc; -use text::Anchor; +use text::BufferId; use util::{ResultExt as _, debug_panic, some_or_debug_panic}; -use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer}; +use crate::declaration::{ + BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier, +}; +use crate::outline::declarations_in_buffer; // TODO: // // * Skip for remote projects +// +// * Consider making SyntaxIndex not an Entity. // Potential future improvements: // @@ -34,17 +40,19 @@ use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer}; // * Concurrent slotmap // // * Use queue for parsing +// -slotmap::new_key_type! { - pub struct DeclarationId; +pub struct SyntaxIndex { + state: Arc>, + project: WeakEntity, } -pub struct TreeSitterIndex { +#[derive(Default)] +pub struct SyntaxIndexState { declarations: SlotMap, identifiers: HashMap>, files: HashMap, - buffers: HashMap, BufferState>, - project: WeakEntity, + buffers: HashMap, } #[derive(Debug, Default)] @@ -59,52 +67,11 @@ struct BufferState { task: Option>, } -#[derive(Debug, Clone)] -pub enum Declaration { - File { - project_entry_id: ProjectEntryId, - declaration: FileDeclaration, - }, - Buffer { - buffer: WeakEntity, - declaration: BufferDeclaration, - }, -} - -impl Declaration { - fn identifier(&self) -> &Identifier { - match self { - Declaration::File { declaration, .. } => &declaration.identifier, - Declaration::Buffer { declaration, .. 
} => &declaration.identifier, - } - } -} - -#[derive(Debug, Clone)] -pub struct FileDeclaration { - pub parent: Option, - pub identifier: Identifier, - pub item_range: Range, - pub signature_range: Range, - pub signature_text: Arc, -} - -#[derive(Debug, Clone)] -pub struct BufferDeclaration { - pub parent: Option, - pub identifier: Identifier, - pub item_range: Range, - pub signature_range: Range, -} - -impl TreeSitterIndex { +impl SyntaxIndex { pub fn new(project: &Entity, cx: &mut Context) -> Self { let mut this = Self { - declarations: SlotMap::with_key(), - identifiers: HashMap::default(), project: project.downgrade(), - files: HashMap::default(), - buffers: HashMap::default(), + state: Arc::new(Mutex::new(SyntaxIndexState::default())), }; let worktree_store = project.read(cx).worktree_store(); @@ -139,73 +106,6 @@ impl TreeSitterIndex { this } - pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> { - self.declarations.get(id) - } - - pub fn declarations_for_identifier( - &self, - identifier: Identifier, - cx: &App, - ) -> Vec { - // make sure to not have a large stack allocation - assert!(N < 32); - - let Some(declaration_ids) = self.identifiers.get(&identifier) else { - return vec![]; - }; - - let mut result = Vec::with_capacity(N); - let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new(); - let mut file_declarations = Vec::new(); - - for declaration_id in declaration_ids { - let declaration = self.declarations.get(*declaration_id); - let Some(declaration) = some_or_debug_panic(declaration) else { - continue; - }; - match declaration { - Declaration::Buffer { buffer, .. } => { - if let Ok(Some(entry_id)) = buffer.read_with(cx, |buffer, cx| { - project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx)) - }) { - included_buffer_entry_ids.push(entry_id); - result.push(declaration.clone()); - if result.len() == N { - return result; - } - } - } - Declaration::File { - project_entry_id, .. - } => { - if !included_buffer_entry_ids.contains(project_entry_id) { - file_declarations.push(declaration.clone()); - } - } - } - } - - for declaration in file_declarations { - match declaration { - Declaration::File { - project_entry_id, .. - } => { - if !included_buffer_entry_ids.contains(&project_entry_id) { - result.push(declaration); - - if result.len() == N { - return result; - } - } - } - Declaration::Buffer { .. 
} => {} - } - } - - result - } - fn handle_worktree_store_event( &mut self, _worktree_store: Entity, @@ -215,21 +115,33 @@ impl TreeSitterIndex { use WorktreeStoreEvent::*; match event { WorktreeUpdatedEntries(worktree_id, updated_entries_set) => { - for (path, entry_id, path_change) in updated_entries_set.iter() { - if let PathChange::Removed = path_change { - self.files.remove(entry_id); - } else { - let project_path = ProjectPath { - worktree_id: *worktree_id, - path: path.clone(), - }; - self.update_file(*entry_id, project_path, cx); + let state = Arc::downgrade(&self.state); + let worktree_id = *worktree_id; + let updated_entries_set = updated_entries_set.clone(); + cx.spawn(async move |this, cx| { + let Some(state) = state.upgrade() else { return }; + for (path, entry_id, path_change) in updated_entries_set.iter() { + if let PathChange::Removed = path_change { + state.lock().await.files.remove(entry_id); + } else { + let project_path = ProjectPath { + worktree_id, + path: path.clone(), + }; + this.update(cx, |this, cx| { + this.update_file(*entry_id, project_path, cx); + }) + .ok(); + } } - } + }) + .detach(); } WorktreeDeletedEntry(_worktree_id, project_entry_id) => { - // TODO: Is this needed? - self.files.remove(project_entry_id); + let project_entry_id = *project_entry_id; + self.with_state(cx, move |state| { + state.files.remove(&project_entry_id); + }) } _ => {} } @@ -251,15 +163,42 @@ impl TreeSitterIndex { } } + pub fn state(&self) -> &Arc> { + &self.state + } + + fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) { + if let Some(mut state) = self.state.try_lock() { + f(&mut state); + return; + } + let state = Arc::downgrade(&self.state); + cx.background_spawn(async move { + let Some(state) = state.upgrade() else { + return; + }; + let mut state = state.lock().await; + f(&mut state) + }) + .detach(); + } + fn register_buffer(&mut self, buffer: &Entity, cx: &mut Context) { - self.buffers - .insert(buffer.downgrade(), BufferState::default()); - let weak_buf = buffer.downgrade(); - cx.observe_release(buffer, move |this, _buffer, _cx| { - this.buffers.remove(&weak_buf); + let buffer_id = buffer.read(cx).remote_id(); + cx.observe_release(buffer, move |this, _buffer, cx| { + this.with_state(cx, move |state| { + if let Some(buffer_state) = state.buffers.remove(&buffer_id) { + SyntaxIndexState::remove_buffer_declarations( + &buffer_state.declarations, + &mut state.declarations, + &mut state.identifiers, + ); + } + }) }) .detach(); cx.subscribe(buffer, Self::handle_buffer_event).detach(); + self.update_buffer(buffer.clone(), cx); } @@ -275,10 +214,19 @@ impl TreeSitterIndex { } } - fn update_buffer(&mut self, buffer: Entity, cx: &Context) { - let mut parse_status = buffer.read(cx).parse_status(); + fn update_buffer(&mut self, buffer_entity: Entity, cx: &mut Context) { + let buffer = buffer_entity.read(cx); + + let Some(project_entry_id) = + project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx)) + else { + return; + }; + let buffer_id = buffer.remote_id(); + + let mut parse_status = buffer.parse_status(); let snapshot_task = cx.spawn({ - let weak_buffer = buffer.downgrade(); + let weak_buffer = buffer_entity.downgrade(); async move |_, cx| { while *parse_status.borrow() != language::ParseStatus::Idle { parse_status.changed().await?; @@ -289,75 +237,77 @@ impl TreeSitterIndex { let parse_task = cx.background_spawn(async move { let snapshot = snapshot_task.await?; + let rope = snapshot.text.as_rope().clone(); - anyhow::Ok( + 
anyhow::Ok(( declarations_in_buffer(&snapshot) .into_iter() .map(|item| { ( item.parent_index, - BufferDeclaration::from_outline(item, &snapshot), + BufferDeclaration::from_outline(item, &rope), ) }) .collect::>(), - ) + rope, + )) }); let task = cx.spawn({ - let weak_buffer = buffer.downgrade(); async move |this, cx| { - let Ok(declarations) = parse_task.await else { + let Ok((declarations, rope)) = parse_task.await else { return; }; - this.update(cx, |this, _cx| { - let buffer_state = this - .buffers - .entry(weak_buffer.clone()) - .or_insert_with(Default::default); - - for old_declaration_id in &buffer_state.declarations { - let Some(declaration) = this.declarations.remove(*old_declaration_id) - else { - debug_panic!("declaration not found"); - continue; - }; - if let Some(identifier_declarations) = - this.identifiers.get_mut(declaration.identifier()) - { - identifier_declarations.remove(old_declaration_id); + this.update(cx, move |this, cx| { + this.with_state(cx, move |state| { + let buffer_state = state + .buffers + .entry(buffer_id) + .or_insert_with(Default::default); + + SyntaxIndexState::remove_buffer_declarations( + &buffer_state.declarations, + &mut state.declarations, + &mut state.identifiers, + ); + + let mut new_ids = Vec::with_capacity(declarations.len()); + state.declarations.reserve(declarations.len()); + for (parent_index, mut declaration) in declarations { + declaration.parent = parent_index + .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); + + let identifier = declaration.identifier.clone(); + let declaration_id = state.declarations.insert(Declaration::Buffer { + rope: rope.clone(), + buffer_id, + declaration, + project_entry_id, + }); + new_ids.push(declaration_id); + + state + .identifiers + .entry(identifier) + .or_default() + .insert(declaration_id); } - } - - let mut new_ids = Vec::with_capacity(declarations.len()); - this.declarations.reserve(declarations.len()); - for (parent_index, mut declaration) in declarations { - declaration.parent = parent_index - .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); - - let identifier = declaration.identifier.clone(); - let declaration_id = this.declarations.insert(Declaration::Buffer { - buffer: weak_buffer.clone(), - declaration, - }); - new_ids.push(declaration_id); - - this.identifiers - .entry(identifier) - .or_default() - .insert(declaration_id); - } - buffer_state.declarations = new_ids; + buffer_state.declarations = new_ids; + }); }) .ok(); } }); - self.buffers - .entry(buffer.downgrade()) - .or_insert_with(Default::default) - .task = Some(task); + self.with_state(cx, move |state| { + state + .buffers + .entry(buffer_id) + .or_insert_with(Default::default) + .task = Some(task) + }); } fn update_file( @@ -401,14 +351,10 @@ impl TreeSitterIndex { let parse_task = cx.background_spawn(async move { let snapshot = snapshot_task.await?; + let rope = snapshot.as_rope(); let declarations = declarations_in_buffer(&snapshot) .into_iter() - .map(|item| { - ( - item.parent_index, - FileDeclaration::from_outline(item, &snapshot), - ) - }) + .map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope))) .collect::>(); anyhow::Ok(declarations) }); @@ -419,84 +365,160 @@ impl TreeSitterIndex { let Ok(declarations) = parse_task.await else { return; }; - this.update(cx, |this, _cx| { - let file_state = this.files.entry(entry_id).or_insert_with(Default::default); - - for old_declaration_id in &file_state.declarations { - let Some(declaration) = this.declarations.remove(*old_declaration_id) - else 
{ - debug_panic!("declaration not found"); - continue; - }; - if let Some(identifier_declarations) = - this.identifiers.get_mut(declaration.identifier()) - { - identifier_declarations.remove(old_declaration_id); + this.update(cx, |this, cx| { + this.with_state(cx, move |state| { + let file_state = + state.files.entry(entry_id).or_insert_with(Default::default); + + for old_declaration_id in &file_state.declarations { + let Some(declaration) = state.declarations.remove(*old_declaration_id) + else { + debug_panic!("declaration not found"); + continue; + }; + if let Some(identifier_declarations) = + state.identifiers.get_mut(declaration.identifier()) + { + identifier_declarations.remove(old_declaration_id); + } } - } - - let mut new_ids = Vec::with_capacity(declarations.len()); - this.declarations.reserve(declarations.len()); - for (parent_index, mut declaration) in declarations { - declaration.parent = parent_index - .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); - - let identifier = declaration.identifier.clone(); - let declaration_id = this.declarations.insert(Declaration::File { - project_entry_id: entry_id, - declaration, - }); - new_ids.push(declaration_id); - - this.identifiers - .entry(identifier) - .or_default() - .insert(declaration_id); - } + let mut new_ids = Vec::with_capacity(declarations.len()); + state.declarations.reserve(declarations.len()); + + for (parent_index, mut declaration) in declarations { + declaration.parent = parent_index + .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); + + let identifier = declaration.identifier.clone(); + let declaration_id = state.declarations.insert(Declaration::File { + project_entry_id: entry_id, + declaration, + }); + new_ids.push(declaration_id); + + state + .identifiers + .entry(identifier) + .or_default() + .insert(declaration_id); + } - file_state.declarations = new_ids; + file_state.declarations = new_ids; + }); }) .ok(); } }); - self.files - .entry(entry_id) - .or_insert_with(Default::default) - .task = Some(task); + self.with_state(cx, move |state| { + state + .files + .entry(entry_id) + .or_insert_with(Default::default) + .task = Some(task); + }); } } -impl BufferDeclaration { - pub fn from_outline(declaration: OutlineDeclaration, snapshot: &BufferSnapshot) -> Self { - // use of anchor_before is a guess that the proper behavior is to expand to include - // insertions immediately before the declaration, but not for insertions immediately after - Self { - parent: None, - identifier: declaration.identifier, - item_range: snapshot.anchor_before(declaration.item_range.start) - ..snapshot.anchor_before(declaration.item_range.end), - signature_range: snapshot.anchor_before(declaration.signature_range.start) - ..snapshot.anchor_before(declaration.signature_range.end), +impl SyntaxIndexState { + pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> { + self.declarations.get(id) + } + + /// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector. + /// + /// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded. 
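+    ///
+    /// A usage sketch (assumes an already-populated index; `8` is an arbitrary limit):
+    ///
+    /// ```ignore
+    /// let decls = index_state.declarations_for_identifier::<8>(&identifier);
+    /// for declaration in &decls {
+    ///     let (text, truncated) = declaration.item_text();
+    ///     println!("{:?} truncated={truncated}: {text}", declaration.identifier());
+    /// }
+    /// ```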
+ pub fn declarations_for_identifier( + &self, + identifier: &Identifier, + ) -> Vec { + // make sure to not have a large stack allocation + assert!(N < 32); + + let Some(declaration_ids) = self.identifiers.get(&identifier) else { + return vec![]; + }; + + let mut result = Vec::with_capacity(N); + let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new(); + let mut file_declarations = Vec::new(); + + for declaration_id in declaration_ids { + let declaration = self.declarations.get(*declaration_id); + let Some(declaration) = some_or_debug_panic(declaration) else { + continue; + }; + match declaration { + Declaration::Buffer { + project_entry_id, .. + } => { + included_buffer_entry_ids.push(*project_entry_id); + result.push(declaration.clone()); + if result.len() == N { + return Vec::new(); + } + } + Declaration::File { + project_entry_id, .. + } => { + if !included_buffer_entry_ids.contains(&project_entry_id) { + file_declarations.push(declaration.clone()); + } + } + } + } + + for declaration in file_declarations { + match declaration { + Declaration::File { + project_entry_id, .. + } => { + if !included_buffer_entry_ids.contains(&project_entry_id) { + result.push(declaration); + + if result.len() == N { + return Vec::new(); + } + } + } + Declaration::Buffer { .. } => {} + } + } + + result + } + + pub fn file_declaration_count(&self, declaration: &Declaration) -> usize { + match declaration { + Declaration::File { + project_entry_id, .. + } => self + .files + .get(project_entry_id) + .map(|file_state| file_state.declarations.len()) + .unwrap_or_default(), + Declaration::Buffer { buffer_id, .. } => self + .buffers + .get(buffer_id) + .map(|buffer_state| buffer_state.declarations.len()) + .unwrap_or_default(), } } -} -impl FileDeclaration { - pub fn from_outline( - declaration: OutlineDeclaration, - snapshot: &BufferSnapshot, - ) -> FileDeclaration { - FileDeclaration { - parent: None, - identifier: declaration.identifier, - item_range: declaration.item_range, - signature_text: snapshot - .text_for_range(declaration.signature_range.clone()) - .collect::() - .into(), - signature_range: declaration.signature_range, + fn remove_buffer_declarations( + old_declaration_ids: &[DeclarationId], + declarations: &mut SlotMap, + identifiers: &mut HashMap>, + ) { + for old_declaration_id in old_declaration_ids { + let Some(declaration) = declarations.remove(*old_declaration_id) else { + debug_panic!("declaration not found"); + continue; + }; + if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) { + identifier_declarations.remove(old_declaration_id); + } } } } @@ -509,13 +531,13 @@ mod tests { use gpui::TestAppContext; use indoc::indoc; use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; - use project::{FakeFs, Project, ProjectItem}; + use project::{FakeFs, Project}; use serde_json::json; use settings::SettingsStore; use text::OffsetRangeExt as _; use util::path; - use crate::tree_sitter_index::TreeSitterIndex; + use crate::syntax_index::SyntaxIndex; #[gpui::test] async fn test_unopen_indexed_files(cx: &mut TestAppContext) { @@ -525,17 +547,19 @@ mod tests { language_id: rust_lang_id, }; - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(main.clone(), cx); + let index_state = index.read_with(cx, |index, _cx| index.state().clone()); + let index_state = index_state.lock().await; + cx.update(|cx| { + let decls = index_state.declarations_for_identifier::<8>(&main); assert_eq!(decls.len(), 2); 
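+            // Both results are file-backed declarations, since no buffer is open for
+            // these files; declarations_for_identifier returns buffer-backed
+            // declarations ahead of file-backed ones when both exist.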
let decl = expect_file_decl("c.rs", &decls[0], &project, cx); assert_eq!(decl.identifier, main.clone()); - assert_eq!(decl.item_range, 32..279); + assert_eq!(decl.item_range_in_file, 32..280); let decl = expect_file_decl("a.rs", &decls[1], &project, cx); assert_eq!(decl.identifier, main); - assert_eq!(decl.item_range, 0..97); + assert_eq!(decl.item_range_in_file, 0..98); }); } @@ -547,15 +571,17 @@ mod tests { language_id: rust_lang_id, }; - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx); + let index_state = index.read_with(cx, |index, _cx| index.state().clone()); + let index_state = index_state.lock().await; + cx.update(|cx| { + let decls = index_state.declarations_for_identifier::<8>(&test_process_data); assert_eq!(decls.len(), 1); let decl = expect_file_decl("c.rs", &decls[0], &project, cx); assert_eq!(decl.identifier, test_process_data); let parent_id = decl.parent.unwrap(); - let parent = index.declaration(parent_id).unwrap(); + let parent = index_state.declaration(parent_id).unwrap(); let parent_decl = expect_file_decl("c.rs", &parent, &project, cx); assert_eq!( parent_decl.identifier, @@ -586,16 +612,18 @@ mod tests { cx.run_until_parked(); - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx); + let index_state = index.read_with(cx, |index, _cx| index.state().clone()); + let index_state = index_state.lock().await; + cx.update(|cx| { + let decls = index_state.declarations_for_identifier::<8>(&test_process_data); assert_eq!(decls.len(), 1); - let decl = expect_buffer_decl("c.rs", &decls[0], cx); + let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx); assert_eq!(decl.identifier, test_process_data); let parent_id = decl.parent.unwrap(); - let parent = index.declaration(parent_id).unwrap(); - let parent_decl = expect_buffer_decl("c.rs", &parent, cx); + let parent = index_state.declaration(parent_id).unwrap(); + let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx); assert_eq!( parent_decl.identifier, Identifier { @@ -613,16 +641,13 @@ mod tests { async fn test_declarations_limt(cx: &mut TestAppContext) { let (_, index, rust_lang_id) = init_test(cx).await; - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<1>( - Identifier { - name: "main".into(), - language_id: rust_lang_id, - }, - cx, - ); - assert_eq!(decls.len(), 1); + let index_state = index.read_with(cx, |index, _cx| index.state().clone()); + let index_state = index_state.lock().await; + let decls = index_state.declarations_for_identifier::<1>(&Identifier { + name: "main".into(), + language_id: rust_lang_id, }); + assert_eq!(decls.len(), 0); } #[gpui::test] @@ -644,24 +669,31 @@ mod tests { cx.run_until_parked(); - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(main.clone(), cx); - assert_eq!(decls.len(), 2); - let decl = expect_buffer_decl("c.rs", &decls[0], cx); - assert_eq!(decl.identifier, main); - assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279); + let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone()); + { + let index_state = index_state_arc.lock().await; - expect_file_decl("a.rs", &decls[1], &project, cx); - }); + cx.update(|cx| { + let decls = index_state.declarations_for_identifier::<8>(&main); + assert_eq!(decls.len(), 2); + let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx); + assert_eq!(decl.identifier, main); + 
assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280); + + expect_file_decl("a.rs", &decls[1], &project, cx); + }); + } - // Need to trigger flush_effects so that the observe_release handler will run. - cx.update(|_cx| { + // Drop the buffer and wait for release + cx.update(|_| { drop(buffer); }); cx.run_until_parked(); - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(main, cx); + let index_state = index_state_arc.lock().await; + + cx.update(|cx| { + let decls = index_state.declarations_for_identifier::<8>(&main); assert_eq!(decls.len(), 2); expect_file_decl("c.rs", &decls[0], &project, cx); expect_file_decl("a.rs", &decls[1], &project, cx); @@ -671,24 +703,20 @@ mod tests { fn expect_buffer_decl<'a>( path: &str, declaration: &'a Declaration, + project: &Entity, cx: &App, ) -> &'a BufferDeclaration { if let Declaration::Buffer { declaration, - buffer, + project_entry_id, + .. } = declaration { - assert_eq!( - buffer - .upgrade() - .unwrap() - .read(cx) - .project_path(cx) - .unwrap() - .path - .as_ref(), - Path::new(path), - ); + let project_path = project + .read(cx) + .path_for_entry(*project_entry_id, cx) + .unwrap(); + assert_eq!(project_path.path.as_ref(), Path::new(path),); declaration } else { panic!("Expected a buffer declaration, found {:?}", declaration); @@ -723,7 +751,7 @@ mod tests { async fn init_test( cx: &mut TestAppContext, - ) -> (Entity, Entity, LanguageId) { + ) -> (Entity, Entity, LanguageId) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); cx.set_global(settings_store); @@ -801,7 +829,7 @@ mod tests { let lang_id = lang.id(); language_registry.add(Arc::new(lang)); - let index = cx.new(|cx| TreeSitterIndex::new(&project, cx)); + let index = cx.new(|cx| SyntaxIndex::new(&project, cx)); cx.run_until_parked(); (project, index, lang_id) diff --git a/crates/edit_prediction_context/src/text_similarity.rs b/crates/edit_prediction_context/src/text_similarity.rs new file mode 100644 index 0000000000000000000000000000000000000000..f7a9822ecca01f1b6ff1dc04bdc12fbcddc5159b --- /dev/null +++ b/crates/edit_prediction_context/src/text_similarity.rs @@ -0,0 +1,241 @@ +use regex::Regex; +use std::{collections::HashMap, sync::LazyLock}; + +use crate::reference::Reference; + +// TODO: Consider implementing sliding window similarity matching like +// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts +// +// That implementation could actually be more efficient - no need to track words in the window that +// are not in the query. + +static IDENTIFIER_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap()); + +#[derive(Debug)] +pub struct IdentifierOccurrences { + identifier_to_count: HashMap, + total_count: usize, +} + +impl IdentifierOccurrences { + pub fn within_string(code: &str) -> Self { + Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str())) + } + + #[allow(dead_code)] + pub fn within_references(references: &[Reference]) -> Self { + Self::from_iterator( + references + .iter() + .map(|reference| reference.identifier.name.as_ref()), + ) + } + + pub fn from_iterator<'a>(identifier_iterator: impl Iterator) -> Self { + let mut identifier_to_count = HashMap::new(); + let mut total_count = 0; + for identifier in identifier_iterator { + // TODO: Score matches that match case higher? + // + // TODO: Also include unsplit identifier? 
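+            //
+            // For example, "parseXMLDocument" and "parse_xml_document" both contribute
+            // the lowercased parts ["parse", "xml", "document"], so occurrence counts
+            // are insensitive to casing style.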
+ for identifier_part in split_identifier(identifier) { + identifier_to_count + .entry(identifier_part.to_lowercase()) + .and_modify(|count| *count += 1) + .or_insert(1); + total_count += 1; + } + } + IdentifierOccurrences { + identifier_to_count, + total_count, + } + } +} + +// Splits camelcase / snakecase / kebabcase / pascalcase +// +// TODO: Make this more efficient / elegant. +fn split_identifier(identifier: &str) -> Vec<&str> { + let mut parts = Vec::new(); + let mut start = 0; + let chars: Vec = identifier.chars().collect(); + + if chars.is_empty() { + return parts; + } + + let mut i = 0; + while i < chars.len() { + let ch = chars[i]; + + // Handle explicit delimiters (underscore and hyphen) + if ch == '_' || ch == '-' { + if i > start { + parts.push(&identifier[start..i]); + } + start = i + 1; + i += 1; + continue; + } + + // Handle camelCase and PascalCase transitions + if i > 0 && i < chars.len() { + let prev_char = chars[i - 1]; + + // Transition from lowercase/digit to uppercase + if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() { + parts.push(&identifier[start..i]); + start = i; + } + // Handle sequences like "XMLParser" -> ["XML", "Parser"] + else if i + 1 < chars.len() + && ch.is_uppercase() + && chars[i + 1].is_lowercase() + && prev_char.is_uppercase() + { + parts.push(&identifier[start..i]); + start = i; + } + } + + i += 1; + } + + // Add the last part if there's any remaining + if start < identifier.len() { + parts.push(&identifier[start..]); + } + + // Filter out empty strings + parts.into_iter().filter(|s| !s.is_empty()).collect() +} + +pub fn jaccard_similarity<'a>( + mut set_a: &'a IdentifierOccurrences, + mut set_b: &'a IdentifierOccurrences, +) -> f32 { + if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() { + std::mem::swap(&mut set_a, &mut set_b); + } + let intersection = set_a + .identifier_to_count + .keys() + .filter(|key| set_b.identifier_to_count.contains_key(*key)) + .count(); + let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection; + intersection as f32 / union as f32 +} + +// TODO +#[allow(dead_code)] +pub fn overlap_coefficient<'a>( + mut set_a: &'a IdentifierOccurrences, + mut set_b: &'a IdentifierOccurrences, +) -> f32 { + if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() { + std::mem::swap(&mut set_a, &mut set_b); + } + let intersection = set_a + .identifier_to_count + .keys() + .filter(|key| set_b.identifier_to_count.contains_key(*key)) + .count(); + intersection as f32 / set_a.identifier_to_count.len() as f32 +} + +// TODO +#[allow(dead_code)] +pub fn weighted_jaccard_similarity<'a>( + mut set_a: &'a IdentifierOccurrences, + mut set_b: &'a IdentifierOccurrences, +) -> f32 { + if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() { + std::mem::swap(&mut set_a, &mut set_b); + } + + let mut numerator = 0; + let mut denominator_a = 0; + let mut used_count_b = 0; + for (symbol, count_a) in set_a.identifier_to_count.iter() { + let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0); + numerator += count_a.min(count_b); + denominator_a += count_a.max(count_b); + used_count_b += count_b; + } + + let denominator = denominator_a + (set_b.total_count - used_count_b); + if denominator == 0 { + 0.0 + } else { + numerator as f32 / denominator as f32 + } +} + +pub fn weighted_overlap_coefficient<'a>( + mut set_a: &'a IdentifierOccurrences, + mut set_b: &'a IdentifierOccurrences, +) -> f32 { + if set_a.identifier_to_count.len() > 
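+        // Keep set_a as the smaller set; the coefficient is
+        // sum(min(count_a, count_b)) / min(total_a, total_b).
+        // For example, counts {a: 2, b: 1} vs {a: 1, c: 3} give
+        // min(2, 1) / min(3, 4) = 1/3.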
set_b.identifier_to_count.len() { + std::mem::swap(&mut set_a, &mut set_b); + } + + let mut numerator = 0; + for (symbol, count_a) in set_a.identifier_to_count.iter() { + let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0); + numerator += count_a.min(count_b); + } + + let denominator = set_a.total_count.min(set_b.total_count); + if denominator == 0 { + 0.0 + } else { + numerator as f32 / denominator as f32 + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_split_identifier() { + assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]); + assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]); + assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]); + assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]); + assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]); + } + + #[test] + fn test_similarity_functions() { + // 10 identifier parts, 8 unique + // Repeats: 2 "outline", 2 "items" + let set_a = IdentifierOccurrences::within_string( + "let mut outline_items = query_outline_items(&language, &tree, &source);", + ); + // 14 identifier parts, 11 unique + // Repeats: 2 "outline", 2 "language", 2 "tree" + let set_b = IdentifierOccurrences::within_string( + "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec {", + ); + + // 6 overlaps: "outline", "items", "query", "language", "tree", "source" + // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str" + assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0)); + + // Numerator is one more than before due to both having 2 "outline". + // Denominator is the same except for 3 more due to the non-overlapping duplicates + assert_eq!( + weighted_jaccard_similarity(&set_a, &set_b), + 7.0 / (7.0 + 7.0 + 3.0) + ); + + // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8. + assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0); + + // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of + // the smaller set, 10. + assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0); + } +} diff --git a/crates/edit_prediction_context/src/wip_requests.rs b/crates/edit_prediction_context/src/wip_requests.rs new file mode 100644 index 0000000000000000000000000000000000000000..9189587929725c8e1e4369fe5bd24cc641d6afab --- /dev/null +++ b/crates/edit_prediction_context/src/wip_requests.rs @@ -0,0 +1,35 @@ +// To discuss: What to send to the new endpoint? Thinking it'd make sense to put `prompt.rs` from +// `zeta_context.rs` in cloud. +// +// * Run excerpt selection at several different sizes, send the largest size with offsets within for +// the smaller sizes. +// +// * Longer event history. +// +// * Many more snippets than could fit in model context - allows ranking experimentation. + +pub struct Zeta2Request { + pub event_history: Vec, + pub excerpt: String, + pub excerpt_subsets: Vec, + /// Within `excerpt` + pub cursor_position: usize, + pub signatures: Vec, + pub retrieved_declarations: Vec, +} + +pub struct Zeta2ExcerptSubset { + /// Within `excerpt` text. + pub excerpt_range: Range, + /// Within `signatures`. + pub parent_signatures: Vec, +} + +pub struct ReferencedDeclaration { + pub text: Arc, + /// Range within `text` + pub signature_range: Range, + /// Indices within `signatures`. + pub parent_signatures: Vec, + // A bunch of score metrics +}