text_similarity.rs

  1use hashbrown::HashTable;
  2use regex::Regex;
  3use std::{
  4    borrow::Cow,
  5    hash::{Hash, Hasher as _},
  6    path::Path,
  7    sync::LazyLock,
  8};
  9use util::rel_path::RelPath;
 10
 11use crate::reference::Reference;
 12
 13// TODO: Consider implementing sliding window similarity matching like
 14// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
 15//
 16// That implementation could actually be more efficient - no need to track words in the window that
 17// are not in the query.
 18
 19// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
 20// two in parallel.
 21
/// Lazily compiled pattern used by [`Occurrences::within_string`] to pull identifier-like tokens
/// out of free text: one or more word characters (`\w+`) delimited by word boundaries (`\b`).
///
/// NOTE(review): in the `regex` crate `\w` and `\b` are Unicode-aware by default, so matches are
/// presumably not limited to ASCII identifiers — confirm against the crate's syntax docs.
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
 23
/// Multiset of text occurrences for text similarity that only stores hashes and counts.
///
/// The original strings are never kept: each identifier part is lowercased and hashed before it
/// is recorded, so two sets can only be compared by hash.
#[derive(Debug, Default)]
pub struct Occurrences {
    /// One entry per distinct hash, keyed by that hash.
    table: HashTable<OccurrenceEntry>,
    /// Number of hashes added so far, duplicates included (i.e. the sum of all entry counts).
    total_count: usize,
}
 30
