// source.rs

  1use crate::text_similarity::occurrences::HashFrom;
  2use arraydeque::ArrayDeque;
  3use std::{borrow::Cow, iter::Peekable, marker::PhantomData, path::Path};
  4use util::rel_path::RelPath;
  5
  6pub trait OccurrenceSource {
  7    fn occurrences_in_utf8_bytes(
  8        str_bytes: impl IntoIterator<Item = u8>,
  9    ) -> impl Iterator<Item = HashFrom<Self>>;
 10
 11    fn occurrences_in_str(str: &str) -> impl Iterator<Item = HashFrom<Self>> {
 12        Self::occurrences_in_utf8_bytes(str.bytes())
 13    }
 14
 15    /// Includes worktree name and omits file extension.
 16    fn occurrences_in_worktree_path(
 17        worktree_name: Option<Cow<'_, str>>,
 18        rel_path: &RelPath,
 19    ) -> impl Iterator<Item = HashFrom<Self>> {
 20        if let Some(worktree_name) = worktree_name {
 21            itertools::Either::Left(
 22                Self::occurrences_in_utf8_bytes(cow_str_into_bytes(worktree_name))
 23                    .chain(Self::occurrences_in_path(rel_path.as_std_path())),
 24            )
 25        } else {
 26            itertools::Either::Right(Self::occurrences_in_path(rel_path.as_std_path()))
 27        }
 28    }
 29
 30    /// Occurrences from a path, the omitting file extension. Note that this does not split on
 31    /// components (they are omitted by `IdentifierParts` but not `CodeParts`).
 32    fn occurrences_in_path(path: &Path) -> impl Iterator<Item = HashFrom<Self>> {
 33        let path_bytes = path.as_os_str().as_encoded_bytes();
 34        let bytes = if let Some(extension) = path.extension() {
 35            &path_bytes[0..path_bytes.len() - extension.as_encoded_bytes().len()]
 36        } else {
 37            path_bytes
 38        };
 39        Self::occurrences_in_utf8_bytes(bytes.iter().cloned())
 40    }
 41}
 42
/// Occurrences source for finding relevant code by matching parts of identifiers.
///
/// * Splits the input into runs of ascii alphanumeric or unicode characters
/// * Splits these on ascii case transitions, handling camelCase and PascalCase
/// * Lowercases each part, so matching is case-insensitive
#[derive(Debug)]
pub struct IdentifierParts;
 50
 51/// Occurrences source for finding similar code, by including full identifiers and sequences of
 52/// symbols.
 53///
 54/// * Splits the input on ascii whitespace
 55/// * Splits these into runs of ascii punctuation or alphanumeric/unicode characters
 56///
 57/// Due to common use in identifiers, `_` and `-` are not treated as punctuation. This is consistent
 58/// with not splitting on case transitions.
 59pub struct CodeParts;
 60
/// Source type for occurrences that come from n-grams, aka w-shingling. Each N length interval of
/// the occurrences produced by the underlying source `S` will be treated as one occurrence.
///
/// Note that this hashes the hashes it's provided for every output - may be more efficient to use a
/// proper rolling hash. Unfortunately, I didn't find a rust rolling hash implementation that
/// operated on updates larger than u8.
#[derive(Debug)]
pub struct NGram<const N: usize, S> {
    // Marker only: `S` is used to tag output hashes, no value of `S` is ever stored.
    _source: PhantomData<S>,
}
 71
impl OccurrenceSource for IdentifierParts {
    // Delegates tokenization to the `HashedIdentifierParts` state machine below.
    fn occurrences_in_utf8_bytes(
        str_bytes: impl IntoIterator<Item = u8>,
    ) -> impl Iterator<Item = HashFrom<Self>> {
        HashedIdentifierParts::new(str_bytes.into_iter())
    }
}
 79
impl OccurrenceSource for CodeParts {
    // Delegates tokenization to the `HashedCodeParts` state machine below.
    fn occurrences_in_utf8_bytes(
        str_bytes: impl IntoIterator<Item = u8>,
    ) -> impl Iterator<Item = HashFrom<Self>> {
        HashedCodeParts::new(str_bytes.into_iter())
    }
}
 87
