1use crate::text_similarity::occurrences::HashFrom;
2use arraydeque::ArrayDeque;
3use std::{borrow::Cow, iter::Peekable, marker::PhantomData, path::Path};
4use util::rel_path::RelPath;
5
/// A tokenization strategy that turns a UTF-8 byte stream into a sequence of
/// hashed occurrences, for use in text-similarity comparisons.
pub trait OccurrenceSource {
    /// Splits `str_bytes` into occurrences, yielding one hash per occurrence.
    fn occurrences_in_utf8_bytes(
        str_bytes: impl IntoIterator<Item = u8>,
    ) -> impl Iterator<Item = HashFrom<Self>>;

    /// Convenience wrapper over [`Self::occurrences_in_utf8_bytes`] for `&str` input.
    fn occurrences_in_str(str: &str) -> impl Iterator<Item = HashFrom<Self>> {
        Self::occurrences_in_utf8_bytes(str.bytes())
    }

    /// Includes worktree name and omits file extension.
    fn occurrences_in_worktree_path(
        worktree_name: Option<Cow<'_, str>>,
        rel_path: &RelPath,
    ) -> impl Iterator<Item = HashFrom<Self>> {
        // Either is used so both branches can return the same opaque iterator type.
        if let Some(worktree_name) = worktree_name {
            itertools::Either::Left(
                Self::occurrences_in_utf8_bytes(cow_str_into_bytes(worktree_name))
                    .chain(Self::occurrences_in_path(rel_path.as_std_path())),
            )
        } else {
            itertools::Either::Right(Self::occurrences_in_path(rel_path.as_std_path()))
        }
    }

    /// Occurrences from a path, omitting the file extension. Note that this does not split on
    /// components (they are omitted by `IdentifierParts` but not `CodeParts`).
    ///
    /// NOTE(review): the `.` preceding the extension is retained in the hashed bytes;
    /// both tokenizers treat it as a separator, so this appears harmless in practice.
    fn occurrences_in_path(path: &Path) -> impl Iterator<Item = HashFrom<Self>> {
        let path_bytes = path.as_os_str().as_encoded_bytes();
        // Drop exactly the extension's bytes from the end, if there is one.
        let bytes = if let Some(extension) = path.extension() {
            &path_bytes[0..path_bytes.len() - extension.as_encoded_bytes().len()]
        } else {
            path_bytes
        };
        Self::occurrences_in_utf8_bytes(bytes.iter().cloned())
    }
}
42
/// Occurrences source for finding relevant code by matching parts of identifiers.
///
/// * Splits the input into runs of ascii alphanumeric or unicode characters
/// * Splits these on ascii case transitions, handling camelCase and PascalCase
/// * Lowercases each part
///
/// See [`CodeParts`] for a variant that keeps identifiers whole and includes
/// punctuation runs.
#[derive(Debug)]
pub struct IdentifierParts;
50
/// Occurrences source for finding similar code, by including full identifiers and sequences of
/// symbols.
///
/// * Splits the input on ascii whitespace
/// * Splits these into runs of ascii punctuation or alphanumeric/unicode characters
///
/// Due to common use in identifiers, `_` and `-` are not treated as punctuation. This is consistent
/// with not splitting on case transitions.
// Debug derived for consistency with `IdentifierParts` and `NGram`.
#[derive(Debug)]
pub struct CodeParts;
60
/// Source type for occurrences that come from n-grams, aka w-shingling. Each N length interval of
/// the input will be treated as one occurrence.
///
/// Note that this hashes the hashes it's provided for every output - may be more efficient to use a
/// proper rolling hash. Unfortunately, I didn't find a rust rolling hash implementation that
/// operated on updates larger than u8.
#[derive(Debug)]
pub struct NGram<const N: usize, S> {
    // Ties the n-gram source to its inner tokenizer without storing one.
    _source: PhantomData<S>,
}
71
// Tokenizes into lowercased identifier parts (see `HashedIdentifierParts`).
impl OccurrenceSource for IdentifierParts {
    fn occurrences_in_utf8_bytes(
        str_bytes: impl IntoIterator<Item = u8>,
    ) -> impl Iterator<Item = HashFrom<Self>> {
        HashedIdentifierParts::new(str_bytes.into_iter())
    }
}
79
// Tokenizes into whitespace-separated word/punctuation runs (see `HashedCodeParts`).
impl OccurrenceSource for CodeParts {
    fn occurrences_in_utf8_bytes(
        str_bytes: impl IntoIterator<Item = u8>,
    ) -> impl Iterator<Item = HashFrom<Self>> {
        HashedCodeParts::new(str_bytes.into_iter())
    }
}
87
// Wraps an inner source's hash stream in a sliding-window n-gram hasher.
impl<const N: usize, S: OccurrenceSource> OccurrenceSource for NGram<N, S> {
    fn occurrences_in_utf8_bytes(
        str_bytes: impl IntoIterator<Item = u8>,
    ) -> impl Iterator<Item = HashFrom<NGram<N, S>>> {
        NGramIterator {
            // Each occurrence hash from the inner source becomes one window element.
            hashes: S::occurrences_in_utf8_bytes(str_bytes),
            window: ArrayDeque::new(),
            _source: PhantomData,
        }
    }
}
99
/// Iterator state for [`IdentifierParts`] tokenization: hashes each
/// identifier-like part of a byte stream, one part per `next()` call.
struct HashedIdentifierParts<I: Iterator<Item = u8>> {
    // Peekable so case-transition splitting can look one byte ahead
    // (needed for sequences like "XMLParser" -> ["xml", "parser"]).
    str_bytes: Peekable<I>,
    // In-progress hash of the current part; `None` between parts.
    hasher: Option<FxHasher32>,
    // Tracks case transitions for camelCase/PascalCase splitting.
    prev_char_is_uppercase: bool,
}

