text_similarity.rs

  1use hashbrown::HashTable;
  2use regex::Regex;
  3use std::{
  4    hash::{Hash, Hasher as _},
  5    sync::LazyLock,
  6};
  7
  8use crate::reference::Reference;
  9
 10// TODO: Consider implementing sliding window similarity matching like
 11// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
 12//
 13// That implementation could actually be more efficient - no need to track words in the window that
 14// are not in the query.
 15
 16// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
 17// two in parallel.
 18
 19static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
 20
 21/// Multiset of text occurrences for text similarity that only stores hashes and counts.
 22#[derive(Debug, Default)]
 23pub struct Occurrences {
 24    table: HashTable<OccurrenceEntry>,
 25    total_count: usize,
 26}
 27
 28#[derive(Debug)]
 29struct OccurrenceEntry {
 30    hash: u64,
 31    count: usize,
 32}
 33
 34impl Occurrences {
 35    pub fn within_string(text: &str) -> Self {
 36        Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str()))
 37    }
 38
 39    #[allow(dead_code)]
 40    pub fn within_references(references: &[Reference]) -> Self {
 41        Self::from_identifiers(
 42            references
 43                .iter()
 44                .map(|reference| reference.identifier.name.as_ref()),
 45        )
 46    }
 47
 48    pub fn from_identifiers<'a>(identifiers: impl IntoIterator<Item = &'a str>) -> Self {
 49        let mut this = Self::default();
 50        // TODO: Score matches that match case higher?
 51        //
 52        // TODO: Also include unsplit identifier?
 53        for identifier in identifiers {
 54            for identifier_part in split_identifier(identifier) {
 55                this.add_hash(fx_hash(&identifier_part.to_lowercase()));
 56            }
 57        }
 58        this
 59    }
 60
 61    fn add_hash(&mut self, hash: u64) {
 62        self.table
 63            .entry(
 64                hash,
 65                |entry: &OccurrenceEntry| entry.hash == hash,
 66                |entry| entry.hash,
 67            )
 68            .and_modify(|entry| entry.count += 1)
 69            .or_insert(OccurrenceEntry { hash, count: 1 });
 70        self.total_count += 1;
 71    }
 72
 73    fn contains_hash(&self, hash: u64) -> bool {
 74        self.get_count(hash) != 0
 75    }
 76
 77    fn get_count(&self, hash: u64) -> usize {
 78        self.table
 79            .find(hash, |entry| entry.hash == hash)
 80            .map(|entry| entry.count)
 81            .unwrap_or(0)
 82    }
 83}
 84
 85pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
 86    let mut hasher = collections::FxHasher::default();
 87    data.hash(&mut hasher);
 88    hasher.finish()
 89}
 90
 91// Splits camelcase / snakecase / kebabcase / pascalcase
 92//
 93// TODO: Make this more efficient / elegant.
 94fn split_identifier(identifier: &str) -> Vec<&str> {
 95    let mut parts = Vec::new();
 96    let mut start = 0;
 97    let chars: Vec<char> = identifier.chars().collect();
 98
 99    if chars.is_empty() {
100        return parts;
101    }
102
103    let mut i = 0;
104    while i < chars.len() {
105        let ch = chars[i];
106
107        // Handle explicit delimiters (underscore and hyphen)
108        if ch == '_' || ch == '-' {
109            if i > start {
110                parts.push(&identifier[start..i]);
111            }
112            start = i + 1;
113            i += 1;
114            continue;
115        }
116
117        // Handle camelCase and PascalCase transitions
118        if i > 0 && i < chars.len() {
119            let prev_char = chars[i - 1];
120
121            // Transition from lowercase/digit to uppercase
122            if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
123                parts.push(&identifier[start..i]);
124                start = i;
125            }
126            // Handle sequences like "XMLParser" -> ["XML", "Parser"]
127            else if i + 1 < chars.len()
128                && ch.is_uppercase()
129                && chars[i + 1].is_lowercase()
130                && prev_char.is_uppercase()
131            {
132                parts.push(&identifier[start..i]);
133                start = i;
134            }
135        }
136
137        i += 1;
138    }
139
140    // Add the last part if there's any remaining
141    if start < identifier.len() {
142        parts.push(&identifier[start..]);
143    }
144
145    // Filter out empty strings
146    parts.into_iter().filter(|s| !s.is_empty()).collect()
147}
148
149pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
150    if set_a.table.len() > set_b.table.len() {
151        std::mem::swap(&mut set_a, &mut set_b);
152    }
153    let intersection = set_a
154        .table
155        .iter()
156        .filter(|entry| set_b.contains_hash(entry.hash))
157        .count();
158    let union = set_a.table.len() + set_b.table.len() - intersection;
159    intersection as f32 / union as f32
160}
161
162// TODO
163#[allow(dead_code)]
164pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
165    if set_a.table.len() > set_b.table.len() {
166        std::mem::swap(&mut set_a, &mut set_b);
167    }
168    let intersection = set_a
169        .table
170        .iter()
171        .filter(|entry| set_b.contains_hash(entry.hash))
172        .count();
173    intersection as f32 / set_a.table.len() as f32
174}
175
176// TODO
177#[allow(dead_code)]
178pub fn weighted_jaccard_similarity<'a>(
179    mut set_a: &'a Occurrences,
180    mut set_b: &'a Occurrences,
181) -> f32 {
182    if set_a.table.len() > set_b.table.len() {
183        std::mem::swap(&mut set_a, &mut set_b);
184    }
185
186    let mut numerator = 0;
187    let mut denominator_a = 0;
188    let mut used_count_b = 0;
189    for entry_a in set_a.table.iter() {
190        let count_a = entry_a.count;
191        let count_b = set_b.get_count(entry_a.hash);
192        numerator += count_a.min(count_b);
193        denominator_a += count_a.max(count_b);
194        used_count_b += count_b;
195    }
196
197    let denominator = denominator_a + (set_b.total_count - used_count_b);
198    if denominator == 0 {
199        0.0
200    } else {
201        numerator as f32 / denominator as f32
202    }
203}
204
205pub fn weighted_overlap_coefficient<'a>(
206    mut set_a: &'a Occurrences,
207    mut set_b: &'a Occurrences,
208) -> f32 {
209    if set_a.table.len() > set_b.table.len() {
210        std::mem::swap(&mut set_a, &mut set_b);
211    }
212
213    let mut numerator = 0;
214    for entry_a in set_a.table.iter() {
215        let count_a = entry_a.count;
216        let count_b = set_b.get_count(entry_a.hash);
217        numerator += count_a.min(count_b);
218    }
219
220    let denominator = set_a.total_count.min(set_b.total_count);
221    if denominator == 0 {
222        0.0
223    } else {
224        numerator as f32 / denominator as f32
225    }
226}
227
#[cfg(test)]
mod test {
    use super::*;

