edit_prediction_context: Minor optimization of text similarity + some renames (#38941)

Michael Sloan created

Release Notes:

- N/A

Change summary

Cargo.lock                                                    |   1 
Cargo.toml                                                    |   1 
crates/cloud_llm_client/src/predict_edits_v3.rs               |  12 
crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs           |  46 
crates/edit_prediction_context/Cargo.toml                     |   1 
crates/edit_prediction_context/src/declaration_scoring.rs     | 143 ++--
crates/edit_prediction_context/src/edit_prediction_context.rs |  23 
crates/edit_prediction_context/src/text_similarity.rs         | 137 ++-
crates/zeta2/src/zeta2.rs                                     |   2 
crates/zeta2_tools/src/zeta2_tools.rs                         |   8 
10 files changed, 198 insertions(+), 176 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5171,6 +5171,7 @@ dependencies = [
  "collections",
  "futures 0.3.31",
  "gpui",
+ "hashbrown 0.15.3",
  "indoc",
  "itertools 0.14.0",
  "language",

Cargo.toml 🔗

@@ -511,6 +511,7 @@ futures-lite = "1.13"
 git2 = { version = "0.20.1", default-features = false }
 globset = "0.4"
 handlebars = "4.3"
+hashbrown = "0.15.3"
 heck = "0.5"
 heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
 hex = "0.4.3"

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -103,13 +103,13 @@ pub struct ReferencedDeclaration {
     /// Index within `signatures`.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub parent_index: Option<usize>,
-    pub score_components: ScoreComponents,
+    pub score_components: DeclarationScoreComponents,
     pub signature_score: f32,
     pub declaration_score: f32,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ScoreComponents {
+pub struct DeclarationScoreComponents {
     pub is_same_file: bool,
     pub is_referenced_nearby: bool,
     pub is_referenced_in_breadcrumb: bool,
@@ -119,12 +119,12 @@ pub struct ScoreComponents {
     pub reference_line_distance: u32,
     pub declaration_line_distance: u32,
     pub declaration_line_distance_rank: usize,
-    pub containing_range_vs_item_jaccard: f32,
-    pub containing_range_vs_signature_jaccard: f32,
+    pub excerpt_vs_item_jaccard: f32,
+    pub excerpt_vs_signature_jaccard: f32,
     pub adjacent_vs_item_jaccard: f32,
     pub adjacent_vs_signature_jaccard: f32,
-    pub containing_range_vs_item_weighted_overlap: f32,
-    pub containing_range_vs_signature_weighted_overlap: f32,
+    pub excerpt_vs_item_weighted_overlap: f32,
+    pub excerpt_vs_signature_weighted_overlap: f32,
     pub adjacent_vs_item_weighted_overlap: f32,
     pub adjacent_vs_signature_weighted_overlap: f32,
 }

crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs 🔗

@@ -70,7 +70,7 @@ pub struct PlannedSnippet<'a> {
 }
 
 #[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
-pub enum SnippetStyle {
+pub enum DeclarationStyle {
     Signature,
     Declaration,
 }
@@ -84,10 +84,10 @@ pub struct SectionLabels {
 impl<'a> PlannedPrompt<'a> {
     /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
     ///
-    /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
-    /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
-    /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
-    /// upgrade.
+    /// Initializes a priority queue by populating it with each snippet, finding the
+    /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
+    /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
+    /// the cost of upgrade.
     ///
     /// TODO: Implement an early halting condition. One option might be to have another priority
     /// queue where the score is the size, and update it accordingly. Another option might be to
@@ -131,13 +131,13 @@ impl<'a> PlannedPrompt<'a> {
         struct QueueEntry {
             score_density: OrderedFloat<f32>,
             declaration_index: usize,
-            style: SnippetStyle,
+            style: DeclarationStyle,
         }
 
         // Initialize priority queue with the best score for each snippet.
         let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
         for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
-            let (style, score_density) = SnippetStyle::iter()
+            let (style, score_density) = DeclarationStyle::iter()
                 .map(|style| {
                     (
                         style,
@@ -186,7 +186,7 @@ impl<'a> PlannedPrompt<'a> {
             this.budget_used += additional_bytes;
             this.add_parents(&mut included_parents, additional_parents);
             let planned_snippet = match queue_entry.style {
-                SnippetStyle::Signature => {
+                DeclarationStyle::Signature => {
                     let Some(text) = declaration.text.get(declaration.signature_range.clone())
                     else {
                         return Err(anyhow!(
@@ -203,7 +203,7 @@ impl<'a> PlannedPrompt<'a> {
                         text_is_truncated: declaration.text_is_truncated,
                     }
                 }
-                SnippetStyle::Declaration => PlannedSnippet {
+                DeclarationStyle::Declaration => PlannedSnippet {
                     path: declaration.path.clone(),
                     range: declaration.range.clone(),
                     text: &declaration.text,
@@ -213,11 +213,13 @@ impl<'a> PlannedPrompt<'a> {
             this.snippets.push(planned_snippet);
 
             // When a Signature is consumed, insert an entry for Definition style.
-            if queue_entry.style == SnippetStyle::Signature {
-                let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
-                let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
-                let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
-                let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
+            if queue_entry.style == DeclarationStyle::Signature {
+                let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
+                let declaration_size =
+                    declaration_size(&declaration, DeclarationStyle::Declaration);
+                let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
+                let declaration_score =
+                    declaration_score(&declaration, DeclarationStyle::Declaration);
 
                 let score_diff = declaration_score - signature_score;
                 let size_diff = declaration_size.saturating_sub(signature_size);
@@ -225,7 +227,7 @@ impl<'a> PlannedPrompt<'a> {
                     queue.push(QueueEntry {
                         declaration_index: queue_entry.declaration_index,
                         score_density: OrderedFloat(score_diff / (size_diff as f32)),
-                        style: SnippetStyle::Declaration,
+                        style: DeclarationStyle::Declaration,
                     });
                 }
             }
@@ -510,20 +512,20 @@ impl<'a> PlannedPrompt<'a> {
     }
 }
 
-fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
+fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
     declaration_score(declaration, style) / declaration_size(declaration, style) as f32
 }
 
-fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
+fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
     match style {
-        SnippetStyle::Signature => declaration.signature_score,
-        SnippetStyle::Declaration => declaration.declaration_score,
+        DeclarationStyle::Signature => declaration.signature_score,
+        DeclarationStyle::Declaration => declaration.declaration_score,
     }
 }
 
-fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
+fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
     match style {
-        SnippetStyle::Signature => declaration.signature_range.len(),
-        SnippetStyle::Declaration => declaration.text.len(),
+        DeclarationStyle::Signature => declaration.signature_range.len(),
+        DeclarationStyle::Declaration => declaration.text.len(),
     }
 }

crates/edit_prediction_context/Cargo.toml 🔗

@@ -18,6 +18,7 @@ cloud_llm_client.workspace = true
 collections.workspace = true
 futures.workspace = true
 gpui.workspace = true
+hashbrown.workspace = true
 itertools.workspace = true
 language.workspace = true
 log.workspace = true

crates/edit_prediction_context/src/declaration_scoring.rs 🔗

@@ -1,4 +1,4 @@
-use cloud_llm_client::predict_edits_v3::ScoreComponents;
+use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
 use itertools::Itertools as _;
 use language::BufferSnapshot;
 use ordered_float::OrderedFloat;
@@ -8,76 +8,67 @@ use strum::EnumIter;
 use text::{Point, ToPoint};
 
 use crate::{
-    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
+    Declaration, EditPredictionExcerpt, Identifier,
     reference::{Reference, ReferenceRegion},
     syntax_index::SyntaxIndexState,
-    text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
+    text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
 };
 
 const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 
 #[derive(Clone, Debug)]
-pub struct ScoredSnippet {
+pub struct ScoredDeclaration {
     pub identifier: Identifier,
     pub declaration: Declaration,
-    pub score_components: ScoreComponents,
-    pub scores: Scores,
+    pub score_components: DeclarationScoreComponents,
+    pub scores: DeclarationScores,
 }
 
 #[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
-pub enum SnippetStyle {
+pub enum DeclarationStyle {
     Signature,
     Declaration,
 }
 
-impl ScoredSnippet {
-    /// Returns the score for this snippet with the specified style.
-    pub fn score(&self, style: SnippetStyle) -> f32 {
+impl ScoredDeclaration {
+    /// Returns the score for this declaration with the specified style.
+    pub fn score(&self, style: DeclarationStyle) -> f32 {
         match style {
-            SnippetStyle::Signature => self.scores.signature,
-            SnippetStyle::Declaration => self.scores.declaration,
+            DeclarationStyle::Signature => self.scores.signature,
+            DeclarationStyle::Declaration => self.scores.declaration,
         }
     }
 
-    pub fn size(&self, style: SnippetStyle) -> usize {
+    pub fn size(&self, style: DeclarationStyle) -> usize {
         match &self.declaration {
             Declaration::File { declaration, .. } => match style {
-                SnippetStyle::Signature => declaration.signature_range.len(),
-                SnippetStyle::Declaration => declaration.text.len(),
+                DeclarationStyle::Signature => declaration.signature_range.len(),
+                DeclarationStyle::Declaration => declaration.text.len(),
             },
             Declaration::Buffer { declaration, .. } => match style {
-                SnippetStyle::Signature => declaration.signature_range.len(),
-                SnippetStyle::Declaration => declaration.item_range.len(),
+                DeclarationStyle::Signature => declaration.signature_range.len(),
+                DeclarationStyle::Declaration => declaration.item_range.len(),
             },
         }
     }
 
-    pub fn score_density(&self, style: SnippetStyle) -> f32 {
+    pub fn score_density(&self, style: DeclarationStyle) -> f32 {
         self.score(style) / (self.size(style)) as f32
     }
 }
 
-pub fn scored_snippets(
+pub fn scored_declarations(
     index: &SyntaxIndexState,
     excerpt: &EditPredictionExcerpt,
-    excerpt_text: &EditPredictionExcerptText,
+    excerpt_occurrences: &Occurrences,
+    adjacent_occurrences: &Occurrences,
     identifier_to_references: HashMap<Identifier, Vec<Reference>>,
     cursor_offset: usize,
     current_buffer: &BufferSnapshot,
-) -> Vec<ScoredSnippet> {
-    let containing_range_identifier_occurrences =
-        IdentifierOccurrences::within_string(&excerpt_text.body);
+) -> Vec<ScoredDeclaration> {
     let cursor_point = cursor_offset.to_point(&current_buffer);
 
-    let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
-    let end_point = Point::new(cursor_point.row + 1, 0);
-    let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
-        &current_buffer
-            .text_for_range(start_point..end_point)
-            .collect::<String>(),
-    );
-
-    let mut snippets = identifier_to_references
+    let mut declarations = identifier_to_references
         .into_iter()
         .flat_map(|(identifier, references)| {
             let declarations =
@@ -137,7 +128,7 @@ pub fn scored_snippets(
                     )| {
                         let same_file_declaration_count = index.file_declaration_count(declaration);
 
-                        score_snippet(
+                        score_declaration(
                             &identifier,
                             &references,
                             declaration.clone(),
@@ -146,8 +137,8 @@ pub fn scored_snippets(
                             declaration_line_distance_rank,
                             same_file_declaration_count,
                             declaration_count,
-                            &containing_range_identifier_occurrences,
-                            &adjacent_identifier_occurrences,
+                            &excerpt_occurrences,
+                            &adjacent_occurrences,
                             cursor_point,
                             current_buffer,
                         )
@@ -158,14 +149,14 @@ pub fn scored_snippets(
         .flatten()
         .collect::<Vec<_>>();
 
-    snippets.sort_unstable_by_key(|snippet| {
-        let score_density = snippet
-            .score_density(SnippetStyle::Declaration)
-            .max(snippet.score_density(SnippetStyle::Signature));
+    declarations.sort_unstable_by_key(|declaration| {
+        let score_density = declaration
+            .score_density(DeclarationStyle::Declaration)
+            .max(declaration.score_density(DeclarationStyle::Signature));
         Reverse(OrderedFloat(score_density))
     });
 
-    snippets
+    declarations
 }
 
 fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
@@ -178,7 +169,7 @@ fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Rang
     }
 }
 
-fn score_snippet(
+fn score_declaration(
     identifier: &Identifier,
     references: &[Reference],
     declaration: Declaration,
@@ -187,11 +178,11 @@ fn score_snippet(
     declaration_line_distance_rank: usize,
     same_file_declaration_count: usize,
     declaration_count: usize,
-    containing_range_identifier_occurrences: &IdentifierOccurrences,
-    adjacent_identifier_occurrences: &IdentifierOccurrences,
+    excerpt_occurrences: &Occurrences,
+    adjacent_occurrences: &Occurrences,
     cursor: Point,
     current_buffer: &BufferSnapshot,
-) -> Option<ScoredSnippet> {
+) -> Option<ScoredDeclaration> {
     let is_referenced_nearby = references
         .iter()
         .any(|r| r.region == ReferenceRegion::Nearby);
@@ -208,37 +199,27 @@ fn score_snippet(
         .min()
         .unwrap();
 
-    let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
-    let item_signature_occurrences =
-        IdentifierOccurrences::within_string(&declaration.signature_text().0);
-    let containing_range_vs_item_jaccard = jaccard_similarity(
-        containing_range_identifier_occurrences,
-        &item_source_occurrences,
-    );
-    let containing_range_vs_signature_jaccard = jaccard_similarity(
-        containing_range_identifier_occurrences,
-        &item_signature_occurrences,
-    );
+    let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
+    let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
+    let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
+    let excerpt_vs_signature_jaccard =
+        jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
     let adjacent_vs_item_jaccard =
-        jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
+        jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
     let adjacent_vs_signature_jaccard =
-        jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
+        jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
 
-    let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
-        containing_range_identifier_occurrences,
-        &item_source_occurrences,
-    );
-    let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
-        containing_range_identifier_occurrences,
-        &item_signature_occurrences,
-    );
+    let excerpt_vs_item_weighted_overlap =
+        weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
+    let excerpt_vs_signature_weighted_overlap =
+        weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
     let adjacent_vs_item_weighted_overlap =
-        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
+        weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
     let adjacent_vs_signature_weighted_overlap =
-        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
+        weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
 
     // TODO: Consider adding declaration_file_count
-    let score_components = ScoreComponents {
+    let score_components = DeclarationScoreComponents {
         is_same_file,
         is_referenced_nearby,
         is_referenced_in_breadcrumb,
@@ -248,32 +229,32 @@ fn score_snippet(
         reference_count,
         same_file_declaration_count,
         declaration_count,
-        containing_range_vs_item_jaccard,
-        containing_range_vs_signature_jaccard,
+        excerpt_vs_item_jaccard,
+        excerpt_vs_signature_jaccard,
         adjacent_vs_item_jaccard,
         adjacent_vs_signature_jaccard,
-        containing_range_vs_item_weighted_overlap,
-        containing_range_vs_signature_weighted_overlap,
+        excerpt_vs_item_weighted_overlap,
+        excerpt_vs_signature_weighted_overlap,
         adjacent_vs_item_weighted_overlap,
         adjacent_vs_signature_weighted_overlap,
     };
 
-    Some(ScoredSnippet {
+    Some(ScoredDeclaration {
         identifier: identifier.clone(),
         declaration: declaration,
-        scores: Scores::score(&score_components),
+        scores: DeclarationScores::score(&score_components),
         score_components,
     })
 }
 
 #[derive(Clone, Debug, Serialize)]
-pub struct Scores {
+pub struct DeclarationScores {
     pub signature: f32,
     pub declaration: f32,
 }
 
-impl Scores {
-    fn score(components: &ScoreComponents) -> Scores {
+impl DeclarationScores {
+    fn score(components: &DeclarationScoreComponents) -> DeclarationScores {
         // TODO: handle truncation
 
         // Score related to how likely this is the correct declaration, range 0 to 1
@@ -295,13 +276,11 @@ impl Scores {
         // For now instead of linear combination, the scores are just multiplied together.
         let combined_score = 10.0 * accuracy_score * distance_score;
 
-        Scores {
-            signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
+        DeclarationScores {
+            signature: combined_score * components.excerpt_vs_signature_weighted_overlap,
             // declaration score gets boosted both by being multiplied by 2 and by there being more
             // weighted overlap.
-            declaration: 2.0
-                * combined_score
-                * components.containing_range_vs_item_weighted_overlap,
+            declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap,
         }
     }
 }

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -21,7 +21,7 @@ pub struct EditPredictionContext {
     pub excerpt: EditPredictionExcerpt,
     pub excerpt_text: EditPredictionExcerptText,
     pub cursor_offset_in_excerpt: usize,
-    pub snippets: Vec<ScoredSnippet>,
+    pub declarations: Vec<ScoredDeclaration>,
 }
 
 impl EditPredictionContext {
@@ -58,17 +58,28 @@ impl EditPredictionContext {
             index_state,
         )?;
         let excerpt_text = excerpt.text(buffer);
+        let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
+
+        let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
+        let adjacent_end = Point::new(cursor_point.row + 1, 0);
+        let adjacent_occurrences = text_similarity::Occurrences::within_string(
+            &buffer
+                .text_for_range(adjacent_start..adjacent_end)
+                .collect::<String>(),
+        );
+
         let cursor_offset_in_file = cursor_point.to_offset(buffer);
         // TODO fix this to not need saturating_sub
         let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
 
-        let snippets = if let Some(index_state) = index_state {
+        let declarations = if let Some(index_state) = index_state {
             let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
 
-            scored_snippets(
+            scored_declarations(
                 &index_state,
                 &excerpt,
-                &excerpt_text,
+                &excerpt_occurrences,
+                &adjacent_occurrences,
                 references,
                 cursor_offset_in_file,
                 buffer,
@@ -81,7 +92,7 @@ impl EditPredictionContext {
             excerpt,
             excerpt_text,
             cursor_offset_in_excerpt,
-            snippets,
+            declarations,
         })
     }
 }
@@ -137,7 +148,7 @@ mod tests {
             .unwrap();
 
         let mut snippet_identifiers = context
-            .snippets
+            .declarations
             .iter()
             .map(|snippet| snippet.identifier.name.as_ref())
             .collect::<Vec<_>>();

crates/edit_prediction_context/src/text_similarity.rs 🔗

@@ -1,5 +1,9 @@
+use hashbrown::HashTable;
 use regex::Regex;
-use std::{collections::HashMap, sync::LazyLock};
+use std::{
+    hash::{Hash, Hasher as _},
+    sync::LazyLock,
+};
 
 use crate::reference::Reference;
 
@@ -14,49 +18,76 @@ use crate::reference::Reference;
 
 static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
 
-// TODO: use &str or Cow<str> keys?
-#[derive(Debug)]
-pub struct IdentifierOccurrences {
-    identifier_to_count: HashMap<String, usize>,
+/// Multiset of text occurrences for text similarity that only stores hashes and counts.
+#[derive(Debug, Default)]
+pub struct Occurrences {
+    table: HashTable<OccurrenceEntry>,
     total_count: usize,
 }
 
-impl IdentifierOccurrences {
-    pub fn within_string(code: &str) -> Self {
-        Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
+#[derive(Debug)]
+struct OccurrenceEntry {
+    hash: u64,
+    count: usize,
+}
+
+impl Occurrences {
+    pub fn within_string(text: &str) -> Self {
+        Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str()))
     }
 
     #[allow(dead_code)]
     pub fn within_references(references: &[Reference]) -> Self {
-        Self::from_iterator(
+        Self::from_identifiers(
             references
                 .iter()
                 .map(|reference| reference.identifier.name.as_ref()),
         )
     }
 
-    pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
-        let mut identifier_to_count = HashMap::new();
-        let mut total_count = 0;
-        for identifier in identifier_iterator {
-            // TODO: Score matches that match case higher?
-            //
-            // TODO: Also include unsplit identifier?
+    pub fn from_identifiers<'a>(identifiers: impl IntoIterator<Item = &'a str>) -> Self {
+        let mut this = Self::default();
+        // TODO: Score matches that match case higher?
+        //
+        // TODO: Also include unsplit identifier?
+        for identifier in identifiers {
             for identifier_part in split_identifier(identifier) {
-                identifier_to_count
-                    .entry(identifier_part.to_lowercase())
-                    .and_modify(|count| *count += 1)
-                    .or_insert(1);
-                total_count += 1;
+                this.add_hash(fx_hash(&identifier_part.to_lowercase()));
             }
         }
-        IdentifierOccurrences {
-            identifier_to_count,
-            total_count,
-        }
+        this
+    }
+
+    fn add_hash(&mut self, hash: u64) {
+        self.table
+            .entry(
+                hash,
+                |entry: &OccurrenceEntry| entry.hash == hash,
+                |entry| entry.hash,
+            )
+            .and_modify(|entry| entry.count += 1)
+            .or_insert(OccurrenceEntry { hash, count: 1 });
+        self.total_count += 1;
+    }
+
+    fn contains_hash(&self, hash: u64) -> bool {
+        self.get_count(hash) != 0
+    }
+
+    fn get_count(&self, hash: u64) -> usize {
+        self.table
+            .find(hash, |entry| entry.hash == hash)
+            .map(|entry| entry.count)
+            .unwrap_or(0)
     }
 }
 
+pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
+    let mut hasher = collections::FxHasher::default();
+    data.hash(&mut hasher);
+    hasher.finish()
+}
+
 // Splits camelcase / snakecase / kebabcase / pascalcase
 //
 // TODO: Make this more efficient / elegant.
@@ -115,54 +146,49 @@ fn split_identifier(identifier: &str) -> Vec<&str> {
     parts.into_iter().filter(|s| !s.is_empty()).collect()
 }
 
-pub fn jaccard_similarity<'a>(
-    mut set_a: &'a IdentifierOccurrences,
-    mut set_b: &'a IdentifierOccurrences,
-) -> f32 {
-    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
+    if set_a.table.len() > set_b.table.len() {
         std::mem::swap(&mut set_a, &mut set_b);
     }
     let intersection = set_a
-        .identifier_to_count
-        .keys()
-        .filter(|key| set_b.identifier_to_count.contains_key(*key))
+        .table
+        .iter()
+        .filter(|entry| set_b.contains_hash(entry.hash))
         .count();
-    let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
+    let union = set_a.table.len() + set_b.table.len() - intersection;
     intersection as f32 / union as f32
 }
 
 // TODO
 #[allow(dead_code)]
-pub fn overlap_coefficient<'a>(
-    mut set_a: &'a IdentifierOccurrences,
-    mut set_b: &'a IdentifierOccurrences,
-) -> f32 {
-    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
+    if set_a.table.len() > set_b.table.len() {
         std::mem::swap(&mut set_a, &mut set_b);
     }
     let intersection = set_a
-        .identifier_to_count
-        .keys()
-        .filter(|key| set_b.identifier_to_count.contains_key(*key))
+        .table
+        .iter()
+        .filter(|entry| set_b.contains_hash(entry.hash))
         .count();
-    intersection as f32 / set_a.identifier_to_count.len() as f32
+    intersection as f32 / set_a.table.len() as f32
 }
 
 // TODO
 #[allow(dead_code)]
 pub fn weighted_jaccard_similarity<'a>(
-    mut set_a: &'a IdentifierOccurrences,
-    mut set_b: &'a IdentifierOccurrences,
+    mut set_a: &'a Occurrences,
+    mut set_b: &'a Occurrences,
 ) -> f32 {
-    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+    if set_a.table.len() > set_b.table.len() {
         std::mem::swap(&mut set_a, &mut set_b);
     }
 
     let mut numerator = 0;
     let mut denominator_a = 0;
     let mut used_count_b = 0;
-    for (symbol, count_a) in set_a.identifier_to_count.iter() {
-        let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
+    for entry_a in set_a.table.iter() {
+        let count_a = entry_a.count;
+        let count_b = set_b.get_count(entry_a.hash);
         numerator += count_a.min(count_b);
         denominator_a += count_a.max(count_b);
         used_count_b += count_b;
@@ -177,16 +203,17 @@ pub fn weighted_jaccard_similarity<'a>(
 }
 
 pub fn weighted_overlap_coefficient<'a>(
-    mut set_a: &'a IdentifierOccurrences,
-    mut set_b: &'a IdentifierOccurrences,
+    mut set_a: &'a Occurrences,
+    mut set_b: &'a Occurrences,
 ) -> f32 {
-    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+    if set_a.table.len() > set_b.table.len() {
         std::mem::swap(&mut set_a, &mut set_b);
     }
 
     let mut numerator = 0;
-    for (symbol, count_a) in set_a.identifier_to_count.iter() {
-        let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
+    for entry_a in set_a.table.iter() {
+        let count_a = entry_a.count;
+        let count_b = set_b.get_count(entry_a.hash);
         numerator += count_a.min(count_b);
     }
 
@@ -215,12 +242,12 @@ mod test {
     fn test_similarity_functions() {
         // 10 identifier parts, 8 unique
         // Repeats: 2 "outline", 2 "items"
-        let set_a = IdentifierOccurrences::within_string(
+        let set_a = Occurrences::within_string(
             "let mut outline_items = query_outline_items(&language, &tree, &source);",
         );
         // 14 identifier parts, 11 unique
         // Repeats: 2 "outline", 2 "language", 2 "tree"
-        let set_b = IdentifierOccurrences::within_string(
+        let set_b = Occurrences::within_string(
             "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
         );
 

crates/zeta2/src/zeta2.rs 🔗

@@ -733,7 +733,7 @@ fn make_cloud_request(
     let mut declaration_to_signature_index = HashMap::default();
     let mut referenced_declarations = Vec::new();
 
-    for snippet in context.snippets {
+    for snippet in context.declarations {
         let project_entry_id = snippet.declaration.project_entry_id();
         let Some(path) = worktrees.iter().find_map(|worktree| {
             worktree.entry_for_id(project_entry_id).map(|entry| {

crates/zeta2_tools/src/zeta2_tools.rs 🔗

@@ -18,7 +18,7 @@ use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
 use workspace::{Item, SplitDirection, Workspace};
 use zeta2::{Zeta, ZetaOptions};
 
-use edit_prediction_context::{EditPredictionExcerptOptions, SnippetStyle};
+use edit_prediction_context::{DeclarationStyle, EditPredictionExcerptOptions};
 
 actions!(
     dev,
@@ -285,7 +285,7 @@ impl Zeta2Inspector {
                 let mut languages = HashMap::default();
                 for lang_id in prediction
                     .context
-                    .snippets
+                    .declarations
                     .iter()
                     .map(|snippet| snippet.declaration.identifier().language_id)
                     .chain(prediction.context.excerpt_text.language_id)
@@ -334,7 +334,7 @@ impl Zeta2Inspector {
                                 cx,
                             );
 
-                            for snippet in &prediction.context.snippets {
+                            for snippet in &prediction.context.declarations {
                                 let path = this
                                     .project
                                     .read(cx)
@@ -345,7 +345,7 @@ impl Zeta2Inspector {
                                         "{} (Score density: {})",
                                         path.map(|p| p.path.display(path_style).to_string())
                                             .unwrap_or_else(|| "".to_string()),
-                                        snippet.score_density(SnippetStyle::Declaration)
+                                        snippet.score_density(DeclarationStyle::Declaration)
                                     ))
                                     .unwrap()
                                     .into(),