streaming_fuzzy_matcher.rs

  1use language::{Point, TextBufferSnapshot};
  2use std::{cmp, ops::Range};
  3
  4const REPLACEMENT_COST: u32 = 1;
  5const INSERTION_COST: u32 = 3;
  6const DELETION_COST: u32 = 10;
  7
  8/// A streaming fuzzy matcher that can process text chunks incrementally
  9/// and return the best match found so far at each step.
 10pub struct StreamingFuzzyMatcher {
 11    snapshot: TextBufferSnapshot,
 12    query_lines: Vec<String>,
 13    incomplete_line: String,
 14    best_matches: Vec<Range<usize>>,
 15    matrix: SearchMatrix,
 16}
 17
 18impl StreamingFuzzyMatcher {
 19    pub fn new(snapshot: TextBufferSnapshot) -> Self {
 20        let buffer_line_count = snapshot.max_point().row as usize + 1;
 21        Self {
 22            snapshot,
 23            query_lines: Vec::new(),
 24            incomplete_line: String::new(),
 25            best_matches: Vec::new(),
 26            matrix: SearchMatrix::new(buffer_line_count + 1),
 27        }
 28    }
 29
 30    /// Returns the query lines.
 31    pub fn query_lines(&self) -> &[String] {
 32        &self.query_lines
 33    }
 34
 35    /// Push a new chunk of text and get the best match found so far.
 36    ///
 37    /// This method accumulates text chunks and processes complete lines.
 38    /// Partial lines are buffered internally until a newline is received.
 39    ///
 40    /// # Returns
 41    ///
 42    /// Returns `Some(range)` if a match has been found with the accumulated
 43    /// query so far, or `None` if no suitable match exists yet.
 44    pub fn push(&mut self, chunk: &str) -> Option<Range<usize>> {
 45        // Add the chunk to our incomplete line buffer
 46        self.incomplete_line.push_str(chunk);
 47
 48        if let Some((last_pos, _)) = self.incomplete_line.match_indices('\n').next_back() {
 49            let complete_part = &self.incomplete_line[..=last_pos];
 50
 51            // Split into lines and add to query_lines
 52            for line in complete_part.lines() {
 53                self.query_lines.push(line.to_string());
 54            }
 55
 56            self.incomplete_line.replace_range(..last_pos + 1, "");
 57
 58            self.best_matches = self.resolve_location_fuzzy();
 59
 60            if let Some(first_match) = self.best_matches.first() {
 61                Some(first_match.clone())
 62            } else {
 63                None
 64            }
 65        } else {
 66            if let Some(first_match) = self.best_matches.first() {
 67                Some(first_match.clone())
 68            } else {
 69                None
 70            }
 71        }
 72    }
 73
 74    /// Finish processing and return the final best match(es).
 75    ///
 76    /// This processes any remaining incomplete line before returning the final
 77    /// match result.
 78    pub fn finish(&mut self) -> Vec<Range<usize>> {
 79        // Process any remaining incomplete line
 80        if !self.incomplete_line.is_empty() {
 81            self.query_lines.push(self.incomplete_line.clone());
 82            self.incomplete_line.clear();
 83            self.best_matches = self.resolve_location_fuzzy();
 84        }
 85        self.best_matches.clone()
 86    }
 87
 88    fn resolve_location_fuzzy(&mut self) -> Vec<Range<usize>> {
 89        let new_query_line_count = self.query_lines.len();
 90        let old_query_line_count = self.matrix.rows.saturating_sub(1);
 91        if new_query_line_count == old_query_line_count {
 92            return Vec::new();
 93        }
 94
 95        self.matrix.resize_rows(new_query_line_count + 1);
 96
 97        // Process only the new query lines
 98        for row in old_query_line_count..new_query_line_count {
 99            let query_line = self.query_lines[row].trim();
100            let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
101
102            self.matrix.set(
103                row + 1,
104                0,
105                SearchState::new(leading_deletion_cost, SearchDirection::Up),
106            );
107
108            let mut buffer_lines = self.snapshot.as_rope().chunks().lines();
109            let mut col = 0;
110            while let Some(buffer_line) = buffer_lines.next() {
111                let buffer_line = buffer_line.trim();
112                let up = SearchState::new(
113                    self.matrix
114                        .get(row, col + 1)
115                        .cost
116                        .saturating_add(DELETION_COST),
117                    SearchDirection::Up,
118                );
119                let left = SearchState::new(
120                    self.matrix
121                        .get(row + 1, col)
122                        .cost
123                        .saturating_add(INSERTION_COST),
124                    SearchDirection::Left,
125                );
126                let diagonal = SearchState::new(
127                    if query_line == buffer_line {
128                        self.matrix.get(row, col).cost
129                    } else if fuzzy_eq(query_line, buffer_line) {
130                        self.matrix.get(row, col).cost + REPLACEMENT_COST
131                    } else {
132                        self.matrix
133                            .get(row, col)
134                            .cost
135                            .saturating_add(DELETION_COST + INSERTION_COST)
136                    },
137                    SearchDirection::Diagonal,
138                );
139                self.matrix
140                    .set(row + 1, col + 1, up.min(left).min(diagonal));
141                col += 1;
142            }
143        }
144
145        // Find all matches with the best cost
146        let buffer_line_count = self.snapshot.max_point().row as usize + 1;
147        let mut best_cost = u32::MAX;
148        let mut matches_with_best_cost = Vec::new();
149
150        for col in 1..=buffer_line_count {
151            let cost = self.matrix.get(new_query_line_count, col).cost;
152            if cost < best_cost {
153                best_cost = cost;
154                matches_with_best_cost.clear();
155                matches_with_best_cost.push(col as u32);
156            } else if cost == best_cost {
157                matches_with_best_cost.push(col as u32);
158            }
159        }
160
161        // Find ranges for the matches
162        let mut valid_matches = Vec::new();
163        for &buffer_row_end in &matches_with_best_cost {
164            let mut matched_lines = 0;
165            let mut query_row = new_query_line_count;
166            let mut buffer_row_start = buffer_row_end;
167            while query_row > 0 && buffer_row_start > 0 {
168                let current = self.matrix.get(query_row, buffer_row_start as usize);
169                match current.direction {
170                    SearchDirection::Diagonal => {
171                        query_row -= 1;
172                        buffer_row_start -= 1;
173                        matched_lines += 1;
174                    }
175                    SearchDirection::Up => {
176                        query_row -= 1;
177                    }
178                    SearchDirection::Left => {
179                        buffer_row_start -= 1;
180                    }
181                }
182            }
183
184            let matched_buffer_row_count = buffer_row_end - buffer_row_start;
185            let matched_ratio = matched_lines as f32
186                / (matched_buffer_row_count as f32).max(new_query_line_count as f32);
187            if matched_ratio >= 0.8 {
188                let buffer_start_ix = self
189                    .snapshot
190                    .point_to_offset(Point::new(buffer_row_start, 0));
191                let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
192                    buffer_row_end - 1,
193                    self.snapshot.line_len(buffer_row_end - 1),
194                ));
195                valid_matches.push((buffer_row_start, buffer_start_ix..buffer_end_ix));
196            }
197        }
198
199        valid_matches.into_iter().map(|(_, range)| range).collect()
200    }
201}
202
203fn fuzzy_eq(left: &str, right: &str) -> bool {
204    const THRESHOLD: f64 = 0.8;
205
206    let min_levenshtein = left.len().abs_diff(right.len());
207    let min_normalized_levenshtein =
208        1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
209    if min_normalized_levenshtein < THRESHOLD {
210        return false;
211    }
212
213    strsim::normalized_levenshtein(left, right) >= THRESHOLD
214}
215
/// Step by which a cell of the search matrix was reached.
///
/// The variant order is significant: the derived ordering ranks `Up` lowest
/// and `Diagonal` highest.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
    /// From the previous row, same column.
    Up,
    /// From the previous column, same row.
    Left,
    /// From the previous row and previous column.
    Diagonal,
}
222
223#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
224struct SearchState {
225    cost: u32,
226    direction: SearchDirection,
227}
228
229impl SearchState {
230    fn new(cost: u32, direction: SearchDirection) -> Self {
231        Self { cost, direction }
232    }
233}
234
235struct SearchMatrix {
236    cols: usize,
237    rows: usize,
238    data: Vec<SearchState>,
239}
240
241impl SearchMatrix {
242    fn new(cols: usize) -> Self {
243        SearchMatrix {
244            cols,
245            rows: 0,
246            data: Vec::new(),
247        }
248    }
249
250    fn resize_rows(&mut self, needed_rows: usize) {
251        debug_assert!(needed_rows > self.rows);
252        self.rows = needed_rows;
253        self.data.resize(
254            self.rows * self.cols,
255            SearchState::new(0, SearchDirection::Diagonal),
256        );
257    }
258
259    fn get(&self, row: usize, col: usize) -> SearchState {
260        debug_assert!(row < self.rows && col < self.cols);
261        self.data[row * self.cols + col]
262    }
263
264    fn set(&mut self, row: usize, col: usize, state: SearchState) {
265        debug_assert!(row < self.rows && col < self.cols);
266        self.data[row * self.cols + col] = state;
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;
    use language::{BufferId, TextBuffer};
    use rand::prelude::*;
    use util::test::{generate_marked_text, marked_text_ranges};

    #[test]
    fn test_empty_query() {
        let buffer = TextBuffer::new(
            0,
            BufferId::new(1).unwrap(),
            "Hello world\nThis is a test\nFoo bar baz",
        );
        let snapshot = buffer.snapshot();

        let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());
        assert_eq!(push(&mut finder, ""), None);
        assert_eq!(finish(finder), None);
    }

    #[test]
    fn test_streaming_exact_match() {
        let buffer = TextBuffer::new(
            0,
            BufferId::new(1).unwrap(),
            "Hello world\nThis is a test\nFoo bar baz",
        );
        let snapshot = buffer.snapshot();

        let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());

        // Push partial query
        assert_eq!(push(&mut finder, "This"), None);

        // Complete the line
        assert_eq!(
            push(&mut finder, " is a test\n"),
            Some("This is a test".to_string())
        );

        // Finish should return the same result
        assert_eq!(finish(finder), Some("This is a test".to_string()));
    }

    #[test]
    fn test_streaming_fuzzy_match() {
        let buffer = TextBuffer::new(
            0,
            BufferId::new(1).unwrap(),
            indoc! {"
                function foo(a, b) {
                    return a + b;
                }

                function bar(x, y) {
                    return x * y;
                }
            "},
        );
        let snapshot = buffer.snapshot();

        let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());

        // Push a fuzzy query that should match the first function
        assert_eq!(
            push(&mut finder, "function foo(a, c) {\n").as_deref(),
            Some("function foo(a, b) {")
        );
        assert_eq!(
            push(&mut finder, "    return a + c;\n}\n").as_deref(),
            Some(concat!(
                "function foo(a, b) {\n",
                "    return a + b;\n",
                "}"
            ))
        );
    }

    #[test]
    fn test_incremental_improvement() {
        let buffer = TextBuffer::new(
            0,
            BufferId::new(1).unwrap(),
            "Line 1\nLine 2\nLine 3\nLine 4\nLine 5",
        );
        let snapshot = buffer.snapshot();

        let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());

        // No match initially
        assert_eq!(push(&mut finder, "Lin"), None);

        // Get a match when we complete a line
        assert_eq!(push(&mut finder, "e 3\n"), Some("Line 3".to_string()));

        // The match might change if we add more specific content
        assert_eq!(
            push(&mut finder, "Line 4\n"),
            Some("Line 3\nLine 4".to_string())
        );
        assert_eq!(finish(finder), Some("Line 3\nLine 4".to_string()));
    }

    #[test]
    fn test_incomplete_lines_buffering() {
        let buffer = TextBuffer::new(
            0,
            BufferId::new(1).unwrap(),
            indoc! {"
                The quick brown fox
                jumps over the lazy dog
                Pack my box with five dozen liquor jugs
            "},
        );
        let snapshot = buffer.snapshot();

        let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());

        // Push text in small chunks across line boundaries
        assert_eq!(push(&mut finder, "jumps "), None); // No newline yet
        assert_eq!(push(&mut finder, "over the"), None); // Still no newline
        assert_eq!(push(&mut finder, " lazy"), None); // Still incomplete

        // Complete the line
        assert_eq!(
            push(&mut finder, " dog\n"),
            Some("jumps over the lazy dog".to_string())
        );
    }

    #[test]
    fn test_multiline_fuzzy_match() {
        let buffer = TextBuffer::new(
            0,
            BufferId::new(1).unwrap(),
            indoc! {r#"
                impl Display for User {
                    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
                        write!(f, "User: {} ({})", self.name, self.email)
                    }
                }

                impl Debug for User {
                    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
                        f.debug_struct("User")
                            .field("name", &self.name)
                            .field("email", &self.email)
                            .finish()
                    }
                }
            "#},
        );
        let snapshot = buffer.snapshot();

        let mut finder = StreamingFuzzyMatcher::new(snapshot.clone());

        assert_eq!(
            push(&mut finder, "impl Debug for User {\n"),
            Some("impl Debug for User {".to_string())
        );
        assert_eq!(
            push(
                &mut finder,
                "    fn fmt(&self, f: &mut Formatter) -> Result {\n"
            )
            .as_deref(),
            Some(concat!(
                "impl Debug for User {\n",
                "    fn fmt(&self, f: &mut Formatter) -> fmt::Result {"
            ))
        );
        assert_eq!(
            push(&mut finder, "        f.debug_struct(\"User\")\n").as_deref(),
            Some(concat!(
                "impl Debug for User {\n",
                "    fn fmt(&self, f: &mut Formatter) -> fmt::Result {\n",
                "        f.debug_struct(\"User\")"
            ))
        );
        assert_eq!(
            push(
                &mut finder,
                "            .field(\"name\", &self.username)\n"
            )
            .as_deref(),
            Some(concat!(
                "impl Debug for User {\n",
                "    fn fmt(&self, f: &mut Formatter) -> fmt::Result {\n",
                "        f.debug_struct(\"User\")\n",
                "            .field(\"name\", &self.name)"
            ))
        );
        assert_eq!(
            finish(finder).as_deref(),
            Some(concat!(
                "impl Debug for User {\n",
                "    fn fmt(&self, f: &mut Formatter) -> fmt::Result {\n",
                "        f.debug_struct(\"User\")\n",
                "            .field(\"name\", &self.name)"
            ))
        );
    }

    #[gpui::test(iterations = 100)]
    fn test_resolve_location_single_line(mut rng: StdRng) {
        assert_location_resolution(
            concat!(
                "    Lorem\n",
                "«    ipsum»\n",
                "    dolor sit amet\n",
                "    consecteur",
            ),
            "ipsum",
            &mut rng,
        );
    }

    #[gpui::test(iterations = 100)]
    fn test_resolve_location_multiline(mut rng: StdRng) {
        assert_location_resolution(
            concat!(
                "    Lorem\n",
                "«    ipsum\n",
                "    dolor sit amet»\n",
                "    consecteur",
            ),
            "ipsum\ndolor sit amet",
            &mut rng,
        );
    }

    #[gpui::test(iterations = 100)]
    fn test_resolve_location_function_with_typo(mut rng: StdRng) {
        // The expected range covers all of `foo1`, closing brace included:
        // the three-line query differs from it only in the return type.
        assert_location_resolution(
            indoc! {"
                «fn foo1(a: usize) -> usize {
                    40
                }»

                fn foo2(b: usize) -> usize {
                    42
                }
            "},
            "fn foo1(a: usize) -> u32 {\n40\n}",
            &mut rng,
        );
    }

    #[gpui::test(iterations = 100)]
    fn test_resolve_location_class_methods(mut rng: StdRng) {
        assert_location_resolution(
            indoc! {"
                class Something {
                    one() { return 1; }
                «    two() { return 2222; }
                    three() { return 333; }
                    four() { return 4444; }
                    five() { return 5555; }
                    six() { return 6666; }»
                    seven() { return 7; }
                    eight() { return 8; }
                }
            "},
            indoc! {"
                two() { return 2222; }
                four() { return 4444; }
                five() { return 5555; }
                six() { return 6666; }
            "},
            &mut rng,
        );
    }

    #[gpui::test(iterations = 100)]
    fn test_resolve_location_imports_no_match(mut rng: StdRng) {
        assert_location_resolution(
            indoc! {"
                use std::ops::Range;
                use std::sync::Mutex;
                use std::{
                    collections::HashMap,
                    env,
                    ffi::{OsStr, OsString},
                    fs,
                    io::{BufRead, BufReader},
                    mem,
                    path::{Path, PathBuf},
                    process::Command,
                    sync::LazyLock,
                    time::SystemTime,
                };
            "},
            indoc! {"
                use std::collections::{HashMap, HashSet};
                use std::ffi::{OsStr, OsString};
                use std::fmt::Write as _;
                use std::fs;
                use std::io::{BufReader, Read, Write};
                use std::mem;
                use std::path::{Path, PathBuf};
                use std::process::Command;
                use std::sync::Arc;
            "},
            &mut rng,
        );
    }

    #[gpui::test(iterations = 100)]
    fn test_resolve_location_nested_closure(mut rng: StdRng) {
        assert_location_resolution(
            indoc! {"
                impl Foo {
                    fn new() -> Self {
                        Self {
                            subscriptions: vec![
                                cx.observe_window_activation(window, |editor, window, cx| {
                                    let active = window.is_window_active();
                                    editor.blink_manager.update(cx, |blink_manager, cx| {
                                        if active {
                                            blink_manager.enable(cx);
                                        } else {
                                            blink_manager.disable(cx);
                                        }
                                    });
                                }),
                            ];
                        }
                    }
                }
            "},
            concat!(
                "                    editor.blink_manager.update(cx, |blink_manager, cx| {\n",
                "                        blink_manager.enable(cx);\n",
                "                    });",
            ),
            &mut rng,
        );
    }

    #[gpui::test(iterations = 100)]
    fn test_resolve_location_tool_invocation(mut rng: StdRng) {
        assert_location_resolution(
            indoc! {r#"
                let tool = cx
                    .update(|cx| working_set.tool(&tool_name, cx))
                    .map_err(|err| {
                        anyhow!("Failed to look up tool '{}': {}", tool_name, err)
                    })?;

                let Some(tool) = tool else {
                    return Err(anyhow!("Tool '{}' not found", tool_name));
                };

                let project = project.clone();
                let action_log = action_log.clone();
                let messages = messages.clone();
                let tool_result = cx
                    .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))
                    .map_err(|err| anyhow!("Failed to start tool '{}': {}", tool_name, err))?;

                tasks.push(tool_result.output);
            "#},
            concat!(
                "let tool_result = cx\n",
                "    .update(|cx| tool.run(invocation.input, &messages, project, action_log, cx))\n",
                "    .output;",
            ),
            &mut rng,
        );
    }

    /// Streams `query` into a matcher in random chunks and checks that the
    /// final matches equal the ranges marked with `«…»` in
    /// `text_with_expected_range`. A text without markers means no match is
    /// expected.
    #[track_caller]
    fn assert_location_resolution(text_with_expected_range: &str, query: &str, rng: &mut StdRng) {
        let (text, expected_ranges) = marked_text_ranges(text_with_expected_range, false);
        let buffer = TextBuffer::new(0, BufferId::new(1).unwrap(), text.clone());
        let snapshot = buffer.snapshot();

        let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());

        // Split query into random chunks
        let chunks = to_random_chunks(rng, query);

        // Push chunks incrementally
        for chunk in &chunks {
            matcher.push(chunk);
        }

        let actual_ranges = matcher.finish();

        // If no expected ranges, we expect no match
        if expected_ranges.is_empty() {
            assert!(
                actual_ranges.is_empty(),
                "Expected no match for query: {:?}, but found: {:?}",
                query,
                actual_ranges
            );
        } else {
            let text_with_actual_range = generate_marked_text(&text, &actual_ranges, false);
            pretty_assertions::assert_eq!(
                text_with_actual_range,
                text_with_expected_range,
                indoc! {"
                    Query: {:?}
                    Chunks: {:?}
                    Expected marked text: {}
                    Actual marked text: {}
                    Expected ranges: {:?}
                    Actual ranges: {:?}"
                },
                query,
                chunks,
                text_with_expected_range,
                text_with_actual_range,
                expected_ranges,
                actual_ranges
            );
        }
    }

    /// Splits `input` at up to 50 randomly chosen byte offsets and returns the
    /// pieces in order; concatenating them yields `input` again.
    ///
    /// `input` must not be empty, and the offsets are raw byte positions, so
    /// this is only suitable for the ASCII queries used by these tests.
    fn to_random_chunks(rng: &mut StdRng, input: &str) -> Vec<String> {
        let chunk_count = rng.gen_range(1..=cmp::min(input.len(), 50));
        let mut chunk_indices = (0..input.len()).choose_multiple(rng, chunk_count);
        chunk_indices.sort();
        chunk_indices.push(input.len());

        let mut chunks = Vec::new();
        let mut last_ix = 0;
        for chunk_ix in chunk_indices {
            chunks.push(input[last_ix..chunk_ix].to_string());
            last_ix = chunk_ix;
        }
        chunks
    }

    /// Pushes `chunk` and returns the text of the reported match, if any.
    fn push(finder: &mut StreamingFuzzyMatcher, chunk: &str) -> Option<String> {
        finder
            .push(chunk)
            .map(|range| finder.snapshot.text_for_range(range).collect::<String>())
    }

    /// Finishes the matcher and returns the text of its first match, if any.
    fn finish(mut finder: StreamingFuzzyMatcher) -> Option<String> {
        let snapshot = finder.snapshot.clone();
        let matches = finder.finish();
        matches
            .first()
            .map(|range| snapshot.text_for_range(range.clone()).collect::<String>())
    }
}
722}