Checkpoint

Michael Sloan and Agus created

Co-authored-by: Agus <agus@zed.dev>

Change summary

crates/edit_prediction_context/Cargo.toml                     |  2 
crates/edit_prediction_context/src/declaration.rs             | 55 ++--
crates/edit_prediction_context/src/declaration_scoring.rs     | 37 ++
crates/edit_prediction_context/src/edit_prediction_context.rs |  6 
crates/edit_prediction_context/src/syntax_index.rs            | 11 
crates/edit_prediction_context/src/text_similarity.rs         |  4 
6 files changed, 73 insertions(+), 42 deletions(-)

Detailed changes

crates/edit_prediction_context/Cargo.toml 🔗

@@ -24,6 +24,7 @@ gpui.workspace = true
 itertools.workspace = true
 language.workspace = true
 log.workspace = true
+ordered-float.workspace = true
 project.workspace = true
 regex.workspace = true
 serde.workspace = true
@@ -40,7 +41,6 @@ futures.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
 indoc.workspace = true
 language = { workspace = true, features = ["test-support"] }
-ordered-float.workspace = true
 pretty_assertions.workspace = true
 project = {workspace= true, features = ["test-support"]}
 serde_json.workspace = true

crates/edit_prediction_context/src/declaration.rs 🔗

@@ -60,17 +60,11 @@ impl Declaration {
             ),
             Declaration::Buffer {
                 rope, declaration, ..
-            } => {
-                let (range, is_truncated) = expand_range_to_line_boundaries_and_truncate(
-                    &declaration.item_range,
-                    ITEM_TEXT_TRUNCATION_LENGTH,
-                    rope,
-                );
-                (
-                    rope.chunks_in_range(range).collect::<Cow<str>>(),
-                    is_truncated,
-                )
-            }
+            } => (
+                rope.chunks_in_range(declaration.item_range.clone())
+                    .collect::<Cow<str>>(),
+                declaration.item_range_is_truncated,
+            ),
         }
     }
 
@@ -82,17 +76,11 @@ impl Declaration {
             ),
             Declaration::Buffer {
                 rope, declaration, ..
-            } => {
-                let (range, is_truncated) = expand_range_to_line_boundaries_and_truncate(
-                    &declaration.signature_range,
-                    ITEM_TEXT_TRUNCATION_LENGTH,
-                    rope,
-                );
-                (
-                    rope.chunks_in_range(range).collect::<Cow<str>>(),
-                    is_truncated,
-                )
-            }
+            } => (
+                rope.chunks_in_range(declaration.signature_range.clone())
+                    .collect::<Cow<str>>(),
+                declaration.signature_range_is_truncated,
+            ),
         }
     }
 }
@@ -175,18 +163,31 @@ pub struct BufferDeclaration {
     pub parent: Option<DeclarationId>,
     pub identifier: Identifier,
     pub item_range: Range<usize>,
+    pub item_range_is_truncated: bool,
     pub signature_range: Range<usize>,
+    pub signature_range_is_truncated: bool,
 }
 
 impl BufferDeclaration {
-    pub fn from_outline(declaration: OutlineDeclaration) -> Self {
-        // use of anchor_before is a guess that the proper behavior is to expand to include
-        // insertions immediately before the declaration, but not for insertions immediately after
+    pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
+        let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate(
+            &declaration.item_range,
+            ITEM_TEXT_TRUNCATION_LENGTH,
+            rope,
+        );
+        let (signature_range, signature_range_is_truncated) =
+            expand_range_to_line_boundaries_and_truncate(
+                &declaration.signature_range,
+                ITEM_TEXT_TRUNCATION_LENGTH,
+                rope,
+            );
         Self {
             parent: None,
             identifier: declaration.identifier,
-            item_range: declaration.item_range,
-            signature_range: declaration.signature_range,
+            item_range,
+            item_range_is_truncated,
+            signature_range,
+            signature_range_is_truncated,
         }
     }
 }

crates/edit_prediction_context/src/declaration_scoring.rs 🔗

@@ -1,5 +1,6 @@
 use itertools::Itertools as _;
 use language::BufferSnapshot;
+use ordered_float::OrderedFloat;
 use serde::Serialize;
 use std::{collections::HashMap, ops::Range};
 use strum::EnumIter;
@@ -12,13 +13,14 @@ use crate::{
     text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
 };
 
