fuzzy.rs

  1use easy_parallel::Parallel;
  2
  3use super::char_bag::CharBag;
  4
  5use std::{
  6    cmp::{max, min, Ordering, Reverse},
  7    collections::BinaryHeap,
  8};
  9
/// Per-character score used when a matched character is not directly adjacent
/// to the previous match and does not follow a word-boundary character.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for every additional
/// character skipped between two consecutive matches.
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound for the distance-based per-character score.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
 13
/// A candidate path, pre-processed for fuzzy matching.
pub struct PathEntry {
    /// Identifier copied verbatim into the resulting `PathMatch::entry_id`.
    pub entry_id: usize,
    /// Character bag of the path; an entry is only scored when this bag is a
    /// superset of the query's bag.
    // NOTE(review): the tests build this from the lowercased path — confirm
    // that production callers do the same.
    pub path_chars: CharBag,
    /// The path's characters in their original case.
    pub path: Vec<char>,
    /// The path's characters, lowercased; this is what the query is matched
    /// against.
    pub lowercase_path: Vec<char>,
    /// Ignored entries are skipped unless the caller passes `include_ignored`.
    pub is_ignored: bool,
}
 21
 22#[derive(Clone, Debug)]
 23pub struct PathMatch {
 24    pub score: f64,
 25    pub positions: Vec<usize>,
 26    pub tree_id: usize,
 27    pub entry_id: usize,
 28    pub skipped_prefix_len: usize,
 29}
 30
 31impl PartialEq for PathMatch {
 32    fn eq(&self, other: &Self) -> bool {
 33        self.score.eq(&other.score)
 34    }
 35}
 36
 37impl Eq for PathMatch {}
 38
 39impl PartialOrd for PathMatch {
 40    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
 41        self.score.partial_cmp(&other.score)
 42    }
 43}
 44
 45impl Ord for PathMatch {
 46    fn cmp(&self, other: &Self) -> Ordering {
 47        self.partial_cmp(other).unwrap_or(Ordering::Equal)
 48    }
 49}
 50
 51pub fn match_paths(
 52    paths_by_tree_id: &[(usize, usize, &[PathEntry])],
 53    query: &str,
 54    include_ignored: bool,
 55    smart_case: bool,
 56    max_results: usize,
 57) -> Vec<PathMatch> {
 58    let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
 59    let query = query.chars().collect::<Vec<_>>();
 60    let lowercase_query = &lowercase_query;
 61    let query = &query;
 62    let query_chars = CharBag::from(&lowercase_query[..]);
 63
 64    let cpus = num_cpus::get();
 65    let path_count = paths_by_tree_id
 66        .iter()
 67        .fold(0, |sum, (_, _, paths)| sum + paths.len());
 68    let segment_size = (path_count + cpus - 1) / cpus;
 69    let mut segment_results = (0..cpus).map(|_| BinaryHeap::new()).collect::<Vec<_>>();
 70
 71    Parallel::new()
 72        .each(
 73            segment_results.iter_mut().enumerate(),
 74            |(segment_idx, results)| {
 75                let segment_start = segment_idx * segment_size;
 76                let segment_end = segment_start + segment_size;
 77
 78                let mut min_score = 0.0;
 79                let mut last_positions = Vec::new();
 80                last_positions.resize(query.len(), 0);
 81                let mut match_positions = Vec::new();
 82                match_positions.resize(query.len(), 0);
 83                let mut score_matrix = Vec::new();
 84                let mut best_position_matrix = Vec::new();
 85
 86                let mut tree_start = 0;
 87                for (tree_id, skipped_prefix_len, paths) in paths_by_tree_id {
 88                    let tree_end = tree_start + paths.len();
 89                    if tree_start < segment_end && segment_start < tree_end {
 90                        let start = max(tree_start, segment_start) - tree_start;
 91                        let end = min(tree_end, segment_end) - tree_start;
 92
 93                        match_single_tree_paths(
 94                            *tree_id,
 95                            *skipped_prefix_len,
 96                            paths,
 97                            start,
 98                            end,
 99                            query,
100                            lowercase_query,
101                            query_chars,
102                            include_ignored,
103                            smart_case,
104                            results,
105                            max_results,
106                            &mut min_score,
107                            &mut match_positions,
108                            &mut last_positions,
109                            &mut score_matrix,
110                            &mut best_position_matrix,
111                        );
112                    }
113                    if tree_end >= segment_end {
114                        break;
115                    }
116                    tree_start = tree_end;
117                }
118            },
119        )
120        .run();
121
122    let mut results = segment_results
123        .into_iter()
124        .flatten()
125        .map(|r| r.0)
126        .collect::<Vec<_>>();
127    results.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
128    results.truncate(max_results);
129    results
130}
131
132fn match_single_tree_paths(
133    tree_id: usize,
134    skipped_prefix_len: usize,
135    path_entries: &[PathEntry],
136    start: usize,
137    end: usize,
138    query: &[char],
139    lowercase_query: &[char],
140    query_chars: CharBag,
141    include_ignored: bool,
142    smart_case: bool,
143    results: &mut BinaryHeap<Reverse<PathMatch>>,
144    max_results: usize,
145    min_score: &mut f64,
146    match_positions: &mut Vec<usize>,
147    last_positions: &mut Vec<usize>,
148    score_matrix: &mut Vec<Option<f64>>,
149    best_position_matrix: &mut Vec<usize>,
150) {
151    for i in start..end {
152        let path_entry = unsafe { &path_entries.get_unchecked(i) };
153
154        if !include_ignored && path_entry.is_ignored {
155            continue;
156        }
157
158        if !path_entry.path_chars.is_superset(query_chars) {
159            continue;
160        }
161
162        if !find_last_positions(
163            last_positions,
164            skipped_prefix_len,
165            &path_entry.lowercase_path,
166            &lowercase_query[..],
167        ) {
168            continue;
169        }
170
171        let matrix_len = query.len() * (path_entry.path.len() - skipped_prefix_len);
172        score_matrix.clear();
173        score_matrix.resize(matrix_len, None);
174        best_position_matrix.clear();
175        best_position_matrix.resize(matrix_len, skipped_prefix_len);
176
177        let score = score_match(
178            &query[..],
179            &lowercase_query[..],
180            &path_entry.path,
181            &path_entry.lowercase_path,
182            skipped_prefix_len,
183            smart_case,
184            &last_positions,
185            score_matrix,
186            best_position_matrix,
187            match_positions,
188            *min_score,
189        );
190
191        if score > 0.0 {
192            results.push(Reverse(PathMatch {
193                tree_id,
194                entry_id: path_entry.entry_id,
195                score,
196                positions: match_positions.clone(),
197                skipped_prefix_len,
198            }));
199            if results.len() == max_results {
200                *min_score = results.peek().unwrap().0.score;
201            }
202        }
203    }
204}
205
/// Records, for every query character, the right-most position in `path` at
/// which it can occur while the query still matches as a subsequence.
///
/// The query is walked back to front; each character is searched for from the
/// end of the not-yet-consumed part of `path`, and the index found is written
/// to `last_positions[i]`.
///
/// Returns `false` as soon as a character cannot be found or its position lies
/// inside the first `skipped_prefix_len` characters. In that case
/// `last_positions` may have been partially overwritten.
///
/// # Panics
///
/// Panics if `last_positions` is shorter than `query`.
fn find_last_positions(
    last_positions: &mut Vec<usize>,
    skipped_prefix_len: usize,
    path: &[char],
    query: &[char],
) -> bool {
    // `rposition` consumes from the back, so after each hit only the characters
    // in front of that hit remain available for the preceding query character.
    let mut remaining = path.iter();
    for (query_idx, query_char) in query.iter().enumerate().rev() {
        match remaining.rposition(|path_char| path_char == query_char) {
            Some(path_idx) if path_idx >= skipped_prefix_len => {
                last_positions[query_idx] = path_idx;
            }
            _ => return false,
        }
    }
    true
}
226
227fn score_match(
228    query: &[char],
229    query_cased: &[char],
230    path: &[char],
231    path_cased: &[char],
232    skipped_prefix_len: usize,
233    smart_case: bool,
234    last_positions: &[usize],
235    score_matrix: &mut [Option<f64>],
236    best_position_matrix: &mut [usize],
237    match_positions: &mut [usize],
238    min_score: f64,
239) -> f64 {
240    let score = recursive_score_match(
241        query,
242        query_cased,
243        path,
244        path_cased,
245        skipped_prefix_len,
246        smart_case,
247        last_positions,
248        score_matrix,
249        best_position_matrix,
250        min_score,
251        0,
252        skipped_prefix_len,
253        query.len() as f64,
254    ) * query.len() as f64;
255
256    if score <= 0.0 {
257        return 0.0;
258    }
259
260    let path_len = path.len() - skipped_prefix_len;
261    let mut cur_start = 0;
262    for i in 0..query.len() {
263        match_positions[i] = best_position_matrix[i * path_len + cur_start] - skipped_prefix_len;
264        cur_start = match_positions[i] + 1;
265    }
266
267    score
268}
269
270fn recursive_score_match(
271    query: &[char],
272    query_cased: &[char],
273    path: &[char],
274    path_cased: &[char],
275    skipped_prefix_len: usize,
276    smart_case: bool,
277    last_positions: &[usize],
278    score_matrix: &mut [Option<f64>],
279    best_position_matrix: &mut [usize],
280    min_score: f64,
281    query_idx: usize,
282    path_idx: usize,
283    cur_score: f64,
284) -> f64 {
285    if query_idx == query.len() {
286        return 1.0;
287    }
288
289    let path_len = path.len() - skipped_prefix_len;
290
291    if let Some(memoized) = score_matrix[query_idx * path_len + path_idx - skipped_prefix_len] {
292        return memoized;
293    }
294
295    let mut score = 0.0;
296    let mut best_position = 0;
297
298    let query_char = query_cased[query_idx];
299    let limit = last_positions[query_idx];
300
301    let mut last_slash = 0;
302    for j in path_idx..=limit {
303        let path_char = path_cased[j];
304        let is_path_sep = path_char == '/' || path_char == '\\';
305
306        if query_idx == 0 && is_path_sep {
307            last_slash = j;
308        }
309
310        if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
311            let mut char_score = 1.0;
312            if j > path_idx {
313                let last = path[j - 1];
314                let curr = path[j];
315
316                if last == '/' {
317                    char_score = 0.9;
318                } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
319                    char_score = 0.8;
320                } else if last.is_lowercase() && curr.is_uppercase() {
321                    char_score = 0.8;
322                } else if last == '.' {
323                    char_score = 0.7;
324                } else if query_idx == 0 {
325                    char_score = BASE_DISTANCE_PENALTY;
326                } else {
327                    char_score = MIN_DISTANCE_PENALTY.max(
328                        BASE_DISTANCE_PENALTY
329                            - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
330                    );
331                }
332            }
333
334            // Apply a severe penalty if the case doesn't match.
335            // This will make the exact matches have higher score than the case-insensitive and the
336            // path insensitive matches.
337            if (smart_case || path[j] == '/') && query[query_idx] != path[j] {
338                char_score *= 0.001;
339            }
340
341            let mut multiplier = char_score;
342
343            // Scale the score based on how deep within the patch we found the match.
344            if query_idx == 0 {
345                multiplier /= (path.len() - last_slash) as f64;
346            }
347
348            let mut next_score = 1.0;
349            if min_score > 0.0 {
350                next_score = cur_score * multiplier;
351                // Scores only decrease. If we can't pass the previous best, bail
352                if next_score < min_score {
353                    // Ensure that score is non-zero so we use it in the memo table.
354                    if score == 0.0 {
355                        score = 1e-18;
356                    }
357                    continue;
358                }
359            }
360
361            let new_score = recursive_score_match(
362                query,
363                query_cased,
364                path,
365                path_cased,
366                skipped_prefix_len,
367                smart_case,
368                last_positions,
369                score_matrix,
370                best_position_matrix,
371                min_score,
372                query_idx + 1,
373                j + 1,
374                next_score,
375            ) * multiplier;
376
377            if new_score > score {
378                score = new_score;
379                best_position = j;
380                // Optimization: can't score better than 1.
381                if new_score == 1.0 {
382                    break;
383                }
384            }
385        }
386    }
387
388    if best_position != 0 {
389        best_position_matrix[query_idx * path_len + path_idx - skipped_prefix_len] = best_position;
390    }
391
392    score_matrix[query_idx * path_len + path_idx - skipped_prefix_len] = Some(score);
393    score
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Runs the single-tree matcher over `paths` (tree id 0, no skipped
    /// prefix, ignored entries included, at most 100 results) and returns
    /// `(path, match positions)` pairs ordered from best to worst score.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let mut path_entries = Vec::new();
        for (i, path) in paths.iter().enumerate() {
            let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
            let path_chars = CharBag::from(&lowercase_path[..]);
            let path = path.chars().collect();
            path_entries.push(PathEntry {
                entry_id: i,
                path_chars,
                path,
                lowercase_path,
                is_ignored: false,
            });
        }

        let mut match_positions = vec![0; query.len()];
        let mut last_positions = vec![0; query.len()];

        let mut results = BinaryHeap::new();
        match_single_tree_paths(
            0,
            0,
            &path_entries,
            0,
            path_entries.len(),
            &query[..],
            &lowercase_query[..],
            query_chars,
            true,
            smart_case,
            &mut results,
            100,
            &mut 0.0,
            &mut match_positions,
            &mut last_positions,
            &mut Vec::new(),
            &mut Vec::new(),
        );

        // `BinaryHeap::into_iter` yields elements in arbitrary order, so sort
        // explicitly. Ascending order of `Reverse<PathMatch>` is descending
        // order of score, i.e. best match first.
        results
            .into_sorted_vec()
            .into_iter()
            .map(|Reverse(path_match)| (paths[path_match.entry_id], path_match.positions))
            .collect()
    }
}
496}