impl<const N: usize, S: OccurrenceSource> OccurrenceSource for NGram<N, S> {
    // Tokenizes via the underlying source `S`, then re-hashes each sliding window
    // of N hashes into a single occurrence (see `NGramIterator`).
    fn occurrences_in_utf8_bytes(
        str_bytes: impl IntoIterator<Item = u8>,
    ) -> impl Iterator<Item = HashFrom<NGram<N, S>>> {
        NGramIterator {
            hashes: S::occurrences_in_utf8_bytes(str_bytes),
            window: ArrayDeque::new(),
            _source: PhantomData,
        }
    }
}
 99
/// Iterator state machine that hashes identifier parts one byte at a time,
/// yielding one hash per part (see the `Iterator` impl for the split rules).
struct HashedIdentifierParts<I: Iterator<Item = u8>> {
    // Peekable so case-run splits ("XMLParser") can look one byte ahead.
    str_bytes: Peekable<I>,
    // Hasher for the part currently being accumulated; `None` between parts.
    hasher: Option<FxHasher32>,
    // Tracks case transitions for camelCase / PascalCase splitting.
    prev_char_is_uppercase: bool,
}

impl<I: Iterator<Item = u8>> HashedIdentifierParts<I> {
    fn new(str_bytes: I) -> Self {
        Self {
            str_bytes: str_bytes.peekable(),
            hasher: None,
            prev_char_is_uppercase: false,
        }
    }
}
115
116impl<I: Iterator<Item = u8>> Iterator for HashedIdentifierParts<I> {
117    type Item = HashFrom<IdentifierParts>;
118
119    fn next(&mut self) -> Option<Self::Item> {
120        while let Some(ch) = self.str_bytes.next() {
121            let included = !ch.is_ascii() || ch.is_ascii_alphanumeric();
122            if let Some(mut hasher) = self.hasher.take() {
123                if !included {
124                    return Some(hasher.finish().into());
125                }
126
127                // camelCase and PascalCase
128                let is_uppercase = ch.is_ascii_uppercase();
129                let should_split = is_uppercase
130                    && (!self.prev_char_is_uppercase ||
131                        // sequences like "XMLParser" -> ["XML", "Parser"]
132                        self.str_bytes
133                            .peek()
134                            .map_or(false, |c| c.is_ascii_lowercase()));
135
136                self.prev_char_is_uppercase = is_uppercase;
137
138                if should_split {
139                    let result = (hasher.finish() as u32).into();
140                    let mut hasher = FxHasher32::default();
141                    hasher.write_u8(ch.to_ascii_lowercase());
142                    self.hasher = Some(hasher);
143                    return Some(result);
144                } else {
145                    hasher.write_u8(ch.to_ascii_lowercase());
146                    self.hasher = Some(hasher);
147                }
148            } else if included {
149                let mut hasher = FxHasher32::default();
150                hasher.write_u8(ch.to_ascii_lowercase());
151                self.hasher = Some(hasher);
152                self.prev_char_is_uppercase = ch.is_ascii_uppercase();
153            }
154        }
155
156        if let Some(hasher) = self.hasher.take() {
157            return Some(hasher.finish().into());
158        }
159
160        None
161    }
162}
163
/// Iterator state machine that hashes code parts one byte at a time,
/// yielding one hash per part (see the `Iterator` impl for the split rules).
struct HashedCodeParts<I: Iterator<Item = u8>> {
    str_bytes: Peekable<I>,
    // TODO: Since this doesn't do lowercasing, it might be more efficient to find str slices and
    // hash those, instead of hashing a byte at a time. This would be a bit complex with chunked
    // input, though.
    // Hasher for the part currently being accumulated; `None` between parts.
    hasher: Option<FxHasher32>,
    // Tracks punctuation/word transitions, which also split parts.
    prev_char_is_punctuation: bool,
}

