From da71465437128e73a5ab401516e46e8ca1c8fa05 Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Fri, 26 Sep 2025 01:57:28 -0600 Subject: [PATCH] edit_prediction_context: Minor optimization of text similarity + some renames (#38941) Release Notes: - N/A --- Cargo.lock | 1 + Cargo.toml | 1 + .../cloud_llm_client/src/predict_edits_v3.rs | 12 +- .../src/cloud_zeta2_prompt.rs | 46 +++--- crates/edit_prediction_context/Cargo.toml | 1 + .../src/declaration_scoring.rs | 143 ++++++++---------- .../src/edit_prediction_context.rs | 23 ++- .../src/text_similarity.rs | 137 ++++++++++------- crates/zeta2/src/zeta2.rs | 2 +- crates/zeta2_tools/src/zeta2_tools.rs | 8 +- 10 files changed, 198 insertions(+), 176 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d33f1d02903b717b479d8dfd745ff75d4d49846a..5bbe2f880ba8fb631ccbe382aa0a029f05a78ce2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5171,6 +5171,7 @@ dependencies = [ "collections", "futures 0.3.31", "gpui", + "hashbrown 0.15.3", "indoc", "itertools 0.14.0", "language", diff --git a/Cargo.toml b/Cargo.toml index 3ad8bd2348858bb051f853703de9a7666fcf26a0..be6699edb6d19efac2cb981cbd75efd64580629e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -511,6 +511,7 @@ futures-lite = "1.13" git2 = { version = "0.20.1", default-features = false } globset = "0.4" handlebars = "4.3" +hashbrown = "0.15.3" heck = "0.5" heed = { version = "0.21.0", features = ["read-txn-no-tls"] } hex = "0.4.3" diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 03bd5359cd01048c2edb5c3b8743916ddc3b4f2d..ec475598245b111e2647c63c3edcddd0d15ee5b8 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -103,13 +103,13 @@ pub struct ReferencedDeclaration { /// Index within `signatures`. 
#[serde(skip_serializing_if = "Option::is_none", default)] pub parent_index: Option, - pub score_components: ScoreComponents, + pub score_components: DeclarationScoreComponents, pub signature_score: f32, pub declaration_score: f32, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ScoreComponents { +pub struct DeclarationScoreComponents { pub is_same_file: bool, pub is_referenced_nearby: bool, pub is_referenced_in_breadcrumb: bool, @@ -119,12 +119,12 @@ pub struct ScoreComponents { pub reference_line_distance: u32, pub declaration_line_distance: u32, pub declaration_line_distance_rank: usize, - pub containing_range_vs_item_jaccard: f32, - pub containing_range_vs_signature_jaccard: f32, + pub excerpt_vs_item_jaccard: f32, + pub excerpt_vs_signature_jaccard: f32, pub adjacent_vs_item_jaccard: f32, pub adjacent_vs_signature_jaccard: f32, - pub containing_range_vs_item_weighted_overlap: f32, - pub containing_range_vs_signature_weighted_overlap: f32, + pub excerpt_vs_item_weighted_overlap: f32, + pub excerpt_vs_signature_weighted_overlap: f32, pub adjacent_vs_item_weighted_overlap: f32, pub adjacent_vs_signature_weighted_overlap: f32, } diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index 9c1b64013abd8ade6a951838ded00c36aafba347..42477d9d06cd945d1018cccdfe1ffd6cbf1000cd 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -70,7 +70,7 @@ pub struct PlannedSnippet<'a> { } #[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)] -pub enum SnippetStyle { +pub enum DeclarationStyle { Signature, Declaration, } @@ -84,10 +84,10 @@ pub struct SectionLabels { impl<'a> PlannedPrompt<'a> { /// Greedy one-pass knapsack algorithm to populate the prompt plan. 
Does the following: /// - /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle - /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature" - /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of - /// upgrade. + /// Initializes a priority queue by populating it with each snippet, finding the + /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a + /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects + /// the cost of upgrade. /// /// TODO: Implement an early halting condition. One option might be to have another priority /// queue where the score is the size, and update it accordingly. Another option might be to @@ -131,13 +131,13 @@ impl<'a> PlannedPrompt<'a> { struct QueueEntry { score_density: OrderedFloat, declaration_index: usize, - style: SnippetStyle, + style: DeclarationStyle, } // Initialize priority queue with the best score for each snippet. 
let mut queue: BinaryHeap = BinaryHeap::new(); for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() { - let (style, score_density) = SnippetStyle::iter() + let (style, score_density) = DeclarationStyle::iter() .map(|style| { ( style, @@ -186,7 +186,7 @@ impl<'a> PlannedPrompt<'a> { this.budget_used += additional_bytes; this.add_parents(&mut included_parents, additional_parents); let planned_snippet = match queue_entry.style { - SnippetStyle::Signature => { + DeclarationStyle::Signature => { let Some(text) = declaration.text.get(declaration.signature_range.clone()) else { return Err(anyhow!( @@ -203,7 +203,7 @@ impl<'a> PlannedPrompt<'a> { text_is_truncated: declaration.text_is_truncated, } } - SnippetStyle::Declaration => PlannedSnippet { + DeclarationStyle::Declaration => PlannedSnippet { path: declaration.path.clone(), range: declaration.range.clone(), text: &declaration.text, @@ -213,11 +213,13 @@ impl<'a> PlannedPrompt<'a> { this.snippets.push(planned_snippet); // When a Signature is consumed, insert an entry for Definition style. 
- if queue_entry.style == SnippetStyle::Signature { - let signature_size = declaration_size(&declaration, SnippetStyle::Signature); - let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration); - let signature_score = declaration_score(&declaration, SnippetStyle::Signature); - let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration); + if queue_entry.style == DeclarationStyle::Signature { + let signature_size = declaration_size(&declaration, DeclarationStyle::Signature); + let declaration_size = + declaration_size(&declaration, DeclarationStyle::Declaration); + let signature_score = declaration_score(&declaration, DeclarationStyle::Signature); + let declaration_score = + declaration_score(&declaration, DeclarationStyle::Declaration); let score_diff = declaration_score - signature_score; let size_diff = declaration_size.saturating_sub(signature_size); @@ -225,7 +227,7 @@ impl<'a> PlannedPrompt<'a> { queue.push(QueueEntry { declaration_index: queue_entry.declaration_index, score_density: OrderedFloat(score_diff / (size_diff as f32)), - style: SnippetStyle::Declaration, + style: DeclarationStyle::Declaration, }); } } @@ -510,20 +512,20 @@ impl<'a> PlannedPrompt<'a> { } } -fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 { +fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 { declaration_score(declaration, style) / declaration_size(declaration, style) as f32 } -fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 { +fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 { match style { - SnippetStyle::Signature => declaration.signature_score, - SnippetStyle::Declaration => declaration.declaration_score, + DeclarationStyle::Signature => declaration.signature_score, + DeclarationStyle::Declaration => declaration.declaration_score, } } -fn 
declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize { +fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize { match style { - SnippetStyle::Signature => declaration.signature_range.len(), - SnippetStyle::Declaration => declaration.text.len(), + DeclarationStyle::Signature => declaration.signature_range.len(), + DeclarationStyle::Declaration => declaration.text.len(), } } diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index 75880cad5f3e2807e525908656931853efa19a92..d4321b10fa6ab338c4642ccc6678094f7bb1a385 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -18,6 +18,7 @@ cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true gpui.workspace = true +hashbrown.workspace = true itertools.workspace = true language.workspace = true log.workspace = true diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs index f655387f680a53413161383f0678b21456c271f6..363e61cd21e6cf0432d23a0a50619cf420777fd9 100644 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ b/crates/edit_prediction_context/src/declaration_scoring.rs @@ -1,4 +1,4 @@ -use cloud_llm_client::predict_edits_v3::ScoreComponents; +use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents; use itertools::Itertools as _; use language::BufferSnapshot; use ordered_float::OrderedFloat; @@ -8,76 +8,67 @@ use strum::EnumIter; use text::{Point, ToPoint}; use crate::{ - Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier, + Declaration, EditPredictionExcerpt, Identifier, reference::{Reference, ReferenceRegion}, syntax_index::SyntaxIndexState, - text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient}, + text_similarity::{Occurrences, jaccard_similarity, 
weighted_overlap_coefficient}, }; const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16; #[derive(Clone, Debug)] -pub struct ScoredSnippet { +pub struct ScoredDeclaration { pub identifier: Identifier, pub declaration: Declaration, - pub score_components: ScoreComponents, - pub scores: Scores, + pub score_components: DeclarationScoreComponents, + pub scores: DeclarationScores, } #[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)] -pub enum SnippetStyle { +pub enum DeclarationStyle { Signature, Declaration, } -impl ScoredSnippet { - /// Returns the score for this snippet with the specified style. - pub fn score(&self, style: SnippetStyle) -> f32 { +impl ScoredDeclaration { + /// Returns the score for this declaration with the specified style. + pub fn score(&self, style: DeclarationStyle) -> f32 { match style { - SnippetStyle::Signature => self.scores.signature, - SnippetStyle::Declaration => self.scores.declaration, + DeclarationStyle::Signature => self.scores.signature, + DeclarationStyle::Declaration => self.scores.declaration, } } - pub fn size(&self, style: SnippetStyle) -> usize { + pub fn size(&self, style: DeclarationStyle) -> usize { match &self.declaration { Declaration::File { declaration, .. } => match style { - SnippetStyle::Signature => declaration.signature_range.len(), - SnippetStyle::Declaration => declaration.text.len(), + DeclarationStyle::Signature => declaration.signature_range.len(), + DeclarationStyle::Declaration => declaration.text.len(), }, Declaration::Buffer { declaration, .. 
} => match style { - SnippetStyle::Signature => declaration.signature_range.len(), - SnippetStyle::Declaration => declaration.item_range.len(), + DeclarationStyle::Signature => declaration.signature_range.len(), + DeclarationStyle::Declaration => declaration.item_range.len(), }, } } - pub fn score_density(&self, style: SnippetStyle) -> f32 { + pub fn score_density(&self, style: DeclarationStyle) -> f32 { self.score(style) / (self.size(style)) as f32 } } -pub fn scored_snippets( +pub fn scored_declarations( index: &SyntaxIndexState, excerpt: &EditPredictionExcerpt, - excerpt_text: &EditPredictionExcerptText, + excerpt_occurrences: &Occurrences, + adjacent_occurrences: &Occurrences, identifier_to_references: HashMap>, cursor_offset: usize, current_buffer: &BufferSnapshot, -) -> Vec { - let containing_range_identifier_occurrences = - IdentifierOccurrences::within_string(&excerpt_text.body); +) -> Vec { let cursor_point = cursor_offset.to_point(&current_buffer); - let start_point = Point::new(cursor_point.row.saturating_sub(2), 0); - let end_point = Point::new(cursor_point.row + 1, 0); - let adjacent_identifier_occurrences = IdentifierOccurrences::within_string( - &current_buffer - .text_for_range(start_point..end_point) - .collect::(), - ); - - let mut snippets = identifier_to_references + let mut declarations = identifier_to_references .into_iter() .flat_map(|(identifier, references)| { let declarations = @@ -137,7 +128,7 @@ pub fn scored_snippets( )| { let same_file_declaration_count = index.file_declaration_count(declaration); - score_snippet( + score_declaration( &identifier, &references, declaration.clone(), @@ -146,8 +137,8 @@ pub fn scored_snippets( declaration_line_distance_rank, same_file_declaration_count, declaration_count, - &containing_range_identifier_occurrences, - &adjacent_identifier_occurrences, + &excerpt_occurrences, + &adjacent_occurrences, cursor_point, current_buffer, ) @@ -158,14 +149,14 @@ pub fn scored_snippets( .flatten() .collect::>(); - 
snippets.sort_unstable_by_key(|snippet| { - let score_density = snippet - .score_density(SnippetStyle::Declaration) - .max(snippet.score_density(SnippetStyle::Signature)); + declarations.sort_unstable_by_key(|declaration| { + let score_density = declaration + .score_density(DeclarationStyle::Declaration) + .max(declaration.score_density(DeclarationStyle::Signature)); Reverse(OrderedFloat(score_density)) }); - snippets + declarations } fn range_intersection(a: &Range, b: &Range) -> Option> { @@ -178,7 +169,7 @@ fn range_intersection(a: &Range, b: &Range) -> Option Option { +) -> Option { let is_referenced_nearby = references .iter() .any(|r| r.region == ReferenceRegion::Nearby); @@ -208,37 +199,27 @@ fn score_snippet( .min() .unwrap(); - let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0); - let item_signature_occurrences = - IdentifierOccurrences::within_string(&declaration.signature_text().0); - let containing_range_vs_item_jaccard = jaccard_similarity( - containing_range_identifier_occurrences, - &item_source_occurrences, - ); - let containing_range_vs_signature_jaccard = jaccard_similarity( - containing_range_identifier_occurrences, - &item_signature_occurrences, - ); + let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0); + let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0); + let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences); + let excerpt_vs_signature_jaccard = + jaccard_similarity(excerpt_occurrences, &item_signature_occurrences); let adjacent_vs_item_jaccard = - jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences); + jaccard_similarity(adjacent_occurrences, &item_source_occurrences); let adjacent_vs_signature_jaccard = - jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences); + jaccard_similarity(adjacent_occurrences, &item_signature_occurrences); 
- let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient( - containing_range_identifier_occurrences, - &item_source_occurrences, - ); - let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient( - containing_range_identifier_occurrences, - &item_signature_occurrences, - ); + let excerpt_vs_item_weighted_overlap = + weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences); + let excerpt_vs_signature_weighted_overlap = + weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences); let adjacent_vs_item_weighted_overlap = - weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences); + weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences); let adjacent_vs_signature_weighted_overlap = - weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences); + weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences); // TODO: Consider adding declaration_file_count - let score_components = ScoreComponents { + let score_components = DeclarationScoreComponents { is_same_file, is_referenced_nearby, is_referenced_in_breadcrumb, @@ -248,32 +229,32 @@ fn score_snippet( reference_count, same_file_declaration_count, declaration_count, - containing_range_vs_item_jaccard, - containing_range_vs_signature_jaccard, + excerpt_vs_item_jaccard, + excerpt_vs_signature_jaccard, adjacent_vs_item_jaccard, adjacent_vs_signature_jaccard, - containing_range_vs_item_weighted_overlap, - containing_range_vs_signature_weighted_overlap, + excerpt_vs_item_weighted_overlap, + excerpt_vs_signature_weighted_overlap, adjacent_vs_item_weighted_overlap, adjacent_vs_signature_weighted_overlap, }; - Some(ScoredSnippet { + Some(ScoredDeclaration { identifier: identifier.clone(), declaration: declaration, - scores: Scores::score(&score_components), + scores: DeclarationScores::score(&score_components), score_components, }) } 
#[derive(Clone, Debug, Serialize)] -pub struct Scores { +pub struct DeclarationScores { pub signature: f32, pub declaration: f32, } -impl Scores { - fn score(components: &ScoreComponents) -> Scores { +impl DeclarationScores { + fn score(components: &DeclarationScoreComponents) -> DeclarationScores { // TODO: handle truncation // Score related to how likely this is the correct declaration, range 0 to 1 @@ -295,13 +276,11 @@ impl Scores { // For now instead of linear combination, the scores are just multiplied together. let combined_score = 10.0 * accuracy_score * distance_score; - Scores { - signature: combined_score * components.containing_range_vs_signature_weighted_overlap, + DeclarationScores { + signature: combined_score * components.excerpt_vs_signature_weighted_overlap, // declaration score gets boosted both by being multiplied by 2 and by there being more // weighted overlap. - declaration: 2.0 - * combined_score - * components.containing_range_vs_item_weighted_overlap, + declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap, } } } diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index f3752abab991493660d197fe871838a71f6c8ad1..1118e64eddbbb0bf5bd3a6fb95f41fa962b499bb 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -21,7 +21,7 @@ pub struct EditPredictionContext { pub excerpt: EditPredictionExcerpt, pub excerpt_text: EditPredictionExcerptText, pub cursor_offset_in_excerpt: usize, - pub snippets: Vec, + pub declarations: Vec, } impl EditPredictionContext { @@ -58,17 +58,28 @@ impl EditPredictionContext { index_state, )?; let excerpt_text = excerpt.text(buffer); + let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body); + + let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0); + let adjacent_end = 
Point::new(cursor_point.row + 1, 0); + let adjacent_occurrences = text_similarity::Occurrences::within_string( + &buffer + .text_for_range(adjacent_start..adjacent_end) + .collect::(), + ); + let cursor_offset_in_file = cursor_point.to_offset(buffer); // TODO fix this to not need saturating_sub let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start); - let snippets = if let Some(index_state) = index_state { + let declarations = if let Some(index_state) = index_state { let references = references_in_excerpt(&excerpt, &excerpt_text, buffer); - scored_snippets( + scored_declarations( &index_state, &excerpt, - &excerpt_text, + &excerpt_occurrences, + &adjacent_occurrences, references, cursor_offset_in_file, buffer, @@ -81,7 +92,7 @@ impl EditPredictionContext { excerpt, excerpt_text, cursor_offset_in_excerpt, - snippets, + declarations, }) } } @@ -137,7 +148,7 @@ mod tests { .unwrap(); let mut snippet_identifiers = context - .snippets + .declarations .iter() .map(|snippet| snippet.identifier.name.as_ref()) .collect::>(); diff --git a/crates/edit_prediction_context/src/text_similarity.rs b/crates/edit_prediction_context/src/text_similarity.rs index 2ace7bf10cc6fd13b8a5636212211a3274d3c259..99d8fb4dd191bbec1b8c695f274a0024c6cb32ae 100644 --- a/crates/edit_prediction_context/src/text_similarity.rs +++ b/crates/edit_prediction_context/src/text_similarity.rs @@ -1,5 +1,9 @@ +use hashbrown::HashTable; use regex::Regex; -use std::{collections::HashMap, sync::LazyLock}; +use std::{ + hash::{Hash, Hasher as _}, + sync::LazyLock, +}; use crate::reference::Reference; @@ -14,49 +18,76 @@ use crate::reference::Reference; static IDENTIFIER_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap()); -// TODO: use &str or Cow keys? -#[derive(Debug)] -pub struct IdentifierOccurrences { - identifier_to_count: HashMap, +/// Multiset of text occurrences for text similarity that only stores hashes and counts. 
+#[derive(Debug, Default)] +pub struct Occurrences { + table: HashTable, total_count: usize, } -impl IdentifierOccurrences { - pub fn within_string(code: &str) -> Self { - Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str())) +#[derive(Debug)] +struct OccurrenceEntry { + hash: u64, + count: usize, +} + +impl Occurrences { + pub fn within_string(text: &str) -> Self { + Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str())) } #[allow(dead_code)] pub fn within_references(references: &[Reference]) -> Self { - Self::from_iterator( + Self::from_identifiers( references .iter() .map(|reference| reference.identifier.name.as_ref()), ) } - pub fn from_iterator<'a>(identifier_iterator: impl Iterator) -> Self { - let mut identifier_to_count = HashMap::new(); - let mut total_count = 0; - for identifier in identifier_iterator { - // TODO: Score matches that match case higher? - // - // TODO: Also include unsplit identifier? + pub fn from_identifiers<'a>(identifiers: impl IntoIterator) -> Self { + let mut this = Self::default(); + // TODO: Score matches that match case higher? + // + // TODO: Also include unsplit identifier? 
+ for identifier in identifiers { for identifier_part in split_identifier(identifier) { - identifier_to_count - .entry(identifier_part.to_lowercase()) - .and_modify(|count| *count += 1) - .or_insert(1); - total_count += 1; + this.add_hash(fx_hash(&identifier_part.to_lowercase())); } } - IdentifierOccurrences { - identifier_to_count, - total_count, - } + this + } + + fn add_hash(&mut self, hash: u64) { + self.table + .entry( + hash, + |entry: &OccurrenceEntry| entry.hash == hash, + |entry| entry.hash, + ) + .and_modify(|entry| entry.count += 1) + .or_insert(OccurrenceEntry { hash, count: 1 }); + self.total_count += 1; + } + + fn contains_hash(&self, hash: u64) -> bool { + self.get_count(hash) != 0 + } + + fn get_count(&self, hash: u64) -> usize { + self.table + .find(hash, |entry| entry.hash == hash) + .map(|entry| entry.count) + .unwrap_or(0) } } +pub fn fx_hash(data: &T) -> u64 { + let mut hasher = collections::FxHasher::default(); + data.hash(&mut hasher); + hasher.finish() +} + // Splits camelcase / snakecase / kebabcase / pascalcase // // TODO: Make this more efficient / elegant. 
@@ -115,54 +146,49 @@ fn split_identifier(identifier: &str) -> Vec<&str> { parts.into_iter().filter(|s| !s.is_empty()).collect() } -pub fn jaccard_similarity<'a>( - mut set_a: &'a IdentifierOccurrences, - mut set_b: &'a IdentifierOccurrences, -) -> f32 { - if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() { +pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 { + if set_a.table.len() > set_b.table.len() { std::mem::swap(&mut set_a, &mut set_b); } let intersection = set_a - .identifier_to_count - .keys() - .filter(|key| set_b.identifier_to_count.contains_key(*key)) + .table + .iter() + .filter(|entry| set_b.contains_hash(entry.hash)) .count(); - let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection; + let union = set_a.table.len() + set_b.table.len() - intersection; intersection as f32 / union as f32 } // TODO #[allow(dead_code)] -pub fn overlap_coefficient<'a>( - mut set_a: &'a IdentifierOccurrences, - mut set_b: &'a IdentifierOccurrences, -) -> f32 { - if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() { +pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 { + if set_a.table.len() > set_b.table.len() { std::mem::swap(&mut set_a, &mut set_b); } let intersection = set_a - .identifier_to_count - .keys() - .filter(|key| set_b.identifier_to_count.contains_key(*key)) + .table + .iter() + .filter(|entry| set_b.contains_hash(entry.hash)) .count(); - intersection as f32 / set_a.identifier_to_count.len() as f32 + intersection as f32 / set_a.table.len() as f32 } // TODO #[allow(dead_code)] pub fn weighted_jaccard_similarity<'a>( - mut set_a: &'a IdentifierOccurrences, - mut set_b: &'a IdentifierOccurrences, + mut set_a: &'a Occurrences, + mut set_b: &'a Occurrences, ) -> f32 { - if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() { + if set_a.table.len() > set_b.table.len() { std::mem::swap(&mut set_a, 
&mut set_b); } let mut numerator = 0; let mut denominator_a = 0; let mut used_count_b = 0; - for (symbol, count_a) in set_a.identifier_to_count.iter() { - let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0); + for entry_a in set_a.table.iter() { + let count_a = entry_a.count; + let count_b = set_b.get_count(entry_a.hash); numerator += count_a.min(count_b); denominator_a += count_a.max(count_b); used_count_b += count_b; @@ -177,16 +203,17 @@ pub fn weighted_jaccard_similarity<'a>( } pub fn weighted_overlap_coefficient<'a>( - mut set_a: &'a IdentifierOccurrences, - mut set_b: &'a IdentifierOccurrences, + mut set_a: &'a Occurrences, + mut set_b: &'a Occurrences, ) -> f32 { - if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() { + if set_a.table.len() > set_b.table.len() { std::mem::swap(&mut set_a, &mut set_b); } let mut numerator = 0; - for (symbol, count_a) in set_a.identifier_to_count.iter() { - let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0); + for entry_a in set_a.table.iter() { + let count_a = entry_a.count; + let count_b = set_b.get_count(entry_a.hash); numerator += count_a.min(count_b); } @@ -215,12 +242,12 @@ mod test { fn test_similarity_functions() { // 10 identifier parts, 8 unique // Repeats: 2 "outline", 2 "items" - let set_a = IdentifierOccurrences::within_string( + let set_a = Occurrences::within_string( "let mut outline_items = query_outline_items(&language, &tree, &source);", ); // 14 identifier parts, 11 unique // Repeats: 2 "outline", 2 "language", 2 "tree" - let set_b = IdentifierOccurrences::within_string( + let set_b = Occurrences::within_string( "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec {", ); diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index b15496a7558ce21de775cc7666382098431d2c21..f58bb963dd4cb4d852bf24f2ee7cbb02abc6efa9 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -733,7 +733,7 @@ fn 
make_cloud_request( let mut declaration_to_signature_index = HashMap::default(); let mut referenced_declarations = Vec::new(); - for snippet in context.snippets { + for snippet in context.declarations { let project_entry_id = snippet.declaration.project_entry_id(); let Some(path) = worktrees.iter().find_map(|worktree| { worktree.entry_for_id(project_entry_id).map(|entry| { diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 741c3e4f69df2024637c81ab31df3d6c0e8e9c65..ac4f27be81243c257efa8a5cc498aa95ce6979d7 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -18,7 +18,7 @@ use util::{ResultExt, paths::PathStyle, rel_path::RelPath}; use workspace::{Item, SplitDirection, Workspace}; use zeta2::{Zeta, ZetaOptions}; -use edit_prediction_context::{EditPredictionExcerptOptions, SnippetStyle}; +use edit_prediction_context::{DeclarationStyle, EditPredictionExcerptOptions}; actions!( dev, @@ -285,7 +285,7 @@ impl Zeta2Inspector { let mut languages = HashMap::default(); for lang_id in prediction .context - .snippets + .declarations .iter() .map(|snippet| snippet.declaration.identifier().language_id) .chain(prediction.context.excerpt_text.language_id) @@ -334,7 +334,7 @@ impl Zeta2Inspector { cx, ); - for snippet in &prediction.context.snippets { + for snippet in &prediction.context.declarations { let path = this .project .read(cx) @@ -345,7 +345,7 @@ impl Zeta2Inspector { "{} (Score density: {})", path.map(|p| p.path.display(path_style).to_string()) .unwrap_or_else(|| "".to_string()), - snippet.score_density(SnippetStyle::Declaration) + snippet.score_density(DeclarationStyle::Declaration) )) .unwrap() .into(),