impl<I: Iterator<Item = u8>> HashedIdentifierParts<I> {
    /// Creates a tokenizer over the given UTF-8 byte stream.
    fn new(str_bytes: I) -> Self {
        Self {
            str_bytes: str_bytes.peekable(),
            hasher: None,
            prev_char_is_uppercase: false,
        }
    }
}
115
116impl<I: Iterator<Item = u8>> Iterator for HashedIdentifierParts<I> {
117 type Item = HashFrom<IdentifierParts>;
118
119 fn next(&mut self) -> Option<Self::Item> {
120 while let Some(ch) = self.str_bytes.next() {
121 let included = !ch.is_ascii() || ch.is_ascii_alphanumeric();
122 if let Some(mut hasher) = self.hasher.take() {
123 if !included {
124 return Some(hasher.finish().into());
125 }
126
127 // camelCase and PascalCase
128 let is_uppercase = ch.is_ascii_uppercase();
129 let should_split = is_uppercase
130 && (!self.prev_char_is_uppercase ||
131 // sequences like "XMLParser" -> ["XML", "Parser"]
132 self.str_bytes
133 .peek()
134 .map_or(false, |c| c.is_ascii_lowercase()));
135
136 self.prev_char_is_uppercase = is_uppercase;
137
138 if should_split {
139 let result = (hasher.finish() as u32).into();
140 let mut hasher = FxHasher32::default();
141 hasher.write_u8(ch.to_ascii_lowercase());
142 self.hasher = Some(hasher);
143 return Some(result);
144 } else {
145 hasher.write_u8(ch.to_ascii_lowercase());
146 self.hasher = Some(hasher);
147 }
148 } else if included {
149 let mut hasher = FxHasher32::default();
150 hasher.write_u8(ch.to_ascii_lowercase());
151 self.hasher = Some(hasher);
152 self.prev_char_is_uppercase = ch.is_ascii_uppercase();
153 }
154 }
155
156 if let Some(hasher) = self.hasher.take() {
157 return Some(hasher.finish().into());
158 }
159
160 None
161 }
162}
163
/// Iterator state for [`CodeParts`] tokenization: hashes whitespace-separated
/// runs of punctuation or word-like bytes, one run per `next()` call.
struct HashedCodeParts<I: Iterator<Item = u8>> {
    // NOTE(review): `peek` is never used here; `Peekable` mirrors
    // `HashedIdentifierParts` but could likely be a plain iterator.
    str_bytes: Peekable<I>,
    // TODO: Since this doesn't do lowercasing, it might be more efficient to find str slices and
    // hash those, instead of hashing a byte at a time. This would be a bit complex with chunked
    // input, though.
    //
    // In-progress hash of the current part; `None` between parts.
    hasher: Option<FxHasher32>,
    // Tracks punctuation/word transitions, which also split parts.
    prev_char_is_punctuation: bool,
}

