assistant edit tool: Use buffer search and replace in background (#26679)

Agus Zubiaga and Max created

Instead of getting the whole text from the buffer, replacing with
`String::replace`, and getting a whole diff, we'll now use `SearchQuery`
to get a range, diff only that range, and apply it (all in the
background).

When we match zero strings, we'll record a "bad search", keep going and
report it to the model at the end.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>

Change summary

crates/assistant_tools/src/edit_files_tool.rs | 312 ++++++++++++++------
crates/language/src/buffer.rs                 |   4 
2 files changed, 212 insertions(+), 104 deletions(-)

Detailed changes

crates/assistant_tools/src/edit_files_tool.rs 🔗

@@ -6,16 +6,17 @@ use assistant_tool::Tool;
 use collections::HashSet;
 use edit_action::{EditAction, EditActionParser};
 use futures::StreamExt;
-use gpui::{App, Entity, Task};
+use gpui::{App, AsyncApp, Entity, Task};
 use language_model::{
     LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 };
 use log::{EditToolLog, EditToolRequestId};
-use project::Project;
+use project::{search::SearchQuery, Project};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use std::fmt::Write;
 use std::sync::Arc;
+use util::paths::PathMatcher;
 use util::ResultExt;
 
 #[derive(Debug, Serialize, Deserialize, JsonSchema)]
@@ -93,7 +94,7 @@ impl Tool for EditFilesTool {
                 });
 
                 let task =
-                    EditFilesTool::run(input, messages, project, Some((log.clone(), req_id)), cx);
+                    EditToolRequest::new(input, messages, project, Some((log.clone(), req_id)), cx);
 
                 cx.spawn(|mut cx| async move {
                     let result = task.await;
@@ -112,13 +113,33 @@ impl Tool for EditFilesTool {
                 })
             }
 
-            None => EditFilesTool::run(input, messages, project, None, cx),
+            None => EditToolRequest::new(input, messages, project, None, cx),
         }
     }
 }
 
