1use hashbrown::HashTable;
2use regex::Regex;
3use std::{
4 borrow::Cow,
5 hash::{Hash, Hasher as _},
6 path::Path,
7 sync::LazyLock,
8};
9use util::rel_path::RelPath;
10
11use crate::reference::Reference;
12
13// TODO: Consider implementing sliding window similarity matching like
14// https://github.com/sourcegraph/cody-public-snapshot/blob/8e20ac6c1460c08b0db581c0204658112a246eda/vscode/src/completions/context/retrievers/jaccard-similarity/bestJaccardMatch.ts
15//
16// That implementation could actually be more efficient - no need to track words in the window that
17// are not in the query.
18
19// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
20// two in parallel.
21
/// Lazily compiled pattern matching runs of word characters (`\w+`) bounded by word boundaries.
/// `Occurrences::within_string` uses it to pull identifier-like tokens out of arbitrary text.
static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
23
/// Multiset of text occurrences for text similarity that only stores hashes and counts.
///
/// The original strings are not retained: each identifier part is lowercased and hashed before
/// being recorded, so two parts that differ only in case land in the same entry.
#[derive(Debug, Default)]
pub struct Occurrences {
    /// One entry per distinct hash, holding how many times that hash was added.
    table: HashTable<OccurrenceEntry>,
    /// Total number of additions, i.e. the sum of all per-entry counts.
    total_count: usize,
}
30
/// A single slot of the `Occurrences` table: a hash and its multiplicity.
#[derive(Debug)]
struct OccurrenceEntry {
    /// Hash of the lowercased identifier part; also used as the table key.
    hash: u64,
    /// Number of times this hash has been added.
    count: usize,
}
36
37impl Occurrences {
38 pub fn within_string(text: &str) -> Self {
39 Self::from_identifiers(IDENTIFIER_REGEX.find_iter(text).map(|mat| mat.as_str()))
40 }
41
42 #[allow(dead_code)]
43 pub fn within_references(references: &[Reference]) -> Self {
44 Self::from_identifiers(
45 references
46 .iter()
47 .map(|reference| reference.identifier.name.as_ref()),
48 )
49 }
50
51 pub fn from_identifiers(identifiers: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
52 let mut this = Self::default();
53 // TODO: Score matches that match case higher?
54 //
55 // TODO: Also include unsplit identifier?
56 for identifier in identifiers {
57 for identifier_part in split_identifier(identifier.as_ref()) {
58 this.add_hash(fx_hash(&identifier_part.to_lowercase()));
59 }
60 }
61 this
62 }
63
64 pub fn from_worktree_path(worktree_name: Option<Cow<'_, str>>, rel_path: &RelPath) -> Self {
65 if let Some(worktree_name) = worktree_name {
66 Self::from_identifiers(
67 std::iter::once(worktree_name)
68 .chain(iter_path_without_extension(rel_path.as_std_path())),
69 )
70 } else {
71 Self::from_path(rel_path.as_std_path())
72 }
73 }
74
75 pub fn from_path(path: &Path) -> Self {
76 Self::from_identifiers(iter_path_without_extension(path))
77 }
78
79 fn add_hash(&mut self, hash: u64) {
80 self.table
81 .entry(
82 hash,
83 |entry: &OccurrenceEntry| entry.hash == hash,
84 |entry| entry.hash,
85 )
86 .and_modify(|entry| entry.count += 1)
87 .or_insert(OccurrenceEntry { hash, count: 1 });
88 self.total_count += 1;
89 }
90
91 fn contains_hash(&self, hash: u64) -> bool {
92 self.get_count(hash) != 0
93 }
94
95 fn get_count(&self, hash: u64) -> usize {
96 self.table
97 .find(hash, |entry| entry.hash == hash)
98 .map(|entry| entry.count)
99 .unwrap_or(0)
100 }
101}
102
/// Yields the components of `path` as strings, with the final component replaced by its file
/// stem (the file name without its extension).
///
/// Non-UTF-8 components are converted lossily. If the path has no file stem, only the leading
/// components are yielded.
fn iter_path_without_extension(path: &Path) -> impl Iterator<Item = Cow<'_, str>> {
    // Everything except the final component is taken verbatim.
    let mut leading = path.components();
    leading.next_back();

    // The final component contributes only its stem, when it has one.
    let stem = path.file_stem().map(|name| name.to_string_lossy());

    leading
        .map(|part| part.as_os_str().to_string_lossy())
        .chain(stem)
}
111
112pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
113 let mut hasher = collections::FxHasher::default();
114 data.hash(&mut hasher);
115 hasher.finish()
116}
117
/// Splits an identifier written in camelCase, PascalCase, snake_case or kebab-case into its
/// parts, returned as sub-slices of `identifier` in their original casing.
///
/// Rules:
/// - `_` and `-` are delimiters and never appear in the output; consecutive delimiters do not
///   produce empty parts.
/// - A lowercase letter or ASCII digit followed by an uppercase letter starts a new part
///   (`camelCase` -> `camel`, `Case`).
/// - Inside a run of uppercase letters, the last uppercase letter before a lowercase letter
///   starts a new part (`XMLParser` -> `XML`, `Parser`).
///
/// All slice boundaries are byte offsets taken from `char_indices`, so identifiers containing
/// multi-byte characters are split on valid character boundaries.
fn split_identifier(identifier: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    // Byte offset at which the part currently being scanned begins.
    let mut start = 0;
    let mut prev_char: Option<char> = None;
    let mut chars = identifier.char_indices().peekable();

    while let Some((index, ch)) = chars.next() {
        // Explicit delimiters close the current part and are skipped themselves.
        if ch == '_' || ch == '-' {
            if index > start {
                parts.push(&identifier[start..index]);
            }
            start = index + ch.len_utf8();
            prev_char = Some(ch);
            continue;
        }

        if let Some(prev) = prev_char {
            // Transition from lowercase/digit to uppercase, e.g. "camel|Case".
            let lower_to_upper =
                (prev.is_lowercase() || prev.is_ascii_digit()) && ch.is_uppercase();
            // End of an uppercase run followed by a word, e.g. "XML|Parser".
            let acronym_end = prev.is_uppercase()
                && ch.is_uppercase()
                && chars.peek().is_some_and(|&(_, next)| next.is_lowercase());

            if (lower_to_upper || acronym_end) && index > start {
                parts.push(&identifier[start..index]);
                start = index;
            }
        }

        prev_char = Some(ch);
    }

    // Whatever remains after the last boundary is the final part.
    if start < identifier.len() {
        parts.push(&identifier[start..]);
    }

    parts
}
175
176pub fn jaccard_similarity<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
177 if set_a.table.len() > set_b.table.len() {
178 std::mem::swap(&mut set_a, &mut set_b);
179 }
180 let intersection = set_a
181 .table
182 .iter()
183 .filter(|entry| set_b.contains_hash(entry.hash))
184 .count();
185 let union = set_a.table.len() + set_b.table.len() - intersection;
186 intersection as f32 / union as f32
187}
188
189// TODO
190#[allow(dead_code)]
191pub fn overlap_coefficient<'a>(mut set_a: &'a Occurrences, mut set_b: &'a Occurrences) -> f32 {
192 if set_a.table.len() > set_b.table.len() {
193 std::mem::swap(&mut set_a, &mut set_b);
194 }
195 let intersection = set_a
196 .table
197 .iter()
198 .filter(|entry| set_b.contains_hash(entry.hash))
199 .count();
200 intersection as f32 / set_a.table.len() as f32
201}
202
203// TODO
204#[allow(dead_code)]
205pub fn weighted_jaccard_similarity<'a>(
206 mut set_a: &'a Occurrences,
207 mut set_b: &'a Occurrences,
208) -> f32 {
209 if set_a.table.len() > set_b.table.len() {
210 std::mem::swap(&mut set_a, &mut set_b);
211 }
212
213 let mut numerator = 0;
214 let mut denominator_a = 0;
215 let mut used_count_b = 0;
216 for entry_a in set_a.table.iter() {
217 let count_a = entry_a.count;
218 let count_b = set_b.get_count(entry_a.hash);
219 numerator += count_a.min(count_b);
220 denominator_a += count_a.max(count_b);
221 used_count_b += count_b;
222 }
223
224 let denominator = denominator_a + (set_b.total_count - used_count_b);
225 if denominator == 0 {
226 0.0
227 } else {
228 numerator as f32 / denominator as f32
229 }
230}
231
232pub fn weighted_overlap_coefficient<'a>(
233 mut set_a: &'a Occurrences,
234 mut set_b: &'a Occurrences,
235) -> f32 {
236 if set_a.table.len() > set_b.table.len() {
237 std::mem::swap(&mut set_a, &mut set_b);
238 }
239
240 let mut numerator = 0;
241 for entry_a in set_a.table.iter() {
242 let count_a = entry_a.count;
243 let count_b = set_b.get_count(entry_a.hash);
244 numerator += count_a.min(count_b);
245 }
246
247 let denominator = set_a.total_count.min(set_b.total_count);
248 if denominator == 0 {
249 0.0
250 } else {
251 numerator as f32 / denominator as f32
252 }
253}
254
#[cfg(test)]
mod test {
    use super::*;

