1use regex::Regex;
2use std::{collections::HashMap, sync::LazyLock};
3
4use crate::reference::Reference;
5
6// TODO: Consider implementing sliding window similarity matching like
7// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
8//
9// That implementation could actually be more efficient - no need to track words in the window that
10// are not in the query.
11
/// Lazily compiled pattern used to pull candidate identifiers out of raw source text: each match
/// is a run of word characters (`\w+`) delimited by word boundaries (`\b`).
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
13
/// Multiset of identifier parts, used as the input of the similarity functions in this file.
#[derive(Debug)]
pub struct IdentifierOccurrences {
    /// Maps each lowercased identifier part to the number of times it was seen.
    identifier_to_count: HashMap<String, usize>,
    /// Total number of identifier parts seen, repeats included (the sum of all counts above).
    total_count: usize,
}
19
20impl IdentifierOccurrences {
21 pub fn within_string(code: &str) -> Self {
22 Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
23 }
24
25 #[allow(dead_code)]
26 pub fn within_references(references: &[Reference]) -> Self {
27 Self::from_iterator(
28 references
29 .iter()
30 .map(|reference| reference.identifier.name.as_ref()),
31 )
32 }
33
34 pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
35 let mut identifier_to_count = HashMap::new();
36 let mut total_count = 0;
37 for identifier in identifier_iterator {
38 // TODO: Score matches that match case higher?
39 //
40 // TODO: Also include unsplit identifier?
41 for identifier_part in split_identifier(identifier) {
42 identifier_to_count
43 .entry(identifier_part.to_lowercase())
44 .and_modify(|count| *count += 1)
45 .or_insert(1);
46 total_count += 1;
47 }
48 }
49 IdentifierOccurrences {
50 identifier_to_count,
51 total_count,
52 }
53 }
54}
55
/// Splits an identifier into its constituent words (camelCase / snake_case / kebab-case /
/// PascalCase).
///
/// A new part starts:
/// - after a `_` or `-` delimiter (the delimiter itself is dropped),
/// - at an uppercase letter that follows a lowercase letter or an ASCII digit
///   (`camelCase` -> `["camel", "Case"]`),
/// - at the last uppercase letter of an uppercase run when it is followed by a lowercase letter
///   (`XMLParser` -> `["XML", "Parser"]`).
///
/// The returned slices borrow from `identifier`, keep their original case and are never empty.
///
/// All slicing uses byte offsets taken from `char_indices`, so identifiers that contain
/// multi-byte characters are always cut on character boundaries.
fn split_identifier(identifier: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    // Byte offset at which the current, not yet emitted, part begins.
    let mut start = 0;
    let mut prev_char: Option<char> = None;
    let mut chars = identifier.char_indices().peekable();

    while let Some((offset, ch)) = chars.next() {
        if ch == '_' || ch == '-' {
            // Explicit delimiter: emit what precedes it and skip the delimiter itself.
            if offset > start {
                parts.push(&identifier[start..offset]);
            }
            start = offset + ch.len_utf8();
        } else if let Some(prev) = prev_char {
            // Transition from lowercase/digit to uppercase.
            let lower_to_upper =
                (prev.is_lowercase() || prev.is_ascii_digit()) && ch.is_uppercase();
            // End of an uppercase run: the last capital belongs to the following word.
            let acronym_end = prev.is_uppercase()
                && ch.is_uppercase()
                && matches!(chars.peek(), Some(&(_, next)) if next.is_lowercase());

            if (lower_to_upper || acronym_end) && offset > start {
                parts.push(&identifier[start..offset]);
                start = offset;
            }
        }
        prev_char = Some(ch);
    }

    // Emit whatever remains after the last boundary.
    if start < identifier.len() {
        parts.push(&identifier[start..]);
    }

    parts
}
113
114pub fn jaccard_similarity<'a>(
115 mut set_a: &'a IdentifierOccurrences,
116 mut set_b: &'a IdentifierOccurrences,
117) -> f32 {
118 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
119 std::mem::swap(&mut set_a, &mut set_b);
120 }
121 let intersection = set_a
122 .identifier_to_count
123 .keys()
124 .filter(|key| set_b.identifier_to_count.contains_key(*key))
125 .count();
126 let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
127 intersection as f32 / union as f32
128}
129
130// TODO
131#[allow(dead_code)]
132pub fn overlap_coefficient<'a>(
133 mut set_a: &'a IdentifierOccurrences,
134 mut set_b: &'a IdentifierOccurrences,
135) -> f32 {
136 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
137 std::mem::swap(&mut set_a, &mut set_b);
138 }
139 let intersection = set_a
140 .identifier_to_count
141 .keys()
142 .filter(|key| set_b.identifier_to_count.contains_key(*key))
143 .count();
144 intersection as f32 / set_a.identifier_to_count.len() as f32
145}
146
147// TODO
148#[allow(dead_code)]
149pub fn weighted_jaccard_similarity<'a>(
150 mut set_a: &'a IdentifierOccurrences,
151 mut set_b: &'a IdentifierOccurrences,
152) -> f32 {
153 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
154 std::mem::swap(&mut set_a, &mut set_b);
155 }
156
157 let mut numerator = 0;
158 let mut denominator_a = 0;
159 let mut used_count_b = 0;
160 for (symbol, count_a) in set_a.identifier_to_count.iter() {
161 let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
162 numerator += count_a.min(count_b);
163 denominator_a += count_a.max(count_b);
164 used_count_b += count_b;
165 }
166
167 let denominator = denominator_a + (set_b.total_count - used_count_b);
168 if denominator == 0 {
169 0.0
170 } else {
171 numerator as f32 / denominator as f32
172 }
173}
174
175pub fn weighted_overlap_coefficient<'a>(
176 mut set_a: &'a IdentifierOccurrences,
177 mut set_b: &'a IdentifierOccurrences,
178) -> f32 {
179 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
180 std::mem::swap(&mut set_a, &mut set_b);
181 }
182
183 let mut numerator = 0;
184 for (symbol, count_a) in set_a.identifier_to_count.iter() {
185 let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
186 numerator += count_a.min(count_b);
187 }
188
189 let denominator = set_a.total_count.min(set_b.total_count);
190 if denominator == 0 {
191 0.0
192 } else {
193 numerator as f32 / denominator as f32
194 }
195}
196
// Unit tests for identifier splitting and for the four similarity measures.
#[cfg(test)]
mod test {
    use super::*;

    // One representative identifier per supported naming style; parts keep their original case.
    #[test]
    fn test_split_identifier() {
        assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]);
        assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]);
        assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]);
        assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]);
        assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]);
    }

    // Checks all four measures against hand-computed values for one pair of code snippets.
    // The comparisons are exact `f32` equality: each expected value is written as the same
    // single division the implementation performs.
    #[test]
    fn test_similarity_functions() {
        // 10 identifier parts, 8 unique
        // Repeats: 2 "outline", 2 "items"
        let set_a = IdentifierOccurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // 14 identifier parts, 11 unique
        // Repeats: 2 "outline", 2 "language", 2 "tree"
        let set_b = IdentifierOccurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
        // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
        assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));

        // Numerator is one more than before due to both having 2 "outline".
        // Denominator is the new numerator (7) plus the 7 non-overlaps, plus 3 more for the
        // overlapping parts whose counts differ between the sets: "items" (2 vs 1),
        // "language" (1 vs 2) and "tree" (1 vs 2).
        assert_eq!(
            weighted_jaccard_similarity(&set_a, &set_b),
            7.0 / (7.0 + 7.0 + 3.0)
        );

        // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
        assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);

        // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
        // the smaller set, 10.
        assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
    }
}