impl<I: Iterator<Item = u8>> HashedCodeParts<I> {
    fn new(str_bytes: I) -> Self {
        Self {
            str_bytes: str_bytes.peekable(),
            hasher: None,
            prev_char_is_punctuation: false,
        }
    }
}
182
183impl<I: Iterator<Item = u8>> Iterator for HashedCodeParts<I> {
184    type Item = HashFrom<CodeParts>;
185
186    fn next(&mut self) -> Option<Self::Item> {
187        fn is_punctuation(ch: u8) -> bool {
188            ch.is_ascii_punctuation() && ch != b'_' && ch != b'-'
189        }
190
191        while let Some(ch) = self.str_bytes.next() {
192            let included = !ch.is_ascii() || !ch.is_ascii_whitespace();
193            if let Some(mut hasher) = self.hasher.take() {
194                if !included {
195                    return Some(hasher.finish().into());
196                }
197
198                let is_punctuation = is_punctuation(ch);
199                let should_split = is_punctuation != self.prev_char_is_punctuation;
200                self.prev_char_is_punctuation = is_punctuation;
201
202                if should_split {
203                    let result = (hasher.finish() as u32).into();
204                    let mut hasher = FxHasher32::default();
205                    hasher.write_u8(ch);
206                    self.hasher = Some(hasher);
207                    return Some(result);
208                } else {
209                    hasher.write_u8(ch);
210                    self.hasher = Some(hasher);
211                }
212            } else if included {
213                let mut hasher = FxHasher32::default();
214                hasher.write_u8(ch);
215                self.hasher = Some(hasher);
216                self.prev_char_is_punctuation = is_punctuation(ch);
217            }
218        }
219
220        if let Some(hasher) = self.hasher.take() {
221            return Some(hasher.finish().into());
222        }
223
224        None
225    }
226}
227
/// Adapter that re-hashes each length-N window of base hashes into one occurrence.
struct NGramIterator<const N: usize, S, I> {
    // Underlying occurrence hashes from source `S`.
    hashes: I,
    // Sliding window of the most recent hashes; `Wrapping` evicts the oldest on overflow.
    window: ArrayDeque<u32, N, arraydeque::Wrapping>,
    _source: PhantomData<S>,
}
233
234impl<const N: usize, S, I> Iterator for NGramIterator<N, S, I>
235where
236    I: Iterator<Item = HashFrom<S>>,
237{
238    type Item = HashFrom<NGram<N, S>>;
239
240    fn next(&mut self) -> Option<Self::Item> {
241        while let Some(hash) = self.hashes.next() {
242            if self.window.push_back(hash.into()).is_some() {
243                let mut hasher = FxHasher32::default();
244                for hash in &self.window {
245                    hasher.write_u32(*hash);
246                }
247                return Some(hasher.finish().into());
248            }
249        }
250        return None;
251    }
252}
253
/// 32-bit variant of FXHasher.
///
/// Not cryptographic — used only for cheap occurrence fingerprints. The state
/// starts at zero, so the derived `Default` matches the previous manual impl.
#[derive(Default)]
struct FxHasher32(u32);

impl FxHasher32 {
    /// Folds one byte into the hash state.
    #[inline]
    pub fn write_u8(&mut self, value: u8) {
        self.write_u32(u32::from(value));
    }

    /// Folds one 32-bit word into the hash state (add, then multiply by the seed).
    #[inline]
    pub fn write_u32(&mut self, value: u32) {
        self.0 = self.0.wrapping_add(value).wrapping_mul(0x93d765dd);
    }

    /// Consumes the hasher and returns the accumulated hash.
    pub fn finish(self) -> u32 {
        self.0
    }
}
278
279/// Converts a `Cow<'_, str>` into an iterator of bytes.
280fn cow_str_into_bytes(text: Cow<'_, str>) -> impl Iterator<Item = u8> {
281    match text {
282        Cow::Borrowed(text) => itertools::Either::Left(text.bytes()),
283        Cow::Owned(text) => itertools::Either::Right(text.into_bytes().into_iter()),
284    }
285}
286
#[cfg(test)]
mod test {
    use crate::{
        Similarity as _, WeightedSimilarity as _,
        text_similarity::occurrences::{Occurrences, SmallOccurrences},
    };

    use super::*;

