1use hashbrown::HashTable;
2use regex::Regex;
3use std::{
4 hash::{Hash, Hasher as _},
5 sync::LazyLock,
6};
7
8use crate::reference::Reference;
9
10// TODO: Consider implementing sliding window similarity matching like
11// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
12//
13// That implementation could actually be more efficient - no need to track words in the window that
14// are not in the query.
15
16// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
17// two in parallel.
18
19static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
20
/// Multiset of text occurrences for text similarity that only stores hashes and counts.
///
/// Because only hashes are kept, two distinct identifier parts whose hashes collide are counted
/// as the same element.
#[derive(Debug, Default)]
pub struct Occurrences {
    // One entry per distinct hash, carrying how many times that hash was added.
    table: HashTable<OccurrenceEntry>,
    // Total number of additions, i.e. the sum of every entry's `count` (size including repeats).
    total_count: usize,
}
27
/// One distinct hash stored in an `Occurrences` table.
#[derive(Debug)]
struct OccurrenceEntry {
    // The stored hash; also used to re-hash the entry when the table grows.
    hash: u64,
    // Number of times this hash has been added; always >= 1 once inserted.
    count: usize,
}
33
34impl Occurrences {
35 pub fn within_string(text: &str) -> Self {
36 Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str()))
37 }
38
39 #[allow(dead_code)]
40 pub fn within_references(references: &[Reference]) -> Self {
41 Self::from_identifiers(
42 references
43 .iter()
44 .map(|reference| reference.identifier.name.as_ref()),
45 )
46 }
47
48 pub fn from_identifiers<'a>(identifiers: impl IntoIterator<Item = &'a str>) -> Self {
49 let mut this = Self::default();
50 // TODO: Score matches that match case higher?
51 //
52 // TODO: Also include unsplit identifier?
53 for identifier in identifiers {
54 for identifier_part in split_identifier(identifier) {
55 this.add_hash(fx_hash(&identifier_part.to_lowercase()));
56 }
57 }
58 this
59 }
60
61 fn add_hash(&mut self, hash: u64) {
62 self.table
63 .entry(
64 hash,
65 |entry: &OccurrenceEntry| entry.hash == hash,
66 |entry| entry.hash,
67 )
68 .and_modify(|entry| entry.count += 1)
69 .or_insert(OccurrenceEntry { hash, count: 1 });
70 self.total_count += 1;
71 }
72
73 fn contains_hash(&self, hash: u64) -> bool {
74 self.get_count(hash) != 0
75 }
76
77 fn get_count(&self, hash: u64) -> usize {
78 self.table
79 .find(hash, |entry| entry.hash == hash)
80 .map(|entry| entry.count)
81 .unwrap_or(0)
82 }
83}
84
85pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
86 let mut hasher = collections::FxHasher::default();
87 data.hash(&mut hasher);
88 hasher.finish()
89}
90
/// Splits an identifier into its words, handling snake_case, kebab-case, camelCase and
/// PascalCase, plus acronym runs such as `XMLParser` -> `["XML", "Parser"]`.
///
/// Rules:
/// - `_` and `-` are delimiters and never appear in the output.
/// - A new part starts where a lowercase letter or ASCII digit is followed by an uppercase letter.
/// - A new part starts at the last uppercase letter of an uppercase run when that letter is
///   followed by a lowercase one.
/// - Empty parts are dropped.
///
/// Slicing uses the byte offsets reported by `char_indices`, so identifiers containing
/// multi-byte UTF-8 characters are split at the right positions. The previous implementation
/// mixed char indices with byte offsets, which mis-split or panicked on non-ASCII input.
fn split_identifier(identifier: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    // Byte offset where the current part begins.
    let mut start = 0;
    let mut prev: Option<char> = None;
    let mut chars = identifier.char_indices().peekable();

    while let Some((offset, ch)) = chars.next() {
        if ch == '_' || ch == '-' {
            // Explicit delimiter: close the current part (if non-empty) and skip the delimiter.
            if offset > start {
                parts.push(&identifier[start..offset]);
            }
            start = offset + ch.len_utf8();
        } else if let Some(prev_char) = prev {
            let next_char = chars.peek().map(|&(_, c)| c);
            // "camelCase" / "abc1Def": lowercase or digit followed by uppercase.
            let lower_to_upper =
                (prev_char.is_lowercase() || prev_char.is_ascii_digit()) && ch.is_uppercase();
            // "XMLParser": the 'P' starts a new word because it is followed by lowercase.
            let acronym_end = prev_char.is_uppercase()
                && ch.is_uppercase()
                && next_char.is_some_and(char::is_lowercase);
            if lower_to_upper || acronym_end {
                parts.push(&identifier[start..offset]);
                start = offset;
            }
        }
        prev = Some(ch);
    }

    // Whatever remains after the last split point is the final part.
    if start < identifier.len() {
        parts.push(&identifier[start..]);
    }

    parts.retain(|part| !part.is_empty());
    parts
}
148
149pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
150 if set_a.table.len() > set_b.table.len() {
151 std::mem::swap(&mut set_a, &mut set_b);
152 }
153 let intersection = set_a
154 .table
155 .iter()
156 .filter(|entry| set_b.contains_hash(entry.hash))
157 .count();
158 let union = set_a.table.len() + set_b.table.len() - intersection;
159 intersection as f32 / union as f32
160}
161
162// TODO
163#[allow(dead_code)]
164pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
165 if set_a.table.len() > set_b.table.len() {
166 std::mem::swap(&mut set_a, &mut set_b);
167 }
168 let intersection = set_a
169 .table
170 .iter()
171 .filter(|entry| set_b.contains_hash(entry.hash))
172 .count();
173 intersection as f32 / set_a.table.len() as f32
174}
175
176// TODO
177#[allow(dead_code)]
178pub fn weighted_jaccard_similarity<'a>(
179 mut set_a: &'a Occurrences,
180 mut set_b: &'a Occurrences,
181) -> f32 {
182 if set_a.table.len() > set_b.table.len() {
183 std::mem::swap(&mut set_a, &mut set_b);
184 }
185
186 let mut numerator = 0;
187 let mut denominator_a = 0;
188 let mut used_count_b = 0;
189 for entry_a in set_a.table.iter() {
190 let count_a = entry_a.count;
191 let count_b = set_b.get_count(entry_a.hash);
192 numerator += count_a.min(count_b);
193 denominator_a += count_a.max(count_b);
194 used_count_b += count_b;
195 }
196
197 let denominator = denominator_a + (set_b.total_count - used_count_b);
198 if denominator == 0 {
199 0.0
200 } else {
201 numerator as f32 / denominator as f32
202 }
203}
204
205pub fn weighted_overlap_coefficient<'a>(
206 mut set_a: &'a Occurrences,
207 mut set_b: &'a Occurrences,
208) -> f32 {
209 if set_a.table.len() > set_b.table.len() {
210 std::mem::swap(&mut set_a, &mut set_b);
211 }
212
213 let mut numerator = 0;
214 for entry_a in set_a.table.iter() {
215 let count_a = entry_a.count;
216 let count_b = set_b.get_count(entry_a.hash);
217 numerator += count_a.min(count_b);
218 }
219
220 let denominator = set_a.total_count.min(set_b.total_count);
221 if denominator == 0 {
222 0.0
223 } else {
224 numerator as f32 / denominator as f32
225 }
226}
227
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_split_identifier() {
        let cases: [(&str, &[&str]); 5] = [
            ("snake_case", &["snake", "case"]),
            ("kebab-case", &["kebab", "case"]),
            ("PascalCase", &["Pascal", "Case"]),
            ("camelCase", &["camel", "Case"]),
            ("XMLParser", &["XML", "Parser"]),
        ];
        for (identifier, expected) in cases {
            assert_eq!(split_identifier(identifier), expected);
        }
    }

    #[test]
    fn test_similarity_functions() {
        // Lowercased parts: let, mut, outline, items, query, outline, items, language, tree,
        // source -> 10 total, 8 distinct ("outline" and "items" appear twice).
        let set_a = Occurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // Lowercased parts: pub, fn, query, outline, items, language, language, tree, tree,
        // source, str, vec, outline, item -> 14 total, 11 distinct ("outline", "language" and
        // "tree" appear twice).
        let set_b = Occurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // Distinct parts in both (6): outline, items, query, language, tree, source.
        // Distinct parts in only one (7): let, mut, pub, fn, vec, item, str.
        assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));

        // Numerator = sum of min counts = 7: one more than above, because both sides contain
        // "outline" twice.
        // Denominator = sum of max counts = 17, made up of:
        // - 7 for the numerator's shared counts,
        // - 7 for the unshared parts (each appears once),
        // - 3 extra copies of "items", "language" and "tree" on the side that has more of them.
        assert_eq!(
            weighted_jaccard_similarity(&set_a, &set_b),
            7.0 / (7.0 + 7.0 + 3.0)
        );

        // Same numerator as jaccard_similarity; the denominator is the distinct size of the
        // smaller set, 8.
        assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);

        // Same numerator as weighted_jaccard_similarity; the denominator is the smaller total
        // count, 10.
        assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
    }
}