assistant edit tool: Fuzzy match search block (#26935)

Created by Agus Zubiaga and Antonio Scandurra

Release Notes:

- N/A

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

Cargo.lock                                                         |   2 
crates/assistant_eval/src/main.rs                                  |   7 
crates/assistant_tools/Cargo.toml                                  |   2 
crates/assistant_tools/src/edit_files_tool.rs                      | 100 
crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs | 226 
5 files changed, 258 insertions(+), 79 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -716,6 +716,7 @@ dependencies = [
  "gpui",
  "language",
  "language_model",
+ "pretty_assertions",
  "project",
  "rand 0.8.5",
  "release_channel",
@@ -725,6 +726,7 @@ dependencies = [
  "settings",
  "theme",
  "ui",
+ "unindent",
  "util",
  "workspace",
  "worktree",

crates/assistant_eval/src/main.rs 🔗

@@ -48,7 +48,12 @@ fn main() {
 
     let crate_dir = PathBuf::from("../zed-agent-bench");
     let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap();
-    let repos_dir = crate_dir.join("repos").canonicalize().unwrap();
+
+    let repos_dir = crate_dir.join("repos");
+    if !repos_dir.exists() {
+        std::fs::create_dir_all(&repos_dir).unwrap();
+    }
+    let repos_dir = repos_dir.canonicalize().unwrap();
 
     let all_evals = std::fs::read_dir(&evaluation_data_dir)
         .unwrap()

crates/assistant_tools/Cargo.toml 🔗

@@ -38,5 +38,7 @@ rand.workspace = true
 collections = { workspace = true, features = ["test-support"] }
 gpui = { workspace = true, features = ["test-support"] }
 language = { workspace = true, features = ["test-support"] }
+pretty_assertions.workspace = true
 project = { workspace = true, features = ["test-support"] }
+unindent.workspace = true
 workspace = { workspace = true, features = ["test-support"] }

crates/assistant_tools/src/edit_files_tool.rs 🔗

@@ -1,5 +1,6 @@
 mod edit_action;
 pub mod log;
+mod resolve_search_block;
 
 use anyhow::{anyhow, Context, Result};
 use assistant_tool::{ActionLog, Tool};
@@ -7,16 +8,17 @@ use collections::HashSet;
 use edit_action::{EditAction, EditActionParser};
 use futures::StreamExt;
 use gpui::{App, AsyncApp, Entity, Task};
+use language::OffsetRangeExt;
 use language_model::{
     LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 };
 use log::{EditToolLog, EditToolRequestId};
-use project::{search::SearchQuery, Project};
+use project::Project;
+use resolve_search_block::resolve_search_block;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use std::fmt::Write;
 use std::sync::Arc;
-use util::paths::PathMatcher;
 use util::ResultExt;
 
 #[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -129,24 +131,11 @@ struct EditToolRequest {
     parser: EditActionParser,
     output: String,
     changed_buffers: HashSet<Entity<language::Buffer>>,
-    bad_searches: Vec<BadSearch>,
     project: Entity<Project>,
     action_log: Entity<ActionLog>,
     tool_log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
 }
 
-#[derive(Debug)]
-enum DiffResult {
-    BadSearch(BadSearch),
-    Diff(language::Diff),
-}
-
-#[derive(Debug)]
-struct BadSearch {
-    file_path: String,
-    search: String,
-}
-
 impl EditToolRequest {
     fn new(
         input: EditFilesToolInput,
@@ -204,7 +193,6 @@ impl EditToolRequest {
                 // we start with the success header so we don't need to shift the output in the common case
                 output: Self::SUCCESS_OUTPUT_HEADER.to_string(),
                 changed_buffers: HashSet::default(),
-                bad_searches: Vec::new(),
                 action_log,
                 project,
                 tool_log,
@@ -251,36 +239,30 @@ impl EditToolRequest {
             .update(cx, |project, cx| project.open_buffer(project_path, cx))?
             .await?;
 
-        let result = match action {
+        let diff = match action {
             EditAction::Replace {
                 old,
                 new,
-                file_path,
+                file_path: _,
             } => {
                 let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 
-                cx.background_executor()
-                    .spawn(Self::replace_diff(old, new, file_path, snapshot))
-                    .await
+                let diff = cx
+                    .background_executor()
+                    .spawn(Self::replace_diff(old, new, snapshot))
+                    .await;
+
+                anyhow::Ok(diff)
             }
-            EditAction::Write { content, .. } => Ok(DiffResult::Diff(
-                buffer
-                    .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
-                    .await,
-            )),
+            EditAction::Write { content, .. } => Ok(buffer
+                .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
+                .await),
         }?;
 
-        match result {
-            DiffResult::BadSearch(invalid_replace) => {
-                self.bad_searches.push(invalid_replace);
-            }
-            DiffResult::Diff(diff) => {
-                let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
+        let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
 
-                write!(&mut self.output, "\n\n{}", source)?;
-                self.changed_buffers.insert(buffer);
-            }
-        }
+        write!(&mut self.output, "\n\n{}", source)?;
+        self.changed_buffers.insert(buffer);
 
         Ok(())
     }
@@ -288,29 +270,9 @@ impl EditToolRequest {
     async fn replace_diff(
         old: String,
         new: String,
-        file_path: std::path::PathBuf,
         snapshot: language::BufferSnapshot,
-    ) -> Result<DiffResult> {
-        let query = SearchQuery::text(
-            old.clone(),
-            false,
-            true,
-            true,
-            PathMatcher::new(&[])?,
-            PathMatcher::new(&[])?,
-            None,
-        )?;
-
-        let matches = query.search(&snapshot, None).await;
-
-        if matches.is_empty() {
-            return Ok(DiffResult::BadSearch(BadSearch {
-                search: new.clone(),
-                file_path: file_path.display().to_string(),
-            }));
-        }
-
-        let edit_range = matches[0].clone();
+    ) -> language::Diff {
+        let edit_range = resolve_search_block(&snapshot, &old).to_offset(&snapshot);
         let diff = language::text_diff(&old, &new);
 
         let edits = diff
@@ -328,7 +290,7 @@ impl EditToolRequest {
             edits,
         };
 
-        anyhow::Ok(DiffResult::Diff(diff))
+        diff
     }
 
     const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:";
@@ -354,7 +316,7 @@ impl EditToolRequest {
 
         let errors = self.parser.errors();
 
-        if errors.is_empty() && self.bad_searches.is_empty() {
+        if errors.is_empty() {
             if changed_buffer_count == 0 {
                 return Err(anyhow!(
                     "The instructions didn't lead to any changes. You might need to consult the file contents first."
@@ -377,24 +339,6 @@ impl EditToolRequest {
                 );
             }
 
-            if !self.bad_searches.is_empty() {
-                writeln!(
-                    &mut output,
-                    "\n\nThese searches failed because they didn't match any strings:"
-                )?;
-
-                for replace in self.bad_searches {
-                    writeln!(
-                        &mut output,
-                        "- '{}' does not appear in `{}`",
-                        replace.search.replace("\r", "\\r").replace("\n", "\\n"),
-                        replace.file_path
-                    )?;
-                }
-
-                write!(&mut output, "Make sure to use exact searches.")?;
-            }
-
             if !errors.is_empty() {
                 writeln!(
                     &mut output,

crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs 🔗

@@ -0,0 +1,226 @@
+use language::{Anchor, Bias, BufferSnapshot};
+use std::ops::Range;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
+enum SearchDirection {
+    Up,
+    Left,
+    Diagonal,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
+struct SearchState {
+    cost: u32,
+    direction: SearchDirection,
+}
+
+impl SearchState {
+    fn new(cost: u32, direction: SearchDirection) -> Self {
+        Self { cost, direction }
+    }
+}
+
+struct SearchMatrix {
+    cols: usize,
+    data: Vec<SearchState>,
+}
+
+impl SearchMatrix {
+    fn new(rows: usize, cols: usize) -> Self {
+        SearchMatrix {
+            cols,
+            data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols],
+        }
+    }
+
+    fn get(&self, row: usize, col: usize) -> SearchState {
+        self.data[row * self.cols + col]
+    }
+
+    fn set(&mut self, row: usize, col: usize, cost: SearchState) {
+        self.data[row * self.cols + col] = cost;
+    }
+}
+
+pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range<Anchor> {
+    const INSERTION_COST: u32 = 3;
+    const DELETION_COST: u32 = 10;
+    const WHITESPACE_INSERTION_COST: u32 = 1;
+    const WHITESPACE_DELETION_COST: u32 = 1;
+
+    let buffer_len = buffer.len();
+    let query_len = search_query.len();
+    let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1);
+    let mut leading_deletion_cost = 0_u32;
+    for (row, query_byte) in search_query.bytes().enumerate() {
+        let deletion_cost = if query_byte.is_ascii_whitespace() {
+            WHITESPACE_DELETION_COST
+        } else {
+            DELETION_COST
+        };
+
+        leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost);
+        matrix.set(
+            row + 1,
+            0,
+            SearchState::new(leading_deletion_cost, SearchDirection::Diagonal),
+        );
+
+        for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() {
+            let insertion_cost = if buffer_byte.is_ascii_whitespace() {
+                WHITESPACE_INSERTION_COST
+            } else {
+                INSERTION_COST
+            };
+
+            let up = SearchState::new(
+                matrix.get(row, col + 1).cost.saturating_add(deletion_cost),
+                SearchDirection::Up,
+            );
+            let left = SearchState::new(
+                matrix.get(row + 1, col).cost.saturating_add(insertion_cost),
+                SearchDirection::Left,
+            );
+            let diagonal = SearchState::new(
+                if query_byte == *buffer_byte {
+                    matrix.get(row, col).cost
+                } else {
+                    matrix
+                        .get(row, col)
+                        .cost
+                        .saturating_add(deletion_cost + insertion_cost)
+                },
+                SearchDirection::Diagonal,
+            );
+            matrix.set(row + 1, col + 1, up.min(left).min(diagonal));
+        }
+    }
+
+    // Traceback to find the best match
+    let mut best_buffer_end = buffer_len;
+    let mut best_cost = u32::MAX;
+    for col in 1..=buffer_len {
+        let cost = matrix.get(query_len, col).cost;
+        if cost < best_cost {
+            best_cost = cost;
+            best_buffer_end = col;
+        }
+    }
+
+    let mut query_ix = query_len;
+    let mut buffer_ix = best_buffer_end;
+    while query_ix > 0 && buffer_ix > 0 {
+        let current = matrix.get(query_ix, buffer_ix);
+        match current.direction {
+            SearchDirection::Diagonal => {
+                query_ix -= 1;
+                buffer_ix -= 1;
+            }
+            SearchDirection::Up => {
+                query_ix -= 1;
+            }
+            SearchDirection::Left => {
+                buffer_ix -= 1;
+            }
+        }
+    }
+
+    let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left));
+    start.column = 0;
+    let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right));
+    if end.column > 0 {
+        end.column = buffer.line_len(end.row);
+    }
+
+    buffer.anchor_after(start)..buffer.anchor_before(end)
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::edit_files_tool::resolve_search_block::resolve_search_block;
+    use gpui::{prelude::*, App};
+    use language::{Buffer, OffsetRangeExt as _};
+    use unindent::Unindent as _;
+    use util::test::{generate_marked_text, marked_text_ranges};
+
+    #[gpui::test]
+    fn test_resolve_search_block(cx: &mut App) {
+        assert_resolved(
+            concat!(
+                "    Lorem\n",
+                "«    ipsum\n",
+                "    dolor sit amet»\n",
+                "    consecteur",
+            ),
+            "ipsum\ndolor",
+            cx,
+        );
+
+        assert_resolved(
+            &"
+            «fn foo1(a: usize) -> usize {
+                40
+            }»
+
+            fn foo2(b: usize) -> usize {
+                42
+            }
+            "
+            .unindent(),
+            "fn foo1(b: usize) {\n40\n}",
+            cx,
+        );
+
+        assert_resolved(
+            &"
+            fn main() {
+            «    Foo
+                    .bar()
+                    .baz()
+                    .qux()»
+            }
+
+            fn foo2(b: usize) -> usize {
+                42
+            }
+            "
+            .unindent(),
+            "Foo.bar.baz.qux()",
+            cx,
+        );
+
+        assert_resolved(
+            &"
+            class Something {
+                one() { return 1; }
+            «    two() { return 2222; }
+                three() { return 333; }
+                four() { return 4444; }
+                five() { return 5555; }
+                six() { return 6666; }
+            »    seven() { return 7; }
+                eight() { return 8; }
+            }
+            "
+            .unindent(),
+            &"
+                two() { return 2222; }
+                four() { return 4444; }
+                five() { return 5555; }
+                six() { return 6666; }
+            "
+            .unindent(),
+            cx,
+        );
+    }
+
+    #[track_caller]
+    fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) {
+        let (text, _) = marked_text_ranges(text_with_expected_range, false);
+        let buffer = cx.new(|cx| Buffer::local(text.clone(), cx));
+        let snapshot = buffer.read(cx).snapshot();
+        let range = resolve_search_block(&snapshot, query).to_offset(&snapshot);
+        let text_with_actual_range = generate_marked_text(&text, &[range], false);
+        pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range);
+    }
+}