    /// One representative identifier per supported naming convention.
    #[test]
    fn test_split_identifier() {
        let cases: [(&str, &[&str]); 5] = [
            ("snake_case", &["snake", "case"]),
            ("kebab-case", &["kebab", "case"]),
            ("PascalCase", &["Pascal", "Case"]),
            ("camelCase", &["camel", "Case"]),
            ("XMLParser", &["XML", "Parser"]),
        ];

        for (input, expected) in cases {
            assert_eq!(split_identifier(input), expected.to_vec());
        }
    }

    /// Checks all four similarity measures against hand-computed values for one pair of
    /// code snippets.
    #[test]
    fn test_similarity_functions() {
        // 10 identifier parts, 8 unique
        // Repeats: 2 "outline", 2 "items"
        let call_site = Occurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // 14 identifier parts, 11 unique
        // Repeats: 2 "outline", 2 "language", 2 "tree"
        let declaration = Occurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
        // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
        let expected_jaccard = 6.0 / (6.0 + 7.0);
        assert_eq!(jaccard_similarity(&call_site, &declaration), expected_jaccard);

        // Numerator is one more than before due to both having 2 "outline".
        // Denominator is the same except for 3 more due to the non-overlapping duplicates
        let expected_weighted_jaccard = 7.0 / (7.0 + 7.0 + 3.0);
        assert_eq!(
            weighted_jaccard_similarity(&call_site, &declaration),
            expected_weighted_jaccard
        );

        // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
        let expected_overlap = 6.0 / 8.0;
        assert_eq!(overlap_coefficient(&call_site, &declaration), expected_overlap);

        // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
        // the smaller set, 10.
        let expected_weighted_overlap = 7.0 / 10.0;
        assert_eq!(
            weighted_overlap_coefficient(&call_site, &declaration),
            expected_weighted_overlap
        );
    }
}