From 8ec0309645a2a813ce470e9122e30e3ac71b01e2 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Thu, 13 Mar 2025 12:25:49 -0300 Subject: [PATCH] assistant edit tool: Use buffer search and replace in background (#26679) Instead of getting the whole text from the buffer, replacing with `String::replace`, and getting a whole diff, we'll now use `SearchQuery` to get a range, diff only that range, and apply it (all in the background). When we match zero strings, we'll record a "bad search", keep going and report it to the model at the end. Release Notes: - N/A --------- Co-authored-by: Max --- crates/assistant_tools/src/edit_files_tool.rs | 312 ++++++++++++------ crates/language/src/buffer.rs | 4 +- 2 files changed, 212 insertions(+), 104 deletions(-) diff --git a/crates/assistant_tools/src/edit_files_tool.rs b/crates/assistant_tools/src/edit_files_tool.rs index 3ad8113fbf3263facff5827600ce2465d33ea88d..5ef7e7d77aab9b782bae448dfad0071f6b428394 100644 --- a/crates/assistant_tools/src/edit_files_tool.rs +++ b/crates/assistant_tools/src/edit_files_tool.rs @@ -6,16 +6,17 @@ use assistant_tool::Tool; use collections::HashSet; use edit_action::{EditAction, EditActionParser}; use futures::StreamExt; -use gpui::{App, Entity, Task}; +use gpui::{App, AsyncApp, Entity, Task}; use language_model::{ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role, }; use log::{EditToolLog, EditToolRequestId}; -use project::Project; +use project::{search::SearchQuery, Project}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::fmt::Write; use std::sync::Arc; +use util::paths::PathMatcher; use util::ResultExt; #[derive(Debug, Serialize, Deserialize, JsonSchema)] @@ -93,7 +94,7 @@ impl Tool for EditFilesTool { }); let task = - EditFilesTool::run(input, messages, project, Some((log.clone(), req_id)), cx); + EditToolRequest::new(input, messages, project, Some((log.clone(), req_id)), cx); cx.spawn(|mut cx| async move { let result = task.await; @@ 
-112,13 +113,33 @@ impl Tool for EditFilesTool { }) } - None => EditFilesTool::run(input, messages, project, None, cx), + None => EditToolRequest::new(input, messages, project, None, cx), } } } -impl EditFilesTool { - fn run( +struct EditToolRequest { + parser: EditActionParser, + changed_buffers: HashSet<Entity<language::Buffer>>, + bad_searches: Vec<BadSearch>, + project: Entity<Project>, + log: Option<(Entity<EditToolLog>, EditToolRequestId)>, +} + +#[derive(Debug)] +enum DiffResult { + BadSearch(BadSearch), + Diff(language::Diff), +} + +#[derive(Debug)] +struct BadSearch { + file_path: String, + search: String, +} + +impl EditToolRequest { + fn new( input: EditFilesToolInput, messages: &[LanguageModelRequestMessage], project: Entity<Project>, @@ -147,121 +168,208 @@ impl EditFilesTool { }); cx.spawn(|mut cx| async move { - let request = LanguageModelRequest { + let llm_request = LanguageModelRequest { messages, tools: vec![], stop: vec![], temperature: Some(0.0), }; - let mut parser = EditActionParser::new(); - - let stream = model.stream_completion_text(request, &cx); + let stream = model.stream_completion_text(llm_request, &cx); let mut chunks = stream.await?; - let mut changed_buffers = HashSet::default(); - let mut applied_edits = 0; - - let log = log.clone(); + let mut request = Self { + parser: EditActionParser::new(), + changed_buffers: HashSet::default(), + bad_searches: Vec::new(), + project, + log, + }; while let Some(chunk) = chunks.stream.next().await { - let chunk = chunk?; + request.process_response_chunk(&chunk?, &mut cx).await?; + } + + request.finalize(&mut cx).await + }) + } - let new_actions = parser.parse_chunk(&chunk); + async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> { + let new_actions = self.parser.parse_chunk(chunk); - if let Some((ref log, req_id)) = log { - log.update(&mut cx, |log, cx| { - log.push_editor_response_chunk(req_id, &chunk, &new_actions, cx) - }) - .log_err(); - } + if let Some((ref log, req_id)) = self.log { + log.update(cx, |log, cx| { +
log.push_editor_response_chunk(req_id, chunk, &new_actions, cx) + }) + .log_err(); + } - for action in new_actions { - let project_path = project.read_with(&cx, |project, cx| { - project - .find_project_path(action.file_path(), cx) - .context("Path not found in project") - })??; - - let buffer = project - .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))? - .await?; - - let diff = buffer - .read_with(&cx, |buffer, cx| { - let new_text = match action { - EditAction::Replace { - file_path, - old, - new, - } => { - // TODO: Replace in background? - let text = buffer.text(); - if text.contains(&old) { - text.replace(&old, &new) - } else { - return Err(anyhow!( - "Could not find search text in {}", - file_path.display() - )); - } - } - EditAction::Write { content, .. } => content, - }; - - anyhow::Ok(buffer.diff(new_text, cx)) - })?? - .await; - - let _clock = - buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?; - - changed_buffers.insert(buffer); - - applied_edits += 1; - } + for action in new_actions { + self.apply_action(action, cx).await?; + } + + Ok(()) + } + + async fn apply_action(&mut self, action: EditAction, cx: &mut AsyncApp) -> Result<()> { + let project_path = self.project.read_with(cx, |project, cx| { + project + .find_project_path(action.file_path(), cx) + .context("Path not found in project") + })??; + + let buffer = self + .project + .update(cx, |project, cx| project.open_buffer(project_path, cx))? + .await?; + + let result = match action { + EditAction::Replace { + old, + new, + file_path, + } => { + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?; + + cx.background_executor() + .spawn(Self::replace_diff(old, new, file_path, snapshot)) + .await + } + EditAction::Write { content, .. } => Ok(DiffResult::Diff( + buffer + .read_with(cx, |buffer, cx| buffer.diff(content, cx))? 
+ .await, + )), + }?; + + match result { + DiffResult::BadSearch(invalid_replace) => { + self.bad_searches.push(invalid_replace); } + DiffResult::Diff(diff) => { + let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?; - let mut answer = match changed_buffers.len() { - 0 => "No files were edited.".to_string(), - 1 => "Successfully edited ".to_string(), - _ => "Successfully edited these files:\n\n".to_string(), - }; + self.changed_buffers.insert(buffer); + } + } + + Ok(()) + } + + async fn replace_diff( + old: String, + new: String, + file_path: std::path::PathBuf, + snapshot: language::BufferSnapshot, + ) -> Result<DiffResult> { + let query = SearchQuery::text( + old.clone(), + false, + true, + true, + PathMatcher::new(&[])?, + PathMatcher::new(&[])?, + None, + )?; + + let matches = query.search(&snapshot, None).await; + + if matches.is_empty() { + return Ok(DiffResult::BadSearch(BadSearch { + search: old.clone(), + file_path: file_path.display().to_string(), + })); + } + + let edit_range = matches[0].clone(); + let diff = language::text_diff(&old, &new); + + let edits = diff + .into_iter() + .map(|(old_range, text)| { + let start = edit_range.start + old_range.start; + let end = edit_range.start + old_range.end; + (start..end, text) + }) + .collect::<Vec<_>>(); + + let diff = language::Diff { + base_version: snapshot.version().clone(), + line_ending: snapshot.line_ending(), + edits, + }; + + anyhow::Ok(DiffResult::Diff(diff)) + } + + async fn finalize(self, cx: &mut AsyncApp) -> Result<String> { + let mut answer = match self.changed_buffers.len() { + 0 => "No files were edited.".to_string(), + 1 => "Successfully edited ".to_string(), + _ => "Successfully edited these files:\n\n".to_string(), + }; + + // Save each buffer once at the end + for buffer in self.changed_buffers { + let (path, save_task) = self.project.update(cx, |project, cx| { + let path = buffer + .read(cx) + .file() + .map(|file| file.path().display().to_string()); - // Save each buffer once at the end -
for buffer in changed_buffers { - project - .update(&mut cx, |project, cx| { - if let Some(file) = buffer.read(&cx).file() { - let _ = writeln!(&mut answer, "{}", &file.full_path(cx).display()); - } - - project.save_buffer(buffer, cx) - })? - .await?; + let task = project.save_buffer(buffer.clone(), cx); + + (path, task) + })?; + + save_task.await?; + + if let Some(path) = path { + writeln!(&mut answer, "{}", path)?; } + } - let errors = parser.errors(); - - if errors.is_empty() { - Ok(answer.trim_end().to_string()) - } else { - let error_message = errors - .iter() - .map(|e| e.to_string()) - .collect::<Vec<_>>() - .join("\n"); - - if applied_edits > 0 { - Err(anyhow!( - "Applied {} edit(s), but some blocks failed to parse:\n{}", - applied_edits, - error_message - )) - } else { - Err(anyhow!(error_message)) + let errors = self.parser.errors(); + + if errors.is_empty() && self.bad_searches.is_empty() { + Ok(answer.trim_end().to_string()) + } else { + if !self.bad_searches.is_empty() { + writeln!( + &mut answer, + "\nThese searches failed because they didn't match any strings:" + )?; + + for replace in self.bad_searches { + writeln!( + &mut answer, + "- '{}' does not appear in `{}`", + replace.search.replace("\r", "\\r").replace("\n", "\\n"), + replace.file_path + )?; } + + writeln!(&mut answer, "Make sure to use exact searches.")?; } - }) + + if !errors.is_empty() { + writeln!( + &mut answer, + "\nThese SEARCH/REPLACE blocks failed to parse:" + )?; + + for error in errors { + writeln!(&mut answer, "- {}", error)?; + } + } + + writeln!( + &mut answer, + "\nYou can fix errors by running the tool again. You can include instructions, \ but errors are part of the conversation so you don't need to repeat them." 
+ )?; + + Err(anyhow!(answer)) + } } } diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index b08fe045f132214ba313348e389716a40a4ca2b4..91ad627cbeb1ed82739bed17a5b9bfdbca73b30a 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -526,8 +526,8 @@ impl DerefMut for ChunkRendererContext<'_, '_> { /// A set of edits to a given version of a buffer, computed asynchronously. #[derive(Debug)] pub struct Diff { - pub(crate) base_version: clock::Global, - line_ending: LineEnding, + pub base_version: clock::Global, + pub line_ending: LineEnding, pub edits: Vec<(Range<usize>, Arc<str>)>, }