From cc32bfdfdf8bc95259a9df900b575667f84f77eb Mon Sep 17 00:00:00 2001 From: Agus Date: Wed, 17 Sep 2025 11:47:47 -0300 Subject: [PATCH] Checkpoint: Get score_snippets to compile Co-Authored-By: Finn --- .../src/scored_declaration.rs | 190 +++++++++++------- .../src/tree_sitter_index.rs | 90 ++++++++- 2 files changed, 196 insertions(+), 84 deletions(-) diff --git a/crates/edit_prediction_context/src/scored_declaration.rs b/crates/edit_prediction_context/src/scored_declaration.rs index d61b877aae4690f1456c7fe769947ab3786f00f0..56629a0505a6bae1e32aec158ba4a2985ef803d7 100644 --- a/crates/edit_prediction_context/src/scored_declaration.rs +++ b/crates/edit_prediction_context/src/scored_declaration.rs @@ -1,14 +1,18 @@ +use collections::HashSet; +use gpui::{App, Entity}; use itertools::Itertools as _; +use language::BufferSnapshot; +use project::ProjectEntryId; use serde::Serialize; -use std::collections::HashMap; -use std::path::Path; -use std::sync::Arc; +use std::{collections::HashMap, ops::Range}; use strum::EnumIter; -use tree_sitter::StreamingIterator; +use text::{OffsetRangeExt, Point, ToPoint}; use crate::{ - Declaration, EditPredictionExcerpt, EditPredictionExcerptText, outline::Identifier, - reference::Reference, text_similarity::IdentifierOccurrences, + Declaration, EditPredictionExcerpt, EditPredictionExcerptText, TreeSitterIndex, + outline::Identifier, + reference::{Reference, ReferenceRegion}, + text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient}, }; #[derive(Clone, Debug)] @@ -46,23 +50,29 @@ impl ScoredSnippet { } fn scored_snippets( + index: Entity, excerpt: &EditPredictionExcerpt, excerpt_text: &EditPredictionExcerptText, references: Vec, cursor_offset: usize, + current_buffer: &BufferSnapshot, + cx: &App, ) -> Vec { - let excerpt_occurrences = IdentifierOccurrences::within_string(&excerpt_text.body); + let containing_range_identifier_occurrences = + IdentifierOccurrences::within_string(&excerpt_text.body); + let cursor_point = cursor_offset.to_point(¤t_buffer); - /* todo! - if let Some(cursor_within_excerpt) = cursor_offset.checked_sub(excerpt.range.start) { - } else { - }; - let start_point = Point::new(cursor.row.saturating_sub(2), 0); - let end_point = Point::new(cursor.row + 1, 0); + // todo! ask michael why we needed this + // if let Some(cursor_within_excerpt) = cursor_offset.checked_sub(excerpt.range.start) { + // } else { + // }; + let start_point = Point::new(cursor_point.row.saturating_sub(2), 0); + let end_point = Point::new(cursor_point.row + 1, 0); let adjacent_identifier_occurrences = IdentifierOccurrences::within_string( - &source[offset_from_point(source, start_point)..offset_from_point(source, end_point)], + ¤t_buffer + .text_for_range(start_point..end_point) + .collect::(), ); - */ let mut identifier_to_references: HashMap> = HashMap::new(); for reference in references { @@ -75,74 +85,102 @@ fn scored_snippets( identifier_to_references .into_iter() .flat_map(|(identifier, references)| { - let Some(definitions) = index - .identifier_to_definitions - .get(&(identifier.clone(), language.name.clone())) - else { - return Vec::new(); - }; + let definitions = index + .read(cx) + // todo! pick a limit + .declarations_for_identifier::<16>(&identifier, cx); let definition_count = definitions.len(); - let definition_file_count = definitions.keys().len(); + let total_file_count = definitions + .iter() + .filter_map(|definition| definition.project_entry_id(cx)) + .collect::>() + .len(); definitions - .iter_all() - .flat_map(|(definition_file, file_definitions)| { - let same_file_definition_count = file_definitions.len(); - let is_same_file = reference_file == definition_file.as_ref(); - file_definitions - .iter() - .filter(|definition| { - !is_same_file - || !range_intersection(&definition.item_range, &excerpt_range) - .is_some() - }) - .filter_map(|definition| { - let definition_line_distance = if is_same_file { + .iter() + .filter_map(|definition| match definition { + Declaration::Buffer { + declaration, + buffer, + } => { + let is_same_file = buffer + .read_with(cx, |buffer, _| buffer.remote_id()) + .is_ok_and(|buffer_id| buffer_id == current_buffer.remote_id()); + + if is_same_file { + range_intersection( + &declaration.item_range.to_offset(¤t_buffer), + &excerpt.range, + ) + .is_none() + .then(|| { let definition_line = - point_from_offset(source, definition.item_range.start).row; - (cursor.row as i32 - definition_line as i32).abs() as u32 - } else { - 0 - }; - Some((definition_line_distance, definition)) - }) - .sorted_by_key(|&(distance, _)| distance) - .enumerate() - .map( - |( - definition_line_distance_rank, - (definition_line_distance, definition), - )| { - score_snippet( - &identifier, - &references, - definition_file.clone(), - definition.clone(), - is_same_file, - definition_line_distance, - definition_line_distance_rank, - same_file_definition_count, - definition_count, - definition_file_count, - &containing_range_identifier_occurrences, - &adjacent_identifier_occurrences, - cursor, + declaration.item_range.start.to_point(current_buffer).row; + ( + true, + (cursor_point.row as i32 - definition_line as i32).abs() as u32, + definition, ) - }, - ) - .collect::>() + }) + } else { + Some((false, 0, definition)) + } + } + Declaration::File { .. } => { + // We can assume that a file declaration is in a different file, + // because the current onemust be open + Some((false, 0, definition)) + } }) + .sorted_by_key(|&(_, distance, _)| distance) + .enumerate() + .map( + |( + definition_line_distance_rank, + (is_same_file, definition_line_distance, definition), + )| { + let same_file_definition_count = + index.read(cx).file_declaration_count(definition); + + score_snippet( + &identifier, + &references, + definition.clone(), + is_same_file, + definition_line_distance, + definition_line_distance_rank, + same_file_definition_count, + definition_count, + total_file_count, + &containing_range_identifier_occurrences, + &adjacent_identifier_occurrences, + cursor_point, + current_buffer, + cx, + ) + }, + ) .collect::>() }) .flatten() .collect::>() } +// todo! replace with existing util? +fn range_intersection(a: &Range, b: &Range) -> Option> { + let start = a.start.clone().max(b.start.clone()); + let end = a.end.clone().min(b.end.clone()); + if start < end { + Some(Range { start, end }) + } else { + None + } +} + fn score_snippet( identifier: &Identifier, references: &[Reference], - definition_file: Arc, - definition: OutlineItem, + definition: Declaration, is_same_file: bool, definition_line_distance: u32, definition_line_distance_rank: usize, @@ -152,28 +190,28 @@ fn score_snippet( containing_range_identifier_occurrences: &IdentifierOccurrences, adjacent_identifier_occurrences: &IdentifierOccurrences, cursor: Point, + current_buffer: &BufferSnapshot, + cx: &App, ) -> Option { let is_referenced_nearby = references .iter() - .any(|r| r.reference_region == ReferenceRegion::Nearby); + .any(|r| r.region == ReferenceRegion::Nearby); let is_referenced_in_breadcrumb = references .iter() - .any(|r| r.reference_region == ReferenceRegion::Breadcrumb); + .any(|r| r.region == ReferenceRegion::Breadcrumb); let reference_count = references.len(); let reference_line_distance = references .iter() .map(|r| { - let reference_line = point_from_offset(reference_source, r.range.start).row as i32; + let reference_line = r.range.start.to_point(current_buffer).row as i32; (cursor.row as i32 - reference_line).abs() as u32 }) .min() .unwrap(); - let definition_source = index.path_to_source.get(&definition_file).unwrap(); - let item_source_occurrences = - IdentifierOccurrences::within_string(definition.item(&definition_source)); + let item_source_occurrences = IdentifierOccurrences::within_string(&definition.item_text(cx)); let item_signature_occurrences = - IdentifierOccurrences::within_string(definition.signature(&definition_source)); + IdentifierOccurrences::within_string(&definition.signature_text(cx)); let containing_range_vs_item_jaccard = jaccard_similarity( containing_range_identifier_occurrences, &item_source_occurrences, @@ -223,7 +261,6 @@ fn score_snippet( Some(ScoredSnippet { identifier: identifier.clone(), - declaration_file: definition_file, declaration: definition, scores: score_components.score(), score_components, @@ -238,6 +275,7 @@ pub struct ScoreInputs { pub reference_count: usize, pub same_file_definition_count: usize, pub definition_count: usize, + // todo! do we need this? pub definition_file_count: usize, pub reference_line_distance: u32, pub definition_line_distance: u32, diff --git a/crates/edit_prediction_context/src/tree_sitter_index.rs b/crates/edit_prediction_context/src/tree_sitter_index.rs index 4dc00941fe1b8a7a095fffd5605b040001c02eb7..1aa7e72d3eaad66626a9504ec9daf49565c59897 100644 --- a/crates/edit_prediction_context/src/tree_sitter_index.rs +++ b/crates/edit_prediction_context/src/tree_sitter_index.rs @@ -78,6 +78,57 @@ impl Declaration { Declaration::Buffer { declaration, .. } => &declaration.identifier, } } + + pub fn project_entry_id(&self, cx: &App) -> Option { + match self { + Declaration::File { + project_entry_id, .. + } => Some(*project_entry_id), + Declaration::Buffer { buffer, .. } => buffer + .read_with(cx, |buffer, _cx| { + project::File::from_dyn(buffer.file()) + .and_then(|file| file.project_entry_id(cx)) + }) + .ok() + .flatten(), + } + } + + // todo! pick best return type + pub fn item_text(&self, cx: &App) -> Arc { + match self { + Declaration::File { declaration, .. } => declaration.declaration_text.clone(), + Declaration::Buffer { + buffer, + declaration, + } => buffer + .read_with(cx, |buffer, _cx| { + buffer + .text_for_range(declaration.item_range.clone()) + .collect::() + .into() + }) + .unwrap_or_default(), + } + } + + // todo! pick best return type + pub fn signature_text(&self, cx: &App) -> Arc { + match self { + Declaration::File { declaration, .. } => declaration.signature_text.clone(), + Declaration::Buffer { + buffer, + declaration, + } => buffer + .read_with(cx, |buffer, _cx| { + buffer + .text_for_range(declaration.signature_range.clone()) + .collect::() + .into() + }) + .unwrap_or_default(), + } + } } #[derive(Debug, Clone)] @@ -86,7 +137,9 @@ pub struct FileDeclaration { pub identifier: Identifier, pub item_range: Range, pub signature_range: Range, + // todo! should we just store a range with the declaration text? pub signature_text: Arc, + pub declaration_text: Arc, } #[derive(Debug, Clone)] @@ -145,7 +198,7 @@ impl TreeSitterIndex { pub fn declarations_for_identifier( &self, - identifier: Identifier, + identifier: &Identifier, cx: &App, ) -> Vec { // make sure to not have a large stack allocation @@ -206,6 +259,23 @@ impl TreeSitterIndex { result } + pub fn file_declaration_count(&self, declaration: &Declaration) -> usize { + match declaration { + Declaration::File { + project_entry_id, .. + } => self + .files + .get(project_entry_id) + .map(|file_state| file_state.declarations.len()) + .unwrap_or_default(), + Declaration::Buffer { buffer, .. } => self + .buffers + .get(buffer) + .map(|buffer_state| buffer_state.declarations.len()) + .unwrap_or_default(), + } + } + fn handle_worktree_store_event( &mut self, _worktree_store: Entity, @@ -491,12 +561,16 @@ impl FileDeclaration { FileDeclaration { parent: None, identifier: declaration.identifier, - item_range: declaration.item_range, signature_text: snapshot .text_for_range(declaration.signature_range.clone()) .collect::() .into(), signature_range: declaration.signature_range, + declaration_text: snapshot + .text_for_range(declaration.item_range.clone()) + .collect::() + .into(), + item_range: declaration.item_range, } } } @@ -527,7 +601,7 @@ mod tests { }; index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(main.clone(), cx); + let decls = index.declarations_for_identifier::<8>(&main, cx); assert_eq!(decls.len(), 2); let decl = expect_file_decl("c.rs", &decls[0], &project, cx); @@ -549,7 +623,7 @@ mod tests { }; index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx); + let decls = index.declarations_for_identifier::<8>(&test_process_data, cx); assert_eq!(decls.len(), 1); let decl = expect_file_decl("c.rs", &decls[0], &project, cx); @@ -588,7 +662,7 @@ mod tests { cx.run_until_parked(); index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx); + let decls = index.declarations_for_identifier::<8>(&test_process_data, cx); assert_eq!(decls.len(), 1); let decl = expect_buffer_decl("c.rs", &decls[0], cx); @@ -616,7 +690,7 @@ mod tests { index.read_with(cx, |index, cx| { let decls = index.declarations_for_identifier::<1>( - Identifier { + &Identifier { name: "main".into(), language_id: rust_lang_id, }, @@ -646,7 +720,7 @@ mod tests { cx.run_until_parked(); index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(main.clone(), cx); + let decls = index.declarations_for_identifier::<8>(&main, cx); assert_eq!(decls.len(), 2); let decl = expect_buffer_decl("c.rs", &decls[0], cx); assert_eq!(decl.identifier, main); @@ -669,7 +743,7 @@ mod tests { cx.run_until_parked(); index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(main, cx); + let decls = index.declarations_for_identifier::<8>(&main, cx); assert_eq!(decls.len(), 2); expect_file_decl("c.rs", &decls[0], &project, cx); expect_file_decl("a.rs", &decls[1], &project, cx);