Fuzzy-match lines when applying edits from the assistant (#12056)

Antonio Scandurra created

This uses Jaro-Winkler similarity for now, which seemed to produce
pretty good results in my tests. We can easily swap it with something
else if needed.

Release Notes:

- N/A

Change summary

Cargo.lock                              |  17 ++
crates/assistant/Cargo.toml             |   1 
crates/assistant/src/assistant_panel.rs |   6 
crates/assistant/src/search.rs          | 153 +++++++++++++++-----------
4 files changed, 103 insertions(+), 74 deletions(-)

Detailed changes

Cargo.lock ๐Ÿ”—

@@ -368,6 +368,7 @@ dependencies = [
  "serde_json",
  "settings",
  "smol",
+ "strsim 0.11.1",
  "telemetry_events",
  "theme",
  "tiktoken-rs",
@@ -1684,7 +1685,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4c2f7349907b712260e64b0afe2f84692af14a454be26187d9df565c7f69266a"
 dependencies = [
  "memchr",
- "regex-automata 0.3.8",
+ "regex-automata 0.3.9",
  "serde",
 ]
 
@@ -2094,7 +2095,7 @@ dependencies = [
  "bitflags 1.3.2",
  "clap_lex 0.2.4",
  "indexmap 1.9.3",
- "strsim",
+ "strsim 0.10.0",
  "termcolor",
  "textwrap",
 ]
@@ -2118,7 +2119,7 @@ dependencies = [
  "anstream",
  "anstyle",
  "clap_lex 0.5.1",
- "strsim",
+ "strsim 0.10.0",
 ]
 
 [[package]]
@@ -8141,9 +8142,9 @@ dependencies = [
 
 [[package]]
 name = "regex-automata"
-version = "0.3.8"
+version = "0.3.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795"
+checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9"
 
 [[package]]
 name = "regex-automata"
@@ -9783,6 +9784,12 @@ version = "0.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
 
+[[package]]
+name = "strsim"
+version = "0.11.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
+
 [[package]]
 name = "strum"
 version = "0.25.0"

crates/assistant/Cargo.toml ๐Ÿ”—

@@ -40,6 +40,7 @@ serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
 smol.workspace = true
+strsim = "0.11"
 telemetry_events.workspace = true
 theme.workspace = true
 tiktoken-rs.workspace = true

crates/assistant/src/assistant_panel.rs ๐Ÿ”—

@@ -3058,9 +3058,9 @@ impl ConversationEditor {
                                 .entry(buffer)
                                 .or_insert(Vec::<(Range<language::Anchor>, _)>::new());
                         for suggestion in suggestions {
-                            let ranges =
-                                fuzzy_search_lines(snapshot.as_rope(), &suggestion.old_text);
-                            if let Some(range) = ranges.first() {
+                            if let Some(range) =
+                                fuzzy_search_lines(snapshot.as_rope(), &suggestion.old_text)
+                            {
                                 let edit_start = snapshot.anchor_after(range.start);
                                 let edit_end = snapshot.anchor_before(range.end);
                                 if let Err(ix) = edits.binary_search_by(|(range, _)| {

crates/assistant/src/search.rs ๐Ÿ”—

@@ -6,51 +6,75 @@ use std::ops::Range;
 ///
 /// Returns a vector of ranges of byte offsets in the buffer corresponding
 /// to the entire lines of the buffer.
-pub fn fuzzy_search_lines(haystack: &Rope, needle: &str) -> Vec<Range<usize>> {
-    let mut matches = Vec::new();
+pub fn fuzzy_search_lines(haystack: &Rope, needle: &str) -> Option<Range<usize>> {
+    const SIMILARITY_THRESHOLD: f64 = 0.8;
+
+    let mut best_match: Option<(Range<usize>, f64)> = None; // (range, score)
     let mut haystack_lines = haystack.chunks().lines();
     let mut haystack_line_start = 0;
-    while let Some(haystack_line) = haystack_lines.next() {
+    while let Some(mut haystack_line) = haystack_lines.next() {
         let next_haystack_line_start = haystack_line_start + haystack_line.len() + 1;
-        let mut trimmed_needle_lines = needle.lines().map(|line| line.trim());
-        if Some(haystack_line.trim()) == trimmed_needle_lines.next() {
-            let match_start = haystack_line_start;
-            let mut match_end = next_haystack_line_start;
-            let matched = loop {
-                match (haystack_lines.next(), trimmed_needle_lines.next()) {
-                    (Some(haystack_line), Some(needle_line)) => {
-                        // Haystack line differs from needle line: not a match.
-                        if haystack_line.trim() == needle_line {
-                            match_end = haystack_lines.offset();
-                        } else {
-                            break false;
-                        }
+        let mut advanced_to_next_haystack_line = false;
+
+        let mut matched = true;
+        let match_start = haystack_line_start;
+        let mut match_end = next_haystack_line_start;
+        let mut match_score = 0.0;
+        let mut needle_lines = needle.lines().peekable();
+        while let Some(needle_line) = needle_lines.next() {
+            let similarity = line_similarity(haystack_line, needle_line);
+            if similarity >= SIMILARITY_THRESHOLD {
+                match_end = haystack_lines.offset();
+                match_score += similarity;
+
+                if needle_lines.peek().is_some() {
+                    if let Some(next_haystack_line) = haystack_lines.next() {
+                        advanced_to_next_haystack_line = true;
+                        haystack_line = next_haystack_line;
+                    } else {
+                        matched = false;
+                        break;
                     }
-                    // We exhausted the haystack but not the query: not a match.
-                    (None, Some(_)) => break false,
-                    // We exhausted the query: it's a match.
-                    (_, None) => break true,
+                } else {
+                    break;
                 }
-            };
-
-            if matched {
-                matches.push(match_start..match_end)
+            } else {
+                matched = false;
+                break;
             }
+        }
 
-            // Advance to the next line.
-            haystack_lines.seek(next_haystack_line_start);
+        if matched
+            && best_match
+                .as_ref()
+                .map(|(_, best_score)| match_score > *best_score)
+                .unwrap_or(true)
+        {
+            best_match = Some((match_start..match_end, match_score));
         }
 
+        if advanced_to_next_haystack_line {
+            haystack_lines.seek(next_haystack_line_start);
+        }
         haystack_line_start = next_haystack_line_start;
     }
-    matches
+
+    best_match.map(|(range, _)| range)
+}
+
+/// Calculates the similarity between two lines, ignoring leading and trailing whitespace,
+/// using the Jaro-Winkler distance.
+///
+/// Returns a value between 0.0 and 1.0, where 1.0 indicates an exact match.
+fn line_similarity(line1: &str, line2: &str) -> f64 {
+    strsim::jaro_winkler(line1.trim(), line2.trim())
 }
 
 #[cfg(test)]
 mod test {
     use super::*;
     use gpui::{AppContext, Context as _};
-    use language::{Buffer, OffsetRangeExt};
+    use language::Buffer;
     use unindent::Unindent as _;
     use util::test::marked_text_ranges;
 
@@ -79,17 +103,11 @@ mod test {
                 );
             ยป
 
-                assert_eq!(
+            ยซ    assert_eq!(
                     "something",
                     "else",
                 );
-
-                if b {
-            ยซ        assert_eq!(
-                        1 + 2,
-                        3,
-                    );
-            ยป    }
+            ยป
             }
             "#
             .unindent(),
@@ -99,7 +117,7 @@ mod test {
         let buffer = cx.new_model(|cx| Buffer::local(&text, cx));
         let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
 
-        let actual_ranges = fuzzy_search_lines(
+        let actual_range = fuzzy_search_lines(
             snapshot.as_rope(),
             &"
             assert_eq!(
@@ -108,43 +126,46 @@ mod test {
             );
             "
             .unindent(),
-        );
-        assert_eq!(
-            actual_ranges,
-            expected_ranges,
-            "actual: {:?}, expected: {:?}",
-            actual_ranges
-                .iter()
-                .map(|range| range.to_point(&snapshot))
-                .collect::<Vec<_>>(),
-            expected_ranges
-                .iter()
-                .map(|range| range.to_point(&snapshot))
-                .collect::<Vec<_>>()
-        );
+        )
+        .unwrap();
+        assert_eq!(actual_range, expected_ranges[0]);
 
-        let actual_ranges = fuzzy_search_lines(
+        let actual_range = fuzzy_search_lines(
             snapshot.as_rope(),
             &"
             assert_eq!(
                 1 + 2,
                 3,
-                );
+            );
+            "
+            .unindent(),
+        )
+        .unwrap();
+        assert_eq!(actual_range, expected_ranges[0]);
+
+        let actual_range = fuzzy_search_lines(
+            snapshot.as_rope(),
+            &"
+            asst_eq!(
+                \"something\",
+                \"els\"
+            )
+            "
+            .unindent(),
+        )
+        .unwrap();
+        assert_eq!(actual_range, expected_ranges[1]);
+
+        let actual_range = fuzzy_search_lines(
+            snapshot.as_rope(),
+            &"
+            assert_eq!(
+                2 + 1,
+                3,
+            );
             "
             .unindent(),
         );
-        assert_eq!(
-            actual_ranges,
-            expected_ranges,
-            "actual: {:?}, expected: {:?}",
-            actual_ranges
-                .iter()
-                .map(|range| range.to_point(&snapshot))
-                .collect::<Vec<_>>(),
-            expected_ranges
-                .iter()
-                .map(|range| range.to_point(&snapshot))
-                .collect::<Vec<_>>()
-        );
+        assert_eq!(actual_range, None);
     }
 }