edit predictions: Initial Tree-sitter context gathering (#38372)

Michael Sloan , Agus , Oleksiy , and Finn created

Release Notes:

- N/A

Co-authored-by: Agus <agus@zed.dev>
Co-authored-by: Oleksiy <oleksiy@zed.dev>
Co-authored-by: Finn <finn@zed.dev>

Change summary

Cargo.lock                                                    |   6 
crates/edit_prediction_context/Cargo.toml                     |   7 
crates/edit_prediction_context/src/declaration.rs             | 193 +
crates/edit_prediction_context/src/declaration_scoring.rs     | 326 ++
crates/edit_prediction_context/src/edit_prediction_context.rs | 216 +
crates/edit_prediction_context/src/excerpt.rs                 |   2 
crates/edit_prediction_context/src/outline.rs                 |  12 
crates/edit_prediction_context/src/reference.rs               |   2 
crates/edit_prediction_context/src/syntax_index.rs            | 638 ++--
crates/edit_prediction_context/src/text_similarity.rs         | 241 +
crates/edit_prediction_context/src/wip_requests.rs            |  35 
11 files changed, 1,361 insertions(+), 317 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5140,17 +5140,23 @@ version = "0.1.0"
 dependencies = [
  "anyhow",
  "arrayvec",
+ "clap",
  "collections",
  "futures 0.3.31",
  "gpui",
  "indoc",
+ "itertools 0.14.0",
  "language",
  "log",
+ "ordered-float 2.10.1",
  "pretty_assertions",
  "project",
+ "regex",
+ "serde",
  "serde_json",
  "settings",
  "slotmap",
+ "strum 0.27.1",
  "text",
  "tree-sitter",
  "util",

crates/edit_prediction_context/Cargo.toml 🔗

@@ -15,17 +15,24 @@ path = "src/edit_prediction_context.rs"
 anyhow.workspace = true
 arrayvec.workspace = true
 collections.workspace = true
+futures.workspace = true
 gpui.workspace = true
+itertools.workspace = true
 language.workspace = true
 log.workspace = true
+ordered-float.workspace = true
 project.workspace = true
+regex.workspace = true
+serde.workspace = true
 slotmap.workspace = true
+strum.workspace = true
 text.workspace = true
 tree-sitter.workspace = true
 util.workspace = true
 workspace-hack.workspace = true
 
 [dev-dependencies]
+clap.workspace = true
 futures.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
 indoc.workspace = true

crates/edit_prediction_context/src/declaration.rs 🔗

@@ -0,0 +1,193 @@
+use language::LanguageId;
+use project::ProjectEntryId;
+use std::borrow::Cow;
+use std::ops::Range;
+use std::sync::Arc;
+use text::{Bias, BufferId, Rope};
+
+use crate::outline::OutlineDeclaration;
+
+#[derive(Debug, Clone, Eq, PartialEq, Hash)]
+pub struct Identifier {
+    pub name: Arc<str>,
+    pub language_id: LanguageId,
+}
+
+slotmap::new_key_type! {
+    pub struct DeclarationId;
+}
+
+#[derive(Debug, Clone)]
+pub enum Declaration {
+    File {
+        project_entry_id: ProjectEntryId,
+        declaration: FileDeclaration,
+    },
+    Buffer {
+        project_entry_id: ProjectEntryId,
+        buffer_id: BufferId,
+        rope: Rope,
+        declaration: BufferDeclaration,
+    },
+}
+
+const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
+
+impl Declaration {
+    pub fn identifier(&self) -> &Identifier {
+        match self {
+            Declaration::File { declaration, .. } => &declaration.identifier,
+            Declaration::Buffer { declaration, .. } => &declaration.identifier,
+        }
+    }
+
+    pub fn project_entry_id(&self) -> Option<ProjectEntryId> {
+        match self {
+            Declaration::File {
+                project_entry_id, ..
+            } => Some(*project_entry_id),
+            Declaration::Buffer {
+                project_entry_id, ..
+            } => Some(*project_entry_id),
+        }
+    }
+
+    pub fn item_text(&self) -> (Cow<'_, str>, bool) {
+        match self {
+            Declaration::File { declaration, .. } => (
+                declaration.text.as_ref().into(),
+                declaration.text_is_truncated,
+            ),
+            Declaration::Buffer {
+                rope, declaration, ..
+            } => (
+                rope.chunks_in_range(declaration.item_range.clone())
+                    .collect::<Cow<str>>(),
+                declaration.item_range_is_truncated,
+            ),
+        }
+    }
+
+    pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
+        match self {
+            Declaration::File { declaration, .. } => (
+                declaration.text[declaration.signature_range_in_text.clone()].into(),
+                declaration.signature_is_truncated,
+            ),
+            Declaration::Buffer {
+                rope, declaration, ..
+            } => (
+                rope.chunks_in_range(declaration.signature_range.clone())
+                    .collect::<Cow<str>>(),
+                declaration.signature_range_is_truncated,
+            ),
+        }
+    }
+}
+
+fn expand_range_to_line_boundaries_and_truncate(
+    range: &Range<usize>,
+    limit: usize,
+    rope: &Rope,
+) -> (Range<usize>, bool) {
+    let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
+    point_range.start.column = 0;
+    point_range.end.row += 1;
+    point_range.end.column = 0;
+
+    let mut item_range =
+        rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
+    let is_truncated = item_range.len() > limit;
+    if is_truncated {
+        item_range.end = item_range.start + limit;
+    }
+    item_range.end = rope.clip_offset(item_range.end, Bias::Left);
+    (item_range, is_truncated)
+}
+
+#[derive(Debug, Clone)]
+pub struct FileDeclaration {
+    pub parent: Option<DeclarationId>,
+    pub identifier: Identifier,
+    /// offset range of the declaration in the file, expanded to line boundaries and truncated
+    pub item_range_in_file: Range<usize>,
+    /// text of `item_range_in_file`
+    pub text: Arc<str>,
+    /// whether `text` was truncated
+    pub text_is_truncated: bool,
+    /// offset range of the signature within `text`
+    pub signature_range_in_text: Range<usize>,
+    /// whether `signature` was truncated
+    pub signature_is_truncated: bool,
+}
+
+impl FileDeclaration {
+    pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
+        let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate(
+            &declaration.item_range,
+            ITEM_TEXT_TRUNCATION_LENGTH,
+            rope,
+        );
+
+        // TODO: consider logging if unexpected
+        let signature_start = declaration
+            .signature_range
+            .start
+            .saturating_sub(item_range_in_file.start);
+        let mut signature_end = declaration
+            .signature_range
+            .end
+            .saturating_sub(item_range_in_file.start);
+        let signature_is_truncated = signature_end > item_range_in_file.len();
+        if signature_is_truncated {
+            signature_end = item_range_in_file.len();
+        }
+
+        FileDeclaration {
+            parent: None,
+            identifier: declaration.identifier,
+            signature_range_in_text: signature_start..signature_end,
+            signature_is_truncated,
+            text: rope
+                .chunks_in_range(item_range_in_file.clone())
+                .collect::<String>()
+                .into(),
+            text_is_truncated,
+            item_range_in_file,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct BufferDeclaration {
+    pub parent: Option<DeclarationId>,
+    pub identifier: Identifier,
+    pub item_range: Range<usize>,
+    pub item_range_is_truncated: bool,
+    pub signature_range: Range<usize>,
+    pub signature_range_is_truncated: bool,
+}
+
+impl BufferDeclaration {
+    pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
+        let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate(
+            &declaration.item_range,
+            ITEM_TEXT_TRUNCATION_LENGTH,
+            rope,
+        );
+        let (signature_range, signature_range_is_truncated) =
+            expand_range_to_line_boundaries_and_truncate(
+                &declaration.signature_range,
+                ITEM_TEXT_TRUNCATION_LENGTH,
+                rope,
+            );
+        Self {
+            parent: None,
+            identifier: declaration.identifier,
+            item_range,
+            item_range_is_truncated,
+            signature_range,
+            signature_range_is_truncated,
+        }
+    }
+}

crates/edit_prediction_context/src/declaration_scoring.rs 🔗

@@ -0,0 +1,326 @@
+use itertools::Itertools as _;
+use language::BufferSnapshot;
+use ordered_float::OrderedFloat;
+use serde::Serialize;
+use std::{collections::HashMap, ops::Range};
+use strum::EnumIter;
+use text::{OffsetRangeExt, Point, ToPoint};
+
+use crate::{
+    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
+    reference::{Reference, ReferenceRegion},
+    syntax_index::SyntaxIndexState,
+    text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
+};
+
+const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
+
+// TODO:
+//
+// * Consider adding declaration_file_count
+
+#[derive(Clone, Debug)]
+pub struct ScoredSnippet {
+    pub identifier: Identifier,
+    pub declaration: Declaration,
+    pub score_components: ScoreInputs,
+    pub scores: Scores,
+}
+
+// TODO: Consider having "Concise" style corresponding to `concise_text`
+#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
+pub enum SnippetStyle {
+    Signature,
+    Declaration,
+}
+
+impl ScoredSnippet {
+    /// Returns the score for this snippet with the specified style.
+    pub fn score(&self, style: SnippetStyle) -> f32 {
+        match style {
+            SnippetStyle::Signature => self.scores.signature,
+            SnippetStyle::Declaration => self.scores.declaration,
+        }
+    }
+
+    pub fn size(&self, style: SnippetStyle) -> usize {
+        // TODO: how to handle truncation?
+        match &self.declaration {
+            Declaration::File { declaration, .. } => match style {
+                SnippetStyle::Signature => declaration.signature_range_in_text.len(),
+                SnippetStyle::Declaration => declaration.text.len(),
+            },
+            Declaration::Buffer { declaration, .. } => match style {
+                SnippetStyle::Signature => declaration.signature_range.len(),
+                SnippetStyle::Declaration => declaration.item_range.len(),
+            },
+        }
+    }
+
+    pub fn score_density(&self, style: SnippetStyle) -> f32 {
+        self.score(style) / (self.size(style)) as f32
+    }
+}
+
+pub fn scored_snippets(
+    index: &SyntaxIndexState,
+    excerpt: &EditPredictionExcerpt,
+    excerpt_text: &EditPredictionExcerptText,
+    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
+    cursor_offset: usize,
+    current_buffer: &BufferSnapshot,
+) -> Vec<ScoredSnippet> {
+    let containing_range_identifier_occurrences =
+        IdentifierOccurrences::within_string(&excerpt_text.body);
+    let cursor_point = cursor_offset.to_point(&current_buffer);
+
+    let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
+    let end_point = Point::new(cursor_point.row + 1, 0);
+    let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
+        &current_buffer
+            .text_for_range(start_point..end_point)
+            .collect::<String>(),
+    );
+
+    let mut snippets = identifier_to_references
+        .into_iter()
+        .flat_map(|(identifier, references)| {
+            let declarations =
+                index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
+            let declaration_count = declarations.len();
+
+            declarations
+                .iter()
+                .filter_map(|declaration| match declaration {
+                    Declaration::Buffer {
+                        buffer_id,
+                        declaration: buffer_declaration,
+                        ..
+                    } => {
+                        let is_same_file = buffer_id == &current_buffer.remote_id();
+
+                        if is_same_file {
+                            range_intersection(
+                                &buffer_declaration.item_range.to_offset(&current_buffer),
+                                &excerpt.range,
+                            )
+                            .is_none()
+                            .then(|| {
+                                let declaration_line = buffer_declaration
+                                    .item_range
+                                    .start
+                                    .to_point(current_buffer)
+                                    .row;
+                                (
+                                    true,
+                                    (cursor_point.row as i32 - declaration_line as i32)
+                                        .unsigned_abs(),
+                                    declaration,
+                                )
+                            })
+                        } else {
+                            // TODO should we prefer the current file instead?
+                            Some((false, 0, declaration))
+                        }
+                    }
+                    Declaration::File { .. } => {
+                        // TODO should we prefer the current file instead?
+                        // We can assume that a file declaration is in a different file,
+                        // because the current one must be open
+                        Some((false, 0, declaration))
+                    }
+                })
+                .sorted_by_key(|&(_, distance, _)| distance)
+                .enumerate()
+                .map(
+                    |(
+                        declaration_line_distance_rank,
+                        (is_same_file, declaration_line_distance, declaration),
+                    )| {
+                        let same_file_declaration_count = index.file_declaration_count(declaration);
+
+                        score_snippet(
+                            &identifier,
+                            &references,
+                            declaration.clone(),
+                            is_same_file,
+                            declaration_line_distance,
+                            declaration_line_distance_rank,
+                            same_file_declaration_count,
+                            declaration_count,
+                            &containing_range_identifier_occurrences,
+                            &adjacent_identifier_occurrences,
+                            cursor_point,
+                            current_buffer,
+                        )
+                    },
+                )
+                .collect::<Vec<_>>()
+        })
+        .flatten()
+        .collect::<Vec<_>>();
+
+    snippets.sort_unstable_by_key(|snippet| {
+        OrderedFloat(
+            snippet
+                .score_density(SnippetStyle::Declaration)
+                .max(snippet.score_density(SnippetStyle::Signature)),
+        )
+    });
+
+    snippets
+}
+
+fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
+    let start = a.start.clone().max(b.start.clone());
+    let end = a.end.clone().min(b.end.clone());
+    if start < end {
+        Some(Range { start, end })
+    } else {
+        None
+    }
+}
+
+fn score_snippet(
+    identifier: &Identifier,
+    references: &[Reference],
+    declaration: Declaration,
+    is_same_file: bool,
+    declaration_line_distance: u32,
+    declaration_line_distance_rank: usize,
+    same_file_declaration_count: usize,
+    declaration_count: usize,
+    containing_range_identifier_occurrences: &IdentifierOccurrences,
+    adjacent_identifier_occurrences: &IdentifierOccurrences,
+    cursor: Point,
+    current_buffer: &BufferSnapshot,
+) -> Option<ScoredSnippet> {
+    let is_referenced_nearby = references
+        .iter()
+        .any(|r| r.region == ReferenceRegion::Nearby);
+    let is_referenced_in_breadcrumb = references
+        .iter()
+        .any(|r| r.region == ReferenceRegion::Breadcrumb);
+    let reference_count = references.len();
+    let reference_line_distance = references
+        .iter()
+        .map(|r| {
+            let reference_line = r.range.start.to_point(current_buffer).row as i32;
+            (cursor.row as i32 - reference_line).unsigned_abs()
+        })
+        .min()
+        .unwrap();
+
+    let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
+    let item_signature_occurrences =
+        IdentifierOccurrences::within_string(&declaration.signature_text().0);
+    let containing_range_vs_item_jaccard = jaccard_similarity(
+        containing_range_identifier_occurrences,
+        &item_source_occurrences,
+    );
+    let containing_range_vs_signature_jaccard = jaccard_similarity(
+        containing_range_identifier_occurrences,
+        &item_signature_occurrences,
+    );
+    let adjacent_vs_item_jaccard =
+        jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
+    let adjacent_vs_signature_jaccard =
+        jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
+
+    let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
+        containing_range_identifier_occurrences,
+        &item_source_occurrences,
+    );
+    let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
+        containing_range_identifier_occurrences,
+        &item_signature_occurrences,
+    );
+    let adjacent_vs_item_weighted_overlap =
+        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
+    let adjacent_vs_signature_weighted_overlap =
+        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
+
+    let score_components = ScoreInputs {
+        is_same_file,
+        is_referenced_nearby,
+        is_referenced_in_breadcrumb,
+        reference_line_distance,
+        declaration_line_distance,
+        declaration_line_distance_rank,
+        reference_count,
+        same_file_declaration_count,
+        declaration_count,
+        containing_range_vs_item_jaccard,
+        containing_range_vs_signature_jaccard,
+        adjacent_vs_item_jaccard,
+        adjacent_vs_signature_jaccard,
+        containing_range_vs_item_weighted_overlap,
+        containing_range_vs_signature_weighted_overlap,
+        adjacent_vs_item_weighted_overlap,
+        adjacent_vs_signature_weighted_overlap,
+    };
+
+    Some(ScoredSnippet {
+        identifier: identifier.clone(),
+        declaration: declaration,
+        scores: score_components.score(),
+        score_components,
+    })
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct ScoreInputs {
+    pub is_same_file: bool,
+    pub is_referenced_nearby: bool,
+    pub is_referenced_in_breadcrumb: bool,
+    pub reference_count: usize,
+    pub same_file_declaration_count: usize,
+    pub declaration_count: usize,
+    pub reference_line_distance: u32,
+    pub declaration_line_distance: u32,
+    pub declaration_line_distance_rank: usize,
+    pub containing_range_vs_item_jaccard: f32,
+    pub containing_range_vs_signature_jaccard: f32,
+    pub adjacent_vs_item_jaccard: f32,
+    pub adjacent_vs_signature_jaccard: f32,
+    pub containing_range_vs_item_weighted_overlap: f32,
+    pub containing_range_vs_signature_weighted_overlap: f32,
+    pub adjacent_vs_item_weighted_overlap: f32,
+    pub adjacent_vs_signature_weighted_overlap: f32,
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct Scores {
+    pub signature: f32,
+    pub declaration: f32,
+}
+
+impl ScoreInputs {
+    fn score(&self) -> Scores {
+        // Score related to how likely this is the correct declaration, range 0 to 1
+        let accuracy_score = if self.is_same_file {
+            // TODO: use declaration_line_distance_rank
+            1.0 / self.same_file_declaration_count as f32
+        } else {
+            1.0 / self.declaration_count as f32
+        };
+
+        // Score related to the distance between the reference and cursor, range 0 to 1
+        let distance_score = if self.is_referenced_nearby {
+            1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
+        } else {
+            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
+            0.5
+        };
+
+        // For now instead of linear combination, the scores are just multiplied together.
+        let combined_score = 10.0 * accuracy_score * distance_score;
+
+        Scores {
+            signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
+            // declaration score gets boosted both by being multiplied by 2 and by there being more
+            // weighted overlap.
+            declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
+        }
+    }
+}

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -1,8 +1,220 @@
+mod declaration;
+mod declaration_scoring;
 mod excerpt;
 mod outline;
 mod reference;
-mod tree_sitter_index;
+mod syntax_index;
+mod text_similarity;
 
+pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
 pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
+use gpui::{App, AppContext as _, Entity, Task};
+use language::BufferSnapshot;
 pub use reference::references_in_excerpt;
-pub use tree_sitter_index::{BufferDeclaration, Declaration, FileDeclaration, TreeSitterIndex};
+pub use syntax_index::SyntaxIndex;
+use text::{Point, ToOffset as _};
+
+use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
+
+pub struct EditPredictionContext {
+    pub excerpt: EditPredictionExcerpt,
+    pub excerpt_text: EditPredictionExcerptText,
+    pub snippets: Vec<ScoredSnippet>,
+}
+
+impl EditPredictionContext {
+    pub fn gather(
+        cursor_point: Point,
+        buffer: BufferSnapshot,
+        excerpt_options: EditPredictionExcerptOptions,
+        syntax_index: Entity<SyntaxIndex>,
+        cx: &mut App,
+    ) -> Task<Self> {
+        let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
+        cx.background_spawn(async move {
+            let index_state = index_state.lock().await;
+
+            let excerpt =
+                EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)
+                    .unwrap();
+            let excerpt_text = excerpt.text(&buffer);
+            let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
+            let cursor_offset = cursor_point.to_offset(&buffer);
+
+            let snippets = scored_snippets(
+                &index_state,
+                &excerpt,
+                &excerpt_text,
+                references,
+                cursor_offset,
+                &buffer,
+            );
+
+            Self {
+                excerpt,
+                excerpt_text,
+                snippets,
+            }
+        })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::Arc;
+
+    use gpui::{Entity, TestAppContext};
+    use indoc::indoc;
+    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
+    use project::{FakeFs, Project};
+    use serde_json::json;
+    use settings::SettingsStore;
+    use util::path;
+
+    use crate::{EditPredictionExcerptOptions, SyntaxIndex};
+
+    #[gpui::test]
+    async fn test_call_site(cx: &mut TestAppContext) {
+        let (project, index, _rust_lang_id) = init_test(cx).await;
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let project_path = project.find_project_path("c.rs", cx).unwrap();
+                project.open_buffer(project_path, cx)
+            })
+            .await
+            .unwrap();
+
+        cx.run_until_parked();
+
+        // first process_data call site
+        let cursor_point = language::Point::new(8, 21);
+        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
+
+        let context = cx
+            .update(|cx| {
+                EditPredictionContext::gather(
+                    cursor_point,
+                    buffer_snapshot,
+                    EditPredictionExcerptOptions {
+                        max_bytes: 40,
+                        min_bytes: 10,
+                        target_before_cursor_over_total_bytes: 0.5,
+                        include_parent_signatures: false,
+                    },
+                    index,
+                    cx,
+                )
+            })
+            .await;
+
+        assert_eq!(context.snippets.len(), 1);
+        assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
+        drop(buffer);
+    }
+
+    async fn init_test(
+        cx: &mut TestAppContext,
+    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            language::init(cx);
+            Project::init_settings(cx);
+        });
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            path!("/root"),
+            json!({
+                "a.rs": indoc! {r#"
+                    fn main() {
+                        let x = 1;
+                        let y = 2;
+                        let z = add(x, y);
+                        println!("Result: {}", z);
+                    }
+
+                    fn add(a: i32, b: i32) -> i32 {
+                        a + b
+                    }
+                "#},
+                "b.rs": indoc! {"
+                    pub struct Config {
+                        pub name: String,
+                        pub value: i32,
+                    }
+
+                    impl Config {
+                        pub fn new(name: String, value: i32) -> Self {
+                            Config { name, value }
+                        }
+                    }
+                "},
+                "c.rs": indoc! {r#"
+                    use std::collections::HashMap;
+
+                    fn main() {
+                        let args: Vec<String> = std::env::args().collect();
+                        let data: Vec<i32> = args[1..]
+                            .iter()
+                            .filter_map(|s| s.parse().ok())
+                            .collect();
+                        let result = process_data(data);
+                        println!("{:?}", result);
+                    }
+
+                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
+                        let mut counts = HashMap::new();
+                        for value in data {
+                            *counts.entry(value).or_insert(0) += 1;
+                        }
+                        counts
+                    }
+
+                    #[cfg(test)]
+                    mod tests {
+                        use super::*;
+
+                        #[test]
+                        fn test_process_data() {
+                            let data = vec![1, 2, 2, 3];
+                            let result = process_data(data);
+                            assert_eq!(result.get(&2), Some(&2));
+                        }
+                    }
+                "#}
+            }),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+        let lang = rust_lang();
+        let lang_id = lang.id();
+        language_registry.add(Arc::new(lang));
+
+        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
+        cx.run_until_parked();
+
+        (project, index, lang_id)
+    }
+
+    fn rust_lang() -> Language {
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                matcher: LanguageMatcher {
+                    path_suffixes: vec!["rs".to_string()],
+                    ..Default::default()
+                },
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::LANGUAGE.into()),
+        )
+        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
+        .unwrap()
+        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
+        .unwrap()
+    }
+}

