assistant edit tool: Revert fuzzy matching (#26996)

Agus Zubiaga created

#26935 is leading to bad edits, so let's revert it for now. I'll bring
back a version of this, but it'll likely just focus on indentation
instead of making the whole search fuzzy.

Release Notes: 

- N/A

Change summary

Cargo.lock                                                         |   2 
crates/assistant_tools/Cargo.toml                                  |   2 
crates/assistant_tools/src/edit_files_tool.rs                      | 100 
crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs | 226 
4 files changed, 78 insertions(+), 252 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -718,7 +718,6 @@ dependencies = [
  "itertools 0.14.0",
  "language",
  "language_model",
- "pretty_assertions",
  "project",
  "rand 0.8.5",
  "release_channel",
@@ -728,7 +727,6 @@ dependencies = [
  "settings",
  "theme",
  "ui",
- "unindent",
  "util",
  "workspace",
  "worktree",

crates/assistant_tools/Cargo.toml 🔗

@@ -39,7 +39,5 @@ rand.workspace = true
 collections = { workspace = true, features = ["test-support"] }
 gpui = { workspace = true, features = ["test-support"] }
 language = { workspace = true, features = ["test-support"] }
-pretty_assertions.workspace = true
 project = { workspace = true, features = ["test-support"] }
-unindent.workspace = true
 workspace = { workspace = true, features = ["test-support"] }

crates/assistant_tools/src/edit_files_tool.rs 🔗

@@ -1,6 +1,5 @@
 mod edit_action;
 pub mod log;
-mod resolve_search_block;
 
 use anyhow::{anyhow, Context, Result};
 use assistant_tool::{ActionLog, Tool};
@@ -8,17 +7,16 @@ use collections::HashSet;
 use edit_action::{EditAction, EditActionParser};
 use futures::StreamExt;
 use gpui::{App, AsyncApp, Entity, Task};
-use language::OffsetRangeExt;
 use language_model::{
     LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 };
 use log::{EditToolLog, EditToolRequestId};
-use project::Project;
-use resolve_search_block::resolve_search_block;
+use project::{search::SearchQuery, Project};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use std::fmt::Write;
 use std::sync::Arc;
+use util::paths::PathMatcher;
 use util::ResultExt;
 
 #[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -131,11 +129,24 @@ struct EditToolRequest {
     parser: EditActionParser,
     output: String,
     changed_buffers: HashSet<Entity<language::Buffer>>,
+    bad_searches: Vec<BadSearch>,
     project: Entity<Project>,
     action_log: Entity<ActionLog>,
     tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
 }
 
+#[derive(Debug)]
+enum DiffResult {
+    BadSearch(BadSearch),
+    Diff(language::Diff),
+}
+
+#[derive(Debug)]
+struct BadSearch {
+    file_path: String,
+    search: String,
+}
+
 impl EditToolRequest {
     fn new(
         input: EditFilesToolInput,
@@ -193,6 +204,7 @@ impl EditToolRequest {
                 // we start with the success header so we don't need to shift the output in the common case
                 output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
                 changed_buffers: HashSet::default(),
+                bad_searches: Vec::new(),
                 action_log,
                 project,
                 tool_log,
@@ -239,30 +251,36 @@ impl EditToolRequest {
             .update(cx, |project, cx| project.open_buffer(project_path, cx))?
             .await?;
 
-        let diff = match action {
+        let result = match action {
             EditAction::Replace {
                 old,
                 new,
-                file_path: _,
+                file_path,
             } => {
                 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 
-                let diff = cx
-                    .background_executor()
-                    .spawn(Self::replace_diff(old, new, snapshot))
-                    .await;
-
-                anyhow::Ok(diff)
+                cx.background_executor()
+                    .spawn(Self::replace_diff(old, new, file_path, snapshot))
+                    .await
             }
-            EditAction::Write { content, .. } => Ok(buffer
-                .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
-                .await),
+            EditAction::Write { content, .. } => Ok(DiffResult::Diff(
+                buffer
+                    .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
+                    .await,
+            )),
         }?;
 
-        let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
+        match result {
+            DiffResult::BadSearch(invalid_replace) => {
+                self.bad_searches.push(invalid_replace);
+            }
+            DiffResult::Diff(diff) => {
+                let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
 
-        write!(&mut self.output, "\n\n{}", source)?;
-        self.changed_buffers.insert(buffer);
+                write!(&mut self.output, "\n\n{}", source)?;
+                self.changed_buffers.insert(buffer);
+            }
+        }
 
         Ok(())
     }
@@ -270,9 +288,29 @@ impl EditToolRequest {
     async fn replace_diff(
         old: String,
         new: String,
+        file_path: std::path::PathBuf,
         snapshot: language::BufferSnapshot,
-    ) -> language::Diff {
-        let edit_range = resolve_search_block(&snapshot, &old).to_offset(&snapshot);
+    ) -> Result<DiffResult> {
+        let query = SearchQuery::text(
+            old.clone(),
+            false,
+            true,
+            true,
+            PathMatcher::new(&[])?,
+            PathMatcher::new(&[])?,
+            None,
+        )?;
+
+        let matches = query.search(&snapshot, None).await;
+
+        if matches.is_empty() {
+            return Ok(DiffResult::BadSearch(BadSearch {
+                search: new.clone(),
+                file_path: file_path.display().to_string(),
+            }));
+        }
+
+        let edit_range = matches[0].clone();
         let diff = language::text_diff(&old, &new);
 
         let edits = diff
@@ -290,7 +328,7 @@ impl EditToolRequest {
             edits,
         };
 
-        diff
+        anyhow::Ok(DiffResult::Diff(diff))
     }
 
     const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
@@ -314,7 +352,7 @@ impl EditToolRequest {
 
         let errors = self.parser.errors();
 
-        if errors.is_empty() {
+        if errors.is_empty() && self.bad_searches.is_empty() {
             if changed_buffer_count == 0 {
                 return Err(anyhow!(
                     "The instructions didn't lead to any changes. You might need to consult the file contents first."
@@ -337,6 +375,24 @@ impl EditToolRequest {
                 );
             }
 
+            if !self.bad_searches.is_empty() {
+                writeln!(
+                    &mut output,
+                    "\n\nThese searches failed because they didn't match any strings:"
+                )?;
+
+                for replace in self.bad_searches {
+                    writeln!(
+                        &mut output,
+                        "- '{}' does not appear in `{}`",
+                        replace.search.replace("\r", "\\r").replace("\n", "\\n"),
+                        replace.file_path
+                    )?;
+                }
+
+                write!(&mut output, "Make sure to use exact searches.")?;
+            }
+
             if !errors.is_empty() {
                 writeln!(
                     &mut output,

crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs 🔗

@@ -1,226 +0,0 @@
-use language::{Anchor, Bias, BufferSnapshot};
-use std::ops::Range;
-
-#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
-enum SearchDirection {
-    Up,
-    Left,
-    Diagonal,
-}
-
-#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
-struct SearchState {
-    cost: u32,
-    direction: SearchDirection,
-}
-
-impl SearchState {
-    fn new(cost: u32, direction: SearchDirection) -> Self {
-        Self { cost, direction }
-    }
-}
-
-struct SearchMatrix {
-    cols: usize,
-    data: Vec<SearchState>,
-}
-
-impl SearchMatrix {
-    fn new(rows: usize, cols: usize) -> Self {
-        SearchMatrix {
-            cols,
-            data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
-        }
-    }
-
-    fn get(&self, row: usize, col: usize) -> SearchState {
-        self.data[row * self.cols + col]
-    }
-
-    fn set(&mut self, row: usize, col: usize, cost: SearchState) {
-        self.data[row * self.cols + col] = cost;
-    }
-}
-
-pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
-    const INSERTION_COST: u32 = 3;
-    const DELETION_COST: u32 = 10;
-    const WHITESPACE_INSERTION_COST: u32 = 1;
-    const WHITESPACE_DELETION_COST: u32 = 1;
-
-    let buffer_len = buffer.len();
-    let query_len = search_query.len();
-    let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
-    let mut leading_deletion_cost = 0_u32;
-    for (row, query_byte) in search_query.bytes().enumerate() {
-        let deletion_cost = if query_byte.is_ascii_whitespace() {
-            WHITESPACE_DELETION_COST
-        } else {
-            DELETION_COST
-        };
-
-        leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
-        matrix.set(
-            row + 1,
-            0,
-            SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
-        );
-
-        for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
-            let insertion_cost = if buffer_byte.is_ascii_whitespace() {
-                WHITESPACE_INSERTION_COST
-            } else {
-                INSERTION_COST
-            };
-
-            let up = SearchState::new(
-                matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
-                SearchDirection::Up,
-            );
-            let left = SearchState::new(
-                matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
-                SearchDirection::Left,
-            );
-            let diagonal = SearchState::new(
-                if query_byte == *buffer_byte {
-                    matrix.get(row, col).cost
-                } else {
-                    matrix
-                        .get(row, col)
-                        .cost
-                        .saturating_add(deletion_cost + insertion_cost)
-                },
-                SearchDirection::Diagonal,
-            );
-            matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
-        }
-    }
-
-    // Traceback to find the best match
-    let mut best_buffer_end = buffer_len;
-    let mut best_cost = u32::MAX;
-    for col in 1..=buffer_len {
-        let cost = matrix.get(query_len, col).cost;
-        if cost < best_cost {
-            best_cost = cost;
-            best_buffer_end = col;
-        }
-    }
-
-    let mut query_ix = query_len;
-    let mut buffer_ix = best_buffer_end;
-    while query_ix > 0 && buffer_ix > 0 {
-        let current = matrix.get(query_ix, buffer_ix);
-        match current.direction {
-            SearchDirection::Diagonal => {
-                query_ix -= 1;
-                buffer_ix -= 1;
-            }
-            SearchDirection::Up => {
-                query_ix -= 1;
-            }
-            SearchDirection::Left => {
-                buffer_ix -= 1;
-            }
-        }
-    }
-
-    let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
-    start.column = 0;
-    let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
-    if end.column > 0 {
-        end.column = buffer.line_len(end.row);
-    }
-
-    buffer.anchor_after(start)..buffer.anchor_before(end)
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::edit_files_tool::resolve_search_block::resolve_search_block;
-    use gpui::{prelude::*, App};
-    use language::{Buffer, OffsetRangeExt as _};
-    use unindent::Unindent as _;
-    use util::test::{generate_marked_text, marked_text_ranges};
-
-    #[gpui::test]
-    fn test_resolve_search_block(cx: &mut App) {
-        assert_resolved(
-            concat!(
-                "    Lorem\n",
-                "«    ipsum\n",
-                "    dolor sit amet»\n",
-                "    consecteur",
-            ),
-            "ipsum\ndolor",
-            cx,
-        );
-
-        assert_resolved(
-            &"
-            «fn foo1(a: usize) -> usize {
-                40
-            }»
-
-            fn foo2(b: usize) -> usize {
-                42
-            }
-            "
-            .unindent(),
-            "fn foo1(b: usize) {\n40\n}",
-            cx,
-        );
-
-        assert_resolved(
-            &"
-            fn main() {
-            «    Foo
-                    .bar()
-                    .baz()
-                    .qux()»
-            }
-
-            fn foo2(b: usize) -> usize {
-                42
-            }
-            "
-            .unindent(),
-            "Foo.bar.baz.qux()",
-            cx,
-        );
-
-        assert_resolved(
-            &"
-            class Something {
-                one() { return 1; }
-            «    two() { return 2222; }
-                three() { return 333; }
-                four() { return 4444; }
-                five() { return 5555; }
-                six() { return 6666; }
-            »    seven() { return 7; }
-                eight() { return 8; }
-            }
-            "
-            .unindent(),
-            &"
-                two() { return 2222; }
-                four() { return 4444; }
-                five() { return 5555; }
-                six() { return 6666; }
-            "
-            .unindent(),
-            cx,
-        );
-    }
-
-    #[track_caller]
-    fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) {
-        let (text, _) = marked_text_ranges(text_with_expected_range, false);
-        let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
-        let snapshot = buffer.read(cx).snapshot();
-        let range = resolve_search_block(&snapshot, query).to_offset(&snapshot);
-        let text_with_actual_range = generate_marked_text(&text, &[range], false);
-        pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
-    }
-}