1use easy_parallel::Parallel;
2
3use super::char_bag::CharBag;
4
5use std::{
6 cmp::{max, min, Ordering, Reverse},
7 collections::BinaryHeap,
8};
9
/// Score given to a matched character that is not adjacent to the previous match
/// and does not follow a word boundary ('/', '-', '_', ' ', a digit, '.', or a
/// lower-to-upper case change). For the first query character this value is used as-is.
const BASE_DISTANCE_PENALTY: f64 = 0.6;
/// Amount subtracted from `BASE_DISTANCE_PENALTY` for each skipped path character
/// beyond the first between two consecutive matched characters (not applied to the
/// first query character).
const ADDITIONAL_DISTANCE_PENALTY: f64 = 0.05;
/// Lower bound on the distance-penalized score of a matched character.
const MIN_DISTANCE_PENALTY: f64 = 0.2;
13
/// A candidate path to be fuzzy-matched, together with data the matcher precomputes
/// once per entry.
pub struct PathEntry {
    /// Identifier copied into `PathMatch::entry_id` when this entry matches.
    pub entry_id: usize,
    /// Character set of the path; entries whose set is not a superset of the query's
    /// characters are rejected before any scoring is done.
    pub path_chars: CharBag,
    /// The path's characters in their original case, used for word-boundary bonuses
    /// and the case-mismatch penalty.
    pub path: Vec<char>,
    /// Lowercased path characters, compared against the lowercased query.
    /// NOTE(review): the matcher indexes `path` and `lowercase_path` with the same
    /// indices, so both are assumed to have equal length — confirm where entries are built.
    pub lowercase_path: Vec<char>,
    /// Ignored entries are skipped unless the caller passes `include_ignored = true`.
    pub is_ignored: bool,
}
21
/// One fuzzy-match hit produced by the matcher.
///
/// `positions` holds the indices of the matched characters, relative to the end of
/// the `skipped_prefix_len` characters that were excluded from matching.
#[derive(Clone, Debug)]
pub struct PathMatch {
    pub score: f64,
    pub positions: Vec<usize>,
    pub tree_id: usize,
    pub entry_id: usize,
    pub skipped_prefix_len: usize,
}

// Equality and ordering look at `score` alone, so matches can be ranked in a
// `BinaryHeap`. Two matches with the same score compare equal even when they refer
// to different entries.
impl PartialEq for PathMatch {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score
    }
}

impl Eq for PathMatch {}

impl PartialOrd for PathMatch {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        PartialOrd::partial_cmp(&self.score, &other.score)
    }
}

impl Ord for PathMatch {
    /// Total order over scores; incomparable scores (NaN) are treated as equal.
    fn cmp(&self, other: &Self) -> Ordering {
        match self.partial_cmp(other) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        }
    }
}
50
51pub fn match_paths(
52 paths_by_tree_id: &[(usize, usize, &[PathEntry])],
53 query: &str,
54 include_ignored: bool,
55 smart_case: bool,
56 max_results: usize,
57) -> Vec<PathMatch> {
58 let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
59 let query = query.chars().collect::<Vec<_>>();
60 let lowercase_query = &lowercase_query;
61 let query = &query;
62 let query_chars = CharBag::from(&lowercase_query[..]);
63
64 let cpus = num_cpus::get();
65 let path_count = paths_by_tree_id
66 .iter()
67 .fold(0, |sum, (_, _, paths)| sum + paths.len());
68 let segment_size = (path_count + cpus - 1) / cpus;
69 let mut segment_results = (0..cpus).map(|_| BinaryHeap::new()).collect::<Vec<_>>();
70
71 Parallel::new()
72 .each(
73 segment_results.iter_mut().enumerate(),
74 |(segment_idx, results)| {
75 let segment_start = segment_idx * segment_size;
76 let segment_end = segment_start + segment_size;
77
78 let mut min_score = 0.0;
79 let mut last_positions = Vec::new();
80 last_positions.resize(query.len(), 0);
81 let mut match_positions = Vec::new();
82 match_positions.resize(query.len(), 0);
83 let mut score_matrix = Vec::new();
84 let mut best_position_matrix = Vec::new();
85
86 let mut tree_start = 0;
87 for (tree_id, skipped_prefix_len, paths) in paths_by_tree_id {
88 let tree_end = tree_start + paths.len();
89 if tree_start < segment_end && segment_start < tree_end {
90 let start = max(tree_start, segment_start) - tree_start;
91 let end = min(tree_end, segment_end) - tree_start;
92
93 match_single_tree_paths(
94 *tree_id,
95 *skipped_prefix_len,
96 paths,
97 start,
98 end,
99 query,
100 lowercase_query,
101 query_chars,
102 include_ignored,
103 smart_case,
104 results,
105 max_results,
106 &mut min_score,
107 &mut match_positions,
108 &mut last_positions,
109 &mut score_matrix,
110 &mut best_position_matrix,
111 );
112 }
113 if tree_end >= segment_end {
114 break;
115 }
116 tree_start = tree_end;
117 }
118 },
119 )
120 .run();
121
122 let mut results = segment_results
123 .into_iter()
124 .flatten()
125 .map(|r| r.0)
126 .collect::<Vec<_>>();
127 results.sort_unstable_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
128 results.truncate(max_results);
129 results
130}
131
132fn match_single_tree_paths(
133 tree_id: usize,
134 skipped_prefix_len: usize,
135 path_entries: &[PathEntry],
136 start: usize,
137 end: usize,
138 query: &[char],
139 lowercase_query: &[char],
140 query_chars: CharBag,
141 include_ignored: bool,
142 smart_case: bool,
143 results: &mut BinaryHeap<Reverse<PathMatch>>,
144 max_results: usize,
145 min_score: &mut f64,
146 match_positions: &mut Vec<usize>,
147 last_positions: &mut Vec<usize>,
148 score_matrix: &mut Vec<Option<f64>>,
149 best_position_matrix: &mut Vec<usize>,
150) {
151 for i in start..end {
152 let path_entry = unsafe { &path_entries.get_unchecked(i) };
153
154 if !include_ignored && path_entry.is_ignored {
155 continue;
156 }
157
158 if !path_entry.path_chars.is_superset(query_chars) {
159 continue;
160 }
161
162 if !find_last_positions(
163 last_positions,
164 skipped_prefix_len,
165 &path_entry.lowercase_path,
166 &lowercase_query[..],
167 ) {
168 continue;
169 }
170
171 let matrix_len = query.len() * (path_entry.path.len() - skipped_prefix_len);
172 score_matrix.clear();
173 score_matrix.resize(matrix_len, None);
174 best_position_matrix.clear();
175 best_position_matrix.resize(matrix_len, skipped_prefix_len);
176
177 let score = score_match(
178 &query[..],
179 &lowercase_query[..],
180 &path_entry.path,
181 &path_entry.lowercase_path,
182 skipped_prefix_len,
183 smart_case,
184 &last_positions,
185 score_matrix,
186 best_position_matrix,
187 match_positions,
188 *min_score,
189 );
190
191 if score > 0.0 {
192 results.push(Reverse(PathMatch {
193 tree_id,
194 entry_id: path_entry.entry_id,
195 score,
196 positions: match_positions.clone(),
197 skipped_prefix_len,
198 }));
199 if results.len() == max_results {
200 *min_score = results.peek().unwrap().0.score;
201 }
202 }
203 }
204}
205
/// Records, for each query character, the last path index at which it can match
/// while leaving room for all later query characters to match after it.
///
/// The search runs from the last query character backwards, each one searching only
/// to the left of where its successor was found. `last_positions[i]` receives the
/// index for `query[i]`.
///
/// Returns `false` if some character does not occur, or occurs only inside the
/// first `skipped_prefix_len` characters. Entries already written for later query
/// characters are left in place in that case.
fn find_last_positions(
    last_positions: &mut Vec<usize>,
    skipped_prefix_len: usize,
    path: &[char],
    query: &[char],
) -> bool {
    // Exclusive upper bound of the window still available to earlier query chars.
    let mut search_end = path.len();
    for (query_idx, query_char) in query.iter().enumerate().rev() {
        match path[..search_end].iter().rposition(|c| c == query_char) {
            Some(found) if found >= skipped_prefix_len => {
                last_positions[query_idx] = found;
                search_end = found;
            }
            _ => return false,
        }
    }
    true
}
226
227fn score_match(
228 query: &[char],
229 query_cased: &[char],
230 path: &[char],
231 path_cased: &[char],
232 skipped_prefix_len: usize,
233 smart_case: bool,
234 last_positions: &[usize],
235 score_matrix: &mut [Option<f64>],
236 best_position_matrix: &mut [usize],
237 match_positions: &mut [usize],
238 min_score: f64,
239) -> f64 {
240 let score = recursive_score_match(
241 query,
242 query_cased,
243 path,
244 path_cased,
245 skipped_prefix_len,
246 smart_case,
247 last_positions,
248 score_matrix,
249 best_position_matrix,
250 min_score,
251 0,
252 skipped_prefix_len,
253 query.len() as f64,
254 ) * query.len() as f64;
255
256 if score <= 0.0 {
257 return 0.0;
258 }
259
260 let path_len = path.len() - skipped_prefix_len;
261 let mut cur_start = 0;
262 for i in 0..query.len() {
263 match_positions[i] = best_position_matrix[i * path_len + cur_start] - skipped_prefix_len;
264 cur_start = match_positions[i] + 1;
265 }
266
267 score
268}
269
270fn recursive_score_match(
271 query: &[char],
272 query_cased: &[char],
273 path: &[char],
274 path_cased: &[char],
275 skipped_prefix_len: usize,
276 smart_case: bool,
277 last_positions: &[usize],
278 score_matrix: &mut [Option<f64>],
279 best_position_matrix: &mut [usize],
280 min_score: f64,
281 query_idx: usize,
282 path_idx: usize,
283 cur_score: f64,
284) -> f64 {
285 if query_idx == query.len() {
286 return 1.0;
287 }
288
289 let path_len = path.len() - skipped_prefix_len;
290
291 if let Some(memoized) = score_matrix[query_idx * path_len + path_idx - skipped_prefix_len] {
292 return memoized;
293 }
294
295 let mut score = 0.0;
296 let mut best_position = 0;
297
298 let query_char = query_cased[query_idx];
299 let limit = last_positions[query_idx];
300
301 let mut last_slash = 0;
302 for j in path_idx..=limit {
303 let path_char = path_cased[j];
304 let is_path_sep = path_char == '/' || path_char == '\\';
305
306 if query_idx == 0 && is_path_sep {
307 last_slash = j;
308 }
309
310 if query_char == path_char || (is_path_sep && query_char == '_' || query_char == '\\') {
311 let mut char_score = 1.0;
312 if j > path_idx {
313 let last = path[j - 1];
314 let curr = path[j];
315
316 if last == '/' {
317 char_score = 0.9;
318 } else if last == '-' || last == '_' || last == ' ' || last.is_numeric() {
319 char_score = 0.8;
320 } else if last.is_lowercase() && curr.is_uppercase() {
321 char_score = 0.8;
322 } else if last == '.' {
323 char_score = 0.7;
324 } else if query_idx == 0 {
325 char_score = BASE_DISTANCE_PENALTY;
326 } else {
327 char_score = MIN_DISTANCE_PENALTY.max(
328 BASE_DISTANCE_PENALTY
329 - (j - path_idx - 1) as f64 * ADDITIONAL_DISTANCE_PENALTY,
330 );
331 }
332 }
333
334 // Apply a severe penalty if the case doesn't match.
335 // This will make the exact matches have higher score than the case-insensitive and the
336 // path insensitive matches.
337 if (smart_case || path[j] == '/') && query[query_idx] != path[j] {
338 char_score *= 0.001;
339 }
340
341 let mut multiplier = char_score;
342
343 // Scale the score based on how deep within the patch we found the match.
344 if query_idx == 0 {
345 multiplier /= (path.len() - last_slash) as f64;
346 }
347
348 let mut next_score = 1.0;
349 if min_score > 0.0 {
350 next_score = cur_score * multiplier;
351 // Scores only decrease. If we can't pass the previous best, bail
352 if next_score < min_score {
353 // Ensure that score is non-zero so we use it in the memo table.
354 if score == 0.0 {
355 score = 1e-18;
356 }
357 continue;
358 }
359 }
360
361 let new_score = recursive_score_match(
362 query,
363 query_cased,
364 path,
365 path_cased,
366 skipped_prefix_len,
367 smart_case,
368 last_positions,
369 score_matrix,
370 best_position_matrix,
371 min_score,
372 query_idx + 1,
373 j + 1,
374 next_score,
375 ) * multiplier;
376
377 if new_score > score {
378 score = new_score;
379 best_position = j;
380 // Optimization: can't score better than 1.
381 if new_score == 1.0 {
382 break;
383 }
384 }
385 }
386 }
387
388 if best_position != 0 {
389 best_position_matrix[query_idx * path_len + path_idx - skipped_prefix_len] = best_position;
390 }
391
392 score_matrix[query_idx * path_len + path_idx - skipped_prefix_len] = Some(score);
393 score
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_match_path_entries() {
        let paths = vec![
            "",
            "a",
            "ab",
            "abC",
            "abcd",
            "alphabravocharlie",
            "AlphaBravoCharlie",
            "thisisatestdir",
            "/////ThisIsATestDir",
            "/this/is/a/test/dir",
            "/test/tiatd",
        ];

        assert_eq!(
            match_query("abc", false, &paths),
            vec![
                ("abC", vec![0, 1, 2]),
                ("abcd", vec![0, 1, 2]),
                ("AlphaBravoCharlie", vec![0, 5, 10]),
                ("alphabravocharlie", vec![4, 5, 10]),
            ]
        );
        assert_eq!(
            match_query("t/i/a/t/d", false, &paths),
            vec![("/this/is/a/test/dir", vec![1, 5, 6, 8, 9, 10, 11, 15, 16]),]
        );

        assert_eq!(
            match_query("tiatd", false, &paths),
            vec![
                ("/test/tiatd", vec![6, 7, 8, 9, 10]),
                ("/this/is/a/test/dir", vec![1, 6, 9, 11, 16]),
                ("/////ThisIsATestDir", vec![5, 9, 11, 12, 16]),
                ("thisisatestdir", vec![0, 2, 6, 7, 11]),
            ]
        );
    }

    /// Matches `query` against `paths`, treated as one tree with no skipped prefix and
    /// with ignored entries included. Returns the matching paths and their match
    /// positions, best score first.
    fn match_query<'a>(
        query: &str,
        smart_case: bool,
        paths: &[&'a str],
    ) -> Vec<(&'a str, Vec<usize>)> {
        let lowercase_query = query.to_lowercase().chars().collect::<Vec<_>>();
        let query = query.chars().collect::<Vec<_>>();
        let query_chars = CharBag::from(&lowercase_query[..]);

        let path_entries = paths
            .iter()
            .enumerate()
            .map(|(i, path)| {
                let lowercase_path = path.to_lowercase().chars().collect::<Vec<_>>();
                let path_chars = CharBag::from(&lowercase_path[..]);
                PathEntry {
                    entry_id: i,
                    path_chars,
                    path: path.chars().collect(),
                    lowercase_path,
                    is_ignored: false,
                }
            })
            .collect::<Vec<_>>();

        let mut match_positions = vec![0; query.len()];
        let mut last_positions = vec![0; query.len()];

        let mut results = BinaryHeap::new();
        match_single_tree_paths(
            0,
            0,
            &path_entries,
            0,
            path_entries.len(),
            &query[..],
            &lowercase_query[..],
            query_chars,
            true,
            smart_case,
            &mut results,
            100,
            &mut 0.0,
            &mut match_positions,
            &mut last_positions,
            &mut Vec::new(),
            &mut Vec::new(),
        );

        // `BinaryHeap::into_iter` yields items in arbitrary order. `into_sorted_vec`
        // sorts ascending by `Reverse<PathMatch>`, i.e. by descending score.
        results
            .into_sorted_vec()
            .into_iter()
            .map(|Reverse(path_match)| (paths[path_match.entry_id], path_match.positions))
            .collect()
    }
}