crates/edit_prediction_context/src/excerpt.rs 🔗

@@ -31,7 +31,7 @@ pub struct EditPredictionExcerptOptions {
     pub include_parent_signatures: bool,
 }
 
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub struct EditPredictionExcerpt {
     pub range: Range<usize>,
     pub parent_signature_ranges: Vec<Range<usize>>,

crates/edit_prediction_context/src/outline.rs 🔗

@@ -1,5 +1,7 @@
-use language::{BufferSnapshot, LanguageId, SyntaxMapMatches};
-use std::{cmp::Reverse, ops::Range, sync::Arc};
+use language::{BufferSnapshot, SyntaxMapMatches};
+use std::{cmp::Reverse, ops::Range};
+
+use crate::declaration::Identifier;
 
 // TODO:
 //
@@ -18,12 +20,6 @@ pub struct OutlineDeclaration {
     pub signature_range: Range<usize>,
 }
 
-#[derive(Debug, Clone, Eq, PartialEq, Hash)]
-pub struct Identifier {
-    pub name: Arc<str>,
-    pub language_id: LanguageId,
-}
-
 pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec<OutlineDeclaration> {
     declarations_overlapping_range(0..buffer.len(), buffer)
 }

crates/edit_prediction_context/src/reference.rs 🔗

@@ -3,8 +3,8 @@ use std::collections::HashMap;
 use std::ops::Range;
 
 use crate::{
+    declaration::Identifier,
     excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
-    outline::Identifier,
 };
 
 #[derive(Debug)]

crates/edit_prediction_context/src/tree_sitter_index.rs → crates/edit_prediction_context/src/syntax_index.rs 🔗

@@ -1,20 +1,26 @@
+use std::sync::Arc;
+
 use collections::{HashMap, HashSet};
+use futures::lock::Mutex;
 use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
-use language::{Buffer, BufferEvent, BufferSnapshot};
+use language::{Buffer, BufferEvent};
 use project::buffer_store::{BufferStore, BufferStoreEvent};
 use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
 use project::{PathChange, Project, ProjectEntryId, ProjectPath};
 use slotmap::SlotMap;
-use std::ops::Range;
-use std::sync::Arc;
-use text::Anchor;
+use text::BufferId;
 use util::{ResultExt as _, debug_panic, some_or_debug_panic};
 
-use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer};
+use crate::declaration::{
+    BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
+};
+use crate::outline::declarations_in_buffer;
 
 // TODO:
 //
 // * Skip for remote projects
