agent: Cleanup `StreamingEditFileTool` (#50725)

Bennet Bo Fenner created

- Adds logging if tool fails
- Reduces boilerplate in test 
- Simplified code in a bunch of places

Release Notes:

- N/A

Change summary

crates/agent/src/thread.rs                         |   1 
crates/agent/src/tools/streaming_edit_file_tool.rs | 674 ++++++---------
2 files changed, 272 insertions(+), 403 deletions(-)

Detailed changes

crates/agent/src/thread.rs 🔗

@@ -1452,6 +1452,7 @@ impl Thread {
         self.add_tool(StreamingEditFileTool::new(
             self.project.clone(),
             cx.weak_entity(),
+            self.action_log.clone(),
             language_registry,
         ));
         self.add_tool(FetchTool::new(self.project.read(cx).client().http_client()));

crates/agent/src/tools/streaming_edit_file_tool.rs 🔗

@@ -73,7 +73,7 @@ pub struct StreamingEditFileToolInput {
     /// <example>
     /// `frontend/db.js`
     /// </example>
-    pub path: String,
+    pub path: PathBuf,
 
     /// The mode of operation on the file. Possible values:
     /// - 'write': Replace the entire contents of the file. If the file doesn't exist, it will be created. Requires 'content' field.
@@ -93,7 +93,7 @@ pub struct StreamingEditFileToolInput {
     pub edits: Option<Vec<Edit>>,
 }
 
-#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema)]
 #[serde(rename_all = "snake_case")]
 pub enum StreamingEditFileMode {
     /// Overwrite the file with new content (replacing any existing content).
@@ -187,20 +187,23 @@ impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
 }
 
 pub struct StreamingEditFileTool {
+    project: Entity<Project>,
     thread: WeakEntity<Thread>,
+    action_log: Entity<ActionLog>,
     language_registry: Arc<LanguageRegistry>,
-    project: Entity<Project>,
 }
 
 impl StreamingEditFileTool {
     pub fn new(
         project: Entity<Project>,
         thread: WeakEntity<Thread>,
+        action_log: Entity<ActionLog>,
         language_registry: Arc<LanguageRegistry>,
     ) -> Self {
         Self {
             project,
             thread,
+            action_log,
             language_registry,
         }
     }
@@ -264,11 +267,11 @@ impl AgentTool for StreamingEditFileTool {
                         .read(cx)
                         .short_full_path_for_project_path(&project_path, cx)
                 })
-                .unwrap_or(input.path)
+                .unwrap_or(input.path.to_string_lossy().into_owned())
                 .into(),
             Err(raw_input) => {
-                if let Some(input) =
-                    serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
+                if let Ok(input) =
+                    serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input)
                 {
                     let path = input.path.unwrap_or_default();
                     let path = path.trim();
@@ -311,24 +314,37 @@ impl AgentTool for StreamingEditFileTool {
                     partial = input.recv_partial().fuse() => {
                         let Some(partial_value) = partial else { break };
                         if let Ok(parsed) = serde_json::from_value::<StreamingEditFileToolPartialInput>(partial_value) {
-                            if state.is_none() && let Some(path_str) = &parsed.path
-                                && let Some(display_description) = &parsed.display_description
-                                && let Some(mode) = parsed.mode.clone() {
-                                    state = Some(
-                                        EditSession::new(
-                                            path_str,
-                                            display_description,
-                                            mode,
-                                            &self,
-                                            &event_stream,
-                                            cx,
-                                        )
-                                        .await?,
-                                    );
+                            if state.is_none()
+                                && let StreamingEditFileToolPartialInput {
+                                    path: Some(path),
+                                    display_description: Some(display_description),
+                                    mode: Some(mode),
+                                    ..
+                                } = &parsed
+                            {
+                                match EditSession::new(
+                                    &PathBuf::from(path),
+                                    display_description,
+                                    *mode,
+                                    &self,
+                                    &event_stream,
+                                    cx,
+                                )
+                                .await
+                                {
+                                    Ok(session) => state = Some(session),
+                                    Err(e) => {
+                                        log::error!("Failed to create edit session: {}", e);
+                                        return Err(e);
+                                    }
+                                }
                             }
 
                             if let Some(state) = &mut state {
-                                state.process(parsed, &self, &event_stream, cx)?;
+                                if let Err(e) = state.process(parsed, &self, &event_stream, cx) {
+                                    log::error!("Failed to process edit: {}", e);
+                                    return Err(e);
+                                }
                             }
                         }
                     }
@@ -341,22 +357,39 @@ impl AgentTool for StreamingEditFileTool {
                 input
                     .recv()
                     .await
-                    .map_err(|e| StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}")))?;
+                    .map_err(|e| {
+                        let err = StreamingEditFileToolOutput::error(format!("Failed to receive tool input: {e}"));
+                        log::error!("Failed to receive tool input: {e}");
+                        err
+                    })?;
 
             let mut state = if let Some(state) = state {
                 state
             } else {
-                EditSession::new(
+                match EditSession::new(
                     &full_input.path,
                     &full_input.display_description,
-                    full_input.mode.clone(),
+                    full_input.mode,
                     &self,
                     &event_stream,
                     cx,
                 )
-                .await?
+                .await
+                {
+                    Ok(session) => session,
+                    Err(e) => {
+                        log::error!("Failed to create edit session: {}", e);
+                        return Err(e);
+                    }
+                }
             };
-            state.finalize(full_input, &self, &event_stream, cx).await
+            match state.finalize(full_input, &self, &event_stream, cx).await {
+                Ok(output) => Ok(output),
+                Err(e) => {
+                    log::error!("Failed to finalize edit: {}", e);
+                    Err(e)
+                }
+            }
         })
     }
 
@@ -442,30 +475,24 @@ impl EditPipeline {
     }
 }
 
-/// Compute the `LineIndent` of the first line in a set of query lines.
-fn query_first_line_indent(query_lines: &[String]) -> text::LineIndent {
-    let first_line = query_lines.first().map(|s| s.as_str()).unwrap_or("");
-    text::LineIndent::from_iter(first_line.chars())
-}
-
 impl EditSession {
     async fn new(
-        path_str: &str,
+        path: &PathBuf,
         display_description: &str,
         mode: StreamingEditFileMode,
         tool: &StreamingEditFileTool,
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
     ) -> Result<Self, StreamingEditFileToolOutput> {
-        let path = PathBuf::from(path_str);
         let project_path = cx
-            .update(|cx| resolve_path(mode.clone(), &path, &tool.project, cx))
+            .update(|cx| resolve_path(mode, &path, &tool.project, cx))
             .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
 
         let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx))
         else {
             return Err(StreamingEditFileToolOutput::error(format!(
-                "Worktree at '{path_str}' does not exist"
+                "Worktree at '{}' does not exist",
+                path.to_string_lossy()
             )));
         };
 
@@ -483,12 +510,7 @@ impl EditSession {
             .await
             .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
 
-        let action_log = tool
-            .thread
-            .read_with(cx, |thread, _cx| thread.action_log().clone())
-            .ok();
-
-        ensure_buffer_saved(&buffer, &abs_path, tool, action_log.as_ref(), cx)?;
+        ensure_buffer_saved(&buffer, &abs_path, tool, cx)?;
 
         let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
         event_stream.update_diff(diff.clone());
@@ -500,9 +522,8 @@ impl EditSession {
             }
         }) as Box<dyn FnOnce()>);
 
-        if let Some(action_log) = &action_log {
-            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
-        }
+        tool.action_log
+            .update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
 
         let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
         let old_text = cx
@@ -531,69 +552,31 @@ impl EditSession {
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
     ) -> Result<StreamingEditFileToolOutput, StreamingEditFileToolOutput> {
-        let Self {
-            buffer,
-            old_text,
-            diff,
-            abs_path,
-            parser,
-            pipeline,
-            ..
-        } = self;
-
-        let action_log = tool
-            .thread
-            .read_with(cx, |thread, _cx| thread.action_log().clone())
-            .map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?;
+        let old_text = self.old_text.clone();
 
         match input.mode {
             StreamingEditFileMode::Write => {
-                action_log.update(cx, |log, cx| {
-                    log.buffer_created(buffer.clone(), cx);
-                });
                 let content = input.content.ok_or_else(|| {
                     StreamingEditFileToolOutput::error("'content' field is required for write mode")
                 })?;
 
-                let events = parser.finalize_content(&content);
-                Self::process_events(
-                    &events,
-                    buffer,
-                    diff,
-                    pipeline,
-                    abs_path,
-                    tool,
-                    event_stream,
-                    cx,
-                )?;
+                let events = self.parser.finalize_content(&content);
+                self.process_events(&events, tool, event_stream, cx)?;
+
+                tool.action_log.update(cx, |log, cx| {
+                    log.buffer_created(self.buffer.clone(), cx);
+                });
             }
             StreamingEditFileMode::Edit => {
                 let edits = input.edits.ok_or_else(|| {
                     StreamingEditFileToolOutput::error("'edits' field is required for edit mode")
                 })?;
-
-                let final_edits = edits
-                    .into_iter()
-                    .map(|e| Edit {
-                        old_text: e.old_text,
-                        new_text: e.new_text,
-                    })
-                    .collect::<Vec<_>>();
-                let events = parser.finalize_edits(&final_edits);
-                Self::process_events(
-                    &events,
-                    buffer,
-                    diff,
-                    pipeline,
-                    abs_path,
-                    tool,
-                    event_stream,
-                    cx,
-                )?;
+                let events = self.parser.finalize_edits(&edits);
+                self.process_events(&events, tool, event_stream, cx)?;
             }
         }
 
-        let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
+        let format_on_save_enabled = self.buffer.read_with(cx, |buffer, cx| {
             let settings = language_settings::language_settings(
                 buffer.language().map(|l| l.name()),
                 buffer.file(),
@@ -603,13 +586,13 @@ impl EditSession {
         });
 
         if format_on_save_enabled {
-            action_log.update(cx, |log, cx| {
-                log.buffer_edited(buffer.clone(), cx);
+            tool.action_log.update(cx, |log, cx| {
+                log.buffer_edited(self.buffer.clone(), cx);
             });
 
             let format_task = tool.project.update(cx, |project, cx| {
                 project.format(
-                    HashSet::from_iter([buffer.clone()]),
+                    HashSet::from_iter([self.buffer.clone()]),
                     LspFormatTarget::Buffers,
                     false,
                     FormatTrigger::Save,
@@ -624,9 +607,9 @@ impl EditSession {
             };
         }
 
-        let save_task = tool
-            .project
-            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx));
+        let save_task = tool.project.update(cx, |project, cx| {
+            project.save_buffer(self.buffer.clone(), cx)
+        });
         futures::select! {
             result = save_task.fuse() => { result.map_err(|e| StreamingEditFileToolOutput::error(e.to_string()))?; },
             _ = event_stream.cancelled_by_user().fuse() => {
@@ -634,11 +617,11 @@ impl EditSession {
             }
         };
 
-        action_log.update(cx, |log, cx| {
-            log.buffer_edited(buffer.clone(), cx);
+        tool.action_log.update(cx, |log, cx| {
+            log.buffer_edited(self.buffer.clone(), cx);
         });
 
-        let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+        let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
         let (new_text, unified_diff) = cx
             .background_spawn({
                 let new_snapshot = new_snapshot.clone();
@@ -652,7 +635,7 @@ impl EditSession {
             .await;
 
         let output = StreamingEditFileToolOutput::Success {
-            input_path: PathBuf::from(input.path),
+            input_path: input.path,
             new_text,
             old_text: old_text.clone(),
             diff: unified_diff,
@@ -671,31 +654,13 @@ impl EditSession {
             StreamingEditFileMode::Write => {
                 if let Some(content) = &partial.content {
                     let events = self.parser.push_content(content);
-                    Self::process_events(
-                        &events,
-                        &self.buffer,
-                        &self.diff,
-                        &mut self.pipeline,
-                        &self.abs_path,
-                        tool,
-                        event_stream,
-                        cx,
-                    )?;
+                    self.process_events(&events, tool, event_stream, cx)?;
                 }
             }
             StreamingEditFileMode::Edit => {
                 if let Some(edits) = partial.edits {
                     let events = self.parser.push_edits(&edits);
-                    Self::process_events(
-                        &events,
-                        &self.buffer,
-                        &self.diff,
-                        &mut self.pipeline,
-                        &self.abs_path,
-                        tool,
-                        event_stream,
-                        cx,
-                    )?;
+                    self.process_events(&events, tool, event_stream, cx)?;
                 }
             }
         }
@@ -703,46 +668,38 @@ impl EditSession {
     }
 
     fn process_events(
+        &mut self,
         events: &[ToolEditEvent],
-        buffer: &Entity<Buffer>,
-        diff: &Entity<Diff>,
-        pipeline: &mut EditPipeline,
-        abs_path: &PathBuf,
         tool: &StreamingEditFileTool,
         event_stream: &ToolCallEventStream,
         cx: &mut AsyncApp,
     ) -> Result<(), StreamingEditFileToolOutput> {
-        let action_log = tool
-            .thread
-            .read_with(cx, |thread, _cx| thread.action_log().clone())
-            .ok();
-
         for event in events {
             match event {
                 ToolEditEvent::ContentChunk { chunk } => {
-                    let (buffer_id, insert_at) = buffer.read_with(cx, |buffer, _cx| {
-                        let insert_at = if !pipeline.content_written && buffer.len() > 0 {
-                            0..buffer.len()
-                        } else {
-                            let len = buffer.len();
-                            len..len
-                        };
-                        (buffer.remote_id(), insert_at)
-                    });
+                    let (buffer_id, buffer_len) = self
+                        .buffer
+                        .read_with(cx, |buffer, _cx| (buffer.remote_id(), buffer.len()));
+                    let edit_range = if self.pipeline.content_written {
+                        buffer_len..buffer_len
+                    } else {
+                        0..buffer_len
+                    };
+
                     agent_edit_buffer(
-                        buffer,
-                        [(insert_at, chunk.as_str())],
-                        action_log.as_ref(),
+                        &self.buffer,
+                        [(edit_range, chunk.as_str())],
+                        &tool.action_log,
                         cx,
                     );
                     cx.update(|cx| {
                         tool.set_agent_location(
-                            buffer.downgrade(),
+                            self.buffer.downgrade(),
                             text::Anchor::max_for_buffer(buffer_id),
                             cx,
                         );
                     });
-                    pipeline.content_written = true;
+                    self.pipeline.content_written = true;
                 }
 
                 ToolEditEvent::OldTextChunk {
@@ -750,23 +707,24 @@ impl EditSession {
                     chunk,
                     done: false,
                 } => {
-                    pipeline.ensure_resolving_old_text(*edit_index, buffer, cx);
+                    self.pipeline
+                        .ensure_resolving_old_text(*edit_index, &self.buffer, cx);
 
                     if let EditPipelineEntry::ResolvingOldText { matcher } =
-                        &mut pipeline.edits[*edit_index]
+                        &mut self.pipeline.edits[*edit_index]
+                        && !chunk.is_empty()
                     {
-                        if !chunk.is_empty() {
-                            if let Some(match_range) = matcher.push(chunk, None) {
-                                let anchor_range = buffer.read_with(cx, |buffer, _cx| {
-                                    buffer.anchor_range_between(match_range.clone())
-                                });
-                                diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
-
-                                cx.update(|cx| {
-                                    let position = buffer.read(cx).anchor_before(match_range.end);
-                                    tool.set_agent_location(buffer.downgrade(), position, cx);
-                                });
-                            }
+                        if let Some(match_range) = matcher.push(chunk, None) {
+                            let anchor_range = self.buffer.read_with(cx, |buffer, _cx| {
+                                buffer.anchor_range_between(match_range.clone())
+                            });
+                            self.diff
+                                .update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
+
+                            cx.update(|cx| {
+                                let position = self.buffer.read(cx).anchor_before(match_range.end);
+                                tool.set_agent_location(self.buffer.downgrade(), position, cx);
+                            });
                         }
                     }
                 }
@@ -776,10 +734,11 @@ impl EditSession {
                     chunk,
                     done: true,
                 } => {
-                    pipeline.ensure_resolving_old_text(*edit_index, buffer, cx);
+                    self.pipeline
+                        .ensure_resolving_old_text(*edit_index, &self.buffer, cx);
 
                     let EditPipelineEntry::ResolvingOldText { matcher } =
-                        &mut pipeline.edits[*edit_index]
+                        &mut self.pipeline.edits[*edit_index]
                     else {
                         continue;
                     };
@@ -787,60 +746,47 @@ impl EditSession {
                     if !chunk.is_empty() {
                         matcher.push(chunk, None);
                     }
-                    let matches = matcher.finish();
-
-                    if matches.is_empty() {
-                        return Err(StreamingEditFileToolOutput::error(format!(
-                            "Could not find matching text for edit at index {}. \
-                                 The old_text did not match any content in the file. \
-                                 Please read the file again to get the current content.",
-                            edit_index,
-                        )));
-                    }
-                    if matches.len() > 1 {
-                        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
-                        let lines = matches
-                            .iter()
-                            .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
-                            .collect::<Vec<_>>()
-                            .join(", ");
-                        return Err(StreamingEditFileToolOutput::error(format!(
-                            "Edit {} matched multiple locations in the file at lines: {}. \
-                                 Please provide more context in old_text to uniquely \
-                                 identify the location.",
-                            edit_index, lines
-                        )));
-                    }
-
-                    let range = matches.into_iter().next().expect("checked len above");
+                    let range = extract_match(matcher.finish(), &self.buffer, edit_index, cx)?;
 
-                    let anchor_range = buffer
+                    let anchor_range = self
+                        .buffer
                         .read_with(cx, |buffer, _cx| buffer.anchor_range_between(range.clone()));
-                    diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
+                    self.diff
+                        .update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
 
-                    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+                    let snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 
                     let line = snapshot.offset_to_point(range.start).row;
                     event_stream.update_fields(
-                        ToolCallUpdateFields::new()
-                            .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
+                        ToolCallUpdateFields::new().locations(vec![
+                            ToolCallLocation::new(&self.abs_path).line(Some(line)),
+                        ]),
                     );
 
                     let EditPipelineEntry::ResolvingOldText { matcher } =
-                        &pipeline.edits[*edit_index]
+                        &self.pipeline.edits[*edit_index]
                     else {
                         continue;
                     };
                     let buffer_indent =
                         snapshot.line_indent_for_row(snapshot.offset_to_point(range.start).row);
-                    let query_indent = query_first_line_indent(matcher.query_lines());
+                    let query_indent = text::LineIndent::from_iter(
+                        matcher
+                            .query_lines()
+                            .first()
+                            .map(|s| s.as_str())
+                            .unwrap_or("")
+                            .chars(),
+                    );
                     let indent_delta = compute_indent_delta(buffer_indent, query_indent);
 
                     let old_text_in_buffer =
                         snapshot.text_for_range(range.clone()).collect::<String>();
 
-                    let text_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot());
-                    pipeline.edits[*edit_index] = EditPipelineEntry::StreamingNewText {
+                    let text_snapshot = self
+                        .buffer
+                        .read_with(cx, |buffer, _cx| buffer.text_snapshot());
+                    self.pipeline.edits[*edit_index] = EditPipelineEntry::StreamingNewText {
                         streaming_diff: StreamingDiff::new(old_text_in_buffer),
                         edit_cursor: range.start,
                         reindenter: Reindenter::new(indent_delta),
@@ -848,8 +794,8 @@ impl EditSession {
                     };
 
                     cx.update(|cx| {
-                        let position = buffer.read(cx).anchor_before(range.end);
-                        tool.set_agent_location(buffer.downgrade(), position, cx);
+                        let position = self.buffer.read(cx).anchor_before(range.end);
+                        tool.set_agent_location(self.buffer.downgrade(), position, cx);
                     });
                 }
 
@@ -858,7 +804,7 @@ impl EditSession {
                     chunk,
                     done: false,
                 } => {
-                    if *edit_index >= pipeline.edits.len() {
+                    if *edit_index >= self.pipeline.edits.len() {
                         continue;
                     }
                     let EditPipelineEntry::StreamingNewText {
@@ -867,7 +813,7 @@ impl EditSession {
                         reindenter,
                         original_snapshot,
                         ..
-                    } = &mut pipeline.edits[*edit_index]
+                    } = &mut self.pipeline.edits[*edit_index]
                     else {
                         continue;
                     };
@@ -878,18 +824,18 @@ impl EditSession {
                     }
 
                     let char_ops = streaming_diff.push_new(&reindented);
-                    Self::apply_char_operations(
+                    apply_char_operations(
                         &char_ops,
-                        buffer,
+                        &self.buffer,
                         original_snapshot,
                         edit_cursor,
-                        action_log.as_ref(),
+                        &tool.action_log,
                         cx,
                     );
 
                     let position = original_snapshot.anchor_before(*edit_cursor);
                     cx.update(|cx| {
-                        tool.set_agent_location(buffer.downgrade(), position, cx);
+                        tool.set_agent_location(self.buffer.downgrade(), position, cx);
                     });
                 }
 
@@ -898,7 +844,7 @@ impl EditSession {
                     chunk,
                     done: true,
                 } => {
-                    if *edit_index >= pipeline.edits.len() {
+                    if *edit_index >= self.pipeline.edits.len() {
                         continue;
                     }
 
@@ -908,7 +854,7 @@ impl EditSession {
                         mut reindenter,
                         original_snapshot,
                     } = std::mem::replace(
-                        &mut pipeline.edits[*edit_index],
+                        &mut self.pipeline.edits[*edit_index],
                         EditPipelineEntry::Done,
                     )
                     else {
@@ -921,64 +867,95 @@ impl EditSession {
 
                     if !final_text.is_empty() {
                         let char_ops = streaming_diff.push_new(&final_text);
-                        Self::apply_char_operations(
+                        apply_char_operations(
                             &char_ops,
-                            buffer,
+                            &self.buffer,
                             &original_snapshot,
                             &mut edit_cursor,
-                            action_log.as_ref(),
+                            &tool.action_log,
                             cx,
                         );
                     }
 
                     let remaining_ops = streaming_diff.finish();
-                    Self::apply_char_operations(
+                    apply_char_operations(
                         &remaining_ops,
-                        buffer,
+                        &self.buffer,
                         &original_snapshot,
                         &mut edit_cursor,
-                        action_log.as_ref(),
+                        &tool.action_log,
                         cx,
                     );
 
                     let position = original_snapshot.anchor_before(edit_cursor);
                     cx.update(|cx| {
-                        tool.set_agent_location(buffer.downgrade(), position, cx);
+                        tool.set_agent_location(self.buffer.downgrade(), position, cx);
                     });
                 }
             }
         }
         Ok(())
     }
+}
 
-    fn apply_char_operations(
-        ops: &[CharOperation],
-        buffer: &Entity<Buffer>,
-        snapshot: &text::BufferSnapshot,
-        edit_cursor: &mut usize,
-        action_log: Option<&Entity<ActionLog>>,
-        cx: &mut AsyncApp,
-    ) {
-        for op in ops {
-            match op {
-                CharOperation::Insert { text } => {
-                    let anchor = snapshot.anchor_after(*edit_cursor);
-                    agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx);
-                }
-                CharOperation::Delete { bytes } => {
-                    let delete_end = *edit_cursor + bytes;
-                    let anchor_range = snapshot.anchor_range_around(*edit_cursor..delete_end);
-                    agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx);
-                    *edit_cursor = delete_end;
-                }
-                CharOperation::Keep { bytes } => {
-                    *edit_cursor += bytes;
-                }
+fn apply_char_operations(
+    ops: &[CharOperation],
+    buffer: &Entity<Buffer>,
+    snapshot: &text::BufferSnapshot,
+    edit_cursor: &mut usize,
+    action_log: &Entity<ActionLog>,
+    cx: &mut AsyncApp,
+) {
+    for op in ops {
+        match op {
+            CharOperation::Insert { text } => {
+                let anchor = snapshot.anchor_after(*edit_cursor);
+                agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx);
+            }
+            CharOperation::Delete { bytes } => {
+                let delete_end = *edit_cursor + bytes;
+                let anchor_range = snapshot.anchor_range_around(*edit_cursor..delete_end);
+                agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx);
+                *edit_cursor = delete_end;
+            }
+            CharOperation::Keep { bytes } => {
+                *edit_cursor += bytes;
             }
         }
     }
 }
 
+fn extract_match(
+    matches: Vec<Range<usize>>,
+    buffer: &Entity<Buffer>,
+    edit_index: &usize,
+    cx: &mut AsyncApp,
+) -> Result<Range<usize>, StreamingEditFileToolOutput> {
+    match matches.len() {
+        0 => Err(StreamingEditFileToolOutput::error(format!(
+            "Could not find matching text for edit at index {}. \
+                The old_text did not match any content in the file. \
+                Please read the file again to get the current content.",
+            edit_index,
+        ))),
+        1 => Ok(matches.into_iter().next().unwrap()),
+        _ => {
+            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+            let lines = matches
+                .iter()
+                .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
+                .collect::<Vec<_>>()
+                .join(", ");
+            Err(StreamingEditFileToolOutput::error(format!(
+                "Edit {} matched multiple locations in the file at lines: {}. \
+                    Please provide more context in old_text to uniquely \
+                    identify the location.",
+                edit_index, lines
+            )))
+        }
+    }
+}
+
 /// Edits a buffer and reports the edit to the action log in the same effect
 /// cycle. This ensures the action log's subscription handler sees the version
 /// already updated by `buffer_edited`, so it does not misattribute the agent's
@@ -986,7 +963,7 @@ impl EditSession {
 fn agent_edit_buffer<I, S, T>(
     buffer: &Entity<Buffer>,
     edits: I,
-    action_log: Option<&Entity<ActionLog>>,
+    action_log: &Entity<ActionLog>,
     cx: &mut AsyncApp,
 ) where
     I: IntoIterator<Item = (Range<S>, T)>,
@@ -997,9 +974,7 @@ fn agent_edit_buffer<I, S, T>(
         buffer.update(cx, |buffer, cx| {
             buffer.edit(edits, None, cx);
         });
-        if let Some(action_log) = action_log {
-            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
-        }
+        action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
     });
 }
 
@@ -1007,11 +982,11 @@ fn ensure_buffer_saved(
     buffer: &Entity<Buffer>,
     abs_path: &PathBuf,
     tool: &StreamingEditFileTool,
-    action_log: Option<&Entity<ActionLog>>,
     cx: &mut AsyncApp,
 ) -> Result<(), StreamingEditFileToolOutput> {
-    let last_read_mtime =
-        action_log.and_then(|log| log.read_with(cx, |log, _| log.file_read_time(abs_path)));
+    let last_read_mtime = tool
+        .action_log
+        .read_with(cx, |log, _| log.file_read_time(abs_path));
     let check_result = tool.thread.read_with(cx, |thread, cx| {
         let current = buffer
             .read(cx)
@@ -1140,42 +1115,17 @@ mod tests {
 
     #[gpui::test]
     async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
-        init_test(cx);
-
-        let fs = project::FakeFs::new(cx.executor());
-        fs.insert_tree("/root", json!({"dir": {}})).await;
-        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
-        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
-        let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            crate::Thread::new(
-                project.clone(),
-                cx.new(|_cx| ProjectContext::default()),
-                context_server_registry,
-                Templates::new(),
-                Some(model),
-                cx,
-            )
-        });
-
+        let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await;
         let result = cx
             .update(|cx| {
-                let input = StreamingEditFileToolInput {
-                    display_description: "Create new file".into(),
-                    path: "root/dir/new_file.txt".into(),
-                    mode: StreamingEditFileMode::Write,
-                    content: Some("Hello, World!".into()),
-                    edits: None,
-                };
-                Arc::new(StreamingEditFileTool::new(
-                    project.clone(),
-                    thread.downgrade(),
-                    language_registry,
-                ))
-                .run(
-                    ToolInput::resolved(input),
+                tool.clone().run(
+                    ToolInput::resolved(StreamingEditFileToolInput {
+                        display_description: "Create new file".into(),
+                        path: "root/dir/new_file.txt".into(),
+                        mode: StreamingEditFileMode::Write,
+                        content: Some("Hello, World!".into()),
+                        edits: None,
+                    }),
                     ToolCallEventStream::test().0,
                     cx,
                 )
@@ -1191,43 +1141,18 @@ mod tests {
 
     #[gpui::test]
     async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
-        init_test(cx);
-
-        let fs = project::FakeFs::new(cx.executor());
-        fs.insert_tree("/root", json!({"file.txt": "old content"}))
-            .await;
-        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
-        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
-        let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            crate::Thread::new(
-                project.clone(),
-                cx.new(|_cx| ProjectContext::default()),
-                context_server_registry,
-                Templates::new(),
-                Some(model),
-                cx,
-            )
-        });
-
+        let (tool, _project, _action_log, _fs, _thread) =
+            setup_test(cx, json!({"file.txt": "old content"})).await;
         let result = cx
             .update(|cx| {
-                let input = StreamingEditFileToolInput {
-                    display_description: "Overwrite file".into(),
-                    path: "root/file.txt".into(),
-                    mode: StreamingEditFileMode::Write,
-                    content: Some("new content".into()),
-                    edits: None,
-                };
-                Arc::new(StreamingEditFileTool::new(
-                    project.clone(),
-                    thread.downgrade(),
-                    language_registry,
-                ))
-                .run(
-                    ToolInput::resolved(input),
+                tool.clone().run(
+                    ToolInput::resolved(StreamingEditFileToolInput {
+                        display_description: "Overwrite file".into(),
+                        path: "root/file.txt".into(),
+                        mode: StreamingEditFileMode::Write,
+                        content: Some("new content".into()),
+                        edits: None,
+                    }),
                     ToolCallEventStream::test().0,
                     cx,
                 )
@@ -1246,51 +1171,21 @@ mod tests {
 
     #[gpui::test]
     async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
-        init_test(cx);
-
-        let fs = project::FakeFs::new(cx.executor());
-        fs.insert_tree(
-            "/root",
-            json!({
-                "file.txt": "line 1\nline 2\nline 3\n"
-            }),
-        )
-        .await;
-        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
-        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
-        let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            crate::Thread::new(
-                project.clone(),
-                cx.new(|_cx| ProjectContext::default()),
-                context_server_registry,
-                Templates::new(),
-                Some(model),
-                cx,
-            )
-        });
-
+        let (tool, _project, _action_log, _fs, _thread) =
+            setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await;
         let result = cx
             .update(|cx| {
-                let input = StreamingEditFileToolInput {
-                    display_description: "Edit lines".into(),
-                    path: "root/file.txt".into(),
-                    mode: StreamingEditFileMode::Edit,
-                    content: None,
-                    edits: Some(vec![Edit {
-                        old_text: "line 2".into(),
-                        new_text: "modified line 2".into(),
-                    }]),
-                };
-                Arc::new(StreamingEditFileTool::new(
-                    project.clone(),
-                    thread.downgrade(),
-                    language_registry,
-                ))
-                .run(
-                    ToolInput::resolved(input),
+                tool.clone().run(
+                    ToolInput::resolved(StreamingEditFileToolInput {
+                        display_description: "Edit lines".into(),
+                        path: "root/file.txt".into(),
+                        mode: StreamingEditFileMode::Edit,
+                        content: None,
+                        edits: Some(vec![Edit {
+                            old_text: "line 2".into(),
+                            new_text: "modified line 2".into(),
+                        }]),
+                    }),
                     ToolCallEventStream::test().0,
                     cx,
                 )
@@ -1305,57 +1200,30 @@ mod tests {
 
     #[gpui::test]
     async fn test_streaming_edit_multiple_edits(cx: &mut TestAppContext) {
-        init_test(cx);
-
-        let fs = project::FakeFs::new(cx.executor());
-        fs.insert_tree(
-            "/root",
-            json!({
-                "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"
-            }),
+        let (tool, _project, _action_log, _fs, _thread) = setup_test(
+            cx,
+            json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}),
         )
         .await;
-        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
-        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
-        let context_server_registry =
-            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
-        let model = Arc::new(FakeLanguageModel::default());
-        let thread = cx.new(|cx| {
-            crate::Thread::new(
-                project.clone(),
-                cx.new(|_cx| ProjectContext::default()),
-                context_server_registry,
-                Templates::new(),
-                Some(model),
-                cx,
-            )
-        });
-
         let result = cx
             .update(|cx| {
-                let input = StreamingEditFileToolInput {
-                    display_description: "Edit multiple lines".into(),
-                    path: "root/file.txt".into(),
-                    mode: StreamingEditFileMode::Edit,
-                    content: None,
-                    edits: Some(vec![
-                        Edit {
-                            old_text: "line 5".into(),
-                            new_text: "modified line 5".into(),
-                        },
-                        Edit {
-                            old_text: "line 1".into(),
-                            new_text: "modified line 1".into(),
-                        },
-                    ]),
-                };
-                Arc::new(StreamingEditFileTool::new(
-                    project.clone(),
-                    thread.downgrade(),
-                    language_registry,
-                ))
-                .run(
-                    ToolInput::resolved(input),
+                tool.clone().run(
+                    ToolInput::resolved(StreamingEditFileToolInput {
+                        display_description: "Edit multiple lines".into(),
+                        path: "root/file.txt".into(),
+                        mode: StreamingEditFileMode::Edit,
+                        content: None,
+                        edits: Some(vec![
+                            Edit {
+                                old_text: "line 5".into(),
+                                new_text: "modified line 5".into(),
+                            },
+                            Edit {
+                                old_text: "line 1".into(),
+                                new_text: "modified line 1".into(),
+                            },
+                        ]),
+                    }),
                     ToolCallEventStream::test().0,
                     cx,
                 )