Created by Michael Sloan
Release Notes:
- N/A
Cargo.lock | 1
Cargo.toml | 1
crates/cloud_llm_client/src/predict_edits_v3.rs | 12
crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs | 46
crates/edit_prediction_context/Cargo.toml | 1
crates/edit_prediction_context/src/declaration_scoring.rs | 143 ++--
crates/edit_prediction_context/src/edit_prediction_context.rs | 23
crates/edit_prediction_context/src/text_similarity.rs | 137 ++-
crates/zeta2/src/zeta2.rs | 2
crates/zeta2_tools/src/zeta2_tools.rs | 8
10 files changed, 198 insertions(+), 176 deletions(-)
@@ -5171,6 +5171,7 @@ dependencies = [
"collections",
"futures 0.3.31",
"gpui",
+ "hashbrown 0.15.3",
"indoc",
"itertools 0.14.0",
"language",
@@ -511,6 +511,7 @@ futures-lite = "1.13"
git2 = { version = "0.20.1", default-features = false }
globset = "0.4"
handlebars = "4.3"
+hashbrown = "0.15.3"
heck = "0.5"
heed = { version = "0.21.0", features = ["read-txn-no-tls"] }
hex = "0.4.3"
@@ -103,13 +103,13 @@ pub struct ReferencedDeclaration {
/// Index within `signatures`.
#[serde(skip_serializing_if = "Option::is_none", default)]
pub parent_index: Option<usize>,
- pub score_components: ScoreComponents,
+ pub score_components: DeclarationScoreComponents,
pub signature_score: f32,
pub declaration_score: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ScoreComponents {
+pub struct DeclarationScoreComponents {
pub is_same_file: bool,
pub is_referenced_nearby: bool,
pub is_referenced_in_breadcrumb: bool,
@@ -119,12 +119,12 @@ pub struct ScoreComponents {
pub reference_line_distance: u32,
pub declaration_line_distance: u32,
pub declaration_line_distance_rank: usize,
- pub containing_range_vs_item_jaccard: f32,
- pub containing_range_vs_signature_jaccard: f32,
+ pub excerpt_vs_item_jaccard: f32,
+ pub excerpt_vs_signature_jaccard: f32,
pub adjacent_vs_item_jaccard: f32,
pub adjacent_vs_signature_jaccard: f32,
- pub containing_range_vs_item_weighted_overlap: f32,
- pub containing_range_vs_signature_weighted_overlap: f32,
+ pub excerpt_vs_item_weighted_overlap: f32,
+ pub excerpt_vs_signature_weighted_overlap: f32,
pub adjacent_vs_item_weighted_overlap: f32,
pub adjacent_vs_signature_weighted_overlap: f32,
}
@@ -70,7 +70,7 @@ pub struct PlannedSnippet<'a> {
}
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
-pub enum SnippetStyle {
+pub enum DeclarationStyle {
Signature,
Declaration,
}
@@ -84,10 +84,10 @@ pub struct SectionLabels {
impl<'a> PlannedPrompt<'a> {
/// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
///
- /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
- /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
- /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
- /// upgrade.
+ /// Initializes a priority queue by populating it with each snippet, finding the
+ /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
+ /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
+ /// the cost of upgrade.
///
/// TODO: Implement an early halting condition. One option might be to have another priority
/// queue where the score is the size, and update it accordingly. Another option might be to
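For intuition, here is a minimal sketch of the mechanism this comment describes, with hypothetical simplified types (`Item`, `Entry`, `plan`) standing in for `ReferencedDeclaration` and `QueueEntry`; the real code also tracks parent signatures and seeds the queue by comparing both styles per declaration:

```rust
use ordered_float::OrderedFloat;
use std::collections::BinaryHeap;

// Hypothetical stand-in for ReferencedDeclaration.
struct Item {
    signature_score: f32,
    declaration_score: f32,
    signature_len: usize,
    declaration_len: usize,
}

#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct Entry {
    density: OrderedFloat<f32>, // compared first, so highest density pops first
    index: usize,
    is_upgrade: bool,
}

/// Returns (item index, upgraded-to-full-declaration) pairs that fit in `budget` bytes.
fn plan(items: &[Item], budget: usize) -> Vec<(usize, bool)> {
    // Seed the queue with every item priced as its signature.
    let mut queue: BinaryHeap<Entry> = items
        .iter()
        .enumerate()
        .map(|(index, item)| Entry {
            density: OrderedFloat(item.signature_score / item.signature_len as f32),
            index,
            is_upgrade: false,
        })
        .collect();

    let mut used = 0;
    let mut taken = Vec::new();
    while let Some(entry) = queue.pop() {
        let item = &items[entry.index];
        let cost = if entry.is_upgrade {
            // The signature bytes were already paid for; only the difference counts.
            item.declaration_len.saturating_sub(item.signature_len)
        } else {
            item.signature_len
        };
        if used + cost > budget {
            continue;
        }
        used += cost;
        taken.push((entry.index, entry.is_upgrade));
        if !entry.is_upgrade {
            // Price the upgrade to the full declaration at its marginal value,
            // so it competes with other snippets on extra score per extra byte.
            let score_diff = item.declaration_score - item.signature_score;
            let size_diff = item.declaration_len.saturating_sub(item.signature_len);
            queue.push(Entry {
                density: OrderedFloat(score_diff / size_diff.max(1) as f32),
                index: entry.index,
                is_upgrade: true,
            });
        }
    }
    taken
}
```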
@@ -131,13 +131,13 @@ impl<'a> PlannedPrompt<'a> {
struct QueueEntry {
score_density: OrderedFloat<f32>,
declaration_index: usize,
- style: SnippetStyle,
+ style: DeclarationStyle,
}
// Initialize priority queue with the best score for each snippet.
let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
- let (style, score_density) = SnippetStyle::iter()
+ let (style, score_density) = DeclarationStyle::iter()
.map(|style| {
(
style,
@@ -186,7 +186,7 @@ impl<'a> PlannedPrompt<'a> {
this.budget_used += additional_bytes;
this.add_parents(&mut included_parents, additional_parents);
let planned_snippet = match queue_entry.style {
- SnippetStyle::Signature => {
+ DeclarationStyle::Signature => {
let Some(text) = declaration.text.get(declaration.signature_range.clone())
else {
return Err(anyhow!(
@@ -203,7 +203,7 @@ impl<'a> PlannedPrompt<'a> {
text_is_truncated: declaration.text_is_truncated,
}
}
- SnippetStyle::Declaration => PlannedSnippet {
+ DeclarationStyle::Declaration => PlannedSnippet {
path: declaration.path.clone(),
range: declaration.range.clone(),
text: &declaration.text,
@@ -213,11 +213,13 @@ impl<'a> PlannedPrompt<'a> {
this.snippets.push(planned_snippet);
// When a Signature is consumed, insert an entry for Definition style.
- if queue_entry.style == SnippetStyle::Signature {
- let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
- let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
- let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
- let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
+ if queue_entry.style == DeclarationStyle::Signature {
+ let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
+ let declaration_size =
+ declaration_size(&declaration, DeclarationStyle::Declaration);
+ let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
+ let declaration_score =
+ declaration_score(&declaration, DeclarationStyle::Declaration);
let score_diff = declaration_score - signature_score;
let size_diff = declaration_size.saturating_sub(signature_size);
@@ -225,7 +227,7 @@ impl<'a> PlannedPrompt<'a> {
queue.push(QueueEntry {
declaration_index: queue_entry.declaration_index,
score_density: OrderedFloat(score_diff / (size_diff as f32)),
- style: SnippetStyle::Declaration,
+ style: DeclarationStyle::Declaration,
});
}
}
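To make the upgrade pricing concrete (the numbers are made up): a 120-byte signature with score 0.9 enters the queue at density 0.9 / 120 = 0.0075; once it is taken, the 600-byte full declaration with score 1.5 is re-queued at the marginal density (1.5 - 0.9) / (600 - 120) = 0.00125 rather than its absolute density 1.5 / 600 = 0.0025, since 120 of its bytes are already counted against the budget.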
@@ -510,20 +512,20 @@ impl<'a> PlannedPrompt<'a> {
}
}
-fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
+fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
declaration_score(declaration, style) / declaration_size(declaration, style) as f32
}
-fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
+fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
match style {
- SnippetStyle::Signature => declaration.signature_score,
- SnippetStyle::Declaration => declaration.declaration_score,
+ DeclarationStyle::Signature => declaration.signature_score,
+ DeclarationStyle::Declaration => declaration.declaration_score,
}
}
-fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
+fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
match style {
- SnippetStyle::Signature => declaration.signature_range.len(),
- SnippetStyle::Declaration => declaration.text.len(),
+ DeclarationStyle::Signature => declaration.signature_range.len(),
+ DeclarationStyle::Declaration => declaration.text.len(),
}
}
@@ -18,6 +18,7 @@ cloud_llm_client.workspace = true
collections.workspace = true
futures.workspace = true
gpui.workspace = true
+hashbrown.workspace = true
itertools.workspace = true
language.workspace = true
log.workspace = true
@@ -1,4 +1,4 @@
-use cloud_llm_client::predict_edits_v3::ScoreComponents;
+use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
use itertools::Itertools as _;
use language::BufferSnapshot;
use ordered_float::OrderedFloat;
@@ -8,76 +8,67 @@ use strum::EnumIter;
use text::{Point, ToPoint};
use crate::{
- Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
+ Declaration, EditPredictionExcerpt, Identifier,
reference::{Reference, ReferenceRegion},
syntax_index::SyntaxIndexState,
- text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
+ text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
};
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
#[derive(Clone, Debug)]
-pub struct ScoredSnippet {
+pub struct ScoredDeclaration {
pub identifier: Identifier,
pub declaration: Declaration,
- pub score_components: ScoreComponents,
- pub scores: Scores,
+ pub score_components: DeclarationScoreComponents,
+ pub scores: DeclarationScores,
}
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
-pub enum SnippetStyle {
+pub enum DeclarationStyle {
Signature,
Declaration,
}
-impl ScoredSnippet {
- /// Returns the score for this snippet with the specified style.
- pub fn score(&self, style: SnippetStyle) -> f32 {
+impl ScoredDeclaration {
+ /// Returns the score for this declaration with the specified style.
+ pub fn score(&self, style: DeclarationStyle) -> f32 {
match style {
- SnippetStyle::Signature => self.scores.signature,
- SnippetStyle::Declaration => self.scores.declaration,
+ DeclarationStyle::Signature => self.scores.signature,
+ DeclarationStyle::Declaration => self.scores.declaration,
}
}
- pub fn size(&self, style: SnippetStyle) -> usize {
+ pub fn size(&self, style: DeclarationStyle) -> usize {
match &self.declaration {
Declaration::File { declaration, .. } => match style {
- SnippetStyle::Signature => declaration.signature_range.len(),
- SnippetStyle::Declaration => declaration.text.len(),
+ DeclarationStyle::Signature => declaration.signature_range.len(),
+ DeclarationStyle::Declaration => declaration.text.len(),
},
Declaration::Buffer { declaration, .. } => match style {
- SnippetStyle::Signature => declaration.signature_range.len(),
- SnippetStyle::Declaration => declaration.item_range.len(),
+ DeclarationStyle::Signature => declaration.signature_range.len(),
+ DeclarationStyle::Declaration => declaration.item_range.len(),
},
}
}
- pub fn score_density(&self, style: SnippetStyle) -> f32 {
+ pub fn score_density(&self, style: DeclarationStyle) -> f32 {
self.score(style) / (self.size(style)) as f32
}
}
-pub fn scored_snippets(
+pub fn scored_declarations(
index: &SyntaxIndexState,
excerpt: &EditPredictionExcerpt,
- excerpt_text: &EditPredictionExcerptText,
+ excerpt_occurrences: &Occurrences,
+ adjacent_occurrences: &Occurrences,
identifier_to_references: HashMap<Identifier, Vec<Reference>>,
cursor_offset: usize,
current_buffer: &BufferSnapshot,
-) -> Vec<ScoredSnippet> {
- let containing_range_identifier_occurrences =
- IdentifierOccurrences::within_string(&excerpt_text.body);
+) -> Vec<ScoredDeclaration> {
let cursor_point = cursor_offset.to_point(&current_buffer);
- let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
- let end_point = Point::new(cursor_point.row + 1, 0);
- let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
- &current_buffer
- .text_for_range(start_point..end_point)
- .collect::<String>(),
- );
-
- let mut snippets = identifier_to_references
+ let mut declarations = identifier_to_references
.into_iter()
.flat_map(|(identifier, references)| {
let declarations =
@@ -137,7 +128,7 @@ pub fn scored_snippets(
)| {
let same_file_declaration_count = index.file_declaration_count(declaration);
- score_snippet(
+ score_declaration(
&identifier,
&references,
declaration.clone(),
@@ -146,8 +137,8 @@ pub fn scored_snippets(
declaration_line_distance_rank,
same_file_declaration_count,
declaration_count,
- &containing_range_identifier_occurrences,
- &adjacent_identifier_occurrences,
+ &excerpt_occurrences,
+ &adjacent_occurrences,
cursor_point,
current_buffer,
)
@@ -158,14 +149,14 @@ pub fn scored_snippets(
.flatten()
.collect::<Vec<_>>();
- snippets.sort_unstable_by_key(|snippet| {
- let score_density = snippet
- .score_density(SnippetStyle::Declaration)
- .max(snippet.score_density(SnippetStyle::Signature));
+ declarations.sort_unstable_by_key(|declaration| {
+ let score_density = declaration
+ .score_density(DeclarationStyle::Declaration)
+ .max(declaration.score_density(DeclarationStyle::Signature));
Reverse(OrderedFloat(score_density))
});
- snippets
+ declarations
}
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
@@ -178,7 +169,7 @@ fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Rang
}
}
-fn score_snippet(
+fn score_declaration(
identifier: &Identifier,
references: &[Reference],
declaration: Declaration,
@@ -187,11 +178,11 @@ fn score_snippet(
declaration_line_distance_rank: usize,
same_file_declaration_count: usize,
declaration_count: usize,
- containing_range_identifier_occurrences: &IdentifierOccurrences,
- adjacent_identifier_occurrences: &IdentifierOccurrences,
+ excerpt_occurrences: &Occurrences,
+ adjacent_occurrences: &Occurrences,
cursor: Point,
current_buffer: &BufferSnapshot,
-) -> Option<ScoredSnippet> {
+) -> Option<ScoredDeclaration> {
let is_referenced_nearby = references
.iter()
.any(|r| r.region == ReferenceRegion::Nearby);
@@ -208,37 +199,27 @@ fn score_snippet(
.min()
.unwrap();
- let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
- let item_signature_occurrences =
- IdentifierOccurrences::within_string(&declaration.signature_text().0);
- let containing_range_vs_item_jaccard = jaccard_similarity(
- containing_range_identifier_occurrences,
- &item_source_occurrences,
- );
- let containing_range_vs_signature_jaccard = jaccard_similarity(
- containing_range_identifier_occurrences,
- &item_signature_occurrences,
- );
+ let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
+ let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
+ let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
+ let excerpt_vs_signature_jaccard =
+ jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
let adjacent_vs_item_jaccard =
- jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
+ jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
let adjacent_vs_signature_jaccard =
- jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
+ jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
- let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
- containing_range_identifier_occurrences,
- &item_source_occurrences,
- );
- let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
- containing_range_identifier_occurrences,
- &item_signature_occurrences,
- );
+ let excerpt_vs_item_weighted_overlap =
+ weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
+ let excerpt_vs_signature_weighted_overlap =
+ weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
let adjacent_vs_item_weighted_overlap =
- weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
+ weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
let adjacent_vs_signature_weighted_overlap =
- weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
+ weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
// TODO: Consider adding declaration_file_count
- let score_components = ScoreComponents {
+ let score_components = DeclarationScoreComponents {
is_same_file,
is_referenced_nearby,
is_referenced_in_breadcrumb,
@@ -248,32 +229,32 @@ fn score_snippet(
reference_count,
same_file_declaration_count,
declaration_count,
- containing_range_vs_item_jaccard,
- containing_range_vs_signature_jaccard,
+ excerpt_vs_item_jaccard,
+ excerpt_vs_signature_jaccard,
adjacent_vs_item_jaccard,
adjacent_vs_signature_jaccard,
- containing_range_vs_item_weighted_overlap,
- containing_range_vs_signature_weighted_overlap,
+ excerpt_vs_item_weighted_overlap,
+ excerpt_vs_signature_weighted_overlap,
adjacent_vs_item_weighted_overlap,
adjacent_vs_signature_weighted_overlap,
};
- Some(ScoredSnippet {
+ Some(ScoredDeclaration {
identifier: identifier.clone(),
declaration: declaration,
- scores: Scores::score(&score_components),
+ scores: DeclarationScores::score(&score_components),
score_components,
})
}
#[derive(Clone, Debug, Serialize)]
-pub struct Scores {
+pub struct DeclarationScores {
pub signature: f32,
pub declaration: f32,
}
-impl Scores {
- fn score(components: &ScoreComponents) -> Scores {
+impl DeclarationScores {
+ fn score(components: &DeclarationScoreComponents) -> DeclarationScores {
// TODO: handle truncation
// Score related to how likely this is the correct declaration, range 0 to 1
@@ -295,13 +276,11 @@ impl Scores {
// For now instead of linear combination, the scores are just multiplied together.
let combined_score = 10.0 * accuracy_score * distance_score;
- Scores {
- signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
+ DeclarationScores {
+ signature: combined_score * components.excerpt_vs_signature_weighted_overlap,
// declaration score gets boosted both by being multiplied by 2 and by there being more
// weighted overlap.
- declaration: 2.0
- * combined_score
- * components.containing_range_vs_item_weighted_overlap,
+ declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap,
}
}
}
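Stripped of struct plumbing, the combination visible in this hunk has the following shape (a hypothetical free function for illustration; `accuracy_score` and `distance_score` are computed earlier in the real function and are not shown in this diff):

```rust
// Sketch of the score combination shown above (not the actual function signature).
fn combine(
    accuracy_score: f32,                        // range 0 to 1 per the comment above
    distance_score: f32,                        // computed upstream, not shown here
    excerpt_vs_signature_weighted_overlap: f32,
    excerpt_vs_item_weighted_overlap: f32,
) -> (f32, f32) {
    // Multiplied rather than linearly combined, per the comment above.
    let combined_score = 10.0 * accuracy_score * distance_score;
    let signature = combined_score * excerpt_vs_signature_weighted_overlap;
    // The full-declaration score gets a 2x boost on top of the (typically larger) item overlap.
    let declaration = 2.0 * combined_score * excerpt_vs_item_weighted_overlap;
    (signature, declaration)
}
```

Since selection divides these scores by byte size (`score_density`), the declaration's boost only wins out, roughly speaking, when the extra overlap justifies the extra bytes.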
@@ -21,7 +21,7 @@ pub struct EditPredictionContext {
pub excerpt: EditPredictionExcerpt,
pub excerpt_text: EditPredictionExcerptText,
pub cursor_offset_in_excerpt: usize,
- pub snippets: Vec<ScoredSnippet>,
+ pub declarations: Vec<ScoredDeclaration>,
}
impl EditPredictionContext {
@@ -58,17 +58,28 @@ impl EditPredictionContext {
index_state,
)?;
let excerpt_text = excerpt.text(buffer);
+ let excerpt_occurrences = text_similarity::Occurrences::within_string(&excerpt_text.body);
+
+ let adjacent_start = Point::new(cursor_point.row.saturating_sub(2), 0);
+ let adjacent_end = Point::new(cursor_point.row + 1, 0);
+ let adjacent_occurrences = text_similarity::Occurrences::within_string(
+ &buffer
+ .text_for_range(adjacent_start..adjacent_end)
+ .collect::<String>(),
+ );
+
let cursor_offset_in_file = cursor_point.to_offset(buffer);
// TODO fix this to not need saturating_sub
let cursor_offset_in_excerpt = cursor_offset_in_file.saturating_sub(excerpt.range.start);
- let snippets = if let Some(index_state) = index_state {
+ let declarations = if let Some(index_state) = index_state {
let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
- scored_snippets(
+ scored_declarations(
&index_state,
&excerpt,
- &excerpt_text,
+ &excerpt_occurrences,
+ &adjacent_occurrences,
references,
cursor_offset_in_file,
buffer,
@@ -81,7 +92,7 @@ impl EditPredictionContext {
excerpt,
excerpt_text,
cursor_offset_in_excerpt,
- snippets,
+ declarations,
})
}
}
@@ -137,7 +148,7 @@ mod tests {
.unwrap();
let mut snippet_identifiers = context
- .snippets
+ .declarations
.iter()
.map(|snippet| snippet.identifier.name.as_ref())
.collect::<Vec<_>>();
@@ -1,5 +1,9 @@
+use hashbrown::HashTable;
use regex::Regex;
-use std::{collections::HashMap, sync::LazyLock};
+use std::{
+ hash::{Hash, Hasher as _},
+ sync::LazyLock,
+};
use crate::reference::Reference;
@@ -14,49 +18,76 @@ use crate::reference::Reference;
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
-// TODO: use &str or Cow<str> keys?
-#[derive(Debug)]
-pub struct IdentifierOccurrences {
- identifier_to_count: HashMap<String, usize>,
+/// Multiset of text occurrences for text similarity that only stores hashes and counts.
+#[derive(Debug, Default)]
+pub struct Occurrences {
+ table: HashTable<OccurrenceEntry>,
total_count: usize,
}
-impl IdentifierOccurrences {
- pub fn within_string(code: &str) -> Self {
- Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
+#[derive(Debug)]
+struct OccurrenceEntry {
+ hash: u64,
+ count: usize,
+}
+
+impl Occurrences {
+ pub fn within_string(text: &str) -> Self {
+ Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str()))
}
#[allow(dead_code)]
pub fn within_references(references: &[Reference]) -> Self {
- Self::from_iterator(
+ Self::from_identifiers(
references
.iter()
.map(|reference| reference.identifier.name.as_ref()),
)
}
- pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
- let mut identifier_to_count = HashMap::new();
- let mut total_count = 0;
- for identifier in identifier_iterator {
- // TODO: Score matches that match case higher?
- //
- // TODO: Also include unsplit identifier?
+ pub fn from_identifiers<'a>(identifiers: impl IntoIterator<Item = &'a str>) -> Self {
+ let mut this = Self::default();
+ // TODO: Score matches that match case higher?
+ //
+ // TODO: Also include unsplit identifier?
+ for identifier in identifiers {
for identifier_part in split_identifier(identifier) {
- identifier_to_count
- .entry(identifier_part.to_lowercase())
- .and_modify(|count| *count += 1)
- .or_insert(1);
- total_count += 1;
+ this.add_hash(fx_hash(&identifier_part.to_lowercase()));
}
}
- IdentifierOccurrences {
- identifier_to_count,
- total_count,
- }
+ this
+ }
+
+ fn add_hash(&mut self, hash: u64) {
+ self.table
+ .entry(
+ hash,
+ |entry: &OccurrenceEntry| entry.hash == hash,
+ |entry| entry.hash,
+ )
+ .and_modify(|entry| entry.count += 1)
+ .or_insert(OccurrenceEntry { hash, count: 1 });
+ self.total_count += 1;
+ }
+
+ fn contains_hash(&self, hash: u64) -> bool {
+ self.get_count(hash) != 0
+ }
+
+ fn get_count(&self, hash: u64) -> usize {
+ self.table
+ .find(hash, |entry| entry.hash == hash)
+ .map(|entry| entry.count)
+ .unwrap_or(0)
}
}
+pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
+ let mut hasher = collections::FxHasher::default();
+ data.hash(&mut hasher);
+ hasher.finish()
+}
+
// Splits camelcase / snakecase / kebabcase / pascalcase
//
// TODO: Make this more efficient / elegant.
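To illustrate how the hash-only multiset behaves, a hypothetical module-internal test (it touches private fields and helpers, so it would have to live in this file; the expected counts follow the splitting described in the existing `test_similarity_functions` comments below):

```rust
#[cfg(test)]
mod occurrences_sketch {
    use super::*;

    #[test]
    fn counts_are_stored_per_hash() {
        let set = Occurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // 10 lowercased identifier parts in total, 8 unique
        // ("outline" and "items" each occur twice).
        assert_eq!(set.total_count, 10);
        assert_eq!(set.table.len(), 8);
        // Lookups use the FxHash of the lowercased part; the strings themselves
        // are never stored.
        assert_eq!(set.get_count(fx_hash("outline")), 2);
        assert_eq!(set.get_count(fx_hash("query")), 1);
        assert!(!set.contains_hash(fx_hash("missing")));
    }
}
```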
@@ -115,54 +146,49 @@ fn split_identifier(identifier: &str) -> Vec<&str> {
parts.into_iter().filter(|s| !s.is_empty()).collect()
}
-pub fn jaccard_similarity<'a>(
- mut set_a: &'a IdentifierOccurrences,
- mut set_b: &'a IdentifierOccurrences,
-) -> f32 {
- if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
+ if set_a.table.len() > set_b.table.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let intersection = set_a
- .identifier_to_count
- .keys()
- .filter(|key| set_b.identifier_to_count.contains_key(*key))
+ .table
+ .iter()
+ .filter(|entry| set_b.contains_hash(entry.hash))
.count();
- let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
+ let union = set_a.table.len() + set_b.table.len() - intersection;
intersection as f32 / union as f32
}
// TODO
#[allow(dead_code)]
-pub fn overlap_coefficient<'a>(
- mut set_a: &'a IdentifierOccurrences,
- mut set_b: &'a IdentifierOccurrences,
-) -> f32 {
- if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
+ if set_a.table.len() > set_b.table.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let intersection = set_a
- .identifier_to_count
- .keys()
- .filter(|key| set_b.identifier_to_count.contains_key(*key))
+ .table
+ .iter()
+ .filter(|entry| set_b.contains_hash(entry.hash))
.count();
- intersection as f32 / set_a.identifier_to_count.len() as f32
+ intersection as f32 / set_a.table.len() as f32
}
// TODO
#[allow(dead_code)]
pub fn weighted_jaccard_similarity<'a>(
- mut set_a: &'a IdentifierOccurrences,
- mut set_b: &'a IdentifierOccurrences,
+ mut set_a: &'a Occurrences,
+ mut set_b: &'a Occurrences,
) -> f32 {
- if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+ if set_a.table.len() > set_b.table.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let mut numerator = 0;
let mut denominator_a = 0;
let mut used_count_b = 0;
- for (symbol, count_a) in set_a.identifier_to_count.iter() {
- let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
+ for entry_a in set_a.table.iter() {
+ let count_a = entry_a.count;
+ let count_b = set_b.get_count(entry_a.hash);
numerator += count_a.min(count_b);
denominator_a += count_a.max(count_b);
used_count_b += count_b;
@@ -177,16 +203,17 @@ pub fn weighted_jaccard_similarity<'a>(
}
pub fn weighted_overlap_coefficient<'a>(
- mut set_a: &'a IdentifierOccurrences,
- mut set_b: &'a IdentifierOccurrences,
+ mut set_a: &'a Occurrences,
+ mut set_b: &'a Occurrences,
) -> f32 {
- if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
+ if set_a.table.len() > set_b.table.len() {
std::mem::swap(&mut set_a, &mut set_b);
}
let mut numerator = 0;
- for (symbol, count_a) in set_a.identifier_to_count.iter() {
- let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
+ for entry_a in set_a.table.iter() {
+ let count_a = entry_a.count;
+ let count_b = set_b.get_count(entry_a.hash);
numerator += count_a.min(count_b);
}
@@ -215,12 +242,12 @@ mod test {
fn test_similarity_functions() {
// 10 identifier parts, 8 unique
// Repeats: 2 "outline", 2 "items"
- let set_a = IdentifierOccurrences::within_string(
+ let set_a = Occurrences::within_string(
"let mut outline_items = query_outline_items(&language, &tree, &source);",
);
// 14 identifier parts, 11 unique
// Repeats: 2 "outline", 2 "language", 2 "tree"
- let set_b = IdentifierOccurrences::within_string(
+ let set_b = Occurrences::within_string(
"pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
);
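Worked by hand under the same splitting rules (this arithmetic is not part of the diff): the two strings share six unique parts (query, outline, items, language, tree, source) out of 8 and 11 unique parts respectively, so `jaccard_similarity` works out to 6 / (8 + 11 - 6) = 6/13, roughly 0.46. A check along these lines could sit at the end of that test:

```rust
// Hand-derived expectation, assuming the splitting described in the comments above.
let expected: f32 = 6.0 / (8.0 + 11.0 - 6.0);
assert!((jaccard_similarity(&set_a, &set_b) - expected).abs() < 1e-6);
```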
@@ -733,7 +733,7 @@ fn make_cloud_request(
let mut declaration_to_signature_index = HashMap::default();
let mut referenced_declarations = Vec::new();
- for snippet in context.snippets {
+ for snippet in context.declarations {
let project_entry_id = snippet.declaration.project_entry_id();
let Some(path) = worktrees.iter().find_map(|worktree| {
worktree.entry_for_id(project_entry_id).map(|entry| {
@@ -18,7 +18,7 @@ use util::{ResultExt, paths::PathStyle, rel_path::RelPath};
use workspace::{Item, SplitDirection, Workspace};
use zeta2::{Zeta, ZetaOptions};
-use edit_prediction_context::{EditPredictionExcerptOptions, SnippetStyle};
+use edit_prediction_context::{DeclarationStyle, EditPredictionExcerptOptions};
actions!(
dev,
@@ -285,7 +285,7 @@ impl Zeta2Inspector {
let mut languages = HashMap::default();
for lang_id in prediction
.context
- .snippets
+ .declarations
.iter()
.map(|snippet| snippet.declaration.identifier().language_id)
.chain(prediction.context.excerpt_text.language_id)
@@ -334,7 +334,7 @@ impl Zeta2Inspector {
cx,
);
- for snippet in &prediction.context.snippets {
+ for snippet in &prediction.context.declarations {
let path = this
.project
.read(cx)
@@ -345,7 +345,7 @@ impl Zeta2Inspector {
"{} (Score density: {})",
path.map(|p| p.path.display(path_style).to_string())
.unwrap_or_else(|| "".to_string()),
- snippet.score_density(SnippetStyle::Declaration)
+ snippet.score_density(DeclarationStyle::Declaration)
))
.unwrap()
.into(),