@@ -1,14 +1,18 @@
+use collections::HashSet;
+use gpui::{App, Entity};
use itertools::Itertools as _;
+use language::BufferSnapshot;
+use project::ProjectEntryId;
use serde::Serialize;
-use std::collections::HashMap;
-use std::path::Path;
-use std::sync::Arc;
+use std::{collections::HashMap, ops::Range};
use strum::EnumIter;
-use tree_sitter::StreamingIterator;
+use text::{OffsetRangeExt, Point, ToPoint};
use crate::{
- Declaration, EditPredictionExcerpt, EditPredictionExcerptText, outline::Identifier,
- reference::Reference, text_similarity::IdentifierOccurrences,
+ Declaration, EditPredictionExcerpt, EditPredictionExcerptText, TreeSitterIndex,
+ outline::Identifier,
+ reference::{Reference, ReferenceRegion},
+ text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
};
#[derive(Clone, Debug)]
@@ -46,23 +50,29 @@ impl ScoredSnippet {
}
fn scored_snippets(
+ index: Entity<TreeSitterIndex>,
excerpt: &EditPredictionExcerpt,
excerpt_text: &EditPredictionExcerptText,
references: Vec<Reference>,
cursor_offset: usize,
+ current_buffer: &BufferSnapshot,
+ cx: &App,
) -> Vec<ScoredSnippet> {
- let excerpt_occurrences = IdentifierOccurrences::within_string(&excerpt_text.body);
+ let containing_range_identifier_occurrences =
+ IdentifierOccurrences::within_string(&excerpt_text.body);
+ let cursor_point = cursor_offset.to_point(&current_buffer);
- /* todo!
- if let Some(cursor_within_excerpt) = cursor_offset.checked_sub(excerpt.range.start) {
- } else {
- };
- let start_point = Point::new(cursor.row.saturating_sub(2), 0);
- let end_point = Point::new(cursor.row + 1, 0);
+ // todo! ask michael why we needed this
+ // if let Some(cursor_within_excerpt) = cursor_offset.checked_sub(excerpt.range.start) {
+ // } else {
+ // };
+ let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
+ let end_point = Point::new(cursor_point.row + 1, 0);
let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
- &source[offset_from_point(source, start_point)..offset_from_point(source, end_point)],
+ &current_buffer
+ .text_for_range(start_point..end_point)
+ .collect::<String>(),
);
- */
let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::new();
for reference in references {
@@ -75,74 +85,102 @@ fn scored_snippets(
identifier_to_references
.into_iter()
.flat_map(|(identifier, references)| {
- let Some(definitions) = index
- .identifier_to_definitions
- .get(&(identifier.clone(), language.name.clone()))
- else {
- return Vec::new();
- };
+ let definitions = index
+ .read(cx)
+ // todo! pick a limit
+ .declarations_for_identifier::<16>(&identifier, cx);
let definition_count = definitions.len();
- let definition_file_count = definitions.keys().len();
+ let total_file_count = definitions
+ .iter()
+ .filter_map(|definition| definition.project_entry_id(cx))
+ .collect::<HashSet<ProjectEntryId>>()
+ .len();
definitions
- .iter_all()
- .flat_map(|(definition_file, file_definitions)| {
- let same_file_definition_count = file_definitions.len();
- let is_same_file = reference_file == definition_file.as_ref();
- file_definitions
- .iter()
- .filter(|definition| {
- !is_same_file
- || !range_intersection(&definition.item_range, &excerpt_range)
- .is_some()
- })
- .filter_map(|definition| {
- let definition_line_distance = if is_same_file {
+ .iter()
+ .filter_map(|definition| match definition {
+ Declaration::Buffer {
+ declaration,
+ buffer,
+ } => {
+ let is_same_file = buffer
+ .read_with(cx, |buffer, _| buffer.remote_id())
+ .is_ok_and(|buffer_id| buffer_id == current_buffer.remote_id());
+
+ if is_same_file {
+ range_intersection(
+ &declaration.item_range.to_offset(&current_buffer),
+ &excerpt.range,
+ )
+ .is_none()
+ .then(|| {
let definition_line =
- point_from_offset(source, definition.item_range.start).row;
- (cursor.row as i32 - definition_line as i32).abs() as u32
- } else {
- 0
- };
- Some((definition_line_distance, definition))
- })
- .sorted_by_key(|&(distance, _)| distance)
- .enumerate()
- .map(
- |(
- definition_line_distance_rank,
- (definition_line_distance, definition),
- )| {
- score_snippet(
- &identifier,
- &references,
- definition_file.clone(),
- definition.clone(),
- is_same_file,
- definition_line_distance,
- definition_line_distance_rank,
- same_file_definition_count,
- definition_count,
- definition_file_count,
- &containing_range_identifier_occurrences,
- &adjacent_identifier_occurrences,
- cursor,
+ declaration.item_range.start.to_point(current_buffer).row;
+ (
+ true,
+ (cursor_point.row as i32 - definition_line as i32).abs() as u32,
+ definition,
)
- },
- )
- .collect::<Vec<_>>()
+ })
+ } else {
+ Some((false, 0, definition))
+ }
+ }
+ Declaration::File { .. } => {
+ // We can assume that a file declaration is in a different file,
+ // because the current one must be open
+ Some((false, 0, definition))
+ }
})
+ .sorted_by_key(|&(_, distance, _)| distance)
+ .enumerate()
+ .map(
+ |(
+ definition_line_distance_rank,
+ (is_same_file, definition_line_distance, definition),
+ )| {
+ let same_file_definition_count =
+ index.read(cx).file_declaration_count(definition);
+
+ score_snippet(
+ &identifier,
+ &references,
+ definition.clone(),
+ is_same_file,
+ definition_line_distance,
+ definition_line_distance_rank,
+ same_file_definition_count,
+ definition_count,
+ total_file_count,
+ &containing_range_identifier_occurrences,
+ &adjacent_identifier_occurrences,
+ cursor_point,
+ current_buffer,
+ cx,
+ )
+ },
+ )
.collect::<Vec<_>>()
})
.flatten()
.collect::<Vec<_>>()
}
+// todo! replace with existing util?
+fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
+ let start = a.start.clone().max(b.start.clone());
+ let end = a.end.clone().min(b.end.clone());
+ if start < end {
+ Some(Range { start, end })
+ } else {
+ None
+ }
+}
+
fn score_snippet(
identifier: &Identifier,
references: &[Reference],
- definition_file: Arc<Path>,
- definition: OutlineItem,
+ definition: Declaration,
is_same_file: bool,
definition_line_distance: u32,
definition_line_distance_rank: usize,
@@ -152,28 +190,28 @@ fn score_snippet(
containing_range_identifier_occurrences: &IdentifierOccurrences,
adjacent_identifier_occurrences: &IdentifierOccurrences,
cursor: Point,
+ current_buffer: &BufferSnapshot,
+ cx: &App,
) -> Option<ScoredSnippet> {
let is_referenced_nearby = references
.iter()
- .any(|r| r.reference_region == ReferenceRegion::Nearby);
+ .any(|r| r.region == ReferenceRegion::Nearby);
let is_referenced_in_breadcrumb = references
.iter()
- .any(|r| r.reference_region == ReferenceRegion::Breadcrumb);
+ .any(|r| r.region == ReferenceRegion::Breadcrumb);
let reference_count = references.len();
let reference_line_distance = references
.iter()
.map(|r| {
- let reference_line = point_from_offset(reference_source, r.range.start).row as i32;
+ let reference_line = r.range.start.to_point(current_buffer).row as i32;
(cursor.row as i32 - reference_line).abs() as u32
})
.min()
.unwrap();
- let definition_source = index.path_to_source.get(&definition_file).unwrap();
- let item_source_occurrences =
- IdentifierOccurrences::within_string(definition.item(&definition_source));
+ let item_source_occurrences = IdentifierOccurrences::within_string(&definition.item_text(cx));
let item_signature_occurrences =
- IdentifierOccurrences::within_string(definition.signature(&definition_source));
+ IdentifierOccurrences::within_string(&definition.signature_text(cx));
let containing_range_vs_item_jaccard = jaccard_similarity(
containing_range_identifier_occurrences,
&item_source_occurrences,
@@ -223,7 +261,6 @@ fn score_snippet(
Some(ScoredSnippet {
identifier: identifier.clone(),
- declaration_file: definition_file,
declaration: definition,
scores: score_components.score(),
score_components,
@@ -238,6 +275,7 @@ pub struct ScoreInputs {
pub reference_count: usize,
pub same_file_definition_count: usize,
pub definition_count: usize,
+ // todo! do we need this?
pub definition_file_count: usize,
pub reference_line_distance: u32,
pub definition_line_distance: u32,
@@ -78,6 +78,57 @@ impl Declaration {
Declaration::Buffer { declaration, .. } => &declaration.identifier,
}
}
+
+ pub fn project_entry_id(&self, cx: &App) -> Option<ProjectEntryId> {
+ match self {
+ Declaration::File {
+ project_entry_id, ..
+ } => Some(*project_entry_id),
+ Declaration::Buffer { buffer, .. } => buffer
+ .read_with(cx, |buffer, _cx| {
+ project::File::from_dyn(buffer.file())
+ .and_then(|file| file.project_entry_id(cx))
+ })
+ .ok()
+ .flatten(),
+ }
+ }
+
+ // todo! pick best return type
+ pub fn item_text(&self, cx: &App) -> Arc<str> {
+ match self {
+ Declaration::File { declaration, .. } => declaration.declaration_text.clone(),
+ Declaration::Buffer {
+ buffer,
+ declaration,
+ } => buffer
+ .read_with(cx, |buffer, _cx| {
+ buffer
+ .text_for_range(declaration.item_range.clone())
+ .collect::<String>()
+ .into()
+ })
+ .unwrap_or_default(),
+ }
+ }
+
+ // todo! pick best return type
+ pub fn signature_text(&self, cx: &App) -> Arc<str> {
+ match self {
+ Declaration::File { declaration, .. } => declaration.signature_text.clone(),
+ Declaration::Buffer {
+ buffer,
+ declaration,
+ } => buffer
+ .read_with(cx, |buffer, _cx| {
+ buffer
+ .text_for_range(declaration.signature_range.clone())
+ .collect::<String>()
+ .into()
+ })
+ .unwrap_or_default(),
+ }
+ }
}
#[derive(Debug, Clone)]
@@ -86,7 +137,9 @@ pub struct FileDeclaration {
pub identifier: Identifier,
pub item_range: Range<usize>,
pub signature_range: Range<usize>,
+ // todo! should we just store a range with the declaration text?
pub signature_text: Arc<str>,
+ pub declaration_text: Arc<str>,
}
#[derive(Debug, Clone)]
@@ -145,7 +198,7 @@ impl TreeSitterIndex {
pub fn declarations_for_identifier<const N: usize>(
&self,
- identifier: Identifier,
+ identifier: &Identifier,
cx: &App,
) -> Vec<Declaration> {
// make sure to not have a large stack allocation
@@ -206,6 +259,23 @@ impl TreeSitterIndex {
result
}
+ pub fn file_declaration_count(&self, declaration: &Declaration) -> usize {
+ match declaration {
+ Declaration::File {
+ project_entry_id, ..
+ } => self
+ .files
+ .get(project_entry_id)
+ .map(|file_state| file_state.declarations.len())
+ .unwrap_or_default(),
+ Declaration::Buffer { buffer, .. } => self
+ .buffers
+ .get(buffer)
+ .map(|buffer_state| buffer_state.declarations.len())
+ .unwrap_or_default(),
+ }
+ }
+
fn handle_worktree_store_event(
&mut self,
_worktree_store: Entity<WorktreeStore>,
@@ -491,12 +561,16 @@ impl FileDeclaration {
FileDeclaration {
parent: None,
identifier: declaration.identifier,
- item_range: declaration.item_range,
signature_text: snapshot
.text_for_range(declaration.signature_range.clone())
.collect::<String>()
.into(),
signature_range: declaration.signature_range,
+ declaration_text: snapshot
+ .text_for_range(declaration.item_range.clone())
+ .collect::<String>()
+ .into(),
+ item_range: declaration.item_range,
}
}
}
@@ -527,7 +601,7 @@ mod tests {
};
index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
+ let decls = index.declarations_for_identifier::<8>(&main, cx);
assert_eq!(decls.len(), 2);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
@@ -549,7 +623,7 @@ mod tests {
};
index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
+ let decls = index.declarations_for_identifier::<8>(&test_process_data, cx);
assert_eq!(decls.len(), 1);
let decl = expect_file_decl("c.rs", &decls[0], &project, cx);
@@ -588,7 +662,7 @@ mod tests {
cx.run_until_parked();
index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx);
+ let decls = index.declarations_for_identifier::<8>(&test_process_data, cx);
assert_eq!(decls.len(), 1);
let decl = expect_buffer_decl("c.rs", &decls[0], cx);
@@ -616,7 +690,7 @@ mod tests {
index.read_with(cx, |index, cx| {
let decls = index.declarations_for_identifier::<1>(
- Identifier {
+ &Identifier {
name: "main".into(),
language_id: rust_lang_id,
},
@@ -646,7 +720,7 @@ mod tests {
cx.run_until_parked();
index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<8>(main.clone(), cx);
+ let decls = index.declarations_for_identifier::<8>(&main, cx);
assert_eq!(decls.len(), 2);
let decl = expect_buffer_decl("c.rs", &decls[0], cx);
assert_eq!(decl.identifier, main);
@@ -669,7 +743,7 @@ mod tests {
cx.run_until_parked();
index.read_with(cx, |index, cx| {
- let decls = index.declarations_for_identifier::<8>(main, cx);
+ let decls = index.declarations_for_identifier::<8>(&main, cx);
assert_eq!(decls.len(), 2);
expect_file_decl("c.rs", &decls[0], &project, cx);
expect_file_decl("a.rs", &decls[1], &project, cx);