    #[test]
    fn test_identifier_parts() {
        // Asserts that `text` tokenizes into exactly `expected`, compared by hash.
        #[track_caller]
        fn check(text: &str, expected: &[&str]) {
            assert_eq!(
                IdentifierParts::occurrences_in_str(text).collect::<Vec<_>>(),
                expected
                    .iter()
                    .map(|part| string_fxhash32(part).into())
                    .collect::<Vec<_>>()
            );
        }

        check("", &[]);
        check("a", &["a"]);
        check("abc", &["abc"]);
        check("ABC", &["abc"]);
        check("123", &["123"]);
        check("snake_case", &["snake", "case"]);
        check("kebab-case", &["kebab", "case"]);
        check("PascalCase", &["pascal", "case"]);
        check("camelCase", &["camel", "case"]);
        check("XMLParser", &["xml", "parser"]);
        check("a1B2c3", &["a1", "b2c3"]);
        check("HTML5Parser", &["html5", "parser"]);
        check("_leading_underscore", &["leading", "underscore"]);
        check("trailing_underscore_", &["trailing", "underscore"]);
        check("--multiple--delimiters--", &["multiple", "delimiters"]);
        check(
            "snake_case kebab-case PascalCase camelCase XMLParser",
            &[
                "snake", "case", "kebab", "case", "pascal", "case", "camel", "case", "xml",
                "parser",
            ],
        );
    }

    #[test]
    fn test_code_parts() {
        // Asserts that `text` tokenizes into exactly `expected`, compared by hash.
        #[track_caller]
        fn check(text: &str, expected: &[&str]) {
            assert_eq!(
                CodeParts::occurrences_in_str(text).collect::<Vec<_>>(),
                expected
                    .iter()
                    .map(|part| string_fxhash32(part).into())
                    .collect::<Vec<_>>()
            );
        }

        check("", &[]);
        check("a", &["a"]);
        // This case was previously duplicated; one copy removed.
        check("ABC", &["ABC"]);
        check(
            "pub fn write_u8(&mut self, byte: u8) {",
            &[
                "pub", "fn", "write_u8", "(&", "mut", "self", ",", "byte", ":", "u8", ")", "{",
            ],
        );
        check(
            "snake_case kebab-case PascalCase camelCase XMLParser _leading_underscore --multiple--delimiters--",
            &[
                "snake_case",
                "kebab-case",
                "PascalCase",
                "camelCase",
                "XMLParser",
                "_leading_underscore",
                "--multiple--delimiters--",
            ],
        );
    }

    #[test]
    fn test_similarity_functions() {
        // 10 identifier parts, 8 unique
        // Repeats: 2 "outline", 2 "items"
        let multiset_a = Occurrences::new(IdentifierParts::occurrences_in_str(
            "let mut outline_items = query_outline_items(&language, &tree, &source);",
        ));
        let set_a =
            SmallOccurrences::<8, IdentifierParts>::new(IdentifierParts::occurrences_in_str(
                "let mut outline_items = query_outline_items(&language, &tree, &source);",
            ));
        // 14 identifier parts, 11 unique
        // Repeats: 2 "outline", 2 "language", 2 "tree"
        let set_b = Occurrences::new(IdentifierParts::occurrences_in_str(
            "pub fn query_outline_items(language: &Language, tree: &Tree, source: &str) -> Vec<OutlineItem> {",
        ));

        // 6 overlaps: "outline", "items", "query", "language", "tree", "source"
        // 7 non-overlaps: "let", "mut", "pub", "fn", "vec", "item", "str"
        assert_eq!(multiset_a.jaccard_similarity(&set_b), 6.0 / (6.0 + 7.0));
        assert_eq!(set_a.jaccard_similarity(&set_b), 6.0 / (6.0 + 7.0));

        // Numerator is one more than before due to both having 2 "outline".
        // Denominator is the same except for 3 more due to the non-overlapping duplicates
        assert_eq!(
            multiset_a.weighted_jaccard_similarity(&set_b),
            7.0 / (7.0 + 7.0 + 3.0)
        );

        // Numerator is the same as jaccard_similarity. Denominator is the size of the smaller set, 8.
        assert_eq!(multiset_a.overlap_coefficient(&set_b), 6.0 / 8.0);
        assert_eq!(set_a.overlap_coefficient(&set_b), 6.0 / 8.0);

        // Numerator is the same as weighted_jaccard_similarity. Denominator is the total weight of
        // the smaller set, 10.
        assert_eq!(multiset_a.weighted_overlap_coefficient(&set_b), 7.0 / 10.0);
    }

    // Hashes `text` with the same byte-at-a-time scheme the tokenizers use.
    fn string_fxhash32(text: &str) -> u32 {
        let mut hasher = FxHasher32::default();
        for byte in text.bytes() {
            hasher.write_u8(byte);
        }
        // `finish()` already returns `u32`; the previous cast was redundant.
        hasher.finish()
    }
}