impl<I: Iterator<Item = u8>> HashedCodeParts<I> {
    /// Creates a tokenizer over the given UTF-8 byte stream.
    fn new(str_bytes: I) -> Self {
        Self {
            str_bytes: str_bytes.peekable(),
            hasher: None,
            prev_char_is_punctuation: false,
        }
    }
}
182
183impl<I: Iterator<Item = u8>> Iterator for HashedCodeParts<I> {
184 type Item = HashFrom<CodeParts>;
185
186 fn next(&mut self) -> Option<Self::Item> {
187 fn is_punctuation(ch: u8) -> bool {
188 ch.is_ascii_punctuation() && ch != b'_' && ch != b'-'
189 }
190
191 while let Some(ch) = self.str_bytes.next() {
192 let included = !ch.is_ascii() || !ch.is_ascii_whitespace();
193 if let Some(mut hasher) = self.hasher.take() {
194 if !included {
195 return Some(hasher.finish().into());
196 }
197
198 let is_punctuation = is_punctuation(ch);
199 let should_split = is_punctuation != self.prev_char_is_punctuation;
200 self.prev_char_is_punctuation = is_punctuation;
201
202 if should_split {
203 let result = (hasher.finish() as u32).into();
204 let mut hasher = FxHasher32::default();
205 hasher.write_u8(ch);
206 self.hasher = Some(hasher);
207 return Some(result);
208 } else {
209 hasher.write_u8(ch);
210 self.hasher = Some(hasher);
211 }
212 } else if included {
213 let mut hasher = FxHasher32::default();
214 hasher.write_u8(ch);
215 self.hasher = Some(hasher);
216 self.prev_char_is_punctuation = is_punctuation(ch);
217 }
218 }
219
220 if let Some(hasher) = self.hasher.take() {
221 return Some(hasher.finish().into());
222 }
223
224 None
225 }
226}
227
/// Sliding-window state for [`NGram`]: holds the most recent (up to N) hashes
/// from the inner occurrence stream.
struct NGramIterator<const N: usize, S, I> {
    // Inner stream of occurrence hashes.
    hashes: I,
    // Ring buffer of recent hashes; `Wrapping` evicts the oldest when full.
    window: ArrayDeque<u32, N, arraydeque::Wrapping>,
    // Ties the output hash type to the inner source without storing it.
    _source: PhantomData<S>,
}
233
234impl<const N: usize, S, I> Iterator for NGramIterator<N, S, I>
235where
236 I: Iterator<Item = HashFrom<S>>,
237{
238 type Item = HashFrom<NGram<N, S>>;
239
240 fn next(&mut self) -> Option<Self::Item> {
241 while let Some(hash) = self.hashes.next() {
242 if self.window.push_back(hash.into()).is_some() {
243 let mut hasher = FxHasher32::default();
244 for hash in &self.window {
245 hasher.write_u32(*hash);
246 }
247 return Some(hasher.finish().into());
248 }
249 }
250 return None;
251 }
252}
253
/// 32-bit variant of FXHasher.
///
/// Multiply-add scheme: each written word is added to the state, then the
/// state is multiplied by a constant, with wrapping arithmetic throughout.
/// Not cryptographic and not DoS-resistant; only for in-process similarity
/// hashing.
// `Default` derived (state starts at 0) instead of the previous hand-written
// impl — clippy's `derivable_impls`.
#[derive(Default)]
struct FxHasher32(u32);

impl FxHasher32 {
    /// Mixes a single byte into the hash (zero-extended to 32 bits).
    #[inline]
    pub fn write_u8(&mut self, value: u8) {
        self.write_u32(value as u32);
    }

    /// Mixes a 32-bit word into the hash.
    #[inline]
    pub fn write_u32(&mut self, value: u32) {
        // Wrapping ops: overflow is expected and part of the hash.
        self.0 = self.0.wrapping_add(value).wrapping_mul(0x93d765dd);
    }

