text_similarity.rs

  1use regex::Regex;
  2use std::{collections::HashMap, sync::LazyLock};
  3
  4use crate::reference::Reference;
  5
// TODO: Consider implementing sliding-window similarity matching like
// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
//
// That implementation could actually be more efficient, since there is no need to track words in
// the window that are not in the query.
 11
/// Matches runs of word characters (`\w+`) delimited by word boundaries; used to pull
/// identifier-like tokens out of arbitrary text. Compiled lazily on first use; the pattern is a
/// constant, so the `unwrap` only guards against a malformed literal.
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
 13
/// Occurrence counts of lowercased identifier parts, used as the input to the similarity
/// functions in this file.
#[derive(Debug)]
pub struct IdentifierOccurrences {
    /// Number of occurrences of each lowercased identifier part.
    identifier_to_count: HashMap<String, usize>,
    /// Total number of identifier parts counted, repeats included.
    total_count: usize,
}
 19
 20impl IdentifierOccurrences {
 21    pub fn within_string(code: &str) -> Self {
 22        Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
 23    }
 24
 25    #[allow(dead_code)]
 26    pub fn within_references(references: &[Reference]) -> Self {
 27        Self::from_iterator(
 28            references
 29                .iter()
 30                .map(|reference| reference.identifier.name.as_ref()),
 31        )
 32    }
 33
 34    pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
 35        let mut identifier_to_count = HashMap::new();
 36        let mut total_count = 0;
 37        for identifier in identifier_iterator {
 38            // TODO: Score matches that match case higher?
 39            //
 40            // TODO: Also include unsplit identifier?
 41            for identifier_part in split_identifier(identifier) {
 42                identifier_to_count
 43                    .entry(identifier_part.to_lowercase())
 44                    .and_modify(|count| *count += 1)
 45                    .or_insert(1);
 46                total_count += 1;
 47            }
 48        }
 49        IdentifierOccurrences {
 50            identifier_to_count,
 51            total_count,
 52        }
 53    }
 54}
 55
 56// Splits camelcase / snakecase / kebabcase / pascalcase
 57//
 58// TODO: Make this more efficient / elegant.
 59fn split_identifier(identifier: &str) -> Vec<&str> {
 60    let mut parts = Vec::new();
 61    let mut start = 0;
 62    let chars: Vec<char> = identifier.chars().collect();
 63
 64    if chars.is_empty() {
 65        return parts;
 66    }
 67
 68    let mut i = 0;
 69    while i < chars.len() {
 70        let ch = chars[i];
 71
 72        // Handle explicit delimiters (underscore and hyphen)
 73        if ch == '_' || ch == '-' {
 74            if i > start {
 75                parts.push(&identifier[start..i]);
 76            }
 77            start = i + 1;
 78            i += 1;
 79            continue;
 80        }
 81
 82        // Handle camelCase and PascalCase transitions
 83        if i > 0 && i < chars.len() {
 84            let prev_char = chars[i - 1];
 85
 86            // Transition from lowercase/digit to uppercase
 87            if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
 88                parts.push(&identifier[start..i]);
 89                start = i;
 90            }
 91            // Handle sequences like "XMLParser" -> ["XML", "Parser"]
 92            else if i + 1 < chars.len()
 93                && ch.is_uppercase()
 94                && chars[i + 1].is_lowercase()
 95                && prev_char.is_uppercase()
 96            {
 97                parts.push(&identifier[start..i]);
 98                start = i;
 99            }
100        }
101
102        i += 1;
103    }
104
105    // Add the last part if there's any remaining
106    if start < identifier.len() {
107        parts.push(&identifier[start..]);
108    }
109
110    // Filter out empty strings
111    parts.into_iter().filter(|s| !s.is_empty()).collect()
112}
113
114pub fn jaccard_similarity<'a>(
115    mut set_a: &'a IdentifierOccurrences,
116    mut set_b: &'a IdentifierOccurrences,
117) -> f32 {
118    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
119        std::mem::swap(&mut set_a, &mut set_b);
120    }
121    let intersection = set_a
122        .identifier_to_count
123        .keys()
124        .filter(|key| set_b.identifier_to_count.contains_key(*key))
125        .count();
126    let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
127    intersection as f32 / union as f32
128}
129
130// TODO
131#[allow(dead_code)]
132pub fn overlap_coefficient<'a>(
133    mut set_a: &'a IdentifierOccurrences,
134    mut set_b: &'a IdentifierOccurrences,
135) -> f32 {
136    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
137        std::mem::swap(&mut set_a, &mut set_b);
138    }
139    let intersection = set_a
140        .identifier_to_count
141        .keys()
142        .filter(|key| set_b.identifier_to_count.contains_key(*key))
143        .count();
144    intersection as f32 / set_a.identifier_to_count.len() as f32
145}
146
147// TODO
148#[allow(dead_code)]
149pub fn weighted_jaccard_similarity<'a>(
150    mut set_a: &'a IdentifierOccurrences,
151    mut set_b: &'a IdentifierOccurrences,
152) -> f32 {
153    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
154        std::mem::swap(&mut set_a, &mut set_b);
155    }
156
157    let mut numerator = 0;
158    let mut denominator_a = 0;
159    let mut used_count_b = 0;
160    for (symbol, count_a) in set_a.identifier_to_count.iter() {
161        let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
162        numerator += count_a.min(count_b);
163        denominator_a += count_a.max(count_b);
164        used_count_b += count_b;
165    }
166
167    let denominator = denominator_a + (set_b.total_count - used_count_b);
168    if denominator == 0 {
169        0.0
170    } else {
171        numerator as f32 / denominator as f32
172    }
173}
174
175pub fn weighted_overlap_coefficient<'a>(
176    mut set_a: &'a IdentifierOccurrences,
177    mut set_b: &'a IdentifierOccurrences,
178) -> f32 {
179    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
180        std::mem::swap(&mut set_a, &mut set_b);
181    }
182
183    let mut numerator = 0;
184    for (symbol, count_a) in set_a.identifier_to_count.iter() {
185        let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
186        numerator += count_a.min(count_b);
187    }
188
189    let denominator = set_a.total_count.min(set_b.total_count);
190    if denominator == 0 {
191        0.0
192    } else {
193        numerator as f32 / denominator as f32
194    }
195}
196
#[cfg(test)]
mod test {
    use super::*;

    /// Each supported naming convention is split into the expected parts, casing preserved.
    #[test]
    fn test_split_identifier() {
        let cases = [
            ("snake_case", vec!["snake", "case"]),
            ("kebab-case", vec!["kebab", "case"]),
            ("PascalCase", vec!["Pascal", "Case"]),
            ("camelCase", vec!["camel", "Case"]),
            ("XMLParser", vec!["XML", "Parser"]),
        ];
        for (identifier, expected_parts) in cases {
            assert_eq!(split_identifier(identifier), expected_parts);
        }
    }

    /// Checks all four similarity measures against hand-computed values for one pair of inputs.
    #[test]
    fn test_similarity_functions() {
        // 10 identifier parts, 8 unique
        // Repeats: 2 "outline", 2 "items"
        let call_site = IdentifierOccurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // 14 identifier parts, 11 unique
        // Repeats: 2 "outline", 2 "language", 2 "tree"
        let definition = IdentifierOccurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
        // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
        let unweighted_jaccard = jaccard_similarity(&call_site, &definition);
        assert_eq!(unweighted_jaccard, 6.0 / (6.0 + 7.0));

        // Numerator is one more than before due to both having 2 "outline".
        // Denominator is the same except for 3 more due to the non-overlapping duplicates
        let weighted_jaccard = weighted_jaccard_similarity(&call_site, &definition);
        assert_eq!(weighted_jaccard, 7.0 / (7.0 + 7.0 + 3.0));

        // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller
        // set, 8.
        let unweighted_overlap = overlap_coefficient(&call_site, &definition);
        assert_eq!(unweighted_overlap, 6.0 / 8.0);

        // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight
        // of the smaller set, 10.
        let weighted_overlap = weighted_overlap_coefficient(&call_site, &definition);
        assert_eq!(weighted_overlap, 7.0 / 10.0);
    }
}