/// A single slot of the [`Occurrences`] multiset: a hash and how often it was added.
#[derive(Debug)]
struct OccurrenceEntry {
    /// Hash of a lowercased identifier part; also used as the table key.
    hash: u64,
    /// How many times `hash` has been added. Starts at 1 when the entry is created.
    count: usize,
}
 36
 37impl Occurrences {
 38    pub fn within_string(text: &str) -> Self {
 39        Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str()))
 40    }
 41
 42    #[allow(dead_code)]
 43    pub fn within_references(references: &[Reference]) -> Self {
 44        Self::from_identifiers(
 45            references
 46                .iter()
 47                .map(|reference| reference.identifier.name.as_ref()),
 48        )
 49    }
 50
 51    pub fn from_identifiers(identifiers: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
 52        let mut this = Self::default();
 53        // TODO: Score matches that match case higher?
 54        //
 55        // TODO: Also include unsplit identifier?
 56        for identifier in identifiers {
 57            for identifier_part in split_identifier(identifier.as_ref()) {
 58                this.add_hash(fx_hash(&identifier_part.to_lowercase()));
 59            }
 60        }
 61        this
 62    }
 63
 64    pub fn from_worktree_path(worktree_name: Option<Cow<'_, str>>, rel_path: &RelPath) -> Self {
 65        if let Some(worktree_name) = worktree_name {
 66            Self::from_identifiers(
 67                std::iter::once(worktree_name)
 68                    .chain(iter_path_without_extension(rel_path.as_std_path())),
 69            )
 70        } else {
 71            Self::from_path(rel_path.as_std_path())
 72        }
 73    }
 74
 75    pub fn from_path(path: &Path) -> Self {
 76        Self::from_identifiers(iter_path_without_extension(path))
 77    }
 78
 79    fn add_hash(&mut self, hash: u64) {
 80        self.table
 81            .entry(
 82                hash,
 83                |entry: &OccurrenceEntry| entry.hash == hash,
 84                |entry| entry.hash,
 85            )
 86            .and_modify(|entry| entry.count += 1)
 87            .or_insert(OccurrenceEntry { hash, count: 1 });
 88        self.total_count += 1;
 89    }
 90
 91    fn contains_hash(&self, hash: u64) -> bool {
 92        self.get_count(hash) != 0
 93    }
 94
 95    fn get_count(&self, hash: u64) -> usize {
 96        self.table
 97            .find(hash, |entry| entry.hash == hash)
 98            .map(|entry| entry.count)
 99            .unwrap_or(0)
100    }
101}
102
/// Yields the components of `path` as lossily-decoded strings, with the final component replaced
/// by its file stem (so `foo/bar.txt` yields `foo`, `bar`).
///
/// If the path has no file stem, only the leading components are yielded.
fn iter_path_without_extension(path: &Path) -> impl Iterator<Item = Cow<'_, str>> {
    let stem = path.file_stem().map(|name| name.to_string_lossy());

    // Drop the trailing component; the stem (if any) takes its place at the end.
    let mut leading = path.components();
    let _ = leading.next_back();

    leading
        .map(|part| part.as_os_str().to_string_lossy())
        .chain(stem)
}
111
112pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
113    let mut hasher = collections::FxHasher::default();
114    data.hash(&mut hasher);
115    hasher.finish()
116}
117
// Splits camelcase / snakecase / kebabcase / pascalcase
//
// Returns the non-empty parts of `identifier` as sub-slices of the input, in order:
//   - `_` and `-` are delimiters and are dropped ("snake_case" -> ["snake", "case"]).
//   - A lowercase letter or ASCII digit followed by an uppercase letter starts a new part
//     ("camelCase" -> ["camel", "Case"]).
//   - Inside a run of uppercase letters, the last one starts a new part when it is followed by a
//     lowercase letter ("XMLParser" -> ["XML", "Parser"]).
//
// All positions are byte offsets obtained from `char_indices`, so every slice lands on a char
// boundary even when the identifier contains multi-byte characters.
fn split_identifier(identifier: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    // Byte offset at which the part currently being scanned begins.
    let mut start = 0;
    let mut prev_char: Option<char> = None;
    let mut chars = identifier.char_indices().peekable();

    while let Some((index, ch)) = chars.next() {
        // Handle explicit delimiters (underscore and hyphen)
        if ch == '_' || ch == '-' {
            if index > start {
                parts.push(&identifier[start..index]);
            }
            start = index + ch.len_utf8();
            prev_char = Some(ch);
            continue;
        }

        // Handle camelCase and PascalCase transitions
        let starts_new_part = match prev_char {
            Some(prev) if ch.is_uppercase() => {
                // Transition from lowercase/digit to uppercase
                prev.is_lowercase()
                    || prev.is_ascii_digit()
                    // Handle sequences like "XMLParser" -> ["XML", "Parser"]
                    || (prev.is_uppercase()
                        && chars.peek().is_some_and(|&(_, next)| next.is_lowercase()))
            }
            _ => false,
        };
        if starts_new_part {
            // The previous char is a letter or digit here, never a delimiter, so it belongs to
            // the current part and `start < index`: the pushed slice is never empty.
            parts.push(&identifier[start..index]);
            start = index;
        }

        prev_char = Some(ch);
    }

    // Add the last part if there's any remaining
    if start < identifier.len() {
        parts.push(&identifier[start..]);
    }

    parts
}
175
176pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
177    if set_a.table.len() > set_b.table.len() {
178        std::mem::swap(&mut set_a, &mut set_b);
179    }
180    let intersection = set_a
181        .table
182        .iter()
183        .filter(|entry| set_b.contains_hash(entry.hash))
184        .count();
185    let union = set_a.table.len() + set_b.table.len() - intersection;
186    intersection as f32 / union as f32
187}
188
189// TODO
190#[allow(dead_code)]
191pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
192    if set_a.table.len() > set_b.table.len() {
193        std::mem::swap(&mut set_a, &mut set_b);
194    }
195    let intersection = set_a
196        .table
197        .iter()
198        .filter(|entry| set_b.contains_hash(entry.hash))
199        .count();
200    intersection as f32 / set_a.table.len() as f32
201}
202
203// TODO
204#[allow(dead_code)]
205pub fn weighted_jaccard_similarity<'a>(
206    mut set_a: &'a Occurrences,
207    mut set_b: &'a Occurrences,
208) -> f32 {
209    if set_a.table.len() > set_b.table.len() {
210        std::mem::swap(&mut set_a, &mut set_b);
211    }
212
213    let mut numerator = 0;
214    let mut denominator_a = 0;
215    let mut used_count_b = 0;
216    for entry_a in set_a.table.iter() {
217        let count_a = entry_a.count;
218        let count_b = set_b.get_count(entry_a.hash);
219        numerator += count_a.min(count_b);
220        denominator_a += count_a.max(count_b);
221        used_count_b += count_b;
222    }
223
224    let denominator = denominator_a + (set_b.total_count - used_count_b);
225    if denominator == 0 {
226        0.0
227    } else {
228        numerator as f32 / denominator as f32
229    }
230}
231
232pub fn weighted_overlap_coefficient<'a>(
233    mut set_a: &'a Occurrences,
234    mut set_b: &'a Occurrences,
235) -> f32 {
236    if set_a.table.len() > set_b.table.len() {
237        std::mem::swap(&mut set_a, &mut set_b);
238    }
239
240    let mut numerator = 0;
241    for entry_a in set_a.table.iter() {
242        let count_a = entry_a.count;
243        let count_b = set_b.get_count(entry_a.hash);
244        numerator += count_a.min(count_b);
245    }
246
247    let denominator = set_a.total_count.min(set_b.total_count);
248    if denominator == 0 {
249        0.0
250    } else {
251        numerator as f32 / denominator as f32
252    }
253}
254
#[cfg(test)]
mod test {
    use super::*;

