From 94b63808e0593ee927151b751f290ebb9e31dc25 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 17 Mar 2025 15:33:20 -0300 Subject: [PATCH] assistant edit tool: Fuzzy match search block (#26935) Release Notes: - N/A Co-authored-by: Antonio Scandurra --- Cargo.lock | 2 + crates/assistant_eval/src/main.rs | 7 +- crates/assistant_tools/Cargo.toml | 2 + crates/assistant_tools/src/edit_files_tool.rs | 100 ++------ .../edit_files_tool/resolve_search_block.rs | 226 ++++++++++++++++++ 5 files changed, 258 insertions(+), 79 deletions(-) create mode 100644 crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs diff --git a/Cargo.lock b/Cargo.lock index 9b4248114c6ccd77f75b86ab5f1aa518feb6a7c6..36182efd3dfbafa305d85c7bf931b992cf1ea914 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -716,6 +716,7 @@ dependencies = [ "gpui", "language", "language_model", + "pretty_assertions", "project", "rand 0.8.5", "release_channel", @@ -725,6 +726,7 @@ dependencies = [ "settings", "theme", "ui", + "unindent", "util", "workspace", "worktree", diff --git a/crates/assistant_eval/src/main.rs b/crates/assistant_eval/src/main.rs index 316aaf04ec5e9c924d7787d4e30b547d2d2d16fe..f2fdde1a92c6f602028a39c59721b193fb5aa698 100644 --- a/crates/assistant_eval/src/main.rs +++ b/crates/assistant_eval/src/main.rs @@ -48,7 +48,12 @@ fn main() { let crate_dir = PathBuf::from("../zed-agent-bench"); let evaluation_data_dir = crate_dir.join("evaluation_data").canonicalize().unwrap(); - let repos_dir = crate_dir.join("repos").canonicalize().unwrap(); + + let repos_dir = crate_dir.join("repos"); + if !repos_dir.exists() { + std::fs::create_dir_all(&repos_dir).unwrap(); + } + let repos_dir = repos_dir.canonicalize().unwrap(); let all_evals = std::fs::read_dir(&evaluation_data_dir) .unwrap() diff --git a/crates/assistant_tools/Cargo.toml b/crates/assistant_tools/Cargo.toml index 087b883bb6ed363fb07c4b47b775d5e8f0431953..69861f0bcfa6df07699f365735db1ad315d7deee 100644 --- a/crates/assistant_tools/Cargo.toml +++ b/crates/assistant_tools/Cargo.toml @@ -38,5 +38,7 @@ rand.workspace = true collections = { workspace = true, features = ["test-support"] } gpui = { workspace = true, features = ["test-support"] } language = { workspace = true, features = ["test-support"] } +pretty_assertions.workspace = true project = { workspace = true, features = ["test-support"] } +unindent.workspace = true workspace = { workspace = true, features = ["test-support"] } diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs index 4be7e90b56ac55254e36a35ee9e52a49d4d29d94..5cc81eda37fc80da525f72b427d1a82130007bf9 100644 --- a/crates/assistant_tools/src/edit_files_tool.rs +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -1,5 +1,6 @@ mod edit_action; pub mod log; +mod resolve_search_block; use anyhow::{anyhow, Context, Result}; use assistant_tool::{ActionLog, Tool}; @@ -7,16 +8,17 @@ use collections::HashSet; use edit_action::{EditAction, EditActionParser}; use futures::StreamExt; use gpui::{App, AsyncApp, Entity, Task}; +use language::OffsetRangeExt; use language_model::{ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, }; use log::{EditToolLog, EditToolRequestId}; -use project::{search::SearchQuery, Project}; +use project::Project; +use resolve_search_block::resolve_search_block; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::fmt::Write; use std::sync::Arc; -use util::paths::PathMatcher; use util::ResultExt; #[derive(Debug, Serialize, Deserialize, JsonSchema)] @@ -129,24 +131,11 @@ struct EditToolRequest { parser: EditActionParser, output: String, changed_buffers: HashSet>, - bad_searches: Vec, project: Entity, action_log: Entity, tool_log: Option<(Entity, EditToolRequestId)>, } -#[derive(Debug)] -enum DiffResult { - BadSearch(BadSearch), - Diff(language::Diff), -} - -#[derive(Debug)] -struct BadSearch { - file_path: String, - search: String, -} - impl EditToolRequest { fn new( input: EditFilesToolInput, @@ -204,7 +193,6 @@ impl EditToolRequest { // we start with the success header so we don't need to shift the output in the common case output: Self::SUCCESS_OUTPUT_HEADER.to_string(), changed_buffers: HashSet::default(), - bad_searches: Vec::new(), action_log, project, tool_log, @@ -251,36 +239,30 @@ impl EditToolRequest { .update(cx, |project, cx| project.open_buffer(project_path, cx))? .await?; - let result = match action { + let diff = match action { EditAction::Replace { old, new, - file_path, + file_path: _, } => { let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; - cx.background_executor() - .spawn(Self::replace_diff(old, new, file_path, snapshot)) - .await + let diff = cx + .background_executor() + .spawn(Self::replace_diff(old, new, snapshot)) + .await; + + anyhow::Ok(diff) } - EditAction::Write { content, .. } => Ok(DiffResult::Diff( - buffer - .read_with(cx, |buffer, cx| buffer.diff(content, cx))? - .await, - )), + EditAction::Write { content, .. } => Ok(buffer + .read_with(cx, |buffer, cx| buffer.diff(content, cx))? + .await), }?; - match result { - DiffResult::BadSearch(invalid_replace) => { - self.bad_searches.push(invalid_replace); - } - DiffResult::Diff(diff) => { - let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?; + let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?; - write!(&mut self.output, "\n\n{}", source)?; - self.changed_buffers.insert(buffer); - } - } + write!(&mut self.output, "\n\n{}", source)?; + self.changed_buffers.insert(buffer); Ok(()) } @@ -288,29 +270,9 @@ impl EditToolRequest { async fn replace_diff( old: String, new: String, - file_path: std::path::PathBuf, snapshot: language::BufferSnapshot, - ) -> Result { - let query = SearchQuery::text( - old.clone(), - false, - true, - true, - PathMatcher::new(&[])?, - PathMatcher::new(&[])?, - None, - )?; - - let matches = query.search(&snapshot, None).await; - - if matches.is_empty() { - return Ok(DiffResult::BadSearch(BadSearch { - search: new.clone(), - file_path: file_path.display().to_string(), - })); - } - - let edit_range = matches[0].clone(); + ) -> language::Diff { + let edit_range = resolve_search_block(&snapshot, &old).to_offset(&snapshot); let diff = language::text_diff(&old, &new); let edits = diff @@ -328,7 +290,7 @@ impl EditToolRequest { edits, }; - anyhow::Ok(DiffResult::Diff(diff)) + diff } const SUCCESS_OUTPUT_HEADER: &str = "Successfully applied. Here's a list of changes:"; @@ -354,7 +316,7 @@ impl EditToolRequest { let errors = self.parser.errors(); - if errors.is_empty() && self.bad_searches.is_empty() { + if errors.is_empty() { if changed_buffer_count == 0 { return Err(anyhow!( "The instructions didn't lead to any changes. You might need to consult the file contents first." @@ -377,24 +339,6 @@ impl EditToolRequest { ); } - if !self.bad_searches.is_empty() { - writeln!( - &mut output, - "\n\nThese searches failed because they didn't match any strings:" - )?; - - for replace in self.bad_searches { - writeln!( - &mut output, - "- '{}' does not appear in `{}`", - replace.search.replace("\r", "\\r").replace("\n", "\\n"), - replace.file_path - )?; - } - - write!(&mut output, "Make sure to use exact searches.")?; - } - if !errors.is_empty() { writeln!( &mut output, diff --git a/crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs b/crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs new file mode 100644 index 0000000000000000000000000000000000000000..5d2f61f8bb127800493c7156697159bdae7d20cf --- /dev/null +++ b/crates/assistant_tools/src/edit_files_tool/resolve_search_block.rs @@ -0,0 +1,226 @@ +use language::{Anchor, Bias, BufferSnapshot}; +use std::ops::Range; + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +enum SearchDirection { + Up, + Left, + Diagonal, +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +struct SearchState { + cost: u32, + direction: SearchDirection, +} + +impl SearchState { + fn new(cost: u32, direction: SearchDirection) -> Self { + Self { cost, direction } + } +} + +struct SearchMatrix { + cols: usize, + data: Vec, +} + +impl SearchMatrix { + fn new(rows: usize, cols: usize) -> Self { + SearchMatrix { + cols, + data: vec![SearchState::new(0, SearchDirection::Diagonal); rows * cols], + } + } + + fn get(&self, row: usize, col: usize) -> SearchState { + self.data[row * self.cols + col] + } + + fn set(&mut self, row: usize, col: usize, cost: SearchState) { + self.data[row * self.cols + col] = cost; + } +} + +pub fn resolve_search_block(buffer: &BufferSnapshot, search_query: &str) -> Range { + const INSERTION_COST: u32 = 3; + const DELETION_COST: u32 = 10; + const WHITESPACE_INSERTION_COST: u32 = 1; + const WHITESPACE_DELETION_COST: u32 = 1; + + let buffer_len = buffer.len(); + let query_len = search_query.len(); + let mut matrix = SearchMatrix::new(query_len + 1, buffer_len + 1); + let mut leading_deletion_cost = 0_u32; + for (row, query_byte) in search_query.bytes().enumerate() { + let deletion_cost = if query_byte.is_ascii_whitespace() { + WHITESPACE_DELETION_COST + } else { + DELETION_COST + }; + + leading_deletion_cost = leading_deletion_cost.saturating_add(deletion_cost); + matrix.set( + row + 1, + 0, + SearchState::new(leading_deletion_cost, SearchDirection::Diagonal), + ); + + for (col, buffer_byte) in buffer.bytes_in_range(0..buffer.len()).flatten().enumerate() { + let insertion_cost = if buffer_byte.is_ascii_whitespace() { + WHITESPACE_INSERTION_COST + } else { + INSERTION_COST + }; + + let up = SearchState::new( + matrix.get(row, col + 1).cost.saturating_add(deletion_cost), + SearchDirection::Up, + ); + let left = SearchState::new( + matrix.get(row + 1, col).cost.saturating_add(insertion_cost), + SearchDirection::Left, + ); + let diagonal = SearchState::new( + if query_byte == *buffer_byte { + matrix.get(row, col).cost + } else { + matrix + .get(row, col) + .cost + .saturating_add(deletion_cost + insertion_cost) + }, + SearchDirection::Diagonal, + ); + matrix.set(row + 1, col + 1, up.min(left).min(diagonal)); + } + } + + // Traceback to find the best match + let mut best_buffer_end = buffer_len; + let mut best_cost = u32::MAX; + for col in 1..=buffer_len { + let cost = matrix.get(query_len, col).cost; + if cost < best_cost { + best_cost = cost; + best_buffer_end = col; + } + } + + let mut query_ix = query_len; + let mut buffer_ix = best_buffer_end; + while query_ix > 0 && buffer_ix > 0 { + let current = matrix.get(query_ix, buffer_ix); + match current.direction { + SearchDirection::Diagonal => { + query_ix -= 1; + buffer_ix -= 1; + } + SearchDirection::Up => { + query_ix -= 1; + } + SearchDirection::Left => { + buffer_ix -= 1; + } + } + } + + let mut start = buffer.offset_to_point(buffer.clip_offset(buffer_ix, Bias::Left)); + start.column = 0; + let mut end = buffer.offset_to_point(buffer.clip_offset(best_buffer_end, Bias::Right)); + if end.column > 0 { + end.column = buffer.line_len(end.row); + } + + buffer.anchor_after(start)..buffer.anchor_before(end) +} + +#[cfg(test)] +mod tests { + use crate::edit_files_tool::resolve_search_block::resolve_search_block; + use gpui::{prelude::*, App}; + use language::{Buffer, OffsetRangeExt as _}; + use unindent::Unindent as _; + use util::test::{generate_marked_text, marked_text_ranges}; + + #[gpui::test] + fn test_resolve_search_block(cx: &mut App) { + assert_resolved( + concat!( + " Lorem\n", + "« ipsum\n", + " dolor sit amet»\n", + " consecteur", + ), + "ipsum\ndolor", + cx, + ); + + assert_resolved( + &" + «fn foo1(a: usize) -> usize { + 40 + }» + + fn foo2(b: usize) -> usize { + 42 + } + " + .unindent(), + "fn foo1(b: usize) {\n40\n}", + cx, + ); + + assert_resolved( + &" + fn main() { + « Foo + .bar() + .baz() + .qux()» + } + + fn foo2(b: usize) -> usize { + 42 + } + " + .unindent(), + "Foo.bar.baz.qux()", + cx, + ); + + assert_resolved( + &" + class Something { + one() { return 1; } + « two() { return 2222; } + three() { return 333; } + four() { return 4444; } + five() { return 5555; } + six() { return 6666; } + » seven() { return 7; } + eight() { return 8; } + } + " + .unindent(), + &" + two() { return 2222; } + four() { return 4444; } + five() { return 5555; } + six() { return 6666; } + " + .unindent(), + cx, + ); + } + + #[track_caller] + fn assert_resolved(text_with_expected_range: &str, query: &str, cx: &mut App) { + let (text, _) = marked_text_ranges(text_with_expected_range, false); + let buffer = cx.new(|cx| Buffer::local(text.clone(), cx)); + let snapshot = buffer.read(cx).snapshot(); + let range = resolve_search_block(&snapshot, query).to_offset(&snapshot); + let text_with_actual_range = generate_marked_text(&text, &[range], false); + pretty_assertions::assert_eq!(text_with_actual_range, text_with_expected_range); + } +}