From cb0c4bec24a82ade5046e18752243d814e6a016d Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Thu, 18 Sep 2025 17:38:57 -0600 Subject: [PATCH] Progress preparing new cloud request + using index in excerpt selection Co-authored-by: Agus --- Cargo.lock | 1 + .../cloud_llm_client/src/cloud_llm_client.rs | 2 + .../cloud_llm_client/src/predict_edits_v3.rs | 123 +++++++++++++ crates/edit_prediction_context/Cargo.toml | 1 + .../src/declaration.rs | 24 +++ .../src/declaration_scoring.rs | 77 ++++---- .../src/edit_prediction_context.rs | 168 ++++++++++++++---- crates/edit_prediction_context/src/excerpt.rs | 85 ++++----- .../edit_prediction_context/src/reference.rs | 4 +- .../src/syntax_index.rs | 62 +++++-- .../src/edit_prediction_tools.rs | 11 +- 11 files changed, 408 insertions(+), 150 deletions(-) create mode 100644 crates/cloud_llm_client/src/predict_edits_v3.rs diff --git a/Cargo.lock b/Cargo.lock index 28247c7f4c6bcab64a8acce951e072e80abf751b..3abb36adc1d6768dbe81d573ea6a27704ae4fe59 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5174,6 +5174,7 @@ dependencies = [ "anyhow", "arrayvec", "clap", + "cloud_llm_client", "collections", "futures 0.3.31", "gpui", diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index 24923a318441afeaa2521064b4f433ab9ee1e55f..e0cc42af76156466c31ead17d6421f3634d3ad7c 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -1,3 +1,5 @@ +pub mod predict_edits_v3; + use std::str::FromStr; use std::sync::Arc; diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs new file mode 100644 index 0000000000000000000000000000000000000000..076dc6c5cb2f5c50d460a9c8e01172461ca9b123 --- /dev/null +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -0,0 +1,123 @@ +use serde::{Deserialize, Serialize}; +use std::ops::Range; + +use crate::PredictEditsGitInfo; + +// TODO: snippet ordering within file / relative to excerpt + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Body { + pub excerpt: String, + /// Within `signatures` + pub excerpt_parent: Option, + pub signatures: Vec, + pub referenced_declarations: Vec, + pub events: Vec, + #[serde(default)] + pub can_collect_data: bool, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub diagnostic_groups: Option>, + /// Info about the git repository state, only present when can_collect_data is true. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub git_info: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Event {} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Signature { + pub text: String, + pub text_is_truncated: bool, + #[serde(skip_serializing_if = "Option::is_none", default)] + pub parent_index: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReferencedDeclaration { + pub text: String, + pub text_is_truncated: bool, + /// Range within `text` + pub signature_range: Range, + /// Index within `signatures`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub parent_index: Option, + pub score_components: ScoreComponents, + pub signature_score: f32, + pub declaration_score: f32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ScoreComponents { + pub is_same_file: bool, + pub is_referenced_nearby: bool, + pub is_referenced_in_breadcrumb: bool, + pub reference_count: usize, + pub same_file_declaration_count: usize, + pub declaration_count: usize, + pub reference_line_distance: u32, + pub declaration_line_distance: u32, + pub declaration_line_distance_rank: usize, + pub containing_range_vs_item_jaccard: f32, + pub containing_range_vs_signature_jaccard: f32, + pub adjacent_vs_item_jaccard: f32, + pub adjacent_vs_signature_jaccard: f32, + pub containing_range_vs_item_weighted_overlap: f32, + pub containing_range_vs_signature_weighted_overlap: f32, + pub adjacent_vs_item_weighted_overlap: f32, + pub adjacent_vs_signature_weighted_overlap: f32, +} + +/* +#[derive(Debug, Clone)] +pub struct SerializedJson { + raw: Box, + _phantom: PhantomData, +} + +impl SerializedJson +where + T: Serialize + for<'de> Deserialize<'de>, +{ + pub fn new(value: &T) -> Result { + Ok(SerializedJson { + raw: serde_json::value::to_raw_value(value)?, + _phantom: PhantomData, + }) + } + + pub fn deserialize(&self) -> Result { + serde_json::from_str(self.raw.get()) + } + + pub fn as_raw(&self) -> &RawValue { + &self.raw + } + + pub fn into_raw(self) -> Box { + self.raw + } +} + +impl Serialize for SerializedJson { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + self.raw.serialize(serializer) + } +} + +impl<'de, T> Deserialize<'de> for SerializedJson { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let raw = Box::::deserialize(deserializer)?; + Ok(SerializedJson { + raw, + _phantom: PhantomData, + }) + } +} +*/ diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index 48f51da3912ea5bca589e7b559d5b665b9b762d6..75880cad5f3e2807e525908656931853efa19a92 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -14,6 +14,7 @@ path = "src/edit_prediction_context.rs" [dependencies] anyhow.workspace = true arrayvec.workspace = true +cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs index 8fba85367c70e2cb1211e343bba7a6675a5a5360..2f3b6f437146ff3d91d45aff44b666dde07580b3 100644 --- a/crates/edit_prediction_context/src/declaration.rs +++ b/crates/edit_prediction_context/src/declaration.rs @@ -41,6 +41,20 @@ impl Declaration { } } + pub fn parent(&self) -> Option { + match self { + Declaration::File { declaration, .. } => declaration.parent, + Declaration::Buffer { declaration, .. } => declaration.parent, + } + } + + pub fn as_buffer(&self) -> Option<&BufferDeclaration> { + match self { + Declaration::File { .. } => None, + Declaration::Buffer { declaration, .. } => Some(declaration), + } + } + pub fn project_entry_id(&self) -> ProjectEntryId { match self { Declaration::File { @@ -83,6 +97,16 @@ impl Declaration { ), } } + + pub fn signature_range_in_item_text(&self) -> Range { + match self { + Declaration::File { declaration, .. } => declaration.signature_range_in_text.clone(), + Declaration::Buffer { declaration, .. } => { + declaration.signature_range.start - declaration.item_range.start + ..declaration.signature_range.end - declaration.item_range.start + } + } + } } fn expand_range_to_line_boundaries_and_truncate( diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs index 4cbc4e83c02e0cae912813a261d99f8fe8c41b55..accb92901c173aecf9ad76116f4fc43e9c32c981 100644 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ b/crates/edit_prediction_context/src/declaration_scoring.rs @@ -1,10 +1,11 @@ +use cloud_llm_client::predict_edits_v3::ScoreComponents; use itertools::Itertools as _; use language::BufferSnapshot; use ordered_float::OrderedFloat; use serde::Serialize; use std::{collections::HashMap, ops::Range}; use strum::EnumIter; -use text::{OffsetRangeExt, Point, ToPoint}; +use text::{Point, ToPoint}; use crate::{ Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier, @@ -23,7 +24,7 @@ const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16; pub struct ScoredSnippet { pub identifier: Identifier, pub declaration: Declaration, - pub score_components: ScoreInputs, + pub score_components: ScoreComponents, pub scores: Scores, } @@ -90,8 +91,8 @@ pub fn scored_snippets( let declaration_count = declarations.len(); declarations - .iter() - .filter_map(|declaration| match declaration { + .into_iter() + .filter_map(|(declaration_id, declaration)| match declaration { Declaration::Buffer { buffer_id, declaration: buffer_declaration, @@ -100,24 +101,29 @@ pub fn scored_snippets( let is_same_file = buffer_id == ¤t_buffer.remote_id(); if is_same_file { - range_intersection( - &buffer_declaration.item_range.to_offset(¤t_buffer), - &excerpt.range, - ) - .is_none() - .then(|| { + let overlaps_excerpt = + range_intersection(&buffer_declaration.item_range, &excerpt.range) + .is_some(); + if overlaps_excerpt + || excerpt + .parent_declarations + .iter() + .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id) + { + None + } else { let declaration_line = buffer_declaration .item_range .start .to_point(current_buffer) .row; - ( + Some(( true, (cursor_point.row as i32 - declaration_line as i32) .unsigned_abs(), declaration, - ) - }) + )) + } } else { Some((false, u32::MAX, declaration)) } @@ -238,7 +244,7 @@ fn score_snippet( let adjacent_vs_signature_weighted_overlap = weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences); - let score_components = ScoreInputs { + let score_components = ScoreComponents { is_same_file, is_referenced_nearby, is_referenced_in_breadcrumb, @@ -261,51 +267,30 @@ fn score_snippet( Some(ScoredSnippet { identifier: identifier.clone(), declaration: declaration, - scores: score_components.score(), + scores: Scores::score(&score_components), score_components, }) } -#[derive(Clone, Debug, Serialize)] -pub struct ScoreInputs { - pub is_same_file: bool, - pub is_referenced_nearby: bool, - pub is_referenced_in_breadcrumb: bool, - pub reference_count: usize, - pub same_file_declaration_count: usize, - pub declaration_count: usize, - pub reference_line_distance: u32, - pub declaration_line_distance: u32, - pub declaration_line_distance_rank: usize, - pub containing_range_vs_item_jaccard: f32, - pub containing_range_vs_signature_jaccard: f32, - pub adjacent_vs_item_jaccard: f32, - pub adjacent_vs_signature_jaccard: f32, - pub containing_range_vs_item_weighted_overlap: f32, - pub containing_range_vs_signature_weighted_overlap: f32, - pub adjacent_vs_item_weighted_overlap: f32, - pub adjacent_vs_signature_weighted_overlap: f32, -} - #[derive(Clone, Debug, Serialize)] pub struct Scores { pub signature: f32, pub declaration: f32, } -impl ScoreInputs { - fn score(&self) -> Scores { +impl Scores { + fn score(components: &ScoreComponents) -> Scores { // Score related to how likely this is the correct declaration, range 0 to 1 - let accuracy_score = if self.is_same_file { + let accuracy_score = if components.is_same_file { // TODO: use declaration_line_distance_rank - 1.0 / self.same_file_declaration_count as f32 + 1.0 / components.same_file_declaration_count as f32 } else { - 1.0 / self.declaration_count as f32 + 1.0 / components.declaration_count as f32 }; // Score related to the distance between the reference and cursor, range 0 to 1 - let distance_score = if self.is_referenced_nearby { - 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0) + let distance_score = if components.is_referenced_nearby { + 1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0) } else { // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures 0.5 @@ -315,10 +300,12 @@ impl ScoreInputs { let combined_score = 10.0 * accuracy_score * distance_score; Scores { - signature: combined_score * self.containing_range_vs_signature_weighted_overlap, + signature: combined_score * components.containing_range_vs_signature_weighted_overlap, // declaration score gets boosted both by being multiplied by 2 and by there being more // weighted overlap. - declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap, + declaration: 2.0 + * combined_score + * components.containing_range_vs_item_weighted_overlap, } } } diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index aed2953777d82d65b7e9cb42229d78634d5e4a3d..41f4b02aa616587e5338ce76b6bc783898fb88bf 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -6,8 +6,8 @@ mod reference; mod syntax_index; mod text_similarity; -use std::time::Instant; - +use cloud_llm_client::predict_edits_v3::{self, Signature}; +use collections::HashMap; pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier}; pub use declaration_scoring::SnippetStyle; pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; @@ -18,14 +18,17 @@ pub use reference::references_in_excerpt; pub use syntax_index::SyntaxIndex; use text::{Point, ToOffset as _}; -use crate::declaration_scoring::{ScoredSnippet, scored_snippets}; +use crate::{ + declaration::DeclarationId, + declaration_scoring::{ScoredSnippet, scored_snippets}, + syntax_index::SyntaxIndexState, +}; #[derive(Debug)] pub struct EditPredictionContext { pub excerpt: EditPredictionExcerpt, pub excerpt_text: EditPredictionExcerptText, pub snippets: Vec, - pub retrieval_duration: std::time::Duration, } impl EditPredictionContext { @@ -36,34 +39,135 @@ impl EditPredictionContext { syntax_index: Entity, cx: &mut App, ) -> Task> { - let start = Instant::now(); let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone()); cx.background_spawn(async move { let index_state = index_state.lock().await; + Self::gather_context(cursor_point, buffer, excerpt_options, &index_state) + }) + } - let excerpt = - EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options)?; - let excerpt_text = excerpt.text(&buffer); - let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer); - let cursor_offset = cursor_point.to_offset(&buffer); - - let snippets = scored_snippets( - &index_state, - &excerpt, - &excerpt_text, - references, - cursor_offset, - &buffer, - ); - - Some(Self { - excerpt, - excerpt_text, - snippets, - retrieval_duration: start.elapsed(), - }) + fn gather_context( + cursor_point: Point, + buffer: BufferSnapshot, + excerpt_options: EditPredictionExcerptOptions, + index_state: &SyntaxIndexState, + ) -> Option { + let excerpt = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &buffer, + &excerpt_options, + Some(index_state), + )?; + let excerpt_text = excerpt.text(&buffer); + let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer); + let cursor_offset = cursor_point.to_offset(&buffer); + + let snippets = scored_snippets( + &index_state, + &excerpt, + &excerpt_text, + references, + cursor_offset, + &buffer, + ); + + Some(Self { + excerpt, + excerpt_text, + snippets, }) } + + pub fn cloud_request( + cursor_point: Point, + buffer: BufferSnapshot, + excerpt_options: EditPredictionExcerptOptions, + syntax_index: Entity, + cx: &mut App, + ) -> Task> { + let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone()); + cx.background_spawn(async move { + let index_state = index_state.lock().await; + Self::gather_context(cursor_point, buffer, excerpt_options, &index_state) + .map(|context| context.into_cloud_request(&index_state)) + }) + } + + pub fn into_cloud_request(self, index: &SyntaxIndexState) -> predict_edits_v3::Body { + let mut signatures = Vec::new(); + let mut declaration_to_signature_index = HashMap::default(); + let mut referenced_declarations = Vec::new(); + let excerpt_parent = self + .excerpt + .parent_declarations + .last() + .and_then(|(parent, _)| { + add_signature( + *parent, + &mut declaration_to_signature_index, + &mut signatures, + index, + ) + }); + for snippet in self.snippets { + let parent_index = snippet.declaration.parent().and_then(|parent| { + add_signature( + parent, + &mut declaration_to_signature_index, + &mut signatures, + index, + ) + }); + let (text, text_is_truncated) = snippet.declaration.item_text(); + referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { + text: text.into(), + text_is_truncated, + signature_range: snippet.declaration.signature_range_in_item_text(), + parent_index, + score_components: snippet.score_components, + signature_score: snippet.scores.signature, + declaration_score: snippet.scores.declaration, + }); + } + predict_edits_v3::Body { + excerpt: self.excerpt_text.body, + referenced_declarations, + signatures, + excerpt_parent, + // todo! + events: vec![], + can_collect_data: false, + diagnostic_groups: None, + git_info: None, + } + } +} + +fn add_signature( + declaration_id: DeclarationId, + declaration_to_signature_index: &mut HashMap, + signatures: &mut Vec, + index: &SyntaxIndexState, +) -> Option { + if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) { + return Some(*signature_index); + } + let Some(parent_declaration) = index.declaration(declaration_id) else { + log::error!("bug: missing parent declaration"); + return None; + }; + let parent_index = parent_declaration.parent().and_then(|parent| { + add_signature(parent, declaration_to_signature_index, signatures, index) + }); + let (text, text_is_truncated) = parent_declaration.signature_text(); + let signature_index = signatures.len(); + signatures.push(Signature { + text: text.into(), + text_is_truncated, + parent_index, + }); + declaration_to_signature_index.insert(declaration_id, signature_index); + Some(signature_index) } #[cfg(test)] @@ -105,10 +209,9 @@ mod tests { cursor_point, buffer_snapshot, EditPredictionExcerptOptions { - max_bytes: 40, + max_bytes: 60, min_bytes: 10, target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, }, index, cx, @@ -117,8 +220,13 @@ mod tests { .await .unwrap(); - assert_eq!(context.snippets.len(), 1); - assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data"); + let mut snippet_identifiers = context + .snippets + .iter() + .map(|snippet| snippet.identifier.name.as_ref()) + .collect::>(); + snippet_identifiers.sort(); + assert_eq!(snippet_identifiers, vec!["main", "process_data"]); drop(buffer); } diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs index d27e75bd129147986d2359585d5e048704d8828f..764c9040f247561ba6058dfd2954f42085297a4c 100644 --- a/crates/edit_prediction_context/src/excerpt.rs +++ b/crates/edit_prediction_context/src/excerpt.rs @@ -1,9 +1,11 @@ use language::BufferSnapshot; use std::ops::Range; -use text::{OffsetRangeExt as _, Point, ToOffset as _, ToPoint as _}; +use text::{Point, ToOffset as _, ToPoint as _}; use tree_sitter::{Node, TreeCursor}; use util::RangeExt; +use crate::{BufferDeclaration, declaration::DeclarationId, syntax_index::SyntaxIndexState}; + // TODO: // // - Test parent signatures @@ -27,14 +29,12 @@ pub struct EditPredictionExcerptOptions { pub min_bytes: usize, /// Target ratio of bytes before the cursor divided by total bytes in the window. pub target_before_cursor_over_total_bytes: f32, - /// Whether to include parent signatures - pub include_parent_signatures: bool, } #[derive(Debug, Clone)] pub struct EditPredictionExcerpt { pub range: Range, - pub parent_signature_ranges: Vec>, + pub parent_declarations: Vec<(DeclarationId, Range)>, pub size: usize, } @@ -50,9 +50,9 @@ impl EditPredictionExcerpt { .text_for_range(self.range.clone()) .collect::(); let parent_signatures = self - .parent_signature_ranges + .parent_declarations .iter() - .map(|range| buffer.text_for_range(range.clone()).collect::()) + .map(|(_, range)| buffer.text_for_range(range.clone()).collect::()) .collect(); EditPredictionExcerptText { body, @@ -62,8 +62,9 @@ impl EditPredictionExcerpt { /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the - /// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures - /// of parent outline items. + /// cursor. + /// + /// When `index` is provided, the excerpt will include the signatures of parent outline items. /// /// First tries to use AST node boundaries to select the excerpt, and falls back on line-based /// expansion. @@ -73,6 +74,7 @@ impl EditPredictionExcerpt { query_point: Point, buffer: &BufferSnapshot, options: &EditPredictionExcerptOptions, + syntax_index: Option<&SyntaxIndexState>, ) -> Option { if buffer.len() <= options.max_bytes { log::debug!( @@ -90,17 +92,9 @@ impl EditPredictionExcerpt { return None; } - // TODO: Don't compute text / annotation_range / skip converting to and from anchors. - let outline_items = if options.include_parent_signatures { - buffer - .outline_items_containing(query_range.clone(), false, None) - .into_iter() - .flat_map(|item| { - Some(ExcerptOutlineItem { - item_range: item.range.to_offset(&buffer), - signature_range: item.signature_range?.to_offset(&buffer), - }) - }) + let parent_declarations = if let Some(syntax_index) = syntax_index { + syntax_index + .buffer_declarations_containing_range(buffer.remote_id(), query_range.clone()) .collect() } else { Vec::new() @@ -109,7 +103,7 @@ impl EditPredictionExcerpt { let excerpt_selector = ExcerptSelector { query_offset, query_range, - outline_items: &outline_items, + parent_declarations: &parent_declarations, buffer, options, }; @@ -132,15 +126,15 @@ impl EditPredictionExcerpt { excerpt_selector.select_lines() } - fn new(range: Range, parent_signature_ranges: Vec>) -> Self { + fn new(range: Range, parent_declarations: Vec<(DeclarationId, Range)>) -> Self { let size = range.len() - + parent_signature_ranges + + parent_declarations .iter() - .map(|r| r.len()) + .map(|(_, range)| range.len()) .sum::(); Self { range, - parent_signature_ranges, + parent_declarations, size, } } @@ -150,20 +144,14 @@ impl EditPredictionExcerpt { // this is an issue because parent_signature_ranges may be incorrect log::error!("bug: with_expanded_range called with disjoint range"); } - let mut parent_signature_ranges = Vec::with_capacity(self.parent_signature_ranges.len()); - let mut size = new_range.len(); - for range in &self.parent_signature_ranges { - if range.contains_inclusive(&new_range) { + let mut parent_declarations = Vec::with_capacity(self.parent_declarations.len()); + for (declaration_id, range) in &self.parent_declarations { + if !range.contains_inclusive(&new_range) { break; } - parent_signature_ranges.push(range.clone()); - size += range.len(); - } - Self { - range: new_range, - parent_signature_ranges, - size, + parent_declarations.push((*declaration_id, range.clone())); } + Self::new(new_range, parent_declarations) } fn parent_signatures_size(&self) -> usize { @@ -174,16 +162,11 @@ impl EditPredictionExcerpt { struct ExcerptSelector<'a> { query_offset: usize, query_range: Range, - outline_items: &'a [ExcerptOutlineItem], + parent_declarations: &'a [(DeclarationId, &'a BufferDeclaration)], buffer: &'a BufferSnapshot, options: &'a EditPredictionExcerptOptions, } -struct ExcerptOutlineItem { - item_range: Range, - signature_range: Range, -} - impl<'a> ExcerptSelector<'a> { /// Finds the largest node that is smaller than the window size and contains `query_range`. fn select_tree_sitter_nodes(&self) -> Option { @@ -396,13 +379,13 @@ impl<'a> ExcerptSelector<'a> { } fn make_excerpt(&self, range: Range) -> EditPredictionExcerpt { - let parent_signature_ranges = self - .outline_items + let parent_declarations = self + .parent_declarations .iter() - .filter(|item| item.item_range.contains_inclusive(&range)) - .map(|item| item.signature_range.clone()) + .filter(|(_, declaration)| declaration.item_range.contains_inclusive(&range)) + .map(|(id, declaration)| (*id, declaration.signature_range.clone())) .collect(); - EditPredictionExcerpt::new(range, parent_signature_ranges) + EditPredictionExcerpt::new(range, parent_declarations) } /// Returns `true` if the `forward` excerpt is a better choice than the `backward` excerpt. @@ -493,8 +476,9 @@ mod tests { let buffer = create_buffer(&text, cx); let cursor_point = cursor.to_point(&buffer); - let excerpt = EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options) - .expect("Should select an excerpt"); + let excerpt = + EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &options, None) + .expect("Should select an excerpt"); pretty_assertions::assert_eq!( generate_marked_text(&text, std::slice::from_ref(&excerpt.range), false), generate_marked_text(&text, &[expected_excerpt], false) @@ -517,7 +501,6 @@ fn main() { max_bytes: 20, min_bytes: 10, target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, }; check_example(options, text, cx); @@ -541,7 +524,6 @@ fn bar() {}"#; max_bytes: 65, min_bytes: 10, target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, }; check_example(options, text, cx); @@ -561,7 +543,6 @@ fn main() { max_bytes: 50, min_bytes: 10, target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, }; check_example(options, text, cx); @@ -583,7 +564,6 @@ fn main() { max_bytes: 60, min_bytes: 45, target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, }; check_example(options, text, cx); @@ -608,7 +588,6 @@ fn main() { max_bytes: 120, min_bytes: 10, target_before_cursor_over_total_bytes: 0.6, - include_parent_signatures: false, }; check_example(options, text, cx); diff --git a/crates/edit_prediction_context/src/reference.rs b/crates/edit_prediction_context/src/reference.rs index abb7dd75dd569f7a14bea807bdc64441d0e64871..268f8c39ef84ba29593f502aff7e818e931cc873 100644 --- a/crates/edit_prediction_context/src/reference.rs +++ b/crates/edit_prediction_context/src/reference.rs @@ -33,8 +33,8 @@ pub fn references_in_excerpt( snapshot, ); - for (range, text) in excerpt - .parent_signature_ranges + for ((_, range), text) in excerpt + .parent_declarations .iter() .zip(excerpt_text.parent_signatures.iter()) { diff --git a/crates/edit_prediction_context/src/syntax_index.rs b/crates/edit_prediction_context/src/syntax_index.rs index 64982f5805f08a3ba791578e28778f0c8399fde8..bd80245c6b7ae51a634e9718f6516bac275abfef 100644 --- a/crates/edit_prediction_context/src/syntax_index.rs +++ b/crates/edit_prediction_context/src/syntax_index.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use collections::{HashMap, HashSet}; use futures::lock::Mutex; use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity}; @@ -8,8 +6,11 @@ use project::buffer_store::{BufferStore, BufferStoreEvent}; use project::worktree_store::{WorktreeStore, WorktreeStoreEvent}; use project::{PathChange, Project, ProjectEntryId, ProjectPath}; use slotmap::SlotMap; +use std::iter; +use std::ops::Range; +use std::sync::Arc; use text::BufferId; -use util::{debug_panic, some_or_debug_panic}; +use util::{RangeExt as _, debug_panic, some_or_debug_panic}; use crate::declaration::{ BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier, @@ -432,7 +433,7 @@ impl SyntaxIndexState { pub fn declarations_for_identifier( &self, identifier: &Identifier, - ) -> Vec { + ) -> Vec<(DeclarationId, &Declaration)> { // make sure to not have a large stack allocation assert!(N < 32); @@ -454,7 +455,7 @@ impl SyntaxIndexState { project_entry_id, .. } => { included_buffer_entry_ids.push(*project_entry_id); - result.push(declaration.clone()); + result.push((*declaration_id, declaration)); if result.len() == N { return Vec::new(); } @@ -463,19 +464,19 @@ impl SyntaxIndexState { project_entry_id, .. } => { if !included_buffer_entry_ids.contains(&project_entry_id) { - file_declarations.push(declaration.clone()); + file_declarations.push((*declaration_id, declaration)); } } } } - for declaration in file_declarations { + for (declaration_id, declaration) in file_declarations { match declaration { Declaration::File { project_entry_id, .. } => { if !included_buffer_entry_ids.contains(&project_entry_id) { - result.push(declaration); + result.push((declaration_id, declaration)); if result.len() == N { return Vec::new(); @@ -489,6 +490,35 @@ impl SyntaxIndexState { result } + pub fn buffer_declarations_containing_range( + &self, + buffer_id: BufferId, + range: Range, + ) -> impl Iterator { + let Some(buffer_state) = self.buffers.get(&buffer_id) else { + return itertools::Either::Left(iter::empty()); + }; + + let iter = buffer_state + .declarations + .iter() + .filter_map(move |declaration_id| { + let Some(declaration) = self + .declarations + .get(*declaration_id) + .and_then(|d| d.as_buffer()) + else { + log::error!("bug: missing buffer outline declaration"); + return None; + }; + if declaration.item_range.contains_inclusive(&range) { + return Some((*declaration_id, declaration)); + } + return None; + }); + itertools::Either::Right(iter) + } + pub fn file_declaration_count(&self, declaration: &Declaration) -> usize { match declaration { Declaration::File { @@ -553,11 +583,11 @@ mod tests { let decls = index_state.declarations_for_identifier::<8>(&main); assert_eq!(decls.len(), 2); - let decl = expect_file_decl("c.rs", &decls[0], &project, cx); + let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx); assert_eq!(decl.identifier, main.clone()); assert_eq!(decl.item_range_in_file, 32..280); - let decl = expect_file_decl("a.rs", &decls[1], &project, cx); + let decl = expect_file_decl("a.rs", &decls[1].1, &project, cx); assert_eq!(decl.identifier, main); assert_eq!(decl.item_range_in_file, 0..98); }); @@ -577,7 +607,7 @@ mod tests { let decls = index_state.declarations_for_identifier::<8>(&test_process_data); assert_eq!(decls.len(), 1); - let decl = expect_file_decl("c.rs", &decls[0], &project, cx); + let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx); assert_eq!(decl.identifier, test_process_data); let parent_id = decl.parent.unwrap(); @@ -618,7 +648,7 @@ mod tests { let decls = index_state.declarations_for_identifier::<8>(&test_process_data); assert_eq!(decls.len(), 1); - let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx); + let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx); assert_eq!(decl.identifier, test_process_data); let parent_id = decl.parent.unwrap(); @@ -676,11 +706,11 @@ mod tests { cx.update(|cx| { let decls = index_state.declarations_for_identifier::<8>(&main); assert_eq!(decls.len(), 2); - let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx); + let decl = expect_buffer_decl("c.rs", &decls[0].1, &project, cx); assert_eq!(decl.identifier, main); assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..280); - expect_file_decl("a.rs", &decls[1], &project, cx); + expect_file_decl("a.rs", &decls[1].1, &project, cx); }); } @@ -695,8 +725,8 @@ mod tests { cx.update(|cx| { let decls = index_state.declarations_for_identifier::<8>(&main); assert_eq!(decls.len(), 2); - expect_file_decl("c.rs", &decls[0], &project, cx); - expect_file_decl("a.rs", &decls[1], &project, cx); + expect_file_decl("c.rs", &decls[0].1, &project, cx); + expect_file_decl("a.rs", &decls[1].1, &project, cx); }); } diff --git a/crates/edit_prediction_tools/src/edit_prediction_tools.rs b/crates/edit_prediction_tools/src/edit_prediction_tools.rs index f00a16e026704f1d1da318956f41128a9783a54c..ac84689baa84fa337a75c29f58d8d42f9060fe6f 100644 --- a/crates/edit_prediction_tools/src/edit_prediction_tools.rs +++ b/crates/edit_prediction_tools/src/edit_prediction_tools.rs @@ -4,7 +4,7 @@ use std::{ path::{Path, PathBuf}, str::FromStr, sync::Arc, - time::Duration, + time::{Duration, Instant}, }; use collections::HashMap; @@ -195,6 +195,8 @@ impl EditPredictionTools { .timer(Duration::from_millis(50)) .await; + let mut start_time = None; + let Ok(task) = this.update(cx, |this, cx| { fn number_input_value( input: &Entity, @@ -216,10 +218,10 @@ impl EditPredictionTools { &this.cursor_context_ratio_input, cx, ), - // TODO Display and add to options - include_parent_signatures: false, }; + start_time = Some(Instant::now()); + EditPredictionContext::gather( cursor_position, current_buffer_snapshot, @@ -243,6 +245,7 @@ impl EditPredictionTools { .ok(); return; }; + let retrieval_duration = start_time.unwrap().elapsed(); let mut languages = HashMap::default(); for snippet in context.snippets.iter() { @@ -320,7 +323,7 @@ impl EditPredictionTools { this.last_context = Some(ContextState { context_editor, - retrieval_duration: context.retrieval_duration, + retrieval_duration, }); cx.notify(); })