Checkpoint: Get score_snippets to compile

Agus and Finn created

Co-Authored-By: Finn <finn@zed.dev>

Change summary

crates/edit_prediction_context/src/scored_declaration.rs | 190 ++++++----
crates/edit_prediction_context/src/tree_sitter_index.rs  |  90 ++++
2 files changed, 196 insertions(+), 84 deletions(-)

Detailed changes

crates/edit_prediction_context/src/scored_declaration.rs 🔗

@@ -1,14 +1,18 @@
+use collections::HashSet;
+use gpui::{App, Entity};
 use itertools::Itertools as _;
+use language::BufferSnapshot;
+use project::ProjectEntryId;
 use serde::Serialize;
-use std::collections::HashMap;
-use std::path::Path;
-use std::sync::Arc;
+use std::{collections::HashMap, ops::Range};
 use strum::EnumIter;
-use tree_sitter::StreamingIterator;
+use text::{OffsetRangeExt, Point, ToPoint};
 
 use crate::{
-    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, outline::Identifier,
-    reference::Reference, text_similarity::IdentifierOccurrences,
+    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, TreeSitterIndex,
+    outline::Identifier,
+    reference::{Reference, ReferenceRegion},
+    text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
 };
 
 #[derive(Clone, Debug)]
@@ -46,23 +50,29 @@ impl ScoredSnippet {
 }
 
 fn scored_snippets(
+    index: Entity<TreeSitterIndex>,
     excerpt: &EditPredictionExcerpt,
     excerpt_text: &EditPredictionExcerptText,
     references: Vec<Reference>,
     cursor_offset: usize,
+    current_buffer: &BufferSnapshot,
+    cx: &App,
 ) -> Vec<ScoredSnippet> {
-    let excerpt_occurrences = IdentifierOccurrences::within_string(&excerpt_text.body);
+    let containing_range_identifier_occurrences =
+        IdentifierOccurrences::within_string(&excerpt_text.body);
+    let cursor_point = cursor_offset.to_point(&current_buffer);
 
-    /* todo!
-    if let Some(cursor_within_excerpt) = cursor_offset.checked_sub(excerpt.range.start) {
-    } else {
-    };
-    let start_point = Point::new(cursor.row.saturating_sub(2), 0);
-    let end_point = Point::new(cursor.row + 1, 0);
+    // todo! ask michael why we needed this
+    // if let Some(cursor_within_excerpt) = cursor_offset.checked_sub(excerpt.range.start) {
+    // } else {
+    // };
+    let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
+    let end_point = Point::new(cursor_point.row + 1, 0);
     let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
-        &source[offset_from_point(source, start_point)..offset_from_point(source, end_point)],
+        &current_buffer
+            .text_for_range(start_point..end_point)
+            .collect::<String>(),
     );
-    */
 
     let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::new();
     for reference in references {
@@ -75,74 +85,102 @@ fn scored_snippets(
     identifier_to_references
         .into_iter()
         .flat_map(|(identifier, references)| {
-            let Some(definitions) = index
-                .identifier_to_definitions
-                .get(&(identifier.clone(), language.name.clone()))
-            else {
-                return Vec::new();
-            };
+            let definitions = index
+                .read(cx)
+                // todo! pick a limit
+                .declarations_for_identifier::<16>(&identifier, cx);
             let definition_count = definitions.len();
-            let definition_file_count = definitions.keys().len();
+            let total_file_count = definitions
+                .iter()
+                .filter_map(|definition| definition.project_entry_id(cx))
+                .collect::<HashSet<ProjectEntryId>>()
+                .len();
 
             definitions
-                .iter_all()
-                .flat_map(|(definition_file, file_definitions)| {
-                    let same_file_definition_count = file_definitions.len();
-                    let is_same_file = reference_file == definition_file.as_ref();
-                    file_definitions
-                        .iter()
-                        .filter(|definition| {
-                            !is_same_file
-                                || !range_intersection(&definition.item_range, &excerpt_range)
-                                    .is_some()
-                        })
-                        .filter_map(|definition| {
-                            let definition_line_distance = if is_same_file {
+                .iter()
+                .filter_map(|definition| match definition {
+                    Declaration::Buffer {
+                        declaration,
+                        buffer,
+                    } => {
+                        let is_same_file = buffer
+                            .read_with(cx, |buffer, _| buffer.remote_id())
+                            .is_ok_and(|buffer_id| buffer_id == current_buffer.remote_id());
+
+                        if is_same_file {
+                            range_intersection(
+                                &declaration.item_range.to_offset(&current_buffer),
+                                &excerpt.range,
+                            )
+                            .is_none()
+                            .then(|| {
                                 let definition_line =
-                                    point_from_offset(source, definition.item_range.start).row;
-                                (cursor.row as i32 - definition_line as i32).abs() as u32
-                            } else {
-                                0
-                            };
-                            Some((definition_line_distance, definition))
-                        })
-                        .sorted_by_key(|&(distance, _)| distance)
-                        .enumerate()
-                        .map(
-                            |(
-                                definition_line_distance_rank,
-                                (definition_line_distance, definition),
-                            )| {
-                                score_snippet(
-                                    &identifier,
-                                    &references,
-                                    definition_file.clone(),
-                                    definition.clone(),
-                                    is_same_file,
-                                    definition_line_distance,
-                                    definition_line_distance_rank,
-                                    same_file_definition_count,
-                                    definition_count,
-                                    definition_file_count,
-                                    &containing_range_identifier_occurrences,
-                                    &adjacent_identifier_occurrences,
-                                    cursor,
+                                    declaration.item_range.start.to_point(current_buffer).row;
+                                (
+                                    true,
+                                    (cursor_point.row as i32 - definition_line as i32).abs() as u32,
+                                    definition,
                                 )
-                            },
-                        )
-                        .collect::<Vec<_>>()
+                            })
+                        } else {
+                            Some((false, 0, definition))
+                        }
+                    }
+                    Declaration::File { .. } => {
+                        // We can assume that a file declaration is in a different file,
+                        // because the current one must be open
+                        Some((false, 0, definition))
+                    }
                 })
+                .sorted_by_key(|&(_, distance, _)| distance)
+                .enumerate()
+                .map(
+                    |(
+                        definition_line_distance_rank,
+                        (is_same_file, definition_line_distance, definition),
+                    )| {
+                        let same_file_definition_count =
+                            index.read(cx).file_declaration_count(definition);
+
+                        score_snippet(
+                            &identifier,
+                            &references,
+                            definition.clone(),
+                            is_same_file,
+                            definition_line_distance,
+                            definition_line_distance_rank,
+                            same_file_definition_count,
+                            definition_count,
+                            total_file_count,
+                            &containing_range_identifier_occurrences,
+                            &adjacent_identifier_occurrences,
+                            cursor_point,
+                            current_buffer,
+                            cx,
+                        )
+                    },
+                )
                 .collect::<Vec<_>>()
         })
         .flatten()
         .collect::<Vec<_>>()
 }
 
+// todo! replace with existing util?
+fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
+    let start = a.start.clone().max(b.start.clone());
+    let end = a.end.clone().min(b.end.clone());
+    if start < end {
+        Some(Range { start, end })
+    } else {
+        None
+    }
+}
+
 fn score_snippet(
     identifier: &Identifier,
     references: &[Reference],
-    definition_file: Arc<Path>,
-    definition: OutlineItem,
+    definition: Declaration,
     is_same_file: bool,
     definition_line_distance: u32,
     definition_line_distance_rank: usize,
@@ -152,28 +190,28 @@ fn score_snippet(
     containing_range_identifier_occurrences: &IdentifierOccurrences,
     adjacent_identifier_occurrences: &IdentifierOccurrences,
     cursor: Point,
+    current_buffer: &BufferSnapshot,
+    cx: &App,
 ) -> Option<ScoredSnippet> {
     let is_referenced_nearby = references
         .iter()
-        .any(|r| r.reference_region == ReferenceRegion::Nearby);
+        .any(|r| r.region == ReferenceRegion::Nearby);
     let is_referenced_in_breadcrumb = references
         .iter()
-        .any(|r| r.reference_region == ReferenceRegion::Breadcrumb);
+        .any(|r| r.region == ReferenceRegion::Breadcrumb);
     let reference_count = references.len();
     let reference_line_distance = references
         .iter()
         .map(|r| {
-            let reference_line = point_from_offset(reference_source, r.range.start).row as i32;
+            let reference_line = r.range.start.to_point(current_buffer).row as i32;
             (cursor.row as i32 - reference_line).abs() as u32
         })
         .min()
         .unwrap();
 
-    let definition_source = index.path_to_source.get(&definition_file).unwrap();
-    let item_source_occurrences =
-        IdentifierOccurrences::within_string(definition.item(&definition_source));
+    let item_source_occurrences = IdentifierOccurrences::within_string(&definition.item_text(cx));
     let item_signature_occurrences =
-        IdentifierOccurrences::within_string(definition.signature(&definition_source));
+        IdentifierOccurrences::within_string(&definition.signature_text(cx));
     let containing_range_vs_item_jaccard = jaccard_similarity(
         containing_range_identifier_occurrences,
         &item_source_occurrences,
@@ -223,7 +261,6 @@ fn score_snippet(
 
     Some(ScoredSnippet {
         identifier: identifier.clone(),
-        declaration_file: definition_file,
         declaration: definition,
         scores: score_components.score(),
         score_components,
@@ -238,6 +275,7 @@ pub struct ScoreInputs {
     pub reference_count: usize,
     pub same_file_definition_count: usize,
     pub definition_count: usize,
+    // todo! do we need this?
     pub definition_file_count: usize,
     pub reference_line_distance: u32,
     pub definition_line_distance: u32,

crates/edit_prediction_context/src/tree_sitter_index.rs 🔗

@@ -78,6 +78,57 @@ impl Declaration {
             Declaration::Buffer { declaration, .. } => &declaration.identifier,
         }
     }
+
+    pub fn project_entry_id(&self, cx: &App) -> Option<ProjectEntryId> {
+        match self {
+            Declaration::File {
+                project_entry_id, ..
+            } => Some(*project_entry_id),
+            Declaration::Buffer { buffer, .. } => buffer
+                .read_with(cx, |buffer, _cx| {
+                    project::File::from_dyn(buffer.file())
+                        .and_then(|file| file.project_entry_id(cx))
+                })
+                .ok()
+                .flatten(),
+        }
+    }
+
+    // todo! pick best return type
+    pub fn item_text(&self, cx: &App) -> Arc<str> {
+        match self {
+            Declaration::File { declaration, .. } => declaration.declaration_text.clone(),
+            Declaration::Buffer {
+                buffer,
+                declaration,
+            } => buffer
+                .read_with(cx, |buffer, _cx| {
+                    buffer
+                        .text_for_range(declaration.item_range.clone())
+                        .collect::<String>()
+                        .into()
+                })
+                .unwrap_or_default(),
+        }
+    }
+
+    // todo! pick best return type
+    pub fn signature_text(&self, cx: &App) -> Arc<str> {
+        match self {
+            Declaration::File { declaration, .. } => declaration.signature_text.clone(),
+            Declaration::Buffer {
+                buffer,
+                declaration,
+            } => buffer
+                .read_with(cx, |buffer, _cx| {
+                    buffer
+                        .text_for_range(declaration.signature_range.clone())
+                        .collect::<String>()
+                        .into()
+                })
+                .unwrap_or_default(),
+        }
+    }
 }
 
 #[derive(Debug, Clone)]
@@ -86,7 +137,9 @@ pub struct FileDeclaration {
     pub identifier: Identifier,
     pub item_range: Range<usize>,
     pub signature_range: Range<usize>,
+    // todo! should we just store a range with the declaration text?
     pub signature_text: Arc<str>,
+    pub declaration_text: Arc<str>,
 }
 
 #[derive(Debug, Clone)]
@@ -145,7 +198,7 @@ impl TreeSitterIndex {
 
     pub fn declarations_for_identifier<const N: usize>(
         &self,
-        identifier: Identifier,
+        identifier: &Identifier,
         cx: &App,
     ) -> Vec<Declaration> {
         // make sure to not have a large stack allocation
@@ -206,6 +259,23 @@ impl TreeSitterIndex {
         result
     }
 
+    pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
+        match declaration {
+            Declaration::File {
+                project_entry_id, ..
+            } => self
+                .files
+                .get(project_entry_id)
+                .map(|file_state| file_state.declarations.len())
+                .unwrap_or_default(),
+            Declaration::Buffer { buffer, .. } => self
+                .buffers
+                .get(buffer)
+                .map(|buffer_state| buffer_state.declarations.len())
+                .unwrap_or_default(),
+        }
+    }
+
     fn handle_worktree_store_event(
         &mut self,
         _worktree_store: Entity<WorktreeStore>,
@@ -491,12 +561,16 @@ impl FileDeclaration {
         FileDeclaration {
             parent: None,
             identifier: declaration.identifier,
-            item_range: declaration.item_range,
             signature_text: snapshot
                 .text_for_range(declaration.signature_range.clone())
                 .collect::<String>()
                 .into(),
             signature_range: declaration.signature_range,
+            declaration_text: snapshot
+                .text_for_range(declaration.item_range.clone())
+                .collect::<String>()
+                .into(),
+            item_range: declaration.item_range,
         }
     }
 }
@@ -527,7 +601,7 @@ mod tests {
         };
 
         index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
+            let decls = index.declarations_for_identifier::<8>(&main, cx);
             assert_eq!(decls.len(), 2);
 
             let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
@@ -549,7 +623,7 @@ mod tests {
         };
 
         index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
+            let decls = index.declarations_for_identifier::<8>(&test_process_data, cx);
             assert_eq!(decls.len(), 1);
 
             let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
@@ -588,7 +662,7 @@ mod tests {
         cx.run_until_parked();
 
         index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
+            let decls = index.declarations_for_identifier::<8>(&test_process_data, cx);
             assert_eq!(decls.len(), 1);
 
             let decl = expect_buffer_decl("c.rs", &decls[0], cx);
@@ -616,7 +690,7 @@ mod tests {
 
         index.read_with(cx, |index, cx| {
             let decls = index.declarations_for_identifier::<1>(
-                Identifier {
+                &Identifier {
                     name: "main".into(),
                     language_id: rust_lang_id,
                 },
@@ -646,7 +720,7 @@ mod tests {
         cx.run_until_parked();
 
         index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
+            let decls = index.declarations_for_identifier::<8>(&main, cx);
             assert_eq!(decls.len(), 2);
             let decl = expect_buffer_decl("c.rs", &decls[0], cx);
             assert_eq!(decl.identifier, main);
@@ -669,7 +743,7 @@ mod tests {
         cx.run_until_parked();
 
         index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<8>(main, cx);
+            let decls = index.declarations_for_identifier::<8>(&main, cx);
             assert_eq!(decls.len(), 2);
             expect_file_decl("c.rs", &decls[0], &project, cx);
             expect_file_decl("a.rs", &decls[1], &project, cx);