+//
+// * Consider making SyntaxIndex not an Entity.
 
 // Potential future improvements:
 //
@@ -34,17 +40,19 @@ use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer};
 // * Concurrent slotmap
 //
 // * Use queue for parsing
+//
 
-slotmap::new_key_type! {
-    pub struct DeclarationId;
+pub struct SyntaxIndex {
+    state: Arc<Mutex<SyntaxIndexState>>,
+    project: WeakEntity<Project>,
 }
 
-pub struct TreeSitterIndex {
+#[derive(Default)]
+pub struct SyntaxIndexState {
     declarations: SlotMap<DeclarationId, Declaration>,
     identifiers: HashMap<Identifier, HashSet<DeclarationId>>,
     files: HashMap<ProjectEntryId, FileState>,
-    buffers: HashMap<WeakEntity<Buffer>, BufferState>,
-    project: WeakEntity<Project>,
+    buffers: HashMap<BufferId, BufferState>,
 }
 
 #[derive(Debug, Default)]
@@ -59,52 +67,11 @@ struct BufferState {
     task: Option<Task<()>>,
 }
 
-#[derive(Debug, Clone)]
-pub enum Declaration {
-    File {
-        project_entry_id: ProjectEntryId,
-        declaration: FileDeclaration,
-    },
-    Buffer {
-        buffer: WeakEntity<Buffer>,
-        declaration: BufferDeclaration,
-    },
-}
-
-impl Declaration {
-    fn identifier(&self) -> &Identifier {
-        match self {
-            Declaration::File { declaration, .. } => &declaration.identifier,
-            Declaration::Buffer { declaration, .. } => &declaration.identifier,
-        }
-    }
-}
-
-#[derive(Debug, Clone)]
-pub struct FileDeclaration {
-    pub parent: Option<DeclarationId>,
-    pub identifier: Identifier,
-    pub item_range: Range<usize>,
-    pub signature_range: Range<usize>,
-    pub signature_text: Arc<str>,
-}
-
-#[derive(Debug, Clone)]
-pub struct BufferDeclaration {
-    pub parent: Option<DeclarationId>,
-    pub identifier: Identifier,
-    pub item_range: Range<Anchor>,
-    pub signature_range: Range<Anchor>,
-}
-
-impl TreeSitterIndex {
+impl SyntaxIndex {
     pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
         let mut this = Self {
-            declarations: SlotMap::with_key(),
-            identifiers: HashMap::default(),
             project: project.downgrade(),
-            files: HashMap::default(),
-            buffers: HashMap::default(),
+            state: Arc::new(Mutex::new(SyntaxIndexState::default())),
         };
 
         let worktree_store = project.read(cx).worktree_store();
@@ -139,73 +106,6 @@ impl TreeSitterIndex {
         this
     }
 
-    pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
-        self.declarations.get(id)
-    }
-
-    pub fn declarations_for_identifier<const N: usize>(
-        &self,
-        identifier: Identifier,
-        cx: &App,
-    ) -> Vec<Declaration> {
-        // make sure to not have a large stack allocation
-        assert!(N < 32);
-
-        let Some(declaration_ids) = self.identifiers.get(&identifier) else {
-            return vec![];
-        };
-
-        let mut result = Vec::with_capacity(N);
-        let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
-        let mut file_declarations = Vec::new();
-
-        for declaration_id in declaration_ids {
-            let declaration = self.declarations.get(*declaration_id);
-            let Some(declaration) = some_or_debug_panic(declaration) else {
-                continue;
-            };
-            match declaration {
-                Declaration::Buffer { buffer, .. } => {
-                    if let Ok(Some(entry_id)) = buffer.read_with(cx, |buffer, cx| {
-                        project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
-                    }) {
-                        included_buffer_entry_ids.push(entry_id);
-                        result.push(declaration.clone());
-                        if result.len() == N {
-                            return result;
-                        }
-                    }
-                }
-                Declaration::File {
-                    project_entry_id, ..
-                } => {
-                    if !included_buffer_entry_ids.contains(project_entry_id) {
-                        file_declarations.push(declaration.clone());
-                    }
-                }
-            }
-        }
-
-        for declaration in file_declarations {
-            match declaration {
-                Declaration::File {
-                    project_entry_id, ..
-                } => {
-                    if !included_buffer_entry_ids.contains(&project_entry_id) {
-                        result.push(declaration);
-
-                        if result.len() == N {
-                            return result;
-                        }
-                    }
-                }
-                Declaration::Buffer { .. } => {}
-            }
-        }
-
-        result
-    }
-
     fn handle_worktree_store_event(
         &mut self,
         _worktree_store: Entity<WorktreeStore>,
@@ -215,21 +115,33 @@ impl TreeSitterIndex {
         use WorktreeStoreEvent::*;
         match event {
             WorktreeUpdatedEntries(worktree_id, updated_entries_set) => {
-                for (path, entry_id, path_change) in updated_entries_set.iter() {
-                    if let PathChange::Removed = path_change {
-                        self.files.remove(entry_id);
-                    } else {
-                        let project_path = ProjectPath {
-                            worktree_id: *worktree_id,
-                            path: path.clone(),
-                        };
-                        self.update_file(*entry_id, project_path, cx);
+                let state = Arc::downgrade(&self.state);
+                let worktree_id = *worktree_id;
+                let updated_entries_set = updated_entries_set.clone();
+                cx.spawn(async move |this, cx| {
+                    let Some(state) = state.upgrade() else { return };
+                    for (path, entry_id, path_change) in updated_entries_set.iter() {
+                        if let PathChange::Removed = path_change {
+                            state.lock().await.files.remove(entry_id);
+                        } else {
+                            let project_path = ProjectPath {
+                                worktree_id,
+                                path: path.clone(),
+                            };
+                            this.update(cx, |this, cx| {
+                                this.update_file(*entry_id, project_path, cx);
+                            })
+                            .ok();
+                        }
                     }
-                }
+                })
+                .detach();
             }
             WorktreeDeletedEntry(_worktree_id, project_entry_id) => {
-                // TODO: Is this needed?
-                self.files.remove(project_entry_id);
+                let project_entry_id = *project_entry_id;
+                self.with_state(cx, move |state| {
+                    state.files.remove(&project_entry_id);
+                })
             }
             _ => {}
         }
@@ -251,15 +163,42 @@ impl TreeSitterIndex {
         }
     }
 
+    pub fn state(&self) -> &Arc<Mutex<SyntaxIndexState>> {
+        &self.state
+    }
+
+    fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) {
+        if let Some(mut state) = self.state.try_lock() {
+            f(&mut state);
+            return;
+        }
+        let state = Arc::downgrade(&self.state);
+        cx.background_spawn(async move {
+            let Some(state) = state.upgrade() else {
+                return;
+            };
+            let mut state = state.lock().await;
+            f(&mut state)
+        })
+        .detach();
+    }
+
     fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
-        self.buffers
-            .insert(buffer.downgrade(), BufferState::default());
-        let weak_buf = buffer.downgrade();
-        cx.observe_release(buffer, move |this, _buffer, _cx| {
-            this.buffers.remove(&weak_buf);
+        let buffer_id = buffer.read(cx).remote_id();
+        cx.observe_release(buffer, move |this, _buffer, cx| {
+            this.with_state(cx, move |state| {
+                if let Some(buffer_state) = state.buffers.remove(&buffer_id) {
+                    SyntaxIndexState::remove_buffer_declarations(
+                        &buffer_state.declarations,
+                        &mut state.declarations,
+                        &mut state.identifiers,
+                    );
+                }
+            })
         })
         .detach();
         cx.subscribe(buffer, Self::handle_buffer_event).detach();
+
         self.update_buffer(buffer.clone(), cx);
     }
 
@@ -275,10 +214,19 @@ impl TreeSitterIndex {
         }
     }
 
-    fn update_buffer(&mut self, buffer: Entity<Buffer>, cx: &Context<Self>) {
-        let mut parse_status = buffer.read(cx).parse_status();
+    fn update_buffer(&mut self, buffer_entity: Entity<Buffer>, cx: &mut Context<Self>) {
+        let buffer = buffer_entity.read(cx);
+
+        let Some(project_entry_id) =
+            project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
+        else {
+            return;
+        };
+        let buffer_id = buffer.remote_id();
+
+        let mut parse_status = buffer.parse_status();
         let snapshot_task = cx.spawn({
-            let weak_buffer = buffer.downgrade();
+            let weak_buffer = buffer_entity.downgrade();
             async move |_, cx| {
                 while *parse_status.borrow() != language::ParseStatus::Idle {
                     parse_status.changed().await?;
@@ -289,75 +237,77 @@ impl TreeSitterIndex {
 
         let parse_task = cx.background_spawn(async move {
             let snapshot = snapshot_task.await?;
+            let rope = snapshot.text.as_rope().clone();
 
-            anyhow::Ok(
+            anyhow::Ok((
                 declarations_in_buffer(&snapshot)
                     .into_iter()
                     .map(|item| {
                         (
                             item.parent_index,
-                            BufferDeclaration::from_outline(item, &snapshot),
+                            BufferDeclaration::from_outline(item, &rope),
                         )
                     })
                     .collect::<Vec<_>>(),
-            )
+                rope,
+            ))
         });
 
         let task = cx.spawn({
-            let weak_buffer = buffer.downgrade();
             async move |this, cx| {
-                let Ok(declarations) = parse_task.await else {
+                let Ok((declarations, rope)) = parse_task.await else {
                     return;
                 };
 
-                this.update(cx, |this, _cx| {
-                    let buffer_state = this
-                        .buffers
-                        .entry(weak_buffer.clone())
-                        .or_insert_with(Default::default);
-
-                    for old_declaration_id in &buffer_state.declarations {
-                        let Some(declaration) = this.declarations.remove(*old_declaration_id)
-                        else {
-                            debug_panic!("declaration not found");
-                            continue;
-                        };
-                        if let Some(identifier_declarations) =
-                            this.identifiers.get_mut(declaration.identifier())
-                        {
-                            identifier_declarations.remove(old_declaration_id);
+                this.update(cx, move |this, cx| {
+                    this.with_state(cx, move |state| {
+                        let buffer_state = state
+                            .buffers
+                            .entry(buffer_id)
+                            .or_insert_with(Default::default);
+
+                        SyntaxIndexState::remove_buffer_declarations(
+                            &buffer_state.declarations,
+                            &mut state.declarations,
+                            &mut state.identifiers,
+                        );
+
+                        let mut new_ids = Vec::with_capacity(declarations.len());
+                        state.declarations.reserve(declarations.len());
+                        for (parent_index, mut declaration) in declarations {
+                            declaration.parent = parent_index
+                                .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
+
+                            let identifier = declaration.identifier.clone();
+                            let declaration_id = state.declarations.insert(Declaration::Buffer {
+                                rope: rope.clone(),
+                                buffer_id,
+                                declaration,
+                                project_entry_id,
+                            });
+                            new_ids.push(declaration_id);
+
+                            state
+                                .identifiers
+                                .entry(identifier)
+                                .or_default()
+                                .insert(declaration_id);
                         }
-                    }
-
-                    let mut new_ids = Vec::with_capacity(declarations.len());
-                    this.declarations.reserve(declarations.len());
-                    for (parent_index, mut declaration) in declarations {
-                        declaration.parent = parent_index
-                            .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
-
-                        let identifier = declaration.identifier.clone();
-                        let declaration_id = this.declarations.insert(Declaration::Buffer {
-                            buffer: weak_buffer.clone(),
-                            declaration,
-                        });
-                        new_ids.push(declaration_id);
-
-                        this.identifiers
-                            .entry(identifier)
-                            .or_default()
-                            .insert(declaration_id);
-                    }
 
-                    buffer_state.declarations = new_ids;
+                        buffer_state.declarations = new_ids;
+                    });
                 })
                 .ok();
             }
         });
 
-        self.buffers
-            .entry(buffer.downgrade())
-            .or_insert_with(Default::default)
-            .task = Some(task);
+        self.with_state(cx, move |state| {
+            state
+                .buffers
+                .entry(buffer_id)
+                .or_insert_with(Default::default)
+                .task = Some(task)
+        });
     }
 
     fn update_file(
@@ -401,14 +351,10 @@ impl TreeSitterIndex {
 
         let parse_task = cx.background_spawn(async move {
             let snapshot = snapshot_task.await?;
+            let rope = snapshot.as_rope();
             let declarations = declarations_in_buffer(&snapshot)
                 .into_iter()
-                .map(|item| {
-                    (
-                        item.parent_index,
-                        FileDeclaration::from_outline(item, &snapshot),
-                    )
-                })
+                .map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope)))
                 .collect::<Vec<_>>();
             anyhow::Ok(declarations)
         });
@@ -419,84 +365,160 @@ impl TreeSitterIndex {
                 let Ok(declarations) = parse_task.await else {
                     return;
                 };
-                this.update(cx, |this, _cx| {
-                    let file_state = this.files.entry(entry_id).or_insert_with(Default::default);
-
-                    for old_declaration_id in &file_state.declarations {
-                        let Some(declaration) = this.declarations.remove(*old_declaration_id)
-                        else {
-                            debug_panic!("declaration not found");
-                            continue;
-                        };
-                        if let Some(identifier_declarations) =
-                            this.identifiers.get_mut(declaration.identifier())
-                        {
-                            identifier_declarations.remove(old_declaration_id);
+                this.update(cx, |this, cx| {
+                    this.with_state(cx, move |state| {
+                        let file_state =
+                            state.files.entry(entry_id).or_insert_with(Default::default);
+
+                        for old_declaration_id in &file_state.declarations {
+                            let Some(declaration) = state.declarations.remove(*old_declaration_id)
+                            else {
+                                debug_panic!("declaration not found");
+                                continue;
+                            };
+                            if let Some(identifier_declarations) =
+                                state.identifiers.get_mut(declaration.identifier())
+                            {
+                                identifier_declarations.remove(old_declaration_id);
+                            }
                         }
-                    }
-
-                    let mut new_ids = Vec::with_capacity(declarations.len());
-                    this.declarations.reserve(declarations.len());
 
-                    for (parent_index, mut declaration) in declarations {
-                        declaration.parent = parent_index
-                            .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
-
-                        let identifier = declaration.identifier.clone();
-                        let declaration_id = this.declarations.insert(Declaration::File {
-                            project_entry_id: entry_id,
-                            declaration,
-                        });
-                        new_ids.push(declaration_id);
-
-                        this.identifiers
-                            .entry(identifier)
-                            .or_default()
-                            .insert(declaration_id);
-                    }
+                        let mut new_ids = Vec::with_capacity(declarations.len());
+                        state.declarations.reserve(declarations.len());
+
+                        for (parent_index, mut declaration) in declarations {
+                            declaration.parent = parent_index
+                                .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
+
+                            let identifier = declaration.identifier.clone();
+                            let declaration_id = state.declarations.insert(Declaration::File {
+                                project_entry_id: entry_id,
+                                declaration,
+                            });
+                            new_ids.push(declaration_id);
+
+                            state
+                                .identifiers
+                                .entry(identifier)
+                                .or_default()
+                                .insert(declaration_id);
+                        }
 
-                    file_state.declarations = new_ids;
+                        file_state.declarations = new_ids;
+                    });
                 })
                 .ok();
             }
         });
 
-        self.files
-            .entry(entry_id)
-            .or_insert_with(Default::default)
-            .task = Some(task);
+        self.with_state(cx, move |state| {
+            state
+                .files
+                .entry(entry_id)
+                .or_insert_with(Default::default)
+                .task = Some(task);
+        });
     }
 }
 
-impl BufferDeclaration {
-    pub fn from_outline(declaration: OutlineDeclaration, snapshot: &BufferSnapshot) -> Self {
-        // use of anchor_before is a guess that the proper behavior is to expand to include
-        // insertions immediately before the declaration, but not for insertions immediately after
-        Self {
-            parent: None,
-            identifier: declaration.identifier,
-            item_range: snapshot.anchor_before(declaration.item_range.start)
-                ..snapshot.anchor_before(declaration.item_range.end),
-            signature_range: snapshot.anchor_before(declaration.signature_range.start)
-                ..snapshot.anchor_before(declaration.signature_range.end),
+impl SyntaxIndexState {
+    pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
+        self.declarations.get(id)
+    }
+
+    /// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector.
+    ///
+    /// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded.
+    pub fn declarations_for_identifier<const N: usize>(
+        &self,
+        identifier: &Identifier,
+    ) -> Vec<Declaration> {
+        // make sure to not have a large stack allocation
+        assert!(N < 32);
+
+        let Some(declaration_ids) = self.identifiers.get(&identifier) else {
+            return vec![];
+        };
+
+        let mut result = Vec::with_capacity(N);
+        let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
+        let mut file_declarations = Vec::new();
+
+        for declaration_id in declaration_ids {
+            let declaration = self.declarations.get(*declaration_id);
+            let Some(declaration) = some_or_debug_panic(declaration) else {
+                continue;
+            };
+            match declaration {
+                Declaration::Buffer {
+                    project_entry_id, ..
+                } => {
+                    included_buffer_entry_ids.push(*project_entry_id);
+                    result.push(declaration.clone());
+                    if result.len() == N {
+                        return Vec::new();
+                    }
+                }
+                Declaration::File {
+                    project_entry_id, ..
+                } => {
+                    if !included_buffer_entry_ids.contains(&project_entry_id) {
+                        file_declarations.push(declaration.clone());
+                    }
+                }
+            }
+        }
+
+        for declaration in file_declarations {
+            match declaration {
+                Declaration::File {
+                    project_entry_id, ..
+                } => {
+                    if !included_buffer_entry_ids.contains(&project_entry_id) {
+                        result.push(declaration);
+
+                        if result.len() == N {
+                            return Vec::new();
+                        }
+                    }
+                }
+                Declaration::Buffer { .. } => {}
+            }
+        }
+
+        result
+    }
+
+    pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
+        match declaration {
+            Declaration::File {
+                project_entry_id, ..
+            } => self
+                .files
+                .get(project_entry_id)
+                .map(|file_state| file_state.declarations.len())
+                .unwrap_or_default(),
+            Declaration::Buffer { buffer_id, .. } => self
+                .buffers
+                .get(buffer_id)
+                .map(|buffer_state| buffer_state.declarations.len())
+                .unwrap_or_default(),
         }
     }
-}
 
-impl FileDeclaration {
-    pub fn from_outline(
-        declaration: OutlineDeclaration,
-        snapshot: &BufferSnapshot,
-    ) -> FileDeclaration {
-        FileDeclaration {
-            parent: None,
-            identifier: declaration.identifier,
-            item_range: declaration.item_range,
-            signature_text: snapshot
-                .text_for_range(declaration.signature_range.clone())
-                .collect::<String>()
-                .into(),
-            signature_range: declaration.signature_range,
+    fn remove_buffer_declarations(
+        old_declaration_ids: &[DeclarationId],
+        declarations: &mut SlotMap<DeclarationId, Declaration>,
+        identifiers: &mut HashMap<Identifier, HashSet<DeclarationId>>,
+    ) {
+        for old_declaration_id in old_declaration_ids {
+            let Some(declaration) = declarations.remove(*old_declaration_id) else {
+                debug_panic!("declaration not found");
+                continue;
+            };
+            if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) {
+                identifier_declarations.remove(old_declaration_id);
+            }
         }
     }
 }
@@ -509,13 +531,13 @@ mod tests {
     use gpui::TestAppContext;
     use indoc::indoc;
     use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
-    use project::{FakeFs, Project, ProjectItem};
+    use project::{FakeFs, Project};
     use serde_json::json;
     use settings::SettingsStore;
     use text::OffsetRangeExt as _;
     use util::path;
 
-    use crate::tree_sitter_index::TreeSitterIndex;
+    use crate::syntax_index::SyntaxIndex;
 
     #[gpui::test]
     async fn test_unopen_indexed_files(cx: &mut TestAppContext) {
@@ -525,17 +547,19 @@ mod tests {
             language_id: rust_lang_id,
         };
 
-        index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
+        let index_state = index.read_with(cx, |index, _cx| index.state().clone());
+        let index_state = index_state.lock().await;
+        cx.update(|cx| {
+            let decls = index_state.declarations_for_identifier::<8>(&main);
             assert_eq!(decls.len(), 2);
 
             let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
             assert_eq!(decl.identifier, main.clone());
-            assert_eq!(decl.item_range, 32..279);
+            assert_eq!(decl.item_range_in_file, 32..280);
 
             let decl = expect_file_decl("a.rs", &decls[1], &project, cx);
             assert_eq!(decl.identifier, main);
-            assert_eq!(decl.item_range, 0..97);
+            assert_eq!(decl.item_range_in_file, 0..98);
         });
     }
 
@@ -547,15 +571,17 @@ mod tests {
             language_id: rust_lang_id,
         };
 
-        index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
+        let index_state = index.read_with(cx, |index, _cx| index.state().clone());
+        let index_state = index_state.lock().await;
+        cx.update(|cx| {
+            let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
             assert_eq!(decls.len(), 1);
 
             let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
             assert_eq!(decl.identifier, test_process_data);
 
             let parent_id = decl.parent.unwrap();
-            let parent = index.declaration(parent_id).unwrap();
+            let parent = index_state.declaration(parent_id).unwrap();
             let parent_decl = expect_file_decl("c.rs", &parent, &project, cx);
             assert_eq!(
                 parent_decl.identifier,
@@ -586,16 +612,18 @@ mod tests {
 
         cx.run_until_parked();
 
-        index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
+        let index_state = index.read_with(cx, |index, _cx| index.state().clone());
+        let index_state = index_state.lock().await;
+        cx.update(|cx| {
+            let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
             assert_eq!(decls.len(), 1);
 
-            let decl = expect_buffer_decl("c.rs", &decls[0], cx);
+            let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
             assert_eq!(decl.identifier, test_process_data);
 
             let parent_id = decl.parent.unwrap();
-            let parent = index.declaration(parent_id).unwrap();
-            let parent_decl = expect_buffer_decl("c.rs", &parent, cx);
+            let parent = index_state.declaration(parent_id).unwrap();
+            let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx);
             assert_eq!(
                 parent_decl.identifier,
                 Identifier {
@@ -613,16 +641,13 @@ mod tests {
     async fn test_declarations_limt(cx: &mut TestAppContext) {
         let (_, index, rust_lang_id) = init_test(cx).await;
 
-        index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<1>(
-                Identifier {
-                    name: "main".into(),
-                    language_id: rust_lang_id,
-                },
-                cx,
-            );
-            assert_eq!(decls.len(), 1);
+        let index_state = index.read_with(cx, |index, _cx| index.state().clone());
+        let index_state = index_state.lock().await;
+        let decls = index_state.declarations_for_identifier::<1>(&Identifier {
+            name: "main".into(),
+            language_id: rust_lang_id,
         });
+        assert_eq!(decls.len(), 0);
     }
 
     #[gpui::test]
@@ -644,24 +669,31 @@ mod tests {
 
         cx.run_until_parked();
 
-        index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
-            assert_eq!(decls.len(), 2);
-            let decl = expect_buffer_decl("c.rs", &decls[0], cx);
-            assert_eq!(decl.identifier, main);
-            assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279);
+        let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone());
+        {
+            let index_state = index_state_arc.lock().await;
 
-            expect_file_decl("a.rs", &decls[1], &project, cx);
-        });
+            cx.update(|cx| {
+                let decls = index_state.declarations_for_identifier::<8>(&main);
+                assert_eq!(decls.len(), 2);
+                let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
+                assert_eq!(decl.identifier, main);
+                assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280);
+
+                expect_file_decl("a.rs", &decls[1], &project, cx);
+            });
+        }
 
-        // Need to trigger flush_effects so that the observe_release handler will run.
-        cx.update(|_cx| {
+        // Drop the buffer and wait for release
+        cx.update(|_| {
             drop(buffer);
         });
         cx.run_until_parked();
 
-        index.read_with(cx, |index, cx| {
-            let decls = index.declarations_for_identifier::<8>(main, cx);
+        let index_state = index_state_arc.lock().await;
+
+        cx.update(|cx| {
+            let decls = index_state.declarations_for_identifier::<8>(&main);
             assert_eq!(decls.len(), 2);
             expect_file_decl("c.rs", &decls[0], &project, cx);
             expect_file_decl("a.rs", &decls[1], &project, cx);
@@ -671,24 +703,20 @@ mod tests {
     fn expect_buffer_decl<'a>(
         path: &str,
         declaration: &'a Declaration,
+        project: &Entity<Project>,
         cx: &App,
     ) -> &'a BufferDeclaration {
         if let Declaration::Buffer {
             declaration,
-            buffer,
+            project_entry_id,
+            ..
         } = declaration
         {
-            assert_eq!(
-                buffer
-                    .upgrade()
-                    .unwrap()
-                    .read(cx)
-                    .project_path(cx)
-                    .unwrap()
-                    .path
-                    .as_ref(),
-                Path::new(path),
-            );
+            let project_path = project
+                .read(cx)
+                .path_for_entry(*project_entry_id, cx)
+                .unwrap();
+            assert_eq!(project_path.path.as_ref(), Path::new(path),);
             declaration
         } else {
             panic!("Expected a buffer declaration, found {:?}", declaration);
@@ -723,7 +751,7 @@ mod tests {
 
     async fn init_test(
         cx: &mut TestAppContext,
-    ) -> (Entity<Project>, Entity<TreeSitterIndex>, LanguageId) {
+    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
         cx.update(|cx| {
             let settings_store = SettingsStore::test(cx);
             cx.set_global(settings_store);
@@ -801,7 +829,7 @@ mod tests {
         let lang_id = lang.id();
         language_registry.add(Arc::new(lang));
 
-        let index = cx.new(|cx| TreeSitterIndex::new(&project, cx));
+        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
         cx.run_until_parked();
 
         (project, index, lang_id)

crates/edit_prediction_context/src/text_similarity.rs 🔗

@@ -0,0 +1,241 @@
+use regex::Regex;
+use std::{collections::HashMap, sync::LazyLock};
+
+use crate::reference::Reference;
+
+// TODO: Consider implementing sliding window similarity matching like
+// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
+//
+// That implementation could actually be more efficient - no need to track words in the window that
+// are not in the query.
+
+static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
+
+#[derive(Debug)]
+pub struct IdentifierOccurrences {
+    identifier_to_count: HashMap<String, usize>,
+    total_count: usize,
+}
+
+impl IdentifierOccurrences {
+    pub fn within_string(code: &str) -> Self {
+        Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
+    }
+
+    #[allow(dead_code)]
+    pub fn within_references(references: &[Reference]) -> Self {
+        Self::from_iterator(
+            references
+                .iter()
+                .map(|reference| reference.identifier.name.as_ref()),
+        )
+    }
+
+    pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
+        let mut identifier_to_count = HashMap::new();
+        let mut total_count = 0;
+        for identifier in identifier_iterator {
+            // TODO: Score matches that match case higher?
+            //
+            // TODO: Also include unsplit identifier?
+            for identifier_part in split_identifier(identifier) {
+                identifier_to_count
+                    .entry(identifier_part.to_lowercase())
+                    .and_modify(|count| *count += 1)
+                    .or_insert(1);
+                total_count += 1;
+            }
+        }
+        IdentifierOccurrences {
+            identifier_to_count,
+            total_count,
+        }
+    }
+}
+
+// Splits camelcase / snakecase / kebabcase / pascalcase
+//
+// TODO: Make this more efficient / elegant.
+fn split_identifier(identifier: &str) -> Vec<&str> {
+    let mut parts = Vec::new();
+    let mut start = 0;
+    let chars: Vec<char> = identifier.chars().collect();
+
+    if chars.is_empty() {
+        return parts;
+    }
+
+    let mut i = 0;
+    while i < chars.len() {
+        let ch = chars[i];
+
+        // Handle explicit delimiters (underscore and hyphen)
+        if ch == '_' || ch == '-' {
+            if i > start {
+                parts.push(&identifier[start..i]);
+            }
+            start = i + 1;
+            i += 1;
+            continue;
+        }
+
+        // Handle camelCase and PascalCase transitions
+        if i > 0 && i < chars.len() {
+            let prev_char = chars[i - 1];
+
+            // Transition from lowercase/digit to uppercase
+            if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
+                parts.push(&identifier[start..i]);
+                start = i;
+            }
+            // Handle sequences like "XMLParser" -> ["XML", "Parser"]
+            else if i + 1 < chars.len()
+                && ch.is_uppercase()
+                && chars[i + 1].is_lowercase()
+                && prev_char.is_uppercase()
+            {
+                parts.push(&identifier[start..i]);
+                start = i;
+            }
+        }
+
+        i += 1;
+    }
+
+    // Add the last part if there's any remaining
+    if start < identifier.len() {
+        parts.push(&identifier[start..]);
+    }
+
+    // Filter out empty strings
+    parts.into_iter().filter(|s| !s.is_empty()).collect()
+}
+
+pub fn jaccard_similarity<'a>(
+    mut set_a: &'a IdentifierOccurrences,
+    mut set_b: &'a IdentifierOccurrences,
+) -> f32 {
+    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+        std::mem::swap(&mut set_a, &mut set_b);
+    }
+    let intersection = set_a
+        .identifier_to_count
+        .keys()
+        .filter(|key| set_b.identifier_to_count.contains_key(*key))
+        .count();
+    let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
+    intersection as f32 / union as f32
+}
+
+// TODO
+#[allow(dead_code)]
+pub fn overlap_coefficient<'a>(
+    mut set_a: &'a IdentifierOccurrences,
+    mut set_b: &'a IdentifierOccurrences,
+) -> f32 {
+    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+        std::mem::swap(&mut set_a, &mut set_b);
+    }
+    let intersection = set_a
+        .identifier_to_count
+        .keys()
+        .filter(|key| set_b.identifier_to_count.contains_key(*key))
+        .count();
+    intersection as f32 / set_a.identifier_to_count.len() as f32
+}
+
+// TODO
+#[allow(dead_code)]
+pub fn weighted_jaccard_similarity<'a>(
+    mut set_a: &'a IdentifierOccurrences,
+    mut set_b: &'a IdentifierOccurrences,
+) -> f32 {
+    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+        std::mem::swap(&mut set_a, &mut set_b);
+    }
+
+    let mut numerator = 0;
+    let mut denominator_a = 0;
+    let mut used_count_b = 0;
+    for (symbol, count_a) in set_a.identifier_to_count.iter() {
+        let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
+        numerator += count_a.min(count_b);
+        denominator_a += count_a.max(count_b);
+        used_count_b += count_b;
+    }
+
+    let denominator = denominator_a + (set_b.total_count - used_count_b);
+    if denominator == 0 {
+        0.0
+    } else {
+        numerator as f32 / denominator as f32
+    }
+}
+
+pub fn weighted_overlap_coefficient<'a>(
+    mut set_a: &'a IdentifierOccurrences,
+    mut set_b: &'a IdentifierOccurrences,
+) -> f32 {
+    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+        std::mem::swap(&mut set_a, &mut set_b);
+    }
+
+    let mut numerator = 0;
+    for (symbol, count_a) in set_a.identifier_to_count.iter() {
+        let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
+        numerator += count_a.min(count_b);
+    }
+
+    let denominator = set_a.total_count.min(set_b.total_count);
+    if denominator == 0 {
+        0.0
+    } else {
+        numerator as f32 / denominator as f32
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+
+    #[test]
+    fn test_split_identifier() {
+        assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]);
+        assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]);
+        assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]);
+        assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]);
+        assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]);
+    }
+
+    #[test]
+    fn test_similarity_functions() {
+        // 10 identifier parts, 8 unique
+        // Repeats: 2 "outline", 2 "items"
+        let set_a = IdentifierOccurrences::within_string(
+            "let mut outline_items = query_outline_items(&language, &tree, &source);",
+        );
+        // 14 identifier parts, 11 unique
+        // Repeats: 2 "outline", 2 "language", 2 "tree"
+        let set_b = IdentifierOccurrences::within_string(
+            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
+        );
+
+        // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
+        // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
+        assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));
+
+        // Numerator is one more than before due to both having 2 "outline".
+        // Denominator is the same except for 3 more due to the non-overlapping duplicates
+        assert_eq!(
+            weighted_jaccard_similarity(&set_a, &set_b),
+            7.0 / (7.0 + 7.0 + 3.0)
+        );
+
+        // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
+        assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);
+
+        // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
+        // the smaller set, 10.
+        assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
+    }
+}

crates/edit_prediction_context/src/wip_requests.rs 🔗

@@ -0,0 +1,35 @@
+// To discuss: What to send to the new endpoint? Thinking it'd make sense to put `prompt.rs` from
+// `zeta_context.rs` in cloud.
+//
+// * Run excerpt selection at several different sizes, send the largest size with offsets within for
+// the smaller sizes.
+//
+// * Longer event history.
+//
+// * Many more snippets than could fit in model context - allows ranking experimentation.
+
+pub struct Zeta2Request {
+    pub event_history: Vec<Event>,
+    pub excerpt: String,
+    pub excerpt_subsets: Vec<Zeta2ExcerptSubset>,
+    /// Within `excerpt`
+    pub cursor_position: usize,
+    pub signatures: Vec<String>,
+    pub retrieved_declarations: Vec<ReferencedDeclaration>,
+}
+
+pub struct Zeta2ExcerptSubset {
+    /// Within `excerpt` text.
+    pub excerpt_range: Range<usize>,
+    /// Within `signatures`.
+    pub parent_signatures: Vec<usize>,
+}
+
+pub struct ReferencedDeclaration {
+    pub text: Arc<str>,
+    /// Range within `text`
+    pub signature_range: Range<usize>,
+    /// Indices within `signatures`.
+    pub parent_signatures: Vec<usize>,
+    // A bunch of score metrics
+}