    /// Each supported naming convention splits into the expected parts.
    #[test]
    fn test_split_identifier() {
        let cases: [(&str, &[&str]); 5] = [
            ("snake_case", &["snake", "case"]),
            ("kebab-case", &["kebab", "case"]),
            ("PascalCase", &["Pascal", "Case"]),
            ("camelCase", &["camel", "Case"]),
            ("XMLParser", &["XML", "Parser"]),
        ];
        for (identifier, expected) in cases {
            assert_eq!(split_identifier(identifier), expected);
        }
    }

    /// All four similarity measures on one pair of realistic code snippets.
    #[test]
    fn test_similarity_functions() {
        // 10 identifier parts, 8 unique
        // Repeats: 2 "outline", 2 "items"
        let call_site = Occurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // 14 identifier parts, 11 unique
        // Repeats: 2 "outline", 2 "language", 2 "tree"
        let signature = Occurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
        // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
        let jaccard = jaccard_similarity(&call_site, &signature);
        assert_eq!(jaccard, 6.0 / (6.0 + 7.0));

        // Numerator is one more than before due to both having 2 "outline".
        // Denominator is the same except for 3 more due to the non-overlapping duplicates
        let weighted_jaccard = weighted_jaccard_similarity(&call_site, &signature);
        assert_eq!(weighted_jaccard, 7.0 / (7.0 + 7.0 + 3.0));

        // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller
        // set, 8.
        let overlap = overlap_coefficient(&call_site, &signature);
        assert_eq!(overlap, 6.0 / 8.0);

        // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight
        // of the smaller set, 10.
        let weighted_overlap = weighted_overlap_coefficient(&call_site, &signature);
        assert_eq!(weighted_overlap, 7.0 / 10.0);
    }

    /// The extension of the final component is dropped; leading components are kept as-is.
    #[test]
    fn test_iter_path_without_extension() {
        // An empty path yields nothing at all.
        assert!(iter_path_without_extension(Path::new("")).next().is_none());

        let cases: [(&str, &[&str]); 3] = [
            ("foo", &["foo"]),
            ("foo/bar.txt", &["foo", "bar"]),
            ("foo/bar/baz.txt", &["foo", "bar", "baz"]),
        ];
        for (path, expected) in cases {
            let components: Vec<_> = iter_path_without_extension(Path::new(path)).collect();
            assert_eq!(components, expected);
        }
    }
}