Detailed changes
@@ -5140,17 +5140,23 @@ version = "0.1.0"
dependencies = [
"anyhow",
"arrayvec",
+ "clap",
"collections",
"futures 0.3.31",
"gpui",
"indoc",
+ "itertools 0.14.0",
"language",
"log",
+ "ordered-float 2.10.1",
"pretty_assertions",
"project",
+ "regex",
+ "serde",
"serde_json",
"settings",
"slotmap",
+ "strum 0.27.1",
"text",
"tree-sitter",
"util",
@@ -15,17 +15,24 @@ path = "src/edit_prediction_context.rs"
anyhow.workspace = true
arrayvec.workspace = true
collections.workspace = true
+futures.workspace = true
gpui.workspace = true
+itertools.workspace = true
language.workspace = true
log.workspace = true
+ordered-float.workspace = true
project.workspace = true
+regex.workspace = true
+serde.workspace = true
slotmap.workspace = true
+strum.workspace = true
text.workspace = true
tree-sitter.workspace = true
util.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
+clap.workspace = true
futures.workspace = true
gpui = { workspace = true, features = ["test-support"] }
indoc.workspace = true
@@ -0,0 +1,193 @@
+use language::LanguageId;
+use project::ProjectEntryId;
+use std::borrow::Cow;
+use std::ops::Range;
+use std::sync::Arc;
+use text::{Bias, BufferId, Rope};
+
+use crate::outline::OutlineDeclaration;
+
+#[derive(Debug, Clone, Eq, PartialEq, Hash)]
+pub struct Identifier {
+ pub name: Arc<str>,
+ pub language_id: LanguageId,
+}
+
+slotmap::new_key_type! {
+ pub struct DeclarationId;
+}
+
+#[derive(Debug, Clone)]
+pub enum Declaration {
+ File {
+ project_entry_id: ProjectEntryId,
+ declaration: FileDeclaration,
+ },
+ Buffer {
+ project_entry_id: ProjectEntryId,
+ buffer_id: BufferId,
+ rope: Rope,
+ declaration: BufferDeclaration,
+ },
+}
+
+const ITEM_TEXT_TRUNCATION_LENGTH: usize = 1024;
+
+impl Declaration {
+ pub fn identifier(&self) -> &Identifier {
+ match self {
+ Declaration::File { declaration, .. } => &declaration.identifier,
+ Declaration::Buffer { declaration, .. } => &declaration.identifier,
+ }
+ }
+
+ pub fn project_entry_id(&self) -> Option<ProjectEntryId> {
+ match self {
+ Declaration::File {
+ project_entry_id, ..
+ } => Some(*project_entry_id),
+ Declaration::Buffer {
+ project_entry_id, ..
+ } => Some(*project_entry_id),
+ }
+ }
+
+ pub fn item_text(&self) -> (Cow<'_, str>, bool) {
+ match self {
+ Declaration::File { declaration, .. } => (
+ declaration.text.as_ref().into(),
+ declaration.text_is_truncated,
+ ),
+ Declaration::Buffer {
+ rope, declaration, ..
+ } => (
+ rope.chunks_in_range(declaration.item_range.clone())
+ .collect::<Cow<str>>(),
+ declaration.item_range_is_truncated,
+ ),
+ }
+ }
+
+ pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
+ match self {
+ Declaration::File { declaration, .. } => (
+ declaration.text[declaration.signature_range_in_text.clone()].into(),
+ declaration.signature_is_truncated,
+ ),
+ Declaration::Buffer {
+ rope, declaration, ..
+ } => (
+ rope.chunks_in_range(declaration.signature_range.clone())
+ .collect::<Cow<str>>(),
+ declaration.signature_range_is_truncated,
+ ),
+ }
+ }
+}
+
+fn expand_range_to_line_boundaries_and_truncate(
+ range: &Range<usize>,
+ limit: usize,
+ rope: &Rope,
+) -> (Range<usize>, bool) {
+ let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
+ point_range.start.column = 0;
+ point_range.end.row += 1;
+ point_range.end.column = 0;
+
+ let mut item_range =
+ rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
+ let is_truncated = item_range.len() > limit;
+ if is_truncated {
+ item_range.end = item_range.start + limit;
+ }
+ item_range.end = rope.clip_offset(item_range.end, Bias::Left);
+ (item_range, is_truncated)
+}
+
+#[derive(Debug, Clone)]
+pub struct FileDeclaration {
+ pub parent: Option<DeclarationId>,
+ pub identifier: Identifier,
+ /// offset range of the declaration in the file, expanded to line boundaries and truncated
+ pub item_range_in_file: Range<usize>,
+ /// text of `item_range_in_file`
+ pub text: Arc<str>,
+ /// whether `text` was truncated
+ pub text_is_truncated: bool,
+ /// offset range of the signature within `text`
+ pub signature_range_in_text: Range<usize>,
+ /// whether the text in `signature_range_in_text` was truncated
+ pub signature_is_truncated: bool,
+}
+
+impl FileDeclaration {
+ pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
+ let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate(
+ &declaration.item_range,
+ ITEM_TEXT_TRUNCATION_LENGTH,
+ rope,
+ );
+
+ // TODO: consider logging if unexpected
+ let signature_start = declaration
+ .signature_range
+ .start
+ .saturating_sub(item_range_in_file.start);
+ let mut signature_end = declaration
+ .signature_range
+ .end
+ .saturating_sub(item_range_in_file.start);
+ let signature_is_truncated = signature_end > item_range_in_file.len();
+ if signature_is_truncated {
+ signature_end = item_range_in_file.len();
+ }
+
+ FileDeclaration {
+ parent: None,
+ identifier: declaration.identifier,
+ signature_range_in_text: signature_start..signature_end,
+ signature_is_truncated,
+ text: rope
+ .chunks_in_range(item_range_in_file.clone())
+ .collect::<String>()
+ .into(),
+ text_is_truncated,
+ item_range_in_file,
+ }
+ }
+}
+
+#[derive(Debug, Clone)]
+pub struct BufferDeclaration {
+ pub parent: Option<DeclarationId>,
+ pub identifier: Identifier,
+ pub item_range: Range<usize>,
+ pub item_range_is_truncated: bool,
+ pub signature_range: Range<usize>,
+ pub signature_range_is_truncated: bool,
+}
+
+impl BufferDeclaration {
+ pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> Self {
+ let (item_range, item_range_is_truncated) = expand_range_to_line_boundaries_and_truncate(
+ &declaration.item_range,
+ ITEM_TEXT_TRUNCATION_LENGTH,
+ rope,
+ );
+ let (signature_range, signature_range_is_truncated) =
+ expand_range_to_line_boundaries_and_truncate(
+ &declaration.signature_range,
+ ITEM_TEXT_TRUNCATION_LENGTH,
+ rope,
+ );
+ Self {
+ parent: None,
+ identifier: declaration.identifier,
+ item_range,
+ item_range_is_truncated,
+ signature_range,
+ signature_range_is_truncated,
+ }
+ }
+}
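
A quick illustration of `expand_range_to_line_boundaries_and_truncate` (a hedged sketch, not part of the change; it assumes `Rope: From<&str>`): a range covering just `a + b` on the middle line expands to the whole line, and nothing is truncated while the expanded range stays under the 1024-byte limit.

    let rope = Rope::from("fn add(a: i32, b: i32) -> i32 {\n    a + b\n}\n");
    // `a + b` occupies offsets 36..41 on the middle line.
    let (range, truncated) = expand_range_to_line_boundaries_and_truncate(&(36..41), 1024, &rope);
    assert_eq!(range, 32..42); // the full middle line, including its trailing newline
    assert!(!truncated); // 10 bytes is well under the limit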
@@ -0,0 +1,326 @@
+use itertools::Itertools as _;
+use language::BufferSnapshot;
+use ordered_float::OrderedFloat;
+use serde::Serialize;
+use std::{collections::HashMap, ops::Range};
+use strum::EnumIter;
+use text::{OffsetRangeExt, Point, ToPoint};
+
+use crate::{
+ Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
+ reference::{Reference, ReferenceRegion},
+ syntax_index::SyntaxIndexState,
+ text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
+};
+
+const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
+
+// TODO:
+//
+// * Consider adding declaration_file_count
+
+#[derive(Clone, Debug)]
+pub struct ScoredSnippet {
+ pub identifier: Identifier,
+ pub declaration: Declaration,
+ pub score_components: ScoreInputs,
+ pub scores: Scores,
+}
+
+// TODO: Consider having "Concise" style corresponding to `concise_text`
+#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
+pub enum SnippetStyle {
+ Signature,
+ Declaration,
+}
+
+impl ScoredSnippet {
+ /// Returns the score for this snippet with the specified style.
+ pub fn score(&self, style: SnippetStyle) -> f32 {
+ match style {
+ SnippetStyle::Signature => self.scores.signature,
+ SnippetStyle::Declaration => self.scores.declaration,
+ }
+ }
+
+ pub fn size(&self, style: SnippetStyle) -> usize {
+ // TODO: how to handle truncation?
+ match &self.declaration {
+ Declaration::File { declaration, .. } => match style {
+ SnippetStyle::Signature => declaration.signature_range_in_text.len(),
+ SnippetStyle::Declaration => declaration.text.len(),
+ },
+ Declaration::Buffer { declaration, .. } => match style {
+ SnippetStyle::Signature => declaration.signature_range.len(),
+ SnippetStyle::Declaration => declaration.item_range.len(),
+ },
+ }
+ }
+
+ pub fn score_density(&self, style: SnippetStyle) -> f32 {
+ self.score(style) / self.size(style) as f32
+ }
+}
+
+pub fn scored_snippets(
+ index: &SyntaxIndexState,
+ excerpt: &EditPredictionExcerpt,
+ excerpt_text: &EditPredictionExcerptText,
+ identifier_to_references: HashMap<Identifier, Vec<Reference>>,
+ cursor_offset: usize,
+ current_buffer: &BufferSnapshot,
+) -> Vec<ScoredSnippet> {
+ let containing_range_identifier_occurrences =
+ IdentifierOccurrences::within_string(&excerpt_text.body);
+ let cursor_point = cursor_offset.to_point(&current_buffer);
+
+ let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
+ let end_point = Point::new(cursor_point.row + 1, 0);
+ let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
+ &current_buffer
+ .text_for_range(start_point..end_point)
+ .collect::<String>(),
+ );
+
+ let mut snippets = identifier_to_references
+ .into_iter()
+ .flat_map(|(identifier, references)| {
+ let declarations =
+ index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
+ let declaration_count = declarations.len();
+
+ declarations
+ .iter()
+ .filter_map(|declaration| match declaration {
+ Declaration::Buffer {
+ buffer_id,
+ declaration: buffer_declaration,
+ ..
+ } => {
+ let is_same_file = buffer_id == &current_buffer.remote_id();
+
+ if is_same_file {
+ range_intersection(
+ &buffer_declaration.item_range.to_offset(&current_buffer),
+ &excerpt.range,
+ )
+ .is_none()
+ .then(|| {
+ let declaration_line = buffer_declaration
+ .item_range
+ .start
+ .to_point(current_buffer)
+ .row;
+ (
+ true,
+ (cursor_point.row as i32 - declaration_line as i32)
+ .unsigned_abs(),
+ declaration,
+ )
+ })
+ } else {
+ // TODO should we prefer the current file instead?
+ Some((false, 0, declaration))
+ }
+ }
+ Declaration::File { .. } => {
+ // TODO should we prefer the current file instead?
+ // We can assume that a file declaration is in a different file,
+ // because the current one must be open
+ Some((false, 0, declaration))
+ }
+ })
+ .sorted_by_key(|&(_, distance, _)| distance)
+ .enumerate()
+ .map(
+ |(
+ declaration_line_distance_rank,
+ (is_same_file, declaration_line_distance, declaration),
+ )| {
+ let same_file_declaration_count = index.file_declaration_count(declaration);
+
+ score_snippet(
+ &identifier,
+ &references,
+ declaration.clone(),
+ is_same_file,
+ declaration_line_distance,
+ declaration_line_distance_rank,
+ same_file_declaration_count,
+ declaration_count,
+ &containing_range_identifier_occurrences,
+ &adjacent_identifier_occurrences,
+ cursor_point,
+ current_buffer,
+ )
+ },
+ )
+ .collect::<Vec<_>>()
+ })
+ .flatten()
+ .collect::<Vec<_>>();
+
+ snippets.sort_unstable_by_key(|snippet| {
+ OrderedFloat(
+ snippet
+ .score_density(SnippetStyle::Declaration)
+ .max(snippet.score_density(SnippetStyle::Signature)),
+ )
+ });
+
+ snippets
+}
+
+fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
+ let start = a.start.clone().max(b.start.clone());
+ let end = a.end.clone().min(b.end.clone());
+ if start < end {
+ Some(Range { start, end })
+ } else {
+ None
+ }
+}
+
+fn score_snippet(
+ identifier: &Identifier,
+ references: &[Reference],
+ declaration: Declaration,
+ is_same_file: bool,
+ declaration_line_distance: u32,
+ declaration_line_distance_rank: usize,
+ same_file_declaration_count: usize,
+ declaration_count: usize,
+ containing_range_identifier_occurrences: &IdentifierOccurrences,
+ adjacent_identifier_occurrences: &IdentifierOccurrences,
+ cursor: Point,
+ current_buffer: &BufferSnapshot,
+) -> Option<ScoredSnippet> {
+ let is_referenced_nearby = references
+ .iter()
+ .any(|r| r.region == ReferenceRegion::Nearby);
+ let is_referenced_in_breadcrumb = references
+ .iter()
+ .any(|r| r.region == ReferenceRegion::Breadcrumb);
+ let reference_count = references.len();
+ let reference_line_distance = references
+ .iter()
+ .map(|r| {
+ let reference_line = r.range.start.to_point(current_buffer).row as i32;
+ (cursor.row as i32 - reference_line).unsigned_abs()
+ })
+ .min()
+ .unwrap();
+
+ let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
+ let item_signature_occurrences =
+ IdentifierOccurrences::within_string(&declaration.signature_text().0);
+ let containing_range_vs_item_jaccard = jaccard_similarity(
+ containing_range_identifier_occurrences,
+ &item_source_occurrences,
+ );
+ let containing_range_vs_signature_jaccard = jaccard_similarity(
+ containing_range_identifier_occurrences,
+ &item_signature_occurrences,
+ );
+ let adjacent_vs_item_jaccard =
+ jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
+ let adjacent_vs_signature_jaccard =
+ jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
+
+ let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
+ containing_range_identifier_occurrences,
+ &item_source_occurrences,
+ );
+ let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
+ containing_range_identifier_occurrences,
+ &item_signature_occurrences,
+ );
+ let adjacent_vs_item_weighted_overlap =
+ weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
+ let adjacent_vs_signature_weighted_overlap =
+ weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
+
+ let score_components = ScoreInputs {
+ is_same_file,
+ is_referenced_nearby,
+ is_referenced_in_breadcrumb,
+ reference_line_distance,
+ declaration_line_distance,
+ declaration_line_distance_rank,
+ reference_count,
+ same_file_declaration_count,
+ declaration_count,
+ containing_range_vs_item_jaccard,
+ containing_range_vs_signature_jaccard,
+ adjacent_vs_item_jaccard,
+ adjacent_vs_signature_jaccard,
+ containing_range_vs_item_weighted_overlap,
+ containing_range_vs_signature_weighted_overlap,
+ adjacent_vs_item_weighted_overlap,
+ adjacent_vs_signature_weighted_overlap,
+ };
+
+ Some(ScoredSnippet {
+ identifier: identifier.clone(),
+ declaration,
+ scores: score_components.score(),
+ score_components,
+ })
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct ScoreInputs {
+ pub is_same_file: bool,
+ pub is_referenced_nearby: bool,
+ pub is_referenced_in_breadcrumb: bool,
+ pub reference_count: usize,
+ pub same_file_declaration_count: usize,
+ pub declaration_count: usize,
+ pub reference_line_distance: u32,
+ pub declaration_line_distance: u32,
+ pub declaration_line_distance_rank: usize,
+ pub containing_range_vs_item_jaccard: f32,
+ pub containing_range_vs_signature_jaccard: f32,
+ pub adjacent_vs_item_jaccard: f32,
+ pub adjacent_vs_signature_jaccard: f32,
+ pub containing_range_vs_item_weighted_overlap: f32,
+ pub containing_range_vs_signature_weighted_overlap: f32,
+ pub adjacent_vs_item_weighted_overlap: f32,
+ pub adjacent_vs_signature_weighted_overlap: f32,
+}
+
+#[derive(Clone, Debug, Serialize)]
+pub struct Scores {
+ pub signature: f32,
+ pub declaration: f32,
+}
+
+impl ScoreInputs {
+ fn score(&self) -> Scores {
+ // Score related to how likely this is the correct declaration, range 0 to 1
+ let accuracy_score = if self.is_same_file {
+ // TODO: use declaration_line_distance_rank
+ 1.0 / self.same_file_declaration_count as f32
+ } else {
+ 1.0 / self.declaration_count as f32
+ };
+
+ // Score related to the distance between the reference and cursor, range 0 to 1
+ let distance_score = if self.is_referenced_nearby {
+ 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
+ } else {
+ // same score as ~4 lines away, rationale is to not overly penalize references from parent signatures
+ 0.5
+ };
+
+ // For now instead of linear combination, the scores are just multiplied together.
+ let combined_score = 10.0 * accuracy_score * distance_score;
+
+ Scores {
+ signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
+ // declaration score gets boosted both by being multiplied by 2 and by there being more
+ // weighted overlap.
+ declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
+ }
+ }
+}
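
Worked example of `ScoreInputs::score` (illustrative numbers): a same-file declaration among `same_file_declaration_count = 4` gets `accuracy_score = 1/4 = 0.25`. If it is referenced nearby with `reference_line_distance = 5`, then `distance_score = 1 / (1 + 5/10)^2 ≈ 0.44`, so `combined_score = 10 * 0.25 * 0.44 ≈ 1.11`. With weighted overlaps of 0.5 against the signature and 0.7 against the item text, the final scores are `signature ≈ 1.11 * 0.5 ≈ 0.56` and `declaration ≈ 2 * 1.11 * 0.7 ≈ 1.56`; `scored_snippets` then ranks by the higher of the two score densities.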
@@ -1,8 +1,220 @@
+mod declaration;
+mod declaration_scoring;
mod excerpt;
mod outline;
mod reference;
-mod tree_sitter_index;
+mod syntax_index;
+mod text_similarity;
+pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
+use gpui::{App, AppContext as _, Entity, Task};
+use language::BufferSnapshot;
pub use reference::references_in_excerpt;
-pub use tree_sitter_index::{BufferDeclaration, Declaration, FileDeclaration, TreeSitterIndex};
+pub use syntax_index::SyntaxIndex;
+use text::{Point, ToOffset as _};
+
+use crate::declaration_scoring::{ScoredSnippet, scored_snippets};
+
+pub struct EditPredictionContext {
+ pub excerpt: EditPredictionExcerpt,
+ pub excerpt_text: EditPredictionExcerptText,
+ pub snippets: Vec<ScoredSnippet>,
+}
+
+impl EditPredictionContext {
+ pub fn gather(
+ cursor_point: Point,
+ buffer: BufferSnapshot,
+ excerpt_options: EditPredictionExcerptOptions,
+ syntax_index: Entity<SyntaxIndex>,
+ cx: &mut App,
+ ) -> Task<Self> {
+ let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
+ cx.background_spawn(async move {
+ let index_state = index_state.lock().await;
+
+ let excerpt =
+ EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)
+ .unwrap();
+ let excerpt_text = excerpt.text(&buffer);
+ let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
+ let cursor_offset = cursor_point.to_offset(&buffer);
+
+ let snippets = scored_snippets(
+ &index_state,
+ &excerpt,
+ &excerpt_text,
+ references,
+ cursor_offset,
+ &buffer,
+ );
+
+ Self {
+ excerpt,
+ excerpt_text,
+ snippets,
+ }
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use std::sync::Arc;
+
+ use gpui::{Entity, TestAppContext};
+ use indoc::indoc;
+ use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
+ use project::{FakeFs, Project};
+ use serde_json::json;
+ use settings::SettingsStore;
+ use util::path;
+
+ use crate::{EditPredictionExcerptOptions, SyntaxIndex};
+
+ #[gpui::test]
+ async fn test_call_site(cx: &mut TestAppContext) {
+ let (project, index, _rust_lang_id) = init_test(cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let project_path = project.find_project_path("c.rs", cx).unwrap();
+ project.open_buffer(project_path, cx)
+ })
+ .await
+ .unwrap();
+
+ cx.run_until_parked();
+
+ // first process_data call site
+ let cursor_point = language::Point::new(8, 21);
+ let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
+
+ let context = cx
+ .update(|cx| {
+ EditPredictionContext::gather(
+ cursor_point,
+ buffer_snapshot,
+ EditPredictionExcerptOptions {
+ max_bytes: 40,
+ min_bytes: 10,
+ target_before_cursor_over_total_bytes: 0.5,
+ include_parent_signatures: false,
+ },
+ index,
+ cx,
+ )
+ })
+ .await;
+
+ assert_eq!(context.snippets.len(), 1);
+ assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data");
+ drop(buffer);
+ }
+
+ async fn init_test(
+ cx: &mut TestAppContext,
+ ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ language::init(cx);
+ Project::init_settings(cx);
+ });
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ path!("/root"),
+ json!({
+ "a.rs": indoc! {r#"
+ fn main() {
+ let x = 1;
+ let y = 2;
+ let z = add(x, y);
+ println!("Result: {}", z);
+ }
+
+ fn add(a: i32, b: i32) -> i32 {
+ a + b
+ }
+ "#},
+ "b.rs": indoc! {"
+ pub struct Config {
+ pub name: String,
+ pub value: i32,
+ }
+
+ impl Config {
+ pub fn new(name: String, value: i32) -> Self {
+ Config { name, value }
+ }
+ }
+ "},
+ "c.rs": indoc! {r#"
+ use std::collections::HashMap;
+
+ fn main() {
+ let args: Vec<String> = std::env::args().collect();
+ let data: Vec<i32> = args[1..]
+ .iter()
+ .filter_map(|s| s.parse().ok())
+ .collect();
+ let result = process_data(data);
+ println!("{:?}", result);
+ }
+
+ fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
+ let mut counts = HashMap::new();
+ for value in data {
+ *counts.entry(value).or_insert(0) += 1;
+ }
+ counts
+ }
+
+ #[cfg(test)]
+ mod tests {
+ use super::*;
+
+ #[test]
+ fn test_process_data() {
+ let data = vec![1, 2, 2, 3];
+ let result = process_data(data);
+ assert_eq!(result.get(&2), Some(&2));
+ }
+ }
+ "#}
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ let lang = rust_lang();
+ let lang_id = lang.id();
+ language_registry.add(Arc::new(lang));
+
+ let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
+ cx.run_until_parked();
+
+ (project, index, lang_id)
+ }
+
+ fn rust_lang() -> Language {
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ Some(tree_sitter_rust::LANGUAGE.into()),
+ )
+ .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
+ .unwrap()
+ .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
+ .unwrap()
+ }
+}
@@ -31,7 +31,7 @@ pub struct EditPredictionExcerptOptions {
pub include_parent_signatures: bool,
}
-#[derive(Clone)]
+#[derive(Debug, Clone)]
pub struct EditPredictionExcerpt {
pub range: Range<usize>,
pub parent_signature_ranges: Vec<Range<usize>>,
@@ -1,5 +1,7 @@
-use language::{BufferSnapshot, LanguageId, SyntaxMapMatches};
-use std::{cmp::Reverse, ops::Range, sync::Arc};
+use language::{BufferSnapshot, SyntaxMapMatches};
+use std::{cmp::Reverse, ops::Range};
+
+use crate::declaration::Identifier;
// TODO:
//
@@ -18,12 +20,6 @@ pub struct OutlineDeclaration {
pub signature_range: Range<usize>,
}
-#[derive(Debug, Clone, Eq, PartialEq, Hash)]
-pub struct Identifier {
- pub name: Arc<str>,
- pub language_id: LanguageId,
-}
-
pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec<OutlineDeclaration> {
declarations_overlapping_range(0..buffer.len(), buffer)
}
@@ -3,8 +3,8 @@ use std::collections::HashMap;
use std::ops::Range;
use crate::{
+ declaration::Identifier,
excerpt::{EditPredictionExcerpt, EditPredictionExcerptText},
- outline::Identifier,
};
#[derive(Debug)]
@@ -1,20 +1,26 @@
+use std::sync::Arc;
+
use collections::{HashMap, HashSet};
+use futures::lock::Mutex;
use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity};
-use language::{Buffer, BufferEvent, BufferSnapshot};
+use language::{Buffer, BufferEvent};
use project::buffer_store::{BufferStore, BufferStoreEvent};
use project::worktree_store::{WorktreeStore, WorktreeStoreEvent};
use project::{PathChange, Project, ProjectEntryId, ProjectPath};
use slotmap::SlotMap;
-use std::ops::Range;
-use std::sync::Arc;
-use text::Anchor;
+use text::BufferId;
use util::{ResultExt as _, debug_panic, some_or_debug_panic};
-use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer};
+use crate::declaration::{
+ BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
+};
+use crate::outline::declarations_in_buffer;
// TODO:
//
// * Skip for remote projects
+//
+// * Consider making SyntaxIndex not an Entity.
// Potential future improvements:
//
@@ -34,17 +40,19 @@ use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer};
// * Concurrent slotmap
//
// * Use queue for parsing
+//
-slotmap::new_key_type! {
- pub struct DeclarationId;
+pub struct SyntaxIndex {
+ state: Arc<Mutex<SyntaxIndexState>>,
+ project: WeakEntity<Project>,
}
-pub struct TreeSitterIndex {
+#[derive(Default)]
+pub struct SyntaxIndexState {
declarations: SlotMap<DeclarationId, Declaration>,
identifiers: HashMap<Identifier, HashSet<DeclarationId>>,
files: HashMap<ProjectEntryId, FileState>,
- buffers: HashMap<WeakEntity<Buffer>, BufferState>,
- project: WeakEntity<Project>,
+ buffers: HashMap<BufferId, BufferState>,
}
#[derive(Debug, Default)]
@@ -59,52 +67,11 @@ struct BufferState {
task: Option<Task<()>>,
}
-#[derive(Debug, Clone)]
-pub enum Declaration {
- File {
- project_entry_id: ProjectEntryId,
- declaration: FileDeclaration,
- },
- Buffer {
- buffer: WeakEntity<Buffer>,
- declaration: BufferDeclaration,
- },
-}
-
-impl Declaration {
- fn identifier(&self) -> &Identifier {
- match self {
- Declaration::File { declaration, .. } => &declaration.identifier,
- Declaration::Buffer { declaration, .. } => &declaration.identifier,
- }
- }
-}
-
-#[derive(Debug, Clone)]
-pub struct FileDeclaration {
- pub parent: Option<DeclarationId>,
- pub identifier: Identifier,
- pub item_range: Range<usize>,
- pub signature_range: Range<usize>,
- pub signature_text: Arc<str>,
-}
-
-#[derive(Debug, Clone)]
-pub struct BufferDeclaration {
- pub parent: Option<DeclarationId>,
- pub identifier: Identifier,
- pub item_range: Range<Anchor>,
- pub signature_range: Range<Anchor>,
-}
-
-impl TreeSitterIndex {
+impl SyntaxIndex {
pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
let mut this = Self {
- declarations: SlotMap::with_key(),
- identifiers: HashMap::default(),
project: project.downgrade(),
- files: HashMap::default(),
- buffers: HashMap::default(),
+ state: Arc::new(Mutex::new(SyntaxIndexState::default())),
};
let worktree_store = project.read(cx).worktree_store();
@@ -139,73 +106,6 @@ impl TreeSitterIndex {
this
}
- pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
- self.declarations.get(id)
- }
-
- pub fn declarations_for_identifier<const N: usize>(
- &self,
- identifier: Identifier,
- cx: &App,
- ) -> Vec<Declaration> {
- // make sure to not have a large stack allocation
- assert!(N < 32);
-
- let Some(declaration_ids) = self.identifiers.get(&identifier) else {
- return vec![];
- };
-
- let mut result = Vec::with_capacity(N);
- let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
- let mut file_declarations = Vec::new();
-
- for declaration_id in declaration_ids {
- let declaration = self.declarations.get(*declaration_id);
- let Some(declaration) = some_or_debug_panic(declaration) else {
- continue;
- };
- match declaration {
- Declaration::Buffer { buffer, .. } => {
- if let Ok(Some(entry_id)) = buffer.read_with(cx, |buffer, cx| {
- project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
- }) {
- included_buffer_entry_ids.push(entry_id);
- result.push(declaration.clone());
- if result.len() == N {
- return result;
- }
- }
- }
- Declaration::File {
- project_entry_id, ..
- } => {
- if !included_buffer_entry_ids.contains(project_entry_id) {
- file_declarations.push(declaration.clone());
- }
- }
- }
- }
-
- for declaration in file_declarations {
- match declaration {
- Declaration::File {
- project_entry_id, ..
- } => {
- if !included_buffer_entry_ids.contains(&project_entry_id) {
- result.push(declaration);
-
- if result.len() == N {
- return result;
- }
- }
- }
- Declaration::Buffer { .. } => {}
- }
- }
-
- result
- }
-
fn handle_worktree_store_event(
&mut self,
_worktree_store: Entity<WorktreeStore>,
@@ -215,21 +115,33 @@ impl TreeSitterIndex {
use WorktreeStoreEvent::*;
match event {
WorktreeUpdatedEntries(worktree_id, updated_entries_set) => {
- for (path, entry_id, path_change) in updated_entries_set.iter() {
- if let PathChange::Removed = path_change {
- self.files.remove(entry_id);
- } else {
- let project_path = ProjectPath {
- worktree_id: *worktree_id,
- path: path.clone(),
- };
- self.update_file(*entry_id, project_path, cx);
+ let state = Arc::downgrade(&self.state);
+ let worktree_id = *worktree_id;
+ let updated_entries_set = updated_entries_set.clone();
+ cx.spawn(async move |this, cx| {
+ let Some(state) = state.upgrade() else { return };
+ for (path, entry_id, path_change) in updated_entries_set.iter() {
+ if let PathChange::Removed = path_change {
+ state.lock().await.files.remove(entry_id);
+ } else {
+ let project_path = ProjectPath {
+ worktree_id,
+ path: path.clone(),
+ };
+ this.update(cx, |this, cx| {
+ this.update_file(*entry_id, project_path, cx);
+ })
+ .ok();
+ }
}
- }
+ })
+ .detach();
}
WorktreeDeletedEntry(_worktree_id, project_entry_id) => {
- // TODO: Is this needed?
- self.files.remove(project_entry_id);
+ let project_entry_id = *project_entry_id;
+ self.with_state(cx, move |state| {
+ state.files.remove(&project_entry_id);
+ })
}
_ => {}
}
@@ -251,15 +163,42 @@ impl TreeSitterIndex {
}
}
+ pub fn state(&self) -> &Arc<Mutex<SyntaxIndexState>> {
+ &self.state
+ }
+
+ fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) {
+ if let Some(mut state) = self.state.try_lock() {
+ f(&mut state);
+ return;
+ }
+ let state = Arc::downgrade(&self.state);
+ cx.background_spawn(async move {
+ let Some(state) = state.upgrade() else {
+ return;
+ };
+ let mut state = state.lock().await;
+ f(&mut state)
+ })
+ .detach();
+ }
+
fn register_buffer(&mut self, buffer: &Entity<Buffer>, cx: &mut Context<Self>) {
- self.buffers
- .insert(buffer.downgrade(), BufferState::default());
- let weak_buf = buffer.downgrade();
- cx.observe_release(buffer, move |this, _buffer, _cx| {
- this.buffers.remove(&weak_buf);
+ let buffer_id = buffer.read(cx).remote_id();
+ cx.observe_release(buffer, move |this, _buffer, cx| {
+ this.with_state(cx, move |state| {
+ if let Some(buffer_state) = state.buffers.remove(&buffer_id) {
+ SyntaxIndexState::remove_buffer_declarations(
+ &buffer_state.declarations,
+ &mut state.declarations,
+ &mut state.identifiers,
+ );
+ }
+ })
})
.detach();
cx.subscribe(buffer, Self::handle_buffer_event).detach();
+
self.update_buffer(buffer.clone(), cx);
}
@@ -275,10 +214,19 @@ impl TreeSitterIndex {
}
}
- fn update_buffer(&mut self, buffer: Entity<Buffer>, cx: &Context<Self>) {
- let mut parse_status = buffer.read(cx).parse_status();
+ fn update_buffer(&mut self, buffer_entity: Entity<Buffer>, cx: &mut Context<Self>) {
+ let buffer = buffer_entity.read(cx);
+
+ let Some(project_entry_id) =
+ project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
+ else {
+ return;
+ };
+ let buffer_id = buffer.remote_id();
+
+ let mut parse_status = buffer.parse_status();
let snapshot_task = cx.spawn({
- let weak_buffer = buffer.downgrade();
+ let weak_buffer = buffer_entity.downgrade();
async move |_, cx| {
while *parse_status.borrow() != language::ParseStatus::Idle {
parse_status.changed().await?;
@@ -289,75 +237,77 @@ impl TreeSitterIndex {
let parse_task = cx.background_spawn(async move {
let snapshot = snapshot_task.await?;
+ let rope = snapshot.text.as_rope().clone();
- anyhow::Ok(
+ anyhow::Ok((
declarations_in_buffer(&snapshot)
.into_iter()
.map(|item| {
(
item.parent_index,
- BufferDeclaration::from_outline(item, &snapshot),
+ BufferDeclaration::from_outline(item, &rope),
)
})
.collect::<Vec<_>>(),
- )
+ rope,
+ ))
});
let task = cx.spawn({
- let weak_buffer = buffer.downgrade();
async move |this, cx| {
- let Ok(declarations) = parse_task.await else {
+ let Ok((declarations, rope)) = parse_task.await else {
return;
};
- this.update(cx, |this, _cx| {
- let buffer_state = this
- .buffers
- .entry(weak_buffer.clone())
- .or_insert_with(Default::default);
-
- for old_declaration_id in &buffer_state.declarations {
- let Some(declaration) = this.declarations.remove(*old_declaration_id)
- else {
- debug_panic!("declaration not found");
- continue;
- };
- if let Some(identifier_declarations) =
- this.identifiers.get_mut(declaration.identifier())
- {
- identifier_declarations.remove(old_declaration_id);
+ this.update(cx, move |this, cx| {
+ this.with_state(cx, move |state| {
+ let buffer_state = state
+ .buffers
+ .entry(buffer_id)
+ .or_insert_with(Default::default);
+
+ SyntaxIndexState::remove_buffer_declarations(
+ &buffer_state.declarations,
+ &mut state.declarations,
+ &mut state.identifiers,
+ );
+
+ let mut new_ids = Vec::with_capacity(declarations.len());
+ state.declarations.reserve(declarations.len());
+ for (parent_index, mut declaration) in declarations {
+ declaration.parent = parent_index
+ .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
+
+ let identifier = declaration.identifier.clone();
+ let declaration_id = state.declarations.insert(Declaration::Buffer {
+ rope: rope.clone(),
+ buffer_id,
+ declaration,
+ project_entry_id,
+ });
+ new_ids.push(declaration_id);
+
+ state
+ .identifiers
+ .entry(identifier)
+ .or_default()
+ .insert(declaration_id);
}
- }
-
- let mut new_ids = Vec::with_capacity(declarations.len());
- this.declarations.reserve(declarations.len());
- for (parent_index, mut declaration) in declarations {
- declaration.parent = parent_index
- .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
-
- let identifier = declaration.identifier.clone();
- let declaration_id = this.declarations.insert(Declaration::Buffer {
- buffer: weak_buffer.clone(),
- declaration,
- });
- new_ids.push(declaration_id);
-
- this.identifiers
- .entry(identifier)
- .or_default()
- .insert(declaration_id);
- }
- buffer_state.declarations = new_ids;
+ buffer_state.declarations = new_ids;
+ });
})
.ok();
}
});
- self.buffers
- .entry(buffer.downgrade())
- .or_insert_with(Default::default)
- .task = Some(task);
+ self.with_state(cx, move |state| {
+ state
+ .buffers
+ .entry(buffer_id)
+ .or_insert_with(Default::default)
+ .task = Some(task)
+ });
}
fn update_file(
@@ -401,14 +351,10 @@ impl TreeSitterIndex {
let parse_task = cx.background_spawn(async move {
let snapshot = snapshot_task.await?;
+ let rope = snapshot.as_rope();
let declarations = declarations_in_buffer(&snapshot)
.into_iter()
- .map(|item| {
- (
- item.parent_index,
- FileDeclaration::from_outline(item, &snapshot),
- )
- })
+ .map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope)))
.collect::<Vec<_>>();
anyhow::Ok(declarations)
});
@@ -419,84 +365,160 @@ impl TreeSitterIndex {
let Ok(declarations) = parse_task.await else {
return;
};
- this.update(cx, |this, _cx| {
- let file_state = this.files.entry(entry_id).or_insert_with(Default::default);
-
- for old_declaration_id in &file_state.declarations {
- let Some(declaration) = this.declarations.remove(*old_declaration_id)
- else {
- debug_panic!("declaration not found");
- continue;
- };
- if let Some(identifier_declarations) =
- this.identifiers.get_mut(declaration.identifier())
- {
- identifier_declarations.remove(old_declaration_id);
+ this.update(cx, |this, cx| {
+ this.with_state(cx, move |state| {
+ let file_state =
+ state.files.entry(entry_id).or_insert_with(Default::default);
+
+ for old_declaration_id in &file_state.declarations {
+ let Some(declaration) = state.declarations.remove(*old_declaration_id)
+ else {
+ debug_panic!("declaration not found");
+ continue;
+ };
+ if let Some(identifier_declarations) =
+ state.identifiers.get_mut(declaration.identifier())
+ {
+ identifier_declarations.remove(old_declaration_id);
+ }
}
- }
-
- let mut new_ids = Vec::with_capacity(declarations.len());
- this.declarations.reserve(declarations.len());
- for (parent_index, mut declaration) in declarations {
- declaration.parent = parent_index
- .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
-
- let identifier = declaration.identifier.clone();
- let declaration_id = this.declarations.insert(Declaration::File {
- project_entry_id: entry_id,
- declaration,
- });
- new_ids.push(declaration_id);
-
- this.identifiers
- .entry(identifier)
- .or_default()
- .insert(declaration_id);
- }
+ let mut new_ids = Vec::with_capacity(declarations.len());
+ state.declarations.reserve(declarations.len());
+
+ for (parent_index, mut declaration) in declarations {
+ declaration.parent = parent_index
+ .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied()));
+
+ let identifier = declaration.identifier.clone();
+ let declaration_id = state.declarations.insert(Declaration::File {
+ project_entry_id: entry_id,
+ declaration,
+ });
+ new_ids.push(declaration_id);
+
+ state
+ .identifiers
+ .entry(identifier)
+ .or_default()
+ .insert(declaration_id);
+ }
- file_state.declarations = new_ids;
+ file_state.declarations = new_ids;
+ });
})
.ok();
}
});
- self.files
- .entry(entry_id)
- .or_insert_with(Default::default)
- .task = Some(task);
+ self.with_state(cx, move |state| {
+ state
+ .files
+ .entry(entry_id)
+ .or_insert_with(Default::default)
+ .task = Some(task);
+ });
}
}
-impl BufferDeclaration {
- pub fn from_outline(declaration: OutlineDeclaration, snapshot: &BufferSnapshot) -> Self {
- // use of anchor_before is a guess that the proper behavior is to expand to include
- // insertions immediately before the declaration, but not for insertions immediately after
- Self {
- parent: None,
- identifier: declaration.identifier,
- item_range: snapshot.anchor_before(declaration.item_range.start)
- ..snapshot.anchor_before(declaration.item_range.end),
- signature_range: snapshot.anchor_before(declaration.signature_range.start)
- ..snapshot.anchor_before(declaration.signature_range.end),
+impl SyntaxIndexState {
+ pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> {
+ self.declarations.get(id)
+ }
+
+ /// Returns declarations for the identifier. If the limit is exceeded, returns an empty vector.
+ ///
+ /// TODO: Consider doing some pre-ranking and instead truncating when N is exceeded.
+ pub fn declarations_for_identifier<const N: usize>(
+ &self,
+ identifier: &Identifier,
+ ) -> Vec<Declaration> {
+ // make sure to not have a large stack allocation
+ assert!(N < 32);
+
+ let Some(declaration_ids) = self.identifiers.get(identifier) else {
+ return vec![];
+ };
+
+ let mut result = Vec::with_capacity(N);
+ let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new();
+ let mut file_declarations = Vec::new();
+
+ for declaration_id in declaration_ids {
+ let declaration = self.declarations.get(*declaration_id);
+ let Some(declaration) = some_or_debug_panic(declaration) else {
+ continue;
+ };
+ match declaration {
+ Declaration::Buffer {
+ project_entry_id, ..
+ } => {
+ included_buffer_entry_ids.push(*project_entry_id);
+ result.push(declaration.clone());
+ if result.len() == N {
+ return Vec::new();
+ }
+ }
+ Declaration::File {
+ project_entry_id, ..
+ } => {
+ if !included_buffer_entry_ids.contains(project_entry_id) {
+ file_declarations.push(declaration.clone());
+ }
+ }
+ }
+ }
+
+ for declaration in file_declarations {
+ match declaration {
+ Declaration::File {
+ project_entry_id, ..
+ } => {
+ if !included_buffer_entry_ids.contains(&project_entry_id) {
+ result.push(declaration);
+
+ if result.len() == N {
+ return Vec::new();
+ }
+ }
+ }
+ Declaration::Buffer { .. } => {}
+ }
+ }
+
+ result
+ }
+
+ pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
+ match declaration {
+ Declaration::File {
+ project_entry_id, ..
+ } => self
+ .files
+ .get(project_entry_id)
+ .map(|file_state| file_state.declarations.len())
+ .unwrap_or_default(),
+ Declaration::Buffer { buffer_id, .. } => self
+ .buffers
+ .get(buffer_id)
+ .map(|buffer_state| buffer_state.declarations.len())
+ .unwrap_or_default(),
}
}
-}
-impl FileDeclaration {
- pub fn from_outline(
- declaration: OutlineDeclaration,
- snapshot: &BufferSnapshot,
- ) -> FileDeclaration {
- FileDeclaration {
- parent: None,
- identifier: declaration.identifier,
- item_range: declaration.item_range,
- signature_text: snapshot
- .text_for_range(declaration.signature_range.clone())
- .collect::<String>()
- .into(),
- signature_range: declaration.signature_range,
+ fn remove_buffer_declarations(
+ old_declaration_ids: &[DeclarationId],
+ declarations: &mut SlotMap<DeclarationId, Declaration>,
+ identifiers: &mut HashMap<Identifier, HashSet<DeclarationId>>,
+ ) {
+ for old_declaration_id in old_declaration_ids {
+ let Some(declaration) = declarations.remove(*old_declaration_id) else {
+ debug_panic!("declaration not found");
+ continue;
+ };
+ if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) {
+ identifier_declarations.remove(old_declaration_id);
+ }
}
}
}
@@ -509,13 +531,13 @@ mod tests {
use gpui::TestAppContext;
use indoc::indoc;
use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
- use project::{FakeFs, Project, ProjectItem};
+ use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use text::OffsetRangeExt as _;
use util::path;
- use crate::tree_sitter_index::TreeSitterIndex;
+ use crate::syntax_index::SyntaxIndex;
#[gpui::test]
async fn test_unopen_indexed_files(cx: &mut TestAppContext) {
@@ -525,17 +547,19 @@ mod tests {
language_id: rust_lang_id,
};
- index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
+ let index_state = index.read_with(cx, |index, _cx| index.state().clone());
+ let index_state = index_state.lock().await;
+ cx.update(|cx| {
+ let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, main.clone());
- assert_eq!(decl.item_range, 32..279);
+ assert_eq!(decl.item_range_in_file, 32..280);
let decl = expect_file_decl("a.rs", &decls[1], &project, cx);
assert_eq!(decl.identifier, main);
- assert_eq!(decl.item_range, 0..97);
+ assert_eq!(decl.item_range_in_file, 0..98);
});
}
@@ -547,15 +571,17 @@ mod tests {
language_id: rust_lang_id,
};
- index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
+ let index_state = index.read_with(cx, |index, _cx| index.state().clone());
+ let index_state = index_state.lock().await;
+ cx.update(|cx| {
+ let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
- let parent = index.declaration(parent_id).unwrap();
+ let parent = index_state.declaration(parent_id).unwrap();
let parent_decl = expect_file_decl("c.rs", &parent, &project, cx);
assert_eq!(
parent_decl.identifier,
@@ -586,16 +612,18 @@ mod tests {
cx.run_until_parked();
- index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
+ let index_state = index.read_with(cx, |index, _cx| index.state().clone());
+ let index_state = index_state.lock().await;
+ cx.update(|cx| {
+ let decls = index_state.declarations_for_identifier::<8>(&test_process_data);
assert_eq!(decls.len(), 1);
- let decl = expect_buffer_decl("c.rs", &decls[0], cx);
+ let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
assert_eq!(decl.identifier, test_process_data);
let parent_id = decl.parent.unwrap();
- let parent = index.declaration(parent_id).unwrap();
- let parent_decl = expect_buffer_decl("c.rs", &parent, cx);
+ let parent = index_state.declaration(parent_id).unwrap();
+ let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx);
assert_eq!(
parent_decl.identifier,
Identifier {
@@ -613,16 +641,13 @@ mod tests {
async fn test_declarations_limit(cx: &mut TestAppContext) {
let (_, index, rust_lang_id) = init_test(cx).await;
- index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<1>(
- Identifier {
- name: "main".into(),
- language_id: rust_lang_id,
- },
- cx,
- );
- assert_eq!(decls.len(), 1);
+ let index_state = index.read_with(cx, |index, _cx| index.state().clone());
+ let index_state = index_state.lock().await;
+ let decls = index_state.declarations_for_identifier::<1>(&Identifier {
+ name: "main".into(),
+ language_id: rust_lang_id,
});
+ assert_eq!(decls.len(), 0);
}
#[gpui::test]
@@ -644,24 +669,31 @@ mod tests {
cx.run_until_parked();
- index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
- assert_eq!(decls.len(), 2);
- let decl = expect_buffer_decl("c.rs", &decls[0], cx);
- assert_eq!(decl.identifier, main);
- assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279);
+ let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone());
+ {
+ let index_state = index_state_arc.lock().await;
- expect_file_decl("a.rs", &decls[1], &project, cx);
- });
+ cx.update(|cx| {
+ let decls = index_state.declarations_for_identifier::<8>(&main);
+ assert_eq!(decls.len(), 2);
+ let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx);
+ assert_eq!(decl.identifier, main);
+ assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280);
+
+ expect_file_decl("a.rs", &decls[1], &project, cx);
+ });
+ }
- // Need to trigger flush_effects so that the observe_release handler will run.
- cx.update(|_cx| {
+ // Drop the buffer and wait for release
+ cx.update(|_| {
drop(buffer);
});
cx.run_until_parked();
- index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<8>(main, cx);
+ let index_state = index_state_arc.lock().await;
+
+ cx.update(|cx| {
+ let decls = index_state.declarations_for_identifier::<8>(&main);
assert_eq!(decls.len(), 2);
expect_file_decl("c.rs", &decls[0], &project, cx);
expect_file_decl("a.rs", &decls[1], &project, cx);
@@ -671,24 +703,20 @@ mod tests {
fn expect_buffer_decl<'a>(
path: &str,
declaration: &'a Declaration,
+ project: &Entity<Project>,
cx: &App,
) -> &'a BufferDeclaration {
if let Declaration::Buffer {
declaration,
- buffer,
+ project_entry_id,
+ ..
} = declaration
{
- assert_eq!(
- buffer
- .upgrade()
- .unwrap()
- .read(cx)
- .project_path(cx)
- .unwrap()
- .path
- .as_ref(),
- Path::new(path),
- );
+ let project_path = project
+ .read(cx)
+ .path_for_entry(*project_entry_id, cx)
+ .unwrap();
+ assert_eq!(project_path.path.as_ref(), Path::new(path));
declaration
} else {
panic!("Expected a buffer declaration, found {:?}", declaration);
@@ -723,7 +751,7 @@ mod tests {
async fn init_test(
cx: &mut TestAppContext,
- ) -> (Entity<Project>, Entity<TreeSitterIndex>, LanguageId) {
+ ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
@@ -801,7 +829,7 @@ mod tests {
let lang_id = lang.id();
language_registry.add(Arc::new(lang));
- let index = cx.new(|cx| TreeSitterIndex::new(&project, cx));
+ let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
cx.run_until_parked();
(project, index, lang_id)
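
For reference, the query path these tests exercise, as a minimal sketch (assuming an async test context, since `SyntaxIndexState` now lives behind a futures `Mutex`):

    // Clone the shared state handle out of the entity, then lock it for reading.
    let index_state = index.read_with(cx, |index, _cx| index.state().clone());
    let index_state = index_state.lock().await;
    let decls = index_state.declarations_for_identifier::<8>(&Identifier {
        name: "main".into(),
        language_id: rust_lang_id,
    });
    // Buffer-backed declarations shadow file-backed ones for the same project
    // entry, and exceeding the limit of 8 yields an empty Vec.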
@@ -0,0 +1,241 @@
+use regex::Regex;
+use std::{collections::HashMap, sync::LazyLock};
+
+use crate::reference::Reference;
+
+// TODO: Consider implementing sliding window similarity matching like
+// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
+//
+// That implementation could actually be more efficient - no need to track words in the window that
+// are not in the query.
+
+static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
+
+#[derive(Debug)]
+pub struct IdentifierOccurrences {
+ identifier_to_count: HashMap<String, usize>,
+ total_count: usize,
+}
+
+impl IdentifierOccurrences {
+ pub fn within_string(code: &str) -> Self {
+ Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
+ }
+
+ #[allow(dead_code)]
+ pub fn within_references(references: &[Reference]) -> Self {
+ Self::from_iterator(
+ references
+ .iter()
+ .map(|reference| reference.identifier.name.as_ref()),
+ )
+ }
+
+ pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
+ let mut identifier_to_count = HashMap::new();
+ let mut total_count = 0;
+ for identifier in identifier_iterator {
+ // TODO: Score matches that match case higher?
+ //
+ // TODO: Also include unsplit identifier?
+ for identifier_part in split_identifier(identifier) {
+ identifier_to_count
+ .entry(identifier_part.to_lowercase())
+ .and_modify(|count| *count += 1)
+ .or_insert(1);
+ total_count += 1;
+ }
+ }
+ IdentifierOccurrences {
+ identifier_to_count,
+ total_count,
+ }
+ }
+}
+
+// Splits camelcase / snakecase / kebabcase / pascalcase
+//
+// TODO: Make this more efficient / elegant.
+fn split_identifier(identifier: &str) -> Vec<&str> {
+ let mut parts = Vec::new();
+ let mut start = 0;
+ let chars: Vec<char> = identifier.chars().collect();
+
+ if chars.is_empty() {
+ return parts;
+ }
+
+ let mut i = 0;
+ while i < chars.len() {
+ let ch = chars[i];
+
+ // Handle explicit delimiters (underscore and hyphen)
+ if ch == '_' || ch == '-' {
+ if i > start {
+ parts.push(&identifier[start..i]);
+ }
+ start = i + 1;
+ i += 1;
+ continue;
+ }
+
+ // Handle camelCase and PascalCase transitions
+ if i > 0 && i < chars.len() {
+ let prev_char = chars[i - 1];
+
+ // Transition from lowercase/digit to uppercase
+ if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
+ parts.push(&identifier[start..i]);
+ start = i;
+ }
+ // Handle sequences like "XMLParser" -> ["XML", "Parser"]
+ else if i + 1 < chars.len()
+ && ch.is_uppercase()
+ && chars[i + 1].is_lowercase()
+ && prev_char.is_uppercase()
+ {
+ parts.push(&identifier[start..i]);
+ start = i;
+ }
+ }
+
+ i += 1;
+ }
+
+ // Add the last part if there's any remaining
+ if start < identifier.len() {
+ parts.push(&identifier[start..]);
+ }
+
+ // Filter out empty strings
+ parts.into_iter().filter(|s| !s.is_empty()).collect()
+}
+
+pub fn jaccard_similarity<'a>(
+ mut set_a: &'a IdentifierOccurrences,
+ mut set_b: &'a IdentifierOccurrences,
+) -> f32 {
+ if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+ std::mem::swap(&mut set_a, &mut set_b);
+ }
+ let intersection = set_a
+ .identifier_to_count
+ .keys()
+ .filter(|key| set_b.identifier_to_count.contains_key(*key))
+ .count();
+ let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
+ intersection as f32 / union as f32
+}
+
+// TODO
+#[allow(dead_code)]
+pub fn overlap_coefficient<'a>(
+ mut set_a: &'a IdentifierOccurrences,
+ mut set_b: &'a IdentifierOccurrences,
+) -> f32 {
+ if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+ std::mem::swap(&mut set_a, &mut set_b);
+ }
+ let intersection = set_a
+ .identifier_to_count
+ .keys()
+ .filter(|key| set_b.identifier_to_count.contains_key(*key))
+ .count();
+ intersection as f32 / set_a.identifier_to_count.len() as f32
+}
+
+// TODO
+#[allow(dead_code)]
+pub fn weighted_jaccard_similarity<'a>(
+ mut set_a: &'a IdentifierOccurrences,
+ mut set_b: &'a IdentifierOccurrences,
+) -> f32 {
+ if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+ std::mem::swap(&mut set_a, &mut set_b);
+ }
+
+ let mut numerator = 0;
+ let mut denominator_a = 0;
+ let mut used_count_b = 0;
+ for (symbol, count_a) in set_a.identifier_to_count.iter() {
+ let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
+ numerator += count_a.min(count_b);
+ denominator_a += count_a.max(count_b);
+ used_count_b += count_b;
+ }
+
+ let denominator = denominator_a + (set_b.total_count - used_count_b);
+ if denominator == 0 {
+ 0.0
+ } else {
+ numerator as f32 / denominator as f32
+ }
+}
+
+pub fn weighted_overlap_coefficient<'a>(
+ mut set_a: &'a IdentifierOccurrences,
+ mut set_b: &'a IdentifierOccurrences,
+) -> f32 {
+ if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+ std::mem::swap(&mut set_a, &mut set_b);
+ }
+
+ let mut numerator = 0;
+ for (symbol, count_a) in set_a.identifier_to_count.iter() {
+ let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
+ numerator += count_a.min(count_b);
+ }
+
+ let denominator = set_a.total_count.min(set_b.total_count);
+ if denominator == 0 {
+ 0.0
+ } else {
+ numerator as f32 / denominator as f32
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+
+ #[test]
+ fn test_split_identifier() {
+ assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]);
+ assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]);
+ assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]);
+ assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]);
+ assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]);
+ }
+
+ #[test]
+ fn test_similarity_functions() {
+ // 10 identifier parts, 8 unique
+ // Repeats: 2 "outline", 2 "items"
+ let set_a = IdentifierOccurrences::within_string(
+ "let mut outline_items = query_outline_items(&language, &tree, &source);",
+ );
+ // 14 identifier parts, 11 unique
+ // Repeats: 2 "outline", 2 "language", 2 "tree"
+ let set_b = IdentifierOccurrences::within_string(
+ "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
+ );
+
+ // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
+ // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
+ assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));
+
+ // Numerator is one more than before due to both having 2 "outline".
+ // Denominator is the same except for 3 more due to the non-overlapping duplicates
+ assert_eq!(
+ weighted_jaccard_similarity(&set_a, &set_b),
+ 7.0 / (7.0 + 7.0 + 3.0)
+ );
+
+ // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
+ assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);
+
+ // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
+ // the smaller set, 10.
+ assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
+ }
+}
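
One detail worth noting: occurrences are keyed by lowercased identifier parts, so differently-cased spellings of the same name are counted together. A minimal sketch (the fields are private, so this only holds within the module):

    let occ = IdentifierOccurrences::within_string("XMLParser xml_parser");
    // Both spellings split into the parts "xml" and "parser".
    assert_eq!(occ.identifier_to_count.get("xml"), Some(&2));
    assert_eq!(occ.identifier_to_count.get("parser"), Some(&2));
    assert_eq!(occ.total_count, 4);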
@@ -0,0 +1,35 @@
+// To discuss: What to send to the new endpoint? Thinking it'd make sense to put `prompt.rs` from
+// `zeta_context.rs` in cloud.
+//
+// * Run excerpt selection at several different sizes; send the largest size, with offsets within
+//   it marking the smaller sizes.
+//
+// * Longer event history.
+//
+// * Many more snippets than could fit in model context - allows ranking experimentation.
+
+pub struct Zeta2Request {
+ pub event_history: Vec<Event>,
+ pub excerpt: String,
+ pub excerpt_subsets: Vec<Zeta2ExcerptSubset>,
+ /// Within `excerpt`
+ pub cursor_position: usize,
+ pub signatures: Vec<String>,
+ pub retrieved_declarations: Vec<ReferencedDeclaration>,
+}
+
+pub struct Zeta2ExcerptSubset {
+ /// Within `excerpt` text.
+ pub excerpt_range: Range<usize>,
+ /// Within `signatures`.
+ pub parent_signatures: Vec<usize>,
+}
+
+pub struct ReferencedDeclaration {
+ pub text: Arc<str>,
+ /// Range within `text`
+ pub signature_range: Range<usize>,
+ /// Indices within `signatures`.
+ pub parent_signatures: Vec<usize>,
+ // A bunch of score metrics
+}