    // One representative identifier per supported naming convention.
    #[test]
    fn test_split_identifier() {
        assert_eq!(split_identifier("snake_case"), vec!["snake", "case"]);
        assert_eq!(split_identifier("kebab-case"), vec!["kebab", "case"]);
        assert_eq!(split_identifier("PascalCase"), vec!["Pascal", "Case"]);
        assert_eq!(split_identifier("camelCase"), vec!["camel", "Case"]);
        assert_eq!(split_identifier("XMLParser"), vec!["XML", "Parser"]);
    }

    // Checks all four similarity metrics against hand-counted values for one pair of snippets.
    #[test]
    fn test_similarity_functions() {
        // 10 identifier parts, 8 unique
        // Repeats: 2 "outline", 2 "items"
        let set_a = Occurrences::within_string(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        );
        // 14 identifier parts, 11 unique
        // Repeats: 2 "outline", 2 "language", 2 "tree"
        let set_b = Occurrences::within_string(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        );

        // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
        // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
        assert_eq!(jaccard_similarity(&set_a, &set_b), 6.0 / (6.0 + 7.0));

        // Numerator is one more than before due to both having 2 "outline".
        // Denominator is the same except for 3 more due to the non-overlapping duplicates
        assert_eq!(
            weighted_jaccard_similarity(&set_a, &set_b),
            7.0 / (7.0 + 7.0 + 3.0)
        );

        // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
        assert_eq!(overlap_coefficient(&set_a, &set_b), 6.0 / 8.0);

        // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
        // the smaller set, 10.
        assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
    }

    // Covers the empty path, a bare file name, and nested paths whose last component has an
    // extension.
    #[test]
    fn test_iter_path_without_extension() {
        let mut iter = iter_path_without_extension(Path::new(""));
        assert_eq!(iter.next(), None);

        let iter = iter_path_without_extension(Path::new("foo"));
        assert_eq!(iter.collect::<Vec<_>>(), ["foo"]);

        let iter = iter_path_without_extension(Path::new("foo/bar.txt"));
        assert_eq!(iter.collect::<Vec<_>>(), ["foo", "bar"]);

        let iter = iter_path_without_extension(Path::new("foo/bar/baz.txt"));
        assert_eq!(iter.collect::<Vec<_>>(), ["foo", "bar", "baz"]);
    }
}