    /// Consumes the hasher, returning the accumulated 32-bit hash.
    pub fn finish(self) -> u32 {
        self.0
    }
}
278
279/// Converts a `Cow<'_, str>` into an iterator of bytes.
280fn cow_str_into_bytes(text: Cow<'_, str>) -> impl Iterator<Item = u8> {
281 match text {
282 Cow::Borrowed(text) => itertools::Either::Left(text.bytes()),
283 Cow::Owned(text) => itertools::Either::Right(text.into_bytes().into_iter()),
284 }
285}
286
287#[cfg(test)]
288mod test {
289 use crate::{
290 Similarity as _, WeightedSimilarity as _,
291 text_similarity::occurrences::{Occurrences, SmallOccurrences},
292 };
293
294 use super::*;
295
    /// Verifies `IdentifierParts` splitting on delimiters and case
    /// transitions, and its lowercasing of each part.
    #[test]
    fn test_identifier_parts() {
        // Asserts that `text` tokenizes to exactly the hashes of `expected` parts.
        #[track_caller]
        fn check(text: &str, expected: &[&str]) {
            assert_eq!(
                IdentifierParts::occurrences_in_str(text).collect::<Vec<_>>(),
                expected
                    .iter()
                    .map(|part| string_fxhash32(part).into())
                    .collect::<Vec<_>>()
            );
        }

        check("", &[]);
        check("a", &["a"]);
        check("abc", &["abc"]);
        check("ABC", &["abc"]);
        check("123", &["123"]);
        check("snake_case", &["snake", "case"]);
        check("kebab-case", &["kebab", "case"]);
        check("PascalCase", &["pascal", "case"]);
        check("camelCase", &["camel", "case"]);
        check("XMLParser", &["xml", "parser"]);
        check("a1B2c3", &["a1", "b2c3"]);
        check("HTML5Parser", &["html5", "parser"]);
        check("_leading_underscore", &["leading", "underscore"]);
        check("trailing_underscore_", &["trailing", "underscore"]);
        check("--multiple--delimiters--", &["multiple", "delimiters"]);
        check(
            "snake_case kebab-case PascalCase camelCase XMLParser",
            &[
                "snake", "case", "kebab", "case", "pascal", "case", "camel", "case", "xml",
                "parser",
            ],
        );
    }
332
333 #[test]
334 fn test_code_parts() {
335 #[track_caller]
336 fn check(text: &str, expected: &[&str]) {
337 assert_eq!(
338 CodeParts::occurrences_in_str(text).collect::<Vec<_>>(),
339 expected
340 .iter()
341 .map(|part| string_fxhash32(part).into())
342 .collect::<Vec<_>>()
343 );
344 }
345
346 check("", &[]);
347 check("a", &["a"]);
348 check("ABC", &["ABC"]);
349 check("ABC", &["ABC"]);
350 check(
351 "pub fn write_u8(&mut self, byte: u8) {",
352 &[
353 "pub", "fn", "write_u8", "(&", "mut", "self", ",", "byte", ":", "u8", ")", "{",
354 ],
355 );
356 check(
357 "snake_case kebab-case PascalCase camelCase XMLParser _leading_underscore --multiple--delimiters--",
358 &[
359 "snake_case",
360 "kebab-case",
361 "PascalCase",
362 "camelCase",
363 "XMLParser",
364 "_leading_underscore",
365 "--multiple--delimiters--",
366 ],
367 );
368 }
369
    /// Exercises the similarity metrics end-to-end over `IdentifierParts`
    /// occurrences, checking that the multiset (`Occurrences`) and
    /// fixed-capacity (`SmallOccurrences`) representations agree.
    #[test]
    fn test_similarity_functions() {
        // 10 identifier parts, 8 unique
        // Repeats: 2 "outline", 2 "items"
        let multiset_a = Occurrences::new(IdentifierParts::occurrences_in_str(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        ));
        let set_a =
            SmallOccurrences::<8, IdentifierParts>::new(IdentifierParts::occurrences_in_str(
                "let mut outline_items = query_outline_items(&language, &tree, &source);",
            ));
        // 14 identifier parts, 11 unique
        // Repeats: 2 "outline", 2 "language", 2 "tree"
        let set_b = Occurrences::new(IdentifierParts::occurrences_in_str(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        ));

        // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
        // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
        assert_eq!(multiset_a.jaccard_similarity(&set_b), 6.0 / (6.0 + 7.0));
        assert_eq!(set_a.jaccard_similarity(&set_b), 6.0 / (6.0 + 7.0));

        // Numerator is one more than before due to both having 2 "outline".
        // Denominator is the same except for 3 more due to the non-overlapping duplicates
        assert_eq!(
            multiset_a.weighted_jaccard_similarity(&set_b),
            7.0 / (7.0 + 7.0 + 3.0)
        );

        // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
        assert_eq!(multiset_a.overlap_coefficient(&set_b), 6.0 / 8.0);
        assert_eq!(set_a.overlap_coefficient(&set_b), 6.0 / 8.0);

        // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
        // the smaller set, 10.
        assert_eq!(multiset_a.weighted_overlap_coefficient(&set_b), 7.0 / 10.0);
    }
407
408 fn string_fxhash32(text: &str) -> u32 {
409 let mut hasher = FxHasher32::default();
410 for byte in text.bytes() {
411 hasher.write_u8(byte);
412 }
413 hasher.finish() as u32
414 }
415}