From 05545abab6558e4bd13f395ca676bb6ad8684864 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Wed, 17 Sep 2025 18:06:41 -0600 Subject: [PATCH] Checkpoint Co-authored-by: Agus --- crates/edit_prediction_context/Cargo.toml | 2 +- .../src/declaration.rs | 55 ++++++++++--------- .../src/declaration_scoring.rs | 37 ++++++++++--- .../src/edit_prediction_context.rs | 6 +- .../src/syntax_index.rs | 11 +++- .../src/text_similarity.rs | 4 ++ 6 files changed, 73 insertions(+), 42 deletions(-) diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index 912775d5045a74fd26eaf445e20afef336f40f54..b184c43539c791d13ff2f60aa1441ec3f60d34d2 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -24,6 +24,7 @@ gpui.workspace = true itertools.workspace = true language.workspace = true log.workspace = true +ordered-float.workspace = true project.workspace = true regex.workspace = true serde.workspace = true @@ -40,7 +41,6 @@ futures.workspace = true gpui = { workspace = true, features = ["test-support"] } indoc.workspace = true language = { workspace = true, features = ["test-support"] } -ordered-float.workspace = true pretty_assertions.workspace = true project = {workspace= true, features = ["test-support"]} serde_json.workspace = true diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs index ad944359d54a2d4bf2ddb02df619e49a2d9fd7a8..fcf54fead80194fe97a2719971f86318a57ad75c 100644 --- a/crates/edit_prediction_context/src/declaration.rs +++ b/crates/edit_prediction_context/src/declaration.rs @@ -60,17 +60,11 @@ impl Declaration { ), Declaration::Buffer { rope, declaration, .. - } => { - let (range, is_truncated) = expand_range_to_line_boundaries_and_truncate( - &declaration.item_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - ( - rope.chunks_in_range(range).collect::>(), - is_truncated, - ) - } + } => ( + rope.chunks_in_range(declaration.item_range.clone()) + .collect::>(), + declaration.item_range_is_truncated, + ), } } @@ -82,17 +76,11 @@ impl Declaration { ), Declaration::Buffer { rope, declaration, .. - } => { - let (range, is_truncated) = expand_range_to_line_boundaries_and_truncate( - &declaration.signature_range, - ITEM_TEXT_TRUNCATION_LENGTH, - rope, - ); - ( - rope.chunks_in_range(range).collect::>(), - is_truncated, - ) - } + } => ( + rope.chunks_in_range(declaration.signature_range.clone()) + .collect::>(), + declaration.signature_range_is_truncated, + ), } } } @@ -175,18 +163,31 @@ pub struct BufferDeclaration { pub parent: Option, pub identifier: Identifier, pub item_range: Range, + pub item_range_is_truncated: bool, pub signature_range: Range, + pub signature_range_is_truncated: bool, } impl BufferDeclaration { - pub fn from_outline(declaration: OutlineDeclaration) -> Self { - // use of anchor_before is a guess that the proper behavior is to expand to include - // insertions immediately before the declaration, but not for insertions immediately after + pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self { + let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate( + &declaration.item_range, + ITEM_TEXT_TRUNCATION_LENGTH, + rope, + ); + let (signature_range, signature_range_is_truncated) = + expand_range_to_line_boundaries_and_truncate( + &declaration.signature_range, + ITEM_TEXT_TRUNCATION_LENGTH, + rope, + ); Self { parent: None, identifier: declaration.identifier, - item_range: declaration.item_range, - signature_range: declaration.signature_range, + item_range, + item_range_is_truncated, + signature_range, + signature_range_is_truncated, } } } diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs index af26ea4ca46e83d82776e03a117e8f51855304ed..aeff8487dff64923813b57c67f6500061ca47ad2 100644 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ b/crates/edit_prediction_context/src/declaration_scoring.rs @@ -1,5 +1,6 @@ use itertools::Itertools as _; use language::BufferSnapshot; +use ordered_float::OrderedFloat; use serde::Serialize; use std::{collections::HashMap, ops::Range}; use strum::EnumIter; @@ -12,13 +13,14 @@ use crate::{ text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient}, }; +const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16; + // TODO: // -// * Consider adding declaration_file_count (n) +// * Consider adding declaration_file_count #[derive(Clone, Debug)] pub struct ScoredSnippet { - #[allow(dead_code)] pub identifier: Identifier, pub declaration: Declaration, pub score_components: ScoreInputs, @@ -42,7 +44,17 @@ impl ScoredSnippet { } pub fn size(&self, style: SnippetStyle) -> usize { - todo!() + // TODO: how to handle truncation? + match &self.declaration { + Declaration::File { declaration, .. } => match style { + SnippetStyle::Signature => declaration.signature_range_in_text.len(), + SnippetStyle::Declaration => declaration.text.len(), + }, + Declaration::Buffer { declaration, .. } => match style { + SnippetStyle::Signature => declaration.signature_range.len(), + SnippetStyle::Declaration => declaration.item_range.len(), + }, + } } pub fn score_density(&self, style: SnippetStyle) -> f32 { @@ -70,11 +82,11 @@ pub fn scored_snippets( .collect::(), ); - identifier_to_references + let mut snippets = identifier_to_references .into_iter() .flat_map(|(identifier, references)| { - // todo! pick a limit - let declarations = index.declarations_for_identifier::<16>(&identifier); + let declarations = + index.declarations_for_identifier::(&identifier); let declaration_count = declarations.len(); declarations @@ -144,10 +156,19 @@ pub fn scored_snippets( .collect::>() }) .flatten() - .collect::>() + .collect::>(); + + snippets.sort_unstable_by_key(|snippet| { + OrderedFloat( + snippet + .score_density(SnippetStyle::Declaration) + .max(snippet.score_density(SnippetStyle::Signature)), + ) + }); + + snippets } -// todo! replace with existing util? fn range_intersection(a: &Range, b: &Range) -> Option> { let start = a.start.clone().max(b.start.clone()); let end = a.end.clone().min(b.end.clone()); diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index ff999964f4a7b4634d80680aaa06e4d09c7948a8..5d73dc7f7dcf2223ae1f23b22c2f104842206e12 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -17,9 +17,9 @@ use text::{Point, ToOffset as _}; use crate::declaration_scoring::{ScoredSnippet, scored_snippets}; pub struct EditPredictionContext { - excerpt: EditPredictionExcerpt, - excerpt_text: EditPredictionExcerptText, - snippets: Vec, + pub excerpt: EditPredictionExcerpt, + pub excerpt_text: EditPredictionExcerptText, + pub snippets: Vec, } impl EditPredictionContext { diff --git a/crates/edit_prediction_context/src/syntax_index.rs b/crates/edit_prediction_context/src/syntax_index.rs index bef993d0c50dd06602b0343082929cb8677a3bb8..5aba354c1dc0bd61a6dc5bb0b025ad45db6b12b6 100644 --- a/crates/edit_prediction_context/src/syntax_index.rs +++ b/crates/edit_prediction_context/src/syntax_index.rs @@ -240,17 +240,22 @@ impl SyntaxIndex { let rope = snapshot.text.as_rope().clone(); anyhow::Ok(( - rope, declarations_in_buffer(&snapshot) .into_iter() - .map(|item| (item.parent_index, BufferDeclaration::from_outline(item))) + .map(|item| { + ( + item.parent_index, + BufferDeclaration::from_outline(item, &rope), + ) + }) .collect::>(), + rope, )) }); let task = cx.spawn({ async move |this, cx| { - let Ok((rope, declarations)) = parse_task.await else { + let Ok((declarations, rope)) = parse_task.await else { return; }; diff --git a/crates/edit_prediction_context/src/text_similarity.rs b/crates/edit_prediction_context/src/text_similarity.rs index b81ccc1298798365cff11d9cdb6dd6bfc1541e89..6adbb245c9e4411c39bf1fad99da49f8f8bf7ef9 100644 --- a/crates/edit_prediction_context/src/text_similarity.rs +++ b/crates/edit_prediction_context/src/text_similarity.rs @@ -127,6 +127,8 @@ pub fn jaccard_similarity<'a>( intersection as f32 / union as f32 } +// TODO +#[allow(dead_code)] pub fn overlap_coefficient<'a>( mut set_a: &'a IdentifierOccurrences, mut set_b: &'a IdentifierOccurrences, @@ -142,6 +144,8 @@ pub fn overlap_coefficient<'a>( intersection as f32 / set_a.identifier_to_count.len() as f32 } +// TODO +#[allow(dead_code)] pub fn weighted_jaccard_similarity<'a>( mut set_a: &'a IdentifierOccurrences, mut set_b: &'a IdentifierOccurrences,