+const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
+
 // TODO:
 //
-// * Consider adding declaration_file_count (n)
+// * Consider adding declaration_file_count
 
 #[derive(Clone, Debug)]
 pub struct ScoredSnippet {
-    #[allow(dead_code)]
     pub identifier: Identifier,
     pub declaration: Declaration,
     pub score_components: ScoreInputs,
@@ -42,7 +44,17 @@ impl ScoredSnippet {
     }
 
     pub fn size(&self, style: SnippetStyle) -> usize {
-        todo!()
+        // TODO: how to handle truncation?
+        match &self.declaration {
+            Declaration::File { declaration, .. } => match style {
+                SnippetStyle::Signature => declaration.signature_range_in_text.len(),
+                SnippetStyle::Declaration => declaration.text.len(),
+            },
+            Declaration::Buffer { declaration, .. } => match style {
+                SnippetStyle::Signature => declaration.signature_range.len(),
+                SnippetStyle::Declaration => declaration.item_range.len(),
+            },
+        }
     }
 
     pub fn score_density(&self, style: SnippetStyle) -> f32 {
@@ -70,11 +82,11 @@ pub fn scored_snippets(
             .collect::<String>(),
     );
 
-    identifier_to_references
+    let mut snippets = identifier_to_references
         .into_iter()
         .flat_map(|(identifier, references)| {
-            // todo! pick a limit
-            let declarations = index.declarations_for_identifier::<16>(&identifier);
+            let declarations =
+                index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
             let declaration_count = declarations.len();
 
             declarations
@@ -144,10 +156,19 @@ pub fn scored_snippets(
                 .collect::<Vec<_>>()
         })
         .flatten()
-        .collect::<Vec<_>>()
+        .collect::<Vec<_>>();
+
+    snippets.sort_unstable_by_key(|snippet| {
+        OrderedFloat(
+            snippet
+                .score_density(SnippetStyle::Declaration)
+                .max(snippet.score_density(SnippetStyle::Signature)),
+        )
+    });
+
+    snippets
 }
 
-// todo! replace with existing util?
 fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
     let start = a.start.clone().max(b.start.clone());
     let end = a.end.clone().min(b.end.clone());

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -17,9 +17,9 @@ use text::{Point, ToOffset as _};
 use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
 
 pub struct EditPredictionContext {
-    excerpt: EditPredictionExcerpt,
-    excerpt_text: EditPredictionExcerptText,
-    snippets: Vec<ScoredSnippet>,
+    pub excerpt: EditPredictionExcerpt,
+    pub excerpt_text: EditPredictionExcerptText,
+    pub snippets: Vec<ScoredSnippet>,
 }
 
 impl EditPredictionContext {

crates/edit_prediction_context/src/syntax_index.rs 🔗

@@ -240,17 +240,22 @@ impl SyntaxIndex {
             let rope = snapshot.text.as_rope().clone();
 
             anyhow::Ok((
-                rope,
                 declarations_in_buffer(&snapshot)
                     .into_iter()
-                    .map(|item| (item.parent_index, BufferDeclaration::from_outline(item)))
+                    .map(|item| {
+                        (
+                            item.parent_index,
+                            BufferDeclaration::from_outline(item, &rope),
+                        )
+                    })
                     .collect::<Vec<_>>(),
+                rope,
             ))
         });
 
         let task = cx.spawn({
             async move |this, cx| {
-                let Ok((rope, declarations)) = parse_task.await else {
+                let Ok((declarations, rope)) = parse_task.await else {
                     return;
                 };
 

crates/edit_prediction_context/src/text_similarity.rs 🔗

@@ -127,6 +127,8 @@ pub fn jaccard_similarity<'a>(
     intersection as f32 / union as f32
 }
 
+// TODO
+#[allow(dead_code)]
 pub fn overlap_coefficient<'a>(
     mut set_a: &'a IdentifierOccurrences,
     mut set_b: &'a IdentifierOccurrences,
@@ -142,6 +144,8 @@ pub fn overlap_coefficient<'a>(
     intersection as f32 / set_a.identifier_to_count.len() as f32
 }
 
+// TODO
+#[allow(dead_code)]
 pub fn weighted_jaccard_similarity<'a>(
     mut set_a: &'a IdentifierOccurrences,
     mut set_b: &'a IdentifierOccurrences,