strings.rs

  1use std::{
  2    borrow::Borrow,
  3    cmp::Ordering,
  4    iter,
  5    ops::Range,
  6    sync::atomic::{self, AtomicBool},
  7};
  8
  9use gpui::{BackgroundExecutor, SharedString};
 10use nucleo::Utf32Str;
 11
 12use crate::{
 13    Cancelled, Case, LengthPenalty, Query, case_penalty, count_case_mismatches,
 14    matcher::{self, LENGTH_PENALTY},
 15    positions_from_sorted,
 16};
 17use fuzzy::CharBag;
 18
 19#[derive(Clone, Debug)]
 20pub struct StringMatchCandidate {
 21    pub id: usize,
 22    pub string: SharedString,
 23    char_bag: CharBag,
 24}
 25
 26impl StringMatchCandidate {
 27    pub fn new(id: usize, string: impl ToString) -> Self {
 28        Self::from_shared(id, SharedString::new(string.to_string()))
 29    }
 30
 31    pub fn from_shared(id: usize, string: SharedString) -> Self {
 32        let char_bag = CharBag::from(string.as_ref());
 33        Self {
 34            id,
 35            string,
 36            char_bag,
 37        }
 38    }
 39}
 40
 41#[derive(Clone, Debug)]
 42pub struct StringMatch {
 43    pub candidate_id: usize,
 44    pub score: f64,
 45    pub positions: Vec<usize>,
 46    pub string: SharedString,
 47}
 48
 49impl StringMatch {
 50    pub fn ranges(&self) -> impl '_ + Iterator<Item = Range<usize>> {
 51        let mut positions = self.positions.iter().peekable();
 52        iter::from_fn(move || {
 53            let start = *positions.next()?;
 54            let char_len = self.char_len_at_index(start)?;
 55            let mut end = start + char_len;
 56            while let Some(next_start) = positions.peek() {
 57                if end == **next_start {
 58                    let Some(char_len) = self.char_len_at_index(end) else {
 59                        break;
 60                    };
 61                    end += char_len;
 62                    positions.next();
 63                } else {
 64                    break;
 65                }
 66            }
 67            Some(start..end)
 68        })
 69    }
 70
 71    fn char_len_at_index(&self, ix: usize) -> Option<usize> {
 72        self.string
 73            .get(ix..)
 74            .and_then(|slice| slice.chars().next().map(|c| c.len_utf8()))
 75    }
 76}
 77
 78impl PartialEq for StringMatch {
 79    fn eq(&self, other: &Self) -> bool {
 80        self.cmp(other).is_eq()
 81    }
 82}
 83
 84impl Eq for StringMatch {}
 85
 86impl PartialOrd for StringMatch {
 87    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
 88        Some(self.cmp(other))
 89    }
 90}
 91
 92impl Ord for StringMatch {
 93    fn cmp(&self, other: &Self) -> Ordering {
 94        self.score
 95            .total_cmp(&other.score)
 96            .then_with(|| self.candidate_id.cmp(&other.candidate_id))
 97    }
 98}
 99
100pub async fn match_strings_async<T>(
101    candidates: &[T],
102    query: &str,
103    case: Case,
104    length_penalty: LengthPenalty,
105    max_results: usize,
106    cancel_flag: &AtomicBool,
107    executor: BackgroundExecutor,
108) -> Vec<StringMatch>
109where
110    T: Borrow<StringMatchCandidate> + Sync,
111{
112    if candidates.is_empty() || max_results == 0 {
113        return Vec::new();
114    }
115
116    let Some(query) = Query::build(query, case) else {
117        return empty_query_results(candidates, max_results);
118    };
119
120    let num_cpus = executor.num_cpus().min(candidates.len());
121    let base_size = candidates.len() / num_cpus;
122    let remainder = candidates.len() % num_cpus;
123    let mut segment_results = (0..num_cpus)
124        .map(|_| Vec::with_capacity(max_results.min(candidates.len())))
125        .collect::<Vec<_>>();
126
127    let config = nucleo::Config::DEFAULT;
128    let mut matchers = matcher::get_matchers(num_cpus, config);
129
130    executor
131        .scoped(|scope| {
132            for (segment_idx, (results, matcher)) in segment_results
133                .iter_mut()
134                .zip(matchers.iter_mut())
135                .enumerate()
136            {
137                let query = &query;
138                scope.spawn(async move {
139                    let segment_start = segment_idx * base_size + segment_idx.min(remainder);
140                    let segment_end =
141                        (segment_idx + 1) * base_size + (segment_idx + 1).min(remainder);
142
143                    match_string_helper(
144                        &candidates[segment_start..segment_end],
145                        query,
146                        matcher,
147                        length_penalty,
148                        results,
149                        cancel_flag,
150                    )
151                    .ok();
152                });
153            }
154        })
155        .await;
156
157    matcher::return_matchers(matchers);
158
159    if cancel_flag.load(atomic::Ordering::Acquire) {
160        return Vec::new();
161    }
162
163    let mut results = segment_results.concat();
164    util::truncate_to_bottom_n_sorted_by(&mut results, max_results, &|a, b| b.cmp(a));
165    results
166}
167
168pub fn match_strings<T>(
169    candidates: &[T],
170    query: &str,
171    case: Case,
172    length_penalty: LengthPenalty,
173    max_results: usize,
174) -> Vec<StringMatch>
175where
176    T: Borrow<StringMatchCandidate>,
177{
178    if candidates.is_empty() || max_results == 0 {
179        return Vec::new();
180    }
181
182    let Some(query) = Query::build(query, case) else {
183        return empty_query_results(candidates, max_results);
184    };
185
186    let config = nucleo::Config::DEFAULT;
187    let mut matcher = matcher::get_matcher(config);
188    let mut results = Vec::with_capacity(max_results.min(candidates.len()));
189
190    match_string_helper(
191        candidates,
192        &query,
193        &mut matcher,
194        length_penalty,
195        &mut results,
196        &AtomicBool::new(false),
197    )
198    .ok();
199
200    matcher::return_matcher(matcher);
201    util::truncate_to_bottom_n_sorted_by(&mut results, max_results, &|a, b| b.cmp(a));
202    results
203}
204
205fn empty_query_results<T: Borrow<StringMatchCandidate>>(
206    candidates: &[T],
207    max_results: usize,
208) -> Vec<StringMatch> {
209    candidates
210        .iter()
211        .take(max_results)
212        .map(|candidate| {
213            let borrowed = candidate.borrow();
214            StringMatch {
215                candidate_id: borrowed.id,
216                score: 0.,
217                positions: Vec::new(),
218                string: borrowed.string.clone(),
219            }
220        })
221        .collect()
222}
223
224fn match_string_helper<T>(
225    candidates: &[T],
226    query: &Query,
227    matcher: &mut nucleo::Matcher,
228    length_penalty: LengthPenalty,
229    results: &mut Vec<StringMatch>,
230    cancel_flag: &AtomicBool,
231) -> Result<(), Cancelled>
232where
233    T: Borrow<StringMatchCandidate>,
234{
235    let mut buf = Vec::new();
236    let mut matched_chars: Vec<u32> = Vec::new();
237    let mut candidate_chars: Vec<char> = Vec::new();
238
239    for candidate in candidates {
240        buf.clear();
241        matched_chars.clear();
242        if cancel_flag.load(atomic::Ordering::Relaxed) {
243            return Err(Cancelled);
244        }
245
246        let borrowed = candidate.borrow();
247
248        if !borrowed.char_bag.is_superset(query.char_bag) {
249            continue;
250        }
251
252        let haystack: Utf32Str = Utf32Str::new(borrowed.string.as_ref(), &mut buf);
253
254        let Some(score) = query.pattern.indices(haystack, matcher, &mut matched_chars) else {
255            continue;
256        };
257
258        let case_mismatches = count_case_mismatches(
259            query.query_chars.as_deref(),
260            &matched_chars,
261            borrowed.string.as_ref(),
262            &mut candidate_chars,
263        );
264
265        matched_chars.sort_unstable();
266        matched_chars.dedup();
267
268        let positive = score as f64 * case_penalty(case_mismatches);
269        let adjusted_score =
270            positive - length_penalty_for(borrowed.string.as_ref(), length_penalty);
271        let positions = positions_from_sorted(borrowed.string.as_ref(), &matched_chars);
272
273        results.push(StringMatch {
274            candidate_id: borrowed.id,
275            score: adjusted_score,
276            positions,
277            string: borrowed.string.clone(),
278        });
279    }
280    Ok(())
281}
282
283#[inline]
284fn length_penalty_for(s: &str, length_penalty: LengthPenalty) -> f64 {
285    if length_penalty.is_on() {
286        s.len() as f64 * LENGTH_PENALTY
287    } else {
288        0.0
289    }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::BackgroundExecutor;

    /// Builds candidates whose ids are their indices in `strings`.
    fn candidates(strings: &[&str]) -> Vec<StringMatchCandidate> {
        let mut built = Vec::with_capacity(strings.len());
        for (id, text) in strings.iter().enumerate() {
            built.push(StringMatchCandidate::new(id, text));
        }
        built
    }

    /// Runs `match_strings_async` with a cancel flag that is never set.
    async fn search(
        cs: &[StringMatchCandidate],
        query: &str,
        case: Case,
        length_penalty: LengthPenalty,
        max_results: usize,
        executor: BackgroundExecutor,
    ) -> Vec<StringMatch> {
        let cancel = AtomicBool::new(false);
        match_strings_async(
            cs,
            query,
            case,
            length_penalty,
            max_results,
            &cancel,
            executor,
        )
        .await
    }

    /// Returns the matched strings in result order.
    fn matched_strings(results: &[StringMatch]) -> Vec<&str> {
        results.iter().map(|m| m.string.as_ref()).collect()
    }

    #[gpui::test]
    async fn test_basic_match(executor: BackgroundExecutor) {
        let cs = candidates(&["hello", "world", "help"]);
        let results = search(&cs, "hel", Case::Ignore, LengthPenalty::Off, 10, executor).await;
        let matched = matched_strings(&results);
        assert!(matched.contains(&"hello"));
        assert!(matched.contains(&"help"));
        assert!(!matched.contains(&"world"));
    }

    #[gpui::test]
    async fn test_multi_word_query(executor: BackgroundExecutor) {
        let cs = candidates(&[
            "src/lib/parser.rs",
            "src/bin/main.rs",
            "tests/parser_test.rs",
        ]);
        let results = search(
            &cs,
            "src parser",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            executor,
        )
        .await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].string, "src/lib/parser.rs");
    }

    #[gpui::test]
    async fn test_empty_query_returns_all(executor: BackgroundExecutor) {
        let cs = candidates(&["alpha", "beta", "gamma"]);
        let results = search(&cs, "", Case::Ignore, LengthPenalty::Off, 10, executor).await;
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|m| m.score == 0.0));
    }

    #[gpui::test]
    async fn test_whitespace_only_query_returns_all(executor: BackgroundExecutor) {
        let cs = candidates(&["alpha", "beta", "gamma"]);
        let results = search(
            &cs,
            "   \t\n",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            executor,
        )
        .await;
        assert_eq!(results.len(), 3);
    }

    #[gpui::test]
    async fn test_empty_candidates(executor: BackgroundExecutor) {
        let cs: Vec<StringMatchCandidate> = Vec::new();
        let results = search(&cs, "query", Case::Ignore, LengthPenalty::Off, 10, executor).await;
        assert!(results.is_empty());
    }

    #[gpui::test]
    async fn test_cancellation(executor: BackgroundExecutor) {
        let cs = candidates(&["hello", "world"]);
        // The flag is already set before matching starts.
        let cancel = AtomicBool::new(true);
        let results = match_strings_async(
            &cs,
            "hel",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        assert!(results.is_empty());
    }

    #[gpui::test]
    async fn test_max_results_limit(executor: BackgroundExecutor) {
        let cs = candidates(&["ab", "abc", "abcd", "abcde"]);
        let results = search(&cs, "ab", Case::Ignore, LengthPenalty::Off, 2, executor).await;
        assert_eq!(results.len(), 2);
    }

    #[gpui::test]
    async fn test_scoring_order(executor: BackgroundExecutor) {
        let cs = candidates(&[
            "some_very_long_variable_name_fuzzy",
            "fuzzy",
            "a_fuzzy_thing",
        ]);
        let results = search(
            &cs,
            "fuzzy",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            executor.clone(),
        )
        .await;

        let top_three: [&str; 3] = [
            results[0].string.as_ref(),
            results[1].string.as_ref(),
            results[2].string.as_ref(),
        ];
        let ordered = top_three
            == [
                "fuzzy",
                "a_fuzzy_thing",
                "some_very_long_variable_name_fuzzy",
            ];
        assert!(ordered, "matches are not in the proper order.");

        let results_penalty = search(
            &cs,
            "fuzzy",
            Case::Ignore,
            LengthPenalty::On,
            10,
            executor,
        )
        .await;
        let greater = results[2].score > results_penalty[2].score;
        assert!(greater, "penalize length not affecting long candidates");
    }

    #[gpui::test]
    async fn test_utf8_positions(executor: BackgroundExecutor) {
        let cs = candidates(&["café"]);
        let results = search(&cs, "caf", Case::Ignore, LengthPenalty::Off, 10, executor).await;
        assert_eq!(results.len(), 1);
        let m = &results[0];
        assert_eq!(m.positions, vec![0, 1, 2]);
        for &pos in &m.positions {
            assert!(m.string.is_char_boundary(pos));
        }
    }

    #[gpui::test]
    async fn test_smart_case(executor: BackgroundExecutor) {
        let cs = candidates(&["FooBar", "foobar", "FOOBAR"]);

        let case_insensitive = search(
            &cs,
            "foobar",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            executor.clone(),
        )
        .await;
        assert_eq!(case_insensitive.len(), 3);

        let smart = search(
            &cs,
            "FooBar",
            Case::Smart,
            LengthPenalty::Off,
            10,
            executor,
        )
        .await;
        assert!(smart.iter().any(|m| m.string == "FooBar"));
        let foobar_score = smart.iter().find(|m| m.string == "FooBar").map(|m| m.score);
        let lower_score = smart.iter().find(|m| m.string == "foobar").map(|m| m.score);
        // Only compared when both spellings matched under smart case.
        if let Some((exact, lower)) = foobar_score.zip(lower_score) {
            assert!(exact >= lower);
        }
    }

    #[gpui::test]
    async fn test_smart_case_does_not_flip_order_when_length_penalty_on(
        executor: BackgroundExecutor,
    ) {
        // Regression for the sign bug: with a length penalty large enough to push
        // `total_score - length_penalty` negative, case mismatches used to make
        // scores *better* (less negative). Exact-case match must still rank first.
        let cs = candidates(&[
            "aaaaaaaaaaaaaaaaaaaaaaaaaaaa_FooBar",
            "aaaaaaaaaaaaaaaaaaaaaaaaaaaa_foobar",
        ]);
        let results = search(
            &cs,
            "FooBar",
            Case::Smart,
            LengthPenalty::On,
            10,
            executor,
        )
        .await;
        let score_of = |name: &str| {
            results
                .iter()
                .find(|m| m.string.as_ref() == name)
                .map(|m| m.score)
        };
        let exact = score_of("aaaaaaaaaaaaaaaaaaaaaaaaaaaa_FooBar")
            .expect("exact-case candidate should match");
        let mismatch = score_of("aaaaaaaaaaaaaaaaaaaaaaaaaaaa_foobar")
            .expect("mismatch-case candidate should match");
        assert!(
            exact >= mismatch,
            "exact-case score ({exact}) should be >= mismatch-case score ({mismatch})"
        );
    }

    #[gpui::test]
    async fn test_char_bag_prefilter(executor: BackgroundExecutor) {
        let cs = candidates(&["abcdef", "abc", "def", "aabbcc"]);
        let results = search(&cs, "abc", Case::Ignore, LengthPenalty::Off, 10, executor).await;
        let matched = matched_strings(&results);
        assert!(matched.contains(&"abcdef"));
        assert!(matched.contains(&"abc"));
        assert!(matched.contains(&"aabbcc"));
        assert!(!matched.contains(&"def"));
    }

    #[test]
    fn test_sync_basic_match() {
        let cs = candidates(&["hello", "world", "help"]);
        let results = match_strings(&cs, "hel", Case::Ignore, LengthPenalty::Off, 10);
        let matched = matched_strings(&results);
        assert!(matched.contains(&"hello"));
        assert!(matched.contains(&"help"));
        assert!(!matched.contains(&"world"));
    }

    #[test]
    fn test_sync_empty_query_returns_all() {
        let cs = candidates(&["alpha", "beta", "gamma"]);
        let results = match_strings(&cs, "", Case::Ignore, LengthPenalty::Off, 10);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_sync_whitespace_only_query_returns_all() {
        let cs = candidates(&["alpha", "beta", "gamma"]);
        let results = match_strings(&cs, "  ", Case::Ignore, LengthPenalty::Off, 10);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_sync_max_results() {
        let cs = candidates(&["ab", "abc", "abcd", "abcde"]);
        let results = match_strings(&cs, "ab", Case::Ignore, LengthPenalty::Off, 2);
        assert_eq!(results.len(), 2);
    }

    #[gpui::test]
    async fn test_empty_query_respects_max_results(executor: BackgroundExecutor) {
        let cs = candidates(&["alpha", "beta", "gamma", "delta"]);
        let results = search(&cs, "", Case::Ignore, LengthPenalty::Off, 2, executor).await;
        assert_eq!(results.len(), 2);
    }

    #[gpui::test]
    async fn test_multi_word_with_nonmatching_word(executor: BackgroundExecutor) {
        let cs = candidates(&["src/parser.rs", "src/main.rs"]);
        let results = search(
            &cs,
            "src xyzzy",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            executor,
        )
        .await;
        assert!(
            results.is_empty(),
            "no candidate contains 'xyzzy', so nothing should match"
        );
    }

    #[gpui::test]
    async fn test_segment_size_not_divisible_by_cpus(executor: BackgroundExecutor) {
        executor.set_num_cpus(4);
        let cs = candidates(&["alpha", "beta", "gamma", "delta", "epsilon"]);
        let results = search(&cs, "a", Case::Ignore, LengthPenalty::Off, 10, executor).await;
        let matched = matched_strings(&results);
        assert!(matched.contains(&"alpha"));
        assert!(matched.contains(&"gamma"));
        assert!(matched.contains(&"delta"));
    }

    #[gpui::test]
    async fn test_segment_size_with_many_cpus_few_candidates(executor: BackgroundExecutor) {
        executor.set_num_cpus(16);
        let cs = candidates(&["one", "two", "three"]);
        let results = search(&cs, "o", Case::Ignore, LengthPenalty::Off, 10, executor).await;
        let matched = matched_strings(&results);
        assert!(matched.contains(&"one"));
        assert!(matched.contains(&"two"));
    }

    #[gpui::test]
    async fn test_segment_size_single_candidate(executor: BackgroundExecutor) {
        executor.set_num_cpus(8);
        let cs = candidates(&["lonely"]);
        let results = search(&cs, "lone", Case::Ignore, LengthPenalty::Off, 10, executor).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].string.as_ref(), "lonely");
    }

    #[gpui::test]
    async fn test_segment_size_candidates_equal_cpus(executor: BackgroundExecutor) {
        executor.set_num_cpus(4);
        let cs = candidates(&["aaa", "bbb", "ccc", "ddd"]);
        let results = search(&cs, "a", Case::Ignore, LengthPenalty::Off, 10, executor).await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].string.as_ref(), "aaa");
    }

    #[gpui::test]
    async fn test_segment_size_candidates_one_more_than_cpus(executor: BackgroundExecutor) {
        executor.set_num_cpus(3);
        let cs = candidates(&["ant", "ape", "dog", "axe"]);
        let results = search(&cs, "a", Case::Ignore, LengthPenalty::Off, 10, executor).await;
        let matched = matched_strings(&results);
        assert!(matched.contains(&"ant"));
        assert!(matched.contains(&"ape"));
        assert!(matched.contains(&"axe"));
        assert!(!matched.contains(&"dog"));
    }
}