1use regex::Regex;
2use std::{collections::HashMap, sync::LazyLock};
3
4use crate::reference::Reference;
5
6// TODO: Consider implementing sliding window similarity matching like
7// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
8//
9// That implementation could actually be more efficient - no need to track words in the window that
10// are not in the query.
11
12static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
13
14#[derive(Debug)]
15pub struct IdentifierOccurrences {
16 identifier_to_count: HashMap<String, usize>,
17 total_count: usize,
18}
19
20impl IdentifierOccurrences {
21 pub fn within_string(code: &str) -> Self {
22 Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
23 }
24
25 #[allow(dead_code)]
26 pub fn within_references(references: &[Reference]) -> Self {
27 Self::from_iterator(
28 references
29 .iter()
30 .map(|reference| reference.identifier.name.as_ref()),
31 )
32 }
33
34 pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
35 let mut identifier_to_count = HashMap::new();
36 let mut total_count = 0;
37 for identifier in identifier_iterator {
38 // TODO: Score matches that match case higher?
39 //
40 // TODO: Also include unsplit identifier?
41 for identifier_part in split_identifier(identifier) {
42 identifier_to_count
43 .entry(identifier_part.to_lowercase())
44 .and_modify(|count| *count += 1)
45 .or_insert(1);
46 total_count += 1;
47 }
48 }
49 IdentifierOccurrences {
50 identifier_to_count,
51 total_count,
52 }
53 }
54}
55
/// Splits an identifier into its constituent words (camelCase / snake_case / kebab-case /
/// PascalCase).
///
/// Word boundaries are:
/// - `_` and `-` delimiters, which are dropped;
/// - a lowercase letter or ASCII digit followed by an uppercase letter;
/// - the last uppercase letter of an uppercase run when it is followed by a lowercase
///   letter, so `"XMLParser"` yields `["XML", "Parser"]`.
///
/// The returned slices borrow from `identifier`, keep their original case and are never
/// empty. Every offset is a byte offset obtained from `char_indices`, so identifiers that
/// contain multi-byte characters are sliced on character boundaries.
//
// TODO: Make this more efficient / elegant.
fn split_identifier<'a>(identifier: &'a str) -> Vec<&'a str> {
    let mut parts = Vec::new();
    // Byte offset at which the word currently being scanned begins.
    let mut start = 0;
    let mut prev_char: Option<char> = None;
    let mut chars = identifier.char_indices().peekable();

    while let Some((index, ch)) = chars.next() {
        if ch == '_' || ch == '-' {
            // Explicit delimiter: close the current word (if any) and skip the delimiter.
            if index > start {
                parts.push(&identifier[start..index]);
            }
            start = index + ch.len_utf8();
        } else if let Some(prev) = prev_char {
            // "fooBar" / "foo2Bar": lowercase or digit followed by uppercase.
            let lower_to_upper =
                (prev.is_lowercase() || prev.is_ascii_digit()) && ch.is_uppercase();
            // "XMLParser": the uppercase letter that starts the next capitalized word.
            let acronym_end = prev.is_uppercase()
                && ch.is_uppercase()
                && chars.peek().is_some_and(|&(_, next)| next.is_lowercase());
            if (lower_to_upper || acronym_end) && index > start {
                parts.push(&identifier[start..index]);
                start = index;
            }
        }
        prev_char = Some(ch);
    }

    // Whatever follows the last boundary is the final word.
    if start < identifier.len() {
        parts.push(&identifier[start..]);
    }

    parts
}
113
114pub fn jaccard_similarity<'a>(
115 mut set_a: &'a IdentifierOccurrences,
116 mut set_b: &'a IdentifierOccurrences,
117) -> f32 {
118 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
119 std::mem::swap(&mut set_a, &mut set_b);
120 }
121 let intersection = set_a
122 .identifier_to_count
123 .keys()
124 .filter(|key| set_b.identifier_to_count.contains_key(*key))
125 .count();
126 let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
127 intersection as f32 / union as f32
128}
129
130pub fn overlap_coefficient<'a>(
131 mut set_a: &'a IdentifierOccurrences,
132 mut set_b: &'a IdentifierOccurrences,
133) -> f32 {
134 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
135 std::mem::swap(&mut set_a, &mut set_b);
136 }
137 let intersection = set_a
138 .identifier_to_count
139 .keys()
140 .filter(|key| set_b.identifier_to_count.contains_key(*key))
141 .count();
142 intersection as f32 / set_a.identifier_to_count.len() as f32
143}
144
145pub fn weighted_jaccard_similarity<'a>(
146 mut set_a: &'a IdentifierOccurrences,
147 mut set_b: &'a IdentifierOccurrences,
148) -> f32 {
149 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
150 std::mem::swap(&mut set_a, &mut set_b);
151 }
152
153 let mut numerator = 0;
154 let mut denominator_a = 0;
155 let mut used_count_b = 0;
156 for (symbol, count_a) in set_a.identifier_to_count.iter() {
157 let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
158 numerator += count_a.min(count_b);
159 denominator_a += count_a.max(count_b);
160 used_count_b += count_b;
161 }
162
163 let denominator = denominator_a + (set_b.total_count - used_count_b);
164 if denominator == 0 {
165 0.0
166 } else {
167 numerator as f32 / denominator as f32
168 }
169}
170
171pub fn weighted_overlap_coefficient<'a>(
172 mut set_a: &'a IdentifierOccurrences,
173 mut set_b: &'a IdentifierOccurrences,
174) -> f32 {
175 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
176 std::mem::swap(&mut set_a, &mut set_b);
177 }
178
179 let mut numerator = 0;
180 for (symbol, count_a) in set_a.identifier_to_count.iter() {
181 let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
182 numerator += count_a.min(count_b);
183 }
184
185 let denominator = set_a.total_count.min(set_b.total_count);
186 if denominator == 0 {
187 0.0
188 } else {
189 numerator as f32 / denominator as f32
190 }
191}
192
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_split_identifier() {
        // One call per supported naming convention.
        let check = |identifier: &str, expected: &[&str]| {
            assert_eq!(split_identifier(identifier), expected);
        };

        check("snake_case", &["snake", "case"]);
        check("kebab-case", &["kebab", "case"]);
        check("PascalCase", &["Pascal", "Case"]);
        check("camelCase", &["camel", "Case"]);
        check("XMLParser", &["XML", "Parser"]);
    }

    #[test]
    fn test_similarity_functions() {
        // 10 identifier parts, 8 unique.
        // Repeats: 2 "outline", 2 "items".
        let set_a = IdentifierOccurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // 14 identifier parts, 11 unique.
        // Repeats: 2 "outline", 2 "language", 2 "tree".
        let set_b = IdentifierOccurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // 6 overlaps: "outline", "items", "query", "language", "tree", "source".
        // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str".
        let expected_jaccard: f32 = 6.0 / (6.0 + 7.0);
        assert_eq!(jaccard_similarity(&set_a, &set_b), expected_jaccard);

        // The numerator gains one because both sides contain "outline" twice.
        // The denominator gains 3 from the duplicates that are not matched on the other side.
        let expected_weighted_jaccard: f32 = 7.0 / (7.0 + 7.0 + 3.0);
        assert_eq!(
            weighted_jaccard_similarity(&set_a, &set_b),
            expected_weighted_jaccard
        );

        // Same numerator as `jaccard_similarity`; the denominator is the size of the smaller
        // set, 8.
        let expected_overlap: f32 = 6.0 / 8.0;
        assert_eq!(overlap_coefficient(&set_a, &set_b), expected_overlap);

        // Same numerator as `weighted_jaccard_similarity`; the denominator is the total weight
        // of the smaller set, 10.
        let expected_weighted_overlap: f32 = 7.0 / 10.0;
        assert_eq!(
            weighted_overlap_coefficient(&set_a, &set_b),
            expected_weighted_overlap
        );
    }
}