1use regex::Regex;
2use std::{collections::HashMap, sync::LazyLock};
3
4use crate::reference::Reference;
5
6// TODO: Consider implementing sliding window similarity matching like
7// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
8//
9// That implementation could actually be more efficient - no need to track words in the window that
10// are not in the query.
11
12// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
13// two in parallel.
14
15static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
16
// TODO: use &str or Cow<str> keys?
/// Multiset of lowercased identifier parts, built by `IdentifierOccurrences::from_iterator`.
#[derive(Debug)]
pub struct IdentifierOccurrences {
    /// Occurrence count for each distinct lowercased identifier part.
    identifier_to_count: HashMap<String, usize>,
    /// Number of parts counted including repeats, i.e. the sum of `identifier_to_count`'s values.
    total_count: usize,
}
23
24impl IdentifierOccurrences {
25 pub fn within_string(code: &str) -> Self {
26 Self::from_iterator(IDENTIFIER_REGEX.find_iter(code).map(|mat| mat.as_str()))
27 }
28
29 #[allow(dead_code)]
30 pub fn within_references(references: &[Reference]) -> Self {
31 Self::from_iterator(
32 references
33 .iter()
34 .map(|reference| reference.identifier.name.as_ref()),
35 )
36 }
37
38 pub fn from_iterator<'a>(identifier_iterator: impl Iterator<Item = &'a str>) -> Self {
39 let mut identifier_to_count = HashMap::new();
40 let mut total_count = 0;
41 for identifier in identifier_iterator {
42 // TODO: Score matches that match case higher?
43 //
44 // TODO: Also include unsplit identifier?
45 for identifier_part in split_identifier(identifier) {
46 identifier_to_count
47 .entry(identifier_part.to_lowercase())
48 .and_modify(|count| *count += 1)
49 .or_insert(1);
50 total_count += 1;
51 }
52 }
53 IdentifierOccurrences {
54 identifier_to_count,
55 total_count,
56 }
57 }
58}
59
/// Splits an identifier into its word parts.
///
/// Handles snake_case, kebab-case, camelCase and PascalCase, plus acronym runs such as
/// `XMLParser` -> `["XML", "Parser"]`. `_` and `-` are delimiters and are dropped, and empty
/// parts are never returned. A new part starts at:
/// - an uppercase character that follows a lowercase character or an ASCII digit, or
/// - the last uppercase character of an uppercase run when a lowercase character follows it.
///
/// All slicing uses byte offsets from `char_indices`. Non-ASCII identifiers can therefore be
/// split without panicking on a non-char boundary. The `regex` crate's `\w` is Unicode-aware,
/// so such identifiers do reach this function.
fn split_identifier(identifier: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    // Byte offset at which the current part begins.
    let mut start = 0;
    let mut prev: Option<char> = None;
    let mut chars = identifier.char_indices().peekable();

    while let Some((index, ch)) = chars.next() {
        // Explicit delimiters end the current part and belong to no part.
        if ch == '_' || ch == '-' {
            if index > start {
                parts.push(&identifier[start..index]);
            }
            start = index + ch.len_utf8();
            prev = Some(ch);
            continue;
        }

        if let Some(prev_ch) = prev {
            // camelCase / PascalCase boundary, e.g. "fooBar" or "v2Beta".
            let lower_to_upper =
                (prev_ch.is_lowercase() || prev_ch.is_ascii_digit()) && ch.is_uppercase();
            // End of an acronym: in "XMLParser" the split falls before the "P".
            let acronym_end = prev_ch.is_uppercase()
                && ch.is_uppercase()
                && chars.peek().is_some_and(|&(_, next)| next.is_lowercase());
            if (lower_to_upper || acronym_end) && index > start {
                parts.push(&identifier[start..index]);
                start = index;
            }
        }
        prev = Some(ch);
    }

    if start < identifier.len() {
        parts.push(&identifier[start..]);
    }
    parts
}
117
118pub fn jaccard_similarity<'a>(
119 mut set_a: &'a IdentifierOccurrences,
120 mut set_b: &'a IdentifierOccurrences,
121) -> f32 {
122 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
123 std::mem::swap(&mut set_a, &mut set_b);
124 }
125 let intersection = set_a
126 .identifier_to_count
127 .keys()
128 .filter(|key| set_b.identifier_to_count.contains_key(*key))
129 .count();
130 let union = set_a.identifier_to_count.len() + set_b.identifier_to_count.len() - intersection;
131 intersection as f32 / union as f32
132}
133
134// TODO
135#[allow(dead_code)]
136pub fn overlap_coefficient<'a>(
137 mut set_a: &'a IdentifierOccurrences,
138 mut set_b: &'a IdentifierOccurrences,
139) -> f32 {
140 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
141 std::mem::swap(&mut set_a, &mut set_b);
142 }
143 let intersection = set_a
144 .identifier_to_count
145 .keys()
146 .filter(|key| set_b.identifier_to_count.contains_key(*key))
147 .count();
148 intersection as f32 / set_a.identifier_to_count.len() as f32
149}
150
151// TODO
152#[allow(dead_code)]
153pub fn weighted_jaccard_similarity<'a>(
154 mut set_a: &'a IdentifierOccurrences,
155 mut set_b: &'a IdentifierOccurrences,
156) -> f32 {
157 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
158 std::mem::swap(&mut set_a, &mut set_b);
159 }
160
161 let mut numerator = 0;
162 let mut denominator_a = 0;
163 let mut used_count_b = 0;
164 for (symbol, count_a) in set_a.identifier_to_count.iter() {
165 let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
166 numerator += count_a.min(count_b);
167 denominator_a += count_a.max(count_b);
168 used_count_b += count_b;
169 }
170
171 let denominator = denominator_a + (set_b.total_count - used_count_b);
172 if denominator == 0 {
173 0.0
174 } else {
175 numerator as f32 / denominator as f32
176 }
177}
178
179pub fn weighted_overlap_coefficient<'a>(
180 mut set_a: &'a IdentifierOccurrences,
181 mut set_b: &'a IdentifierOccurrences,
182) -> f32 {
183 if set_a.identifier_to_count.len() > set_b.identifier_to_count.len() {
184 std::mem::swap(&mut set_a, &mut set_b);
185 }
186
187 let mut numerator = 0;
188 for (symbol, count_a) in set_a.identifier_to_count.iter() {
189 let count_b = set_b.identifier_to_count.get(symbol).unwrap_or(&0);
190 numerator += count_a.min(count_b);
191 }
192
193 let denominator = set_a.total_count.min(set_b.total_count);
194 if denominator == 0 {
195 0.0
196 } else {
197 numerator as f32 / denominator as f32
198 }
199}
200
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_split_identifier() {
        let cases: [(&str, &[&str]); 5] = [
            ("snake_case", &["snake", "case"]),
            ("kebab-case", &["kebab", "case"]),
            ("PascalCase", &["Pascal", "Case"]),
            ("camelCase", &["camel", "Case"]),
            ("XMLParser", &["XML", "Parser"]),
        ];
        for (identifier, expected) in cases {
            assert_eq!(split_identifier(identifier), expected, "splitting {identifier:?}");
        }
    }

    #[test]
    fn test_similarity_functions() {
        // Parts: let, mut, outline, items, query, outline, items, language, tree, source.
        // 10 parts, 8 distinct; "outline" and "items" occur twice.
        let set_a = IdentifierOccurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // Parts: pub, fn, query, outline, items, language, language, tree, tree, source, str,
        // vec, outline, item. 14 parts, 11 distinct; "outline", "language" and "tree" occur
        // twice.
        let set_b = IdentifierOccurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // 6 distinct parts are shared: outline, items, query, language, tree, source.
        // 7 appear in only one set: let, mut (A); pub, fn, vec, item, str (B).
        let jaccard = jaccard_similarity(&set_a, &set_b);
        assert_eq!(jaccard, 6.0 / (6.0 + 7.0));

        // Shared weight (sum of per-part minimum counts) is 7, because "outline" occurs twice in
        // both sets. Union weight (sum of per-part maximum counts) is 7 + 7 + 3 = 17. That is
        // the shared weight, plus the 7 single-occurrence unshared parts, plus one extra
        // occurrence each of "items", "language" and "tree" on the side where they repeat.
        let weighted_jaccard = weighted_jaccard_similarity(&set_a, &set_b);
        assert_eq!(weighted_jaccard, 7.0 / (7.0 + 7.0 + 3.0));

        // Same numerator as `jaccard_similarity`; denominator is the smaller distinct count, 8.
        let overlap = overlap_coefficient(&set_a, &set_b);
        assert_eq!(overlap, 6.0 / 8.0);

        // Same numerator as `weighted_jaccard_similarity`; denominator is the smaller total
        // part count, 10.
        let weighted_overlap = weighted_overlap_coefficient(&set_a, &set_b);
        assert_eq!(weighted_overlap, 7.0 / 10.0);
    }
}