-impl EditFilesTool {
-    fn run(
+/// State carried across one run of the edit tool while the model's response is streamed.
+struct EditToolRequest {
+    /// Parser fed with each streamed chunk; its accumulated errors are reported in `finalize`.
+    parser: EditActionParser,
+    /// Buffers that received at least one applied diff; each is saved once in `finalize`.
+    changed_buffers: HashSet<Entity<language::Buffer>>,
+    /// Replace actions whose search matched nothing, collected for the final report.
+    bad_searches: Vec<BadSearch>,
+    /// Project used to resolve paths, open buffers, and save them.
+    project: Entity<Project>,
+    /// Optional log entity and request id that receive every response chunk.
+    log: Option<(Entity<EditToolLog>, EditToolRequestId)>,
+}
+
+/// Outcome of computing the edits for a single action.
+#[derive(Debug)]
+enum DiffResult {
+    /// The search of a replace action matched nothing in the buffer.
+    BadSearch(BadSearch),
+    /// A diff ready to be applied to the buffer.
+    Diff(language::Diff),
+}
+
+/// A replace action that could not be applied because its search matched nothing.
+#[derive(Debug)]
+struct BadSearch {
+    /// Display form of the path of the file that was searched.
+    file_path: String,
+    /// Text reported to the model as not appearing in the file.
+    search: String,
+}
+
+impl EditToolRequest {
+    fn new(
         input: EditFilesToolInput,
         messages: &[LanguageModelRequestMessage],
         project: Entity<Project>,
@@ -147,121 +168,208 @@ impl EditFilesTool {
         });
 
         cx.spawn(|mut cx| async move {
-            let request = LanguageModelRequest {
+            let llm_request = LanguageModelRequest {
                 messages,
                 tools: vec![],
                 stop: vec![],
                 temperature: Some(0.0),
             };
 
-            let mut parser = EditActionParser::new();
-
-            let stream = model.stream_completion_text(request, &cx);
+            let stream = model.stream_completion_text(llm_request, &cx);
             let mut chunks = stream.await?;
 
-            let mut changed_buffers = HashSet::default();
-            let mut applied_edits = 0;
-
-            let log = log.clone();
+            let mut request = Self {
+                parser: EditActionParser::new(),
+                changed_buffers: HashSet::default(),
+                bad_searches: Vec::new(),
+                project,
+                log,
+            };
 
             while let Some(chunk) = chunks.stream.next().await {
-                let chunk = chunk?;
+                request.process_response_chunk(&chunk?, &mut cx).await?;
+            }
+
+            request.finalize(&mut cx).await
+        })
+    }
 
-                let new_actions = parser.parse_chunk(&chunk);
+    /// Feeds one streamed chunk to the parser, forwards it to the log (if any), and
+    /// applies every action the parser returned for that chunk, in order.
+    ///
+    /// Returns the first error produced by `apply_action`. A failure to update the log
+    /// is handed to `log_err` instead of being propagated.
+    async fn process_response_chunk(&mut self, chunk: &str, cx: &mut AsyncApp) -> Result<()> {
+        let new_actions = self.parser.parse_chunk(chunk);
 
-                if let Some((ref log, req_id)) = log {
-                    log.update(&mut cx, |log, cx| {
-                        log.push_editor_response_chunk(req_id, &chunk, &new_actions, cx)
-                    })
-                    .log_err();
-                }
+        if let Some((ref log, req_id)) = self.log {
+            log.update(cx, |log, cx| {
+                log.push_editor_response_chunk(req_id, chunk, &new_actions, cx)
+            })
+            .log_err();
+        }
 
-                for action in new_actions {
-                    let project_path = project.read_with(&cx, |project, cx| {
-                        project
-                            .find_project_path(action.file_path(), cx)
-                            .context("Path not found in project")
-                    })??;
-
-                    let buffer = project
-                        .update(&mut cx, |project, cx| project.open_buffer(project_path, cx))?
-                        .await?;
-
-                    let diff = buffer
-                        .read_with(&cx, |buffer, cx| {
-                            let new_text = match action {
-                                EditAction::Replace {
-                                    file_path,
-                                    old,
-                                    new,
-                                } => {
-                                    // TODO: Replace in background?
-                                    let text = buffer.text();
-                                    if text.contains(&old) {
-                                        text.replace(&old, &new)
-                                    } else {
-                                        return Err(anyhow!(
-                                            "Could not find search text in {}",
-                                            file_path.display()
-                                        ));
-                                    }
-                                }
-                                EditAction::Write { content, .. } => content,
-                            };
-
-                            anyhow::Ok(buffer.diff(new_text, cx))
-                        })??
-                        .await;
-
-                    let _clock =
-                        buffer.update(&mut cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
-
-                    changed_buffers.insert(buffer);
-
-                    applied_edits += 1;
-                }
+        // Actions are applied one at a time; the first failure aborts the request.
+        for action in new_actions {
+            self.apply_action(action, cx).await?;
+        }
+
+        Ok(())
+    }
+
+    /// Applies a single parsed action to its target buffer.
+    ///
+    /// Resolves the action's path in the project, opens the buffer, and computes a diff:
+    /// a `Replace` is diffed on the background executor against a snapshot of the buffer,
+    /// while a `Write` diffs the buffer against `content`. A resulting diff is applied and
+    /// the buffer is remembered in `changed_buffers`; a bad search is recorded in
+    /// `bad_searches` without failing the request.
+    ///
+    /// Errors if the path is not found in the project, the buffer cannot be opened, or
+    /// computing or applying the diff fails.
+    async fn apply_action(&mut self, action: EditAction, cx: &mut AsyncApp) -> Result<()> {
+        let project_path = self.project.read_with(cx, |project, cx| {
+            project
+                .find_project_path(action.file_path(), cx)
+                .context("Path not found in project")
+        })??;
+
+        let buffer = self
+            .project
+            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+            .await?;
+
+        let result = match action {
+            EditAction::Replace {
+                old,
+                new,
+                file_path,
+            } => {
+                // Take a snapshot so the search and diff can run off the main thread.
+                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+                cx.background_executor()
+                    .spawn(Self::replace_diff(old, new, file_path, snapshot))
+                    .await
+            }
+            EditAction::Write { content, .. } => Ok(DiffResult::Diff(
+                buffer
+                    .read_with(cx, |buffer, cx| buffer.diff(content, cx))?
+                    .await,
+            )),
+        }?;
+
+        match result {
+            // A search that matched nothing is remembered and reported in `finalize`.
+            DiffResult::BadSearch(invalid_replace) => {
+                self.bad_searches.push(invalid_replace);
             }
+            DiffResult::Diff(diff) => {
+                let _clock = buffer.update(cx, |buffer, cx| buffer.apply_diff(diff, cx))?;
 
-            let mut answer = match changed_buffers.len() {
-                0 => "No files were edited.".to_string(),
-                1 => "Successfully edited ".to_string(),
-                _ => "Successfully edited these files:\n\n".to_string(),
-            };
+                self.changed_buffers.insert(buffer);
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Computes the diff that replaces the first occurrence of `old` in `snapshot` with `new`.
+    ///
+    /// Only the matched range is diffed (`old` against `new`); the resulting edit ranges
+    /// are shifted by the start of the match before being stored in the returned diff.
+    ///
+    /// Returns `DiffResult::BadSearch` (not an error) when the search yields no match, so
+    /// the caller can keep going and report it at the end. Errors only if the path
+    /// matchers or the search query cannot be constructed.
+    async fn replace_diff(
+        old: String,
+        new: String,
+        file_path: std::path::PathBuf,
+        snapshot: language::BufferSnapshot,
+    ) -> Result<DiffResult> {
+        // NOTE(review): the three flags are assumed to be whole_word = false,
+        // case_sensitive = true, include_ignored = true; confirm against `SearchQuery::text`.
+        let query = SearchQuery::text(
+            old.clone(),
+            false,
+            true,
+            true,
+            PathMatcher::new(&[])?,
+            PathMatcher::new(&[])?,
+            None,
+        )?;
+
+        let matches = query.search(&snapshot, None).await;
+
+        if matches.is_empty() {
+            // Report the text that was searched for (`old`), not the replacement (`new`):
+            // the final message tells the model that this string "does not appear" in the file.
+            return Ok(DiffResult::BadSearch(BadSearch {
+                search: old,
+                file_path: file_path.display().to_string(),
+            }));
+        }
+
+        // Only the first match is replaced; any further matches are left untouched.
+        let edit_range = matches[0].clone();
+        let diff = language::text_diff(&old, &new);
+
+        // The ranges from `text_diff` are treated as offsets into `old`, so shift them
+        // by the start of the match.
+        let edits = diff
+            .into_iter()
+            .map(|(old_range, text)| {
+                let start = edit_range.start + old_range.start;
+                let end = edit_range.start + old_range.end;
+                (start..end, text)
+            })
+            .collect::<Vec<_>>();
+
+        let diff = language::Diff {
+            base_version: snapshot.version().clone(),
+            line_ending: snapshot.line_ending(),
+            edits,
+        };
+
+        anyhow::Ok(DiffResult::Diff(diff))
+    }
+
+    /// Saves every changed buffer and builds the final message for the model.
+    ///
+    /// The message starts with a summary of the edited files. When there are no parse
+    /// errors and no bad searches, the trimmed summary is returned as `Ok`. Otherwise the
+    /// failed searches and parse errors are appended and the whole text is returned as
+    /// `Err`. Buffers are saved in both cases; a failing save aborts with that error.
+    async fn finalize(self, cx: &mut AsyncApp) -> Result<String> {
+        let mut answer = match self.changed_buffers.len() {
+            0 => "No files were edited.".to_string(),
+            1 => "Successfully edited ".to_string(),
+            _ => "Successfully edited these files:\n\n".to_string(),
+        };
+
+        // Save each buffer once at the end
+        for buffer in self.changed_buffers {
+            let (path, save_task) = self.project.update(cx, |project, cx| {
+                let path = buffer
+                    .read(cx)
+                    .file()
+                    .map(|file| file.path().display().to_string());
 
-            // Save each buffer once at the end
-            for buffer in changed_buffers {
-                project
-                    .update(&mut cx, |project, cx| {
-                        if let Some(file) = buffer.read(&cx).file() {
-                            let _ = writeln!(&mut answer, "{}", &file.full_path(cx).display());
-                        }
-
-                        project.save_buffer(buffer, cx)
-                    })?
-                    .await?;
+                let task = project.save_buffer(buffer.clone(), cx);
+
+                (path, task)
+            })?;
+
+            save_task.await?;
+
+            // The path is listed only after its save succeeded.
+            if let Some(path) = path {
+                writeln!(&mut answer, "{}", path)?;
             }
+        }
 
-            let errors = parser.errors();
-
-            if errors.is_empty() {
-                Ok(answer.trim_end().to_string())
-            } else {
-                let error_message = errors
-                    .iter()
-                    .map(|e| e.to_string())
-                    .collect::<Vec<_>>()
-                    .join("\n");
-
-                if applied_edits > 0 {
-                    Err(anyhow!(
-                        "Applied {} edit(s), but some blocks failed to parse:\n{}",
-                        applied_edits,
-                        error_message
-                    ))
-                } else {
-                    Err(anyhow!(error_message))
+        let errors = self.parser.errors();
+
+        if errors.is_empty() && self.bad_searches.is_empty() {
+            Ok(answer.trim_end().to_string())
+        } else {
+            if !self.bad_searches.is_empty() {
+                writeln!(
+                    &mut answer,
+                    "\nThese searches failed because they didn't match any strings:"
+                )?;
+
+                for replace in self.bad_searches {
+                    // Escape line breaks so each failed search stays on one list line.
+                    writeln!(
+                        &mut answer,
+                        "- '{}' does not appear in `{}`",
+                        replace.search.replace("\r", "\\r").replace("\n", "\\n"),
+                        replace.file_path
+                    )?;
                 }
+
+                writeln!(&mut answer, "Make sure to use exact searches.")?;
             }
-        })
+
+            if !errors.is_empty() {
+                writeln!(
+                    &mut answer,
+                    "\nThese SEARCH/REPLACE blocks failed to parse:"
+                )?;
+
+                for error in errors {
+                    writeln!(&mut answer, "- {}", error)?;
+                }
+            }
+
+            // The `\` continuation drops the newline and the next line's indentation, so
+            // the separating space has to come before the backslash.
+            writeln!(
+                &mut answer,
+                "\nYou can fix errors by running the tool again. You can include instructions, \
+                but errors are part of the conversation so you don't need to repeat them."
+            )?;
+
+            Err(anyhow!(answer))
+        }
     }
 }

crates/language/src/buffer.rs 🔗

@@ -526,8 +526,8 @@ impl DerefMut for ChunkRendererContext<'_, '_> {
 /// A set of edits to a given version of a buffer, computed asynchronously.
 #[derive(Debug)]
 pub struct Diff {
-    pub(crate) base_version: clock::Global,
-    line_ending: LineEnding,
+    /// Version of the buffer that `edits` were computed against.
+    pub base_version: clock::Global,
+    /// Line ending recorded alongside the edits.
+    pub line_ending: LineEnding,
+    /// Ranges to replace, each paired with its replacement text.
     pub edits: Vec<(Range<usize>, Arc<str>)>,
 }