text_similarity.rs

  1use regex::Regex;
  2use std::{collections::HashMap, sync::LazyLock};
  3
  4use crate::reference::Reference;
  5
  6// TODO: Consider implementing sliding window similarity matching like
  7// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
  8//
  9// That implementation could actually be more efficient - no need to track words in the window that
 10// are not in the query.
 11
 12// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
 13// two in parallel.
 14
 15static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
 16
// TODO: use &str or Cow<str> keys?
/// Multiset of lowercased identifier parts, as consumed by the similarity functions in this
/// module.
#[derive(Debug)]
pub struct IdentifierOccurrences {
    /// Lowercased identifier part -> number of times it occurred.
    identifier_to_count: HashMap<String, usize>,
    /// Total number of identifier parts counted. `from_iterator` increments this together with
    /// each per-part count, so it equals the sum of `identifier_to_count`'s values.
    total_count: usize,
}
 23
 24impl IdentifierOccurrences {
 25    pub fn within_string(code: &str) -> Self {
 26        Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
 27    }
 28
 29    #[allow(dead_code)]
 30    pub fn within_references(references: &[Reference]) -> Self {
 31        Self::from_iterator(
 32            references
 33                .iter()
 34                .map(|reference| reference.identifier.name.as_ref()),
 35        )
 36    }
 37
 38    pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
 39        let mut identifier_to_count = HashMap::new();
 40        let mut total_count = 0;
 41        for identifier in identifier_iterator {
 42            // TODO: Score matches that match case higher?
 43            //
 44            // TODO: Also include unsplit identifier?
 45            for identifier_part in split_identifier(identifier) {
 46                identifier_to_count
 47                    .entry(identifier_part.to_lowercase())
 48                    .and_modify(|count| *count += 1)
 49                    .or_insert(1);
 50                total_count += 1;
 51            }
 52        }
 53        IdentifierOccurrences {
 54            identifier_to_count,
 55            total_count,
 56        }
 57    }
 58}
 59
 60// Splits camelcase / snakecase / kebabcase / pascalcase
 61//
 62// TODO: Make this more efficient / elegant.
 63fn split_identifier(identifier: &str) -> Vec<&str> {
 64    let mut parts = Vec::new();
 65    let mut start = 0;
 66    let chars: Vec<char> = identifier.chars().collect();
 67
 68    if chars.is_empty() {
 69        return parts;
 70    }
 71
 72    let mut i = 0;
 73    while i < chars.len() {
 74        let ch = chars[i];
 75
 76        // Handle explicit delimiters (underscore and hyphen)
 77        if ch == '_' || ch == '-' {
 78            if i > start {
 79                parts.push(&identifier[start..i]);
 80            }
 81            start = i + 1;
 82            i += 1;
 83            continue;
 84        }
 85
 86        // Handle camelCase and PascalCase transitions
 87        if i > 0 && i < chars.len() {
 88            let prev_char = chars[i - 1];
 89
 90            // Transition from lowercase/digit to uppercase
 91            if (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase() {
 92                parts.push(&identifier[start..i]);
 93                start = i;
 94            }
 95            // Handle sequences like "XMLParser" -> ["XML", "Parser"]
 96            else if i + 1 < chars.len()
 97                && ch.is_uppercase()
 98                && chars[i + 1].is_lowercase()
 99                && prev_char.is_uppercase()
100            {
101                parts.push(&identifier[start..i]);
102                start = i;
103            }
104        }
105
106        i += 1;
107    }
108
109    // Add the last part if there's any remaining
110    if start < identifier.len() {
111        parts.push(&identifier[start..]);
112    }
113
114    // Filter out empty strings
115    parts.into_iter().filter(|s| !s.is_empty()).collect()
116}
117
118pub fn jaccard_similarity<'a>(
119    mut set_a: &'a IdentifierOccurrences,
120    mut set_b: &'a IdentifierOccurrences,
121) -> f32 {
122    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
123        std::mem::swap(&mut set_a, &mut set_b);
124    }
125    let intersection = set_a
126        .identifier_to_count
127        .keys()
128        .filter(|key| set_b.identifier_to_count.contains_key(*key))
129        .count();
130    let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
131    intersection as f32 / union as f32
132}
133
134// TODO
135#[allow(dead_code)]
136pub fn overlap_coefficient<'a>(
137    mut set_a: &'a IdentifierOccurrences,
138    mut set_b: &'a IdentifierOccurrences,
139) -> f32 {
140    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
141        std::mem::swap(&mut set_a, &mut set_b);
142    }
143    let intersection = set_a
144        .identifier_to_count
145        .keys()
146        .filter(|key| set_b.identifier_to_count.contains_key(*key))
147        .count();
148    intersection as f32 / set_a.identifier_to_count.len() as f32
149}
150
151// TODO
152#[allow(dead_code)]
153pub fn weighted_jaccard_similarity<'a>(
154    mut set_a: &'a IdentifierOccurrences,
155    mut set_b: &'a IdentifierOccurrences,
156) -> f32 {
157    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
158        std::mem::swap(&mut set_a, &mut set_b);
159    }
160
161    let mut numerator = 0;
162    let mut denominator_a = 0;
163    let mut used_count_b = 0;
164    for (symbol, count_a) in set_a.identifier_to_count.iter() {
165        let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
166        numerator += count_a.min(count_b);
167        denominator_a += count_a.max(count_b);
168        used_count_b += count_b;
169    }
170
171    let denominator = denominator_a + (set_b.total_count - used_count_b);
172    if denominator == 0 {
173        0.0
174    } else {
175        numerator as f32 / denominator as f32
176    }
177}
178
179pub fn weighted_overlap_coefficient<'a>(
180    mut set_a: &'a IdentifierOccurrences,
181    mut set_b: &'a IdentifierOccurrences,
182) -> f32 {
183    if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
184        std::mem::swap(&mut set_a, &mut set_b);
185    }
186
187    let mut numerator = 0;
188    for (symbol, count_a) in set_a.identifier_to_count.iter() {
189        let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
190        numerator += count_a.min(count_b);
191    }
192
193    let denominator = set_a.total_count.min(set_b.total_count);
194    if denominator == 0 {
195        0.0
196    } else {
197        numerator as f32 / denominator as f32
198    }
199}
200
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_split_identifier() {
        let cases: [(&str, &[&str]); 5] = [
            ("snake_case", &["snake", "case"]),
            ("kebab-case", &["kebab", "case"]),
            ("PascalCase", &["Pascal", "Case"]),
            ("camelCase", &["camel", "Case"]),
            ("XMLParser", &["XML", "Parser"]),
        ];
        for (identifier, expected) in cases {
            assert_eq!(split_identifier(identifier), expected, "splitting {identifier:?}");
        }
    }

    #[test]
    fn test_similarity_functions() {
        // Parts: let, mut, outline, items, query, outline, items, language, tree, source.
        // 10 in total, 8 distinct ("outline" and "items" occur twice).
        let set_a = IdentifierOccurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // Parts: pub, fn, query, outline, items, language, language, tree, tree, source, str,
        // vec, outline, item. 14 in total, 11 distinct ("outline", "language" and "tree" occur
        // twice).
        let set_b = IdentifierOccurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // Distinct parts in both: outline, items, query, language, tree, source (6).
        // Distinct parts in only one: let, mut, pub, fn, vec, item, str (7).
        // Jaccard = 6 / (6 + 7).
        assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / 13.0);

        // Weighted Jaccard = sum of min counts / sum of max counts.
        // Min-sum is 7: the 6 shared parts, plus 1 because both sides contain "outline" twice.
        // Max-sum is 17: 10 from the shared parts ("outline", "items", "language" and "tree"
        // each count 2) plus 7 from the unshared parts.
        assert_eq!(weighted_jaccard_similarity(&set_a, &set_b), 7.0 / 17.0);

        // Overlap = 6 shared distinct parts / 8 distinct parts in the smaller set.
        assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);

        // Weighted overlap = min-sum (7) / smaller total count (10).
        assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
    }
}