strings.rs

  1use std::{
  2    borrow::Borrow,
  3    cmp::Ordering,
  4    iter,
  5    ops::Range,
  6    sync::atomic::{self, AtomicBool},
  7};
  8
  9use gpui::{BackgroundExecutor, SharedString};
 10use nucleo::Utf32Str;
 11use nucleo::pattern::{Atom, AtomKind, CaseMatching, Normalization};
 12
 13use crate::{
 14    Cancelled, Case, LengthPenalty,
 15    matcher::{self, LENGTH_PENALTY},
 16    positions_from_sorted,
 17};
 18use fuzzy::CharBag;
 19
// String matching is always case-insensitive at the nucleo level — using
// `CaseMatching::Smart` there would reject queries whose capitalization
// doesn't match the candidate, breaking pickers like the command palette
// (`"Editor: Backspace"` against the action named `"editor: backspace"`).
// `Case::Smart` is still honored as a *scoring hint*: when the query
// contains uppercase, candidates whose matched characters disagree in case
// are downranked rather than dropped.
//
// Factor applied to a match's (pre-length-penalty) score once per matched
// character whose case differs from the query's: `n` mismatches scale the
// score by `SMART_CASE_PENALTY_PER_MISMATCH.powi(n)`.
const SMART_CASE_PENALTY_PER_MISMATCH: f64 = 0.9;
 28
 29struct Query {
 30    atoms: Vec<Atom>,
 31    source_words: Option<Vec<Vec<char>>>,
 32    char_bag: CharBag,
 33}
 34
 35impl Query {
 36    fn build(query: &str, case: Case) -> Option<Self> {
 37        let mut atoms = Vec::new();
 38        let mut source_words = Vec::new();
 39        let wants_case_penalty = case.is_smart() && query.chars().any(|c| c.is_uppercase());
 40
 41        for word in query.split_whitespace() {
 42            atoms.push(Atom::new(
 43                word,
 44                CaseMatching::Ignore,
 45                Normalization::Smart,
 46                AtomKind::Fuzzy,
 47                false,
 48            ));
 49            if wants_case_penalty {
 50                source_words.push(word.chars().collect());
 51            }
 52        }
 53
 54        if atoms.is_empty() {
 55            return None;
 56        }
 57
 58        Some(Query {
 59            atoms,
 60            source_words: wants_case_penalty.then_some(source_words),
 61            char_bag: CharBag::from(query),
 62        })
 63    }
 64}
 65
 66#[derive(Clone, Debug)]
 67pub struct StringMatchCandidate {
 68    pub id: usize,
 69    pub string: SharedString,
 70    char_bag: CharBag,
 71}
 72
 73impl StringMatchCandidate {
 74    pub fn new(id: usize, string: impl ToString) -> Self {
 75        Self::from_shared(id, SharedString::new(string.to_string()))
 76    }
 77
 78    pub fn from_shared(id: usize, string: SharedString) -> Self {
 79        let char_bag = CharBag::from(string.as_ref());
 80        Self {
 81            id,
 82            string,
 83            char_bag,
 84        }
 85    }
 86}
 87
 88#[derive(Clone, Debug)]
 89pub struct StringMatch {
 90    pub candidate_id: usize,
 91    pub score: f64,
 92    pub positions: Vec<usize>,
 93    pub string: SharedString,
 94}
 95
 96impl StringMatch {
 97    pub fn ranges(&self) -> impl '_ + Iterator<Item = Range<usize>> {
 98        let mut positions = self.positions.iter().peekable();
 99        iter::from_fn(move || {
100            let start = *positions.next()?;
101            let char_len = self.char_len_at_index(start)?;
102            let mut end = start + char_len;
103            while let Some(next_start) = positions.peek() {
104                if end == **next_start {
105                    let Some(char_len) = self.char_len_at_index(end) else {
106                        break;
107                    };
108                    end += char_len;
109                    positions.next();
110                } else {
111                    break;
112                }
113            }
114            Some(start..end)
115        })
116    }
117
118    fn char_len_at_index(&self, ix: usize) -> Option<usize> {
119        self.string
120            .get(ix..)
121            .and_then(|slice| slice.chars().next().map(|c| c.len_utf8()))
122    }
123}
124
125impl PartialEq for StringMatch {
126    fn eq(&self, other: &Self) -> bool {
127        self.cmp(other).is_eq()
128    }
129}
130
131impl Eq for StringMatch {}
132
133impl PartialOrd for StringMatch {
134    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
135        Some(self.cmp(other))
136    }
137}
138
139impl Ord for StringMatch {
140    fn cmp(&self, other: &Self) -> Ordering {
141        self.score
142            .total_cmp(&other.score)
143            .then_with(|| self.candidate_id.cmp(&other.candidate_id))
144    }
145}
146
147pub async fn match_strings_async<T>(
148    candidates: &[T],
149    query: &str,
150    case: Case,
151    length_penalty: LengthPenalty,
152    max_results: usize,
153    cancel_flag: &AtomicBool,
154    executor: BackgroundExecutor,
155) -> Vec<StringMatch>
156where
157    T: Borrow<StringMatchCandidate> + Sync,
158{
159    if candidates.is_empty() || max_results == 0 {
160        return Vec::new();
161    }
162
163    let Some(query) = Query::build(query, case) else {
164        return empty_query_results(candidates, max_results);
165    };
166
167    let num_cpus = executor.num_cpus().min(candidates.len());
168    let segment_size = candidates.len().div_ceil(num_cpus);
169    let mut segment_results = (0..num_cpus)
170        .map(|_| Vec::with_capacity(max_results.min(candidates.len())))
171        .collect::<Vec<_>>();
172
173    let config = nucleo::Config::DEFAULT;
174    let mut matchers = matcher::get_matchers(num_cpus, config);
175
176    executor
177        .scoped(|scope| {
178            for (segment_idx, (results, matcher)) in segment_results
179                .iter_mut()
180                .zip(matchers.iter_mut())
181                .enumerate()
182            {
183                let query = &query;
184                scope.spawn(async move {
185                    let segment_start = segment_idx * segment_size;
186                    let segment_end = (segment_start + segment_size).min(candidates.len());
187
188                    match_string_helper(
189                        &candidates[segment_start..segment_end],
190                        query,
191                        matcher,
192                        length_penalty,
193                        results,
194                        cancel_flag,
195                    )
196                    .ok();
197                });
198            }
199        })
200        .await;
201
202    matcher::return_matchers(matchers);
203
204    if cancel_flag.load(atomic::Ordering::Acquire) {
205        return Vec::new();
206    }
207
208    let mut results = segment_results.concat();
209    util::truncate_to_bottom_n_sorted_by(&mut results, max_results, &|a, b| b.cmp(a));
210    results
211}
212
213pub fn match_strings<T>(
214    candidates: &[T],
215    query: &str,
216    case: Case,
217    length_penalty: LengthPenalty,
218    max_results: usize,
219) -> Vec<StringMatch>
220where
221    T: Borrow<StringMatchCandidate>,
222{
223    if candidates.is_empty() || max_results == 0 {
224        return Vec::new();
225    }
226
227    let Some(query) = Query::build(query, case) else {
228        return empty_query_results(candidates, max_results);
229    };
230
231    let config = nucleo::Config::DEFAULT;
232    let mut matcher = matcher::get_matcher(config);
233    let mut results = Vec::with_capacity(max_results.min(candidates.len()));
234
235    match_string_helper(
236        candidates,
237        &query,
238        &mut matcher,
239        length_penalty,
240        &mut results,
241        &AtomicBool::new(false),
242    )
243    .ok();
244
245    matcher::return_matcher(matcher);
246    util::truncate_to_bottom_n_sorted_by(&mut results, max_results, &|a, b| b.cmp(a));
247    results
248}
249
250fn empty_query_results<T: Borrow<StringMatchCandidate>>(
251    candidates: &[T],
252    max_results: usize,
253) -> Vec<StringMatch> {
254    candidates
255        .iter()
256        .take(max_results)
257        .map(|candidate| {
258            let borrowed = candidate.borrow();
259            StringMatch {
260                candidate_id: borrowed.id,
261                score: 0.,
262                positions: Vec::new(),
263                string: borrowed.string.clone(),
264            }
265        })
266        .collect()
267}
268
269fn match_string_helper<T>(
270    candidates: &[T],
271    query: &Query,
272    matcher: &mut nucleo::Matcher,
273    length_penalty: LengthPenalty,
274    results: &mut Vec<StringMatch>,
275    cancel_flag: &AtomicBool,
276) -> Result<(), Cancelled>
277where
278    T: Borrow<StringMatchCandidate>,
279{
280    let mut buf = Vec::new();
281    let mut matched_chars: Vec<u32> = Vec::new();
282    let mut atom_matched_chars = Vec::new();
283    let mut candidate_chars: Vec<char> = Vec::new();
284
285    for candidate in candidates {
286        buf.clear();
287        matched_chars.clear();
288        if cancel_flag.load(atomic::Ordering::Relaxed) {
289            return Err(Cancelled);
290        }
291
292        let borrowed = candidate.borrow();
293
294        if !borrowed.char_bag.is_superset(query.char_bag) {
295            continue;
296        }
297
298        let haystack: Utf32Str = Utf32Str::new(&borrowed.string, &mut buf);
299
300        if query.source_words.is_some() {
301            candidate_chars.clear();
302            candidate_chars.extend(borrowed.string.chars());
303        }
304
305        let mut total_score: u32 = 0;
306        let mut case_mismatches: u32 = 0;
307        let mut all_matched = true;
308
309        for (atom_idx, atom) in query.atoms.iter().enumerate() {
310            atom_matched_chars.clear();
311            let Some(score) = atom.indices(haystack, matcher, &mut atom_matched_chars) else {
312                all_matched = false;
313                break;
314            };
315            total_score = total_score.saturating_add(score as u32);
316            if let Some(source_words) = query.source_words.as_deref() {
317                let query_chars = &source_words[atom_idx];
318                if query_chars.len() == atom_matched_chars.len() {
319                    for (&query_char, &pos) in query_chars.iter().zip(&atom_matched_chars) {
320                        if let Some(&candidate_char) = candidate_chars.get(pos as usize)
321                            && candidate_char != query_char
322                            && candidate_char.eq_ignore_ascii_case(&query_char)
323                        {
324                            case_mismatches += 1;
325                        }
326                    }
327                }
328            }
329            matched_chars.extend_from_slice(&atom_matched_chars);
330        }
331
332        if all_matched {
333            matched_chars.sort_unstable();
334            matched_chars.dedup();
335
336            let positive = total_score as f64 * case_penalty(case_mismatches);
337            let adjusted_score =
338                positive - length_penalty_for(borrowed.string.as_ref(), length_penalty);
339            let positions = positions_from_sorted(borrowed.string.as_ref(), &matched_chars);
340
341            results.push(StringMatch {
342                candidate_id: borrowed.id,
343                score: adjusted_score,
344                positions,
345                string: borrowed.string.clone(),
346            });
347        }
348    }
349    Ok(())
350}
351
352#[inline]
353fn case_penalty(mismatches: u32) -> f64 {
354    if mismatches == 0 {
355        1.0
356    } else {
357        SMART_CASE_PENALTY_PER_MISMATCH.powi(mismatches as i32)
358    }
359}
360
361#[inline]
362fn length_penalty_for(s: &str, length_penalty: LengthPenalty) -> f64 {
363    if length_penalty.is_on() {
364        s.len() as f64 * LENGTH_PENALTY
365    } else {
366        0.0
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::BackgroundExecutor;

    /// Builds candidates whose ids are their indices in `strings`.
    fn candidates(strings: &[&str]) -> Vec<StringMatchCandidate> {
        strings
            .iter()
            .enumerate()
            .map(|(id, s)| StringMatchCandidate::new(id, s))
            .collect()
    }

    #[gpui::test]
    async fn test_basic_match(executor: BackgroundExecutor) {
        let cs = candidates(&["hello", "world", "help"]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "hel",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        let matched: Vec<&str> = results.iter().map(|m| m.string.as_ref()).collect();
        assert!(matched.contains(&"hello"));
        assert!(matched.contains(&"help"));
        assert!(!matched.contains(&"world"));
    }

    // Every whitespace-separated word must match; words are matched independently.
    #[gpui::test]
    async fn test_multi_word_query(executor: BackgroundExecutor) {
        let cs = candidates(&[
            "src/lib/parser.rs",
            "src/bin/main.rs",
            "tests/parser_test.rs",
        ]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "src parser",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].string, "src/lib/parser.rs");
    }

    #[gpui::test]
    async fn test_empty_query_returns_all(executor: BackgroundExecutor) {
        let cs = candidates(&["alpha", "beta", "gamma"]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        assert_eq!(results.len(), 3);
        assert!(results.iter().all(|m| m.score == 0.0));
    }

    #[gpui::test]
    async fn test_whitespace_only_query_returns_all(executor: BackgroundExecutor) {
        let cs = candidates(&["alpha", "beta", "gamma"]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "   \t\n",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        assert_eq!(results.len(), 3);
    }

    #[gpui::test]
    async fn test_empty_candidates(executor: BackgroundExecutor) {
        let cs: Vec<StringMatchCandidate> = vec![];
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "query",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        assert!(results.is_empty());
    }

    // A flag that is already set yields no results.
    #[gpui::test]
    async fn test_cancellation(executor: BackgroundExecutor) {
        let cs = candidates(&["hello", "world"]);
        let cancel = AtomicBool::new(true);
        let results = match_strings_async(
            &cs,
            "hel",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        assert!(results.is_empty());
    }

    #[gpui::test]
    async fn test_max_results_limit(executor: BackgroundExecutor) {
        let cs = candidates(&["ab", "abc", "abcd", "abcde"]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "ab",
            Case::Ignore,
            LengthPenalty::Off,
            2,
            &cancel,
            executor,
        )
        .await;
        assert_eq!(results.len(), 2);
    }

    #[gpui::test]
    async fn test_scoring_order(executor: BackgroundExecutor) {
        let cs = candidates(&[
            "some_very_long_variable_name_fuzzy",
            "fuzzy",
            "a_fuzzy_thing",
        ]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "fuzzy",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor.clone(),
        )
        .await;
        // Check the count first so a missing match fails clearly instead of panicking
        // on an out-of-bounds index below.
        assert_eq!(results.len(), 3);

        let ordered = matches!(
            (
                results[0].string.as_ref(),
                results[1].string.as_ref(),
                results[2].string.as_ref()
            ),
            (
                "fuzzy",
                "a_fuzzy_thing",
                "some_very_long_variable_name_fuzzy"
            )
        );
        assert!(ordered, "matches are not in the proper order.");

        let results_penalty = match_strings_async(
            &cs,
            "fuzzy",
            Case::Ignore,
            LengthPenalty::On,
            10,
            &cancel,
            executor,
        )
        .await;
        assert_eq!(results_penalty.len(), 3);
        let greater = results[2].score > results_penalty[2].score;
        assert!(greater, "penalize length not affecting long candidates");
    }

    // Positions are byte offsets that land on character boundaries.
    #[gpui::test]
    async fn test_utf8_positions(executor: BackgroundExecutor) {
        let cs = candidates(&["café"]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "caf",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        assert_eq!(results.len(), 1);
        let m = &results[0];
        assert_eq!(m.positions, vec![0, 1, 2]);
        for &pos in &m.positions {
            assert!(m.string.is_char_boundary(pos));
        }
    }

    // Smart case must downrank case mismatches, never drop them.
    #[gpui::test]
    async fn test_smart_case(executor: BackgroundExecutor) {
        let cs = candidates(&["FooBar", "foobar", "FOOBAR"]);
        let cancel = AtomicBool::new(false);

        let case_insensitive = match_strings_async(
            &cs,
            "foobar",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor.clone(),
        )
        .await;
        assert_eq!(case_insensitive.len(), 3);

        let smart = match_strings_async(
            &cs,
            "FooBar",
            Case::Smart,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        assert_eq!(
            smart.len(),
            3,
            "smart case must keep case-mismatched candidates"
        );
        let score_of = |s: &str| {
            smart
                .iter()
                .find(|m| m.string.as_ref() == s)
                .map(|m| m.score)
                .unwrap_or_else(|| panic!("{s} should match"))
        };
        let exact = score_of("FooBar");
        let lower = score_of("foobar");
        assert!(
            exact > lower,
            "exact-case score ({exact}) should beat case-mismatch score ({lower})"
        );
    }

    #[gpui::test]
    async fn test_smart_case_does_not_flip_order_when_length_penalty_on(
        executor: BackgroundExecutor,
    ) {
        // Regression for the sign bug: with a length penalty large enough to push
        // `total_score - length_penalty` negative, case mismatches used to make
        // scores *better* (less negative). Exact-case match must still rank first.
        let cs = candidates(&[
            "aaaaaaaaaaaaaaaaaaaaaaaaaaaa_FooBar",
            "aaaaaaaaaaaaaaaaaaaaaaaaaaaa_foobar",
        ]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "FooBar",
            Case::Smart,
            LengthPenalty::On,
            10,
            &cancel,
            executor,
        )
        .await;
        let exact = results
            .iter()
            .find(|m| m.string.as_ref() == "aaaaaaaaaaaaaaaaaaaaaaaaaaaa_FooBar")
            .map(|m| m.score)
            .expect("exact-case candidate should match");
        let mismatch = results
            .iter()
            .find(|m| m.string.as_ref() == "aaaaaaaaaaaaaaaaaaaaaaaaaaaa_foobar")
            .map(|m| m.score)
            .expect("mismatch-case candidate should match");
        assert!(
            exact >= mismatch,
            "exact-case score ({exact}) should be >= mismatch-case score ({mismatch})"
        );
    }

    // Candidates missing a query character are rejected; order is irrelevant.
    #[gpui::test]
    async fn test_char_bag_prefilter(executor: BackgroundExecutor) {
        let cs = candidates(&["abcdef", "abc", "def", "aabbcc"]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "abc",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        let matched: Vec<&str> = results.iter().map(|m| m.string.as_ref()).collect();
        assert!(matched.contains(&"abcdef"));
        assert!(matched.contains(&"abc"));
        assert!(matched.contains(&"aabbcc"));
        assert!(!matched.contains(&"def"));
    }

    #[test]
    fn test_sync_basic_match() {
        let cs = candidates(&["hello", "world", "help"]);
        let results = match_strings(&cs, "hel", Case::Ignore, LengthPenalty::Off, 10);
        let matched: Vec<&str> = results.iter().map(|m| m.string.as_ref()).collect();
        assert!(matched.contains(&"hello"));
        assert!(matched.contains(&"help"));
        assert!(!matched.contains(&"world"));
    }

    #[test]
    fn test_sync_empty_query_returns_all() {
        let cs = candidates(&["alpha", "beta", "gamma"]);
        let results = match_strings(&cs, "", Case::Ignore, LengthPenalty::Off, 10);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_sync_whitespace_only_query_returns_all() {
        let cs = candidates(&["alpha", "beta", "gamma"]);
        let results = match_strings(&cs, "  ", Case::Ignore, LengthPenalty::Off, 10);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_sync_max_results() {
        let cs = candidates(&["ab", "abc", "abcd", "abcde"]);
        let results = match_strings(&cs, "ab", Case::Ignore, LengthPenalty::Off, 2);
        assert_eq!(results.len(), 2);
    }

    #[gpui::test]
    async fn test_empty_query_respects_max_results(executor: BackgroundExecutor) {
        let cs = candidates(&["alpha", "beta", "gamma", "delta"]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "",
            Case::Ignore,
            LengthPenalty::Off,
            2,
            &cancel,
            executor,
        )
        .await;
        assert_eq!(results.len(), 2);
    }

    #[gpui::test]
    async fn test_multi_word_with_nonmatching_word(executor: BackgroundExecutor) {
        let cs = candidates(&["src/parser.rs", "src/main.rs"]);
        let cancel = AtomicBool::new(false);
        let results = match_strings_async(
            &cs,
            "src xyzzy",
            Case::Ignore,
            LengthPenalty::Off,
            10,
            &cancel,
            executor,
        )
        .await;
        assert!(
            results.is_empty(),
            "no candidate contains 'xyzzy', so nothing should match"
        );
    }
}