From a2e34cb7bff8121a3b195bc3d5d1fbfe668ef5a5 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Wed, 25 Feb 2026 23:58:25 +0100 Subject: [PATCH] agent: Implement streaming for edit file tool (#50004) Before you mark this PR as ready for review, make sure that you have: - [x] Added a solid test coverage and/or screenshots from doing manual testing - [x] Done a self-review taking into account security and performance aspects - [x] Aligned any UI changes with the [UI checklist](https://github.com/zed-industries/zed/blob/main/CONTRIBUTING.md#uiux-checklist) Release Notes: - N/A --------- Co-authored-by: Zed Zippy <234243425+zed-zippy[bot]@users.noreply.github.com> --- crates/agent/src/thread.rs | 1 + crates/agent/src/tools.rs | 1 + .../src/tools/streaming_edit_file_tool.rs | 3679 +++++++++++++++-- crates/agent_ui/src/buffer_codegen.rs | 2 + crates/anthropic/src/anthropic.rs | 6 + crates/language_model/src/request.rs | 1 + .../language_models/src/provider/anthropic.rs | 2 + .../language_models/src/provider/open_ai.rs | 1 + 8 files changed, 3395 insertions(+), 298 deletions(-) diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 39fef567916eb7d4a7bf04db3a0455bead6eee2f..923fbd11126f21459131b7ca194288de6af5498e 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -2461,6 +2461,7 @@ impl Thread { name: tool_name.to_string(), description: tool.description().to_string(), input_schema: tool.input_schema(model.tool_input_format()).log_err()?, + use_input_streaming: tool.supports_input_streaming(), }) }) .collect::>() diff --git a/crates/agent/src/tools.rs b/crates/agent/src/tools.rs index 1962f237045c47935de90ebb231575da29d1205c..b2724801befc7459ad37494d298819f4b7ca6b27 100644 --- a/crates/agent/src/tools.rs +++ b/crates/agent/src/tools.rs @@ -100,6 +100,7 @@ macro_rules! tools { name: T::NAME.to_string(), description: T::description().to_string(), input_schema: T::input_schema(LanguageModelToolSchemaFormat::JsonSchema).to_value(), + use_input_streaming: T::supports_input_streaming(), } } [ diff --git a/crates/agent/src/tools/streaming_edit_file_tool.rs b/crates/agent/src/tools/streaming_edit_file_tool.rs index 95651b44bac44ad3cc67c25c0ef13fc885342ce3..933fa2ff1e996ac6802da2a60ba832104531a230 100644 --- a/crates/agent/src/tools/streaming_edit_file_tool.rs +++ b/crates/agent/src/tools/streaming_edit_file_tool.rs @@ -11,8 +11,8 @@ use anyhow::{Context as _, Result, anyhow}; use collections::HashSet; use futures::FutureExt as _; use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; -use language::LanguageRegistry; use language::language_settings::{self, FormatOnSave}; +use language::{Buffer, LanguageRegistry}; use language_model::LanguageModelToolResultContent; use project::lsp_store::{FormatTrigger, LspFormatTarget}; use project::{Project, ProjectPath}; @@ -23,8 +23,8 @@ use std::path::PathBuf; use std::sync::Arc; use text::BufferSnapshot; use ui::SharedString; -use util::ResultExt; use util::rel_path::RelPath; +use util::{Deferred, ResultExt, debug_panic}; const DEFAULT_UI_TEXT: &str = "Editing file"; @@ -67,7 +67,7 @@ pub struct StreamingEditFileToolInput { /// /// `frontend/db.js` /// - pub path: PathBuf, + pub path: String, /// The mode of operation on the file. Possible values: /// - 'create': Create a new file if it doesn't exist. Requires 'content' field. @@ -109,12 +109,488 @@ pub struct EditOperation { pub new_text: String, } -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +#[derive(Default, Debug, Deserialize)] struct StreamingEditFileToolPartialInput { #[serde(default)] - path: String, + display_description: Option, + #[serde(default)] + path: Option, + #[serde(default)] + mode: Option, + #[serde(default)] + #[allow(dead_code)] + content: Option, + #[serde(default)] + edits: Option>, +} + +#[derive(Default, Debug, Deserialize)] +struct PartialEditOperation { + #[serde(default)] + old_text: Option, #[serde(default)] - display_description: String, + new_text: Option, +} + +enum StreamingEditState { + Idle, + BufferResolved { + abs_path: PathBuf, + buffer: Entity, + old_text: Arc, + diff: Entity, + edit_state: IncrementalEditState, + _finalize_diff_guard: Deferred>, + }, +} + +#[derive(Default)] +struct IncrementalEditState { + applied_count: usize, + in_progress_matcher: Option, + last_old_text_len: usize, +} + +impl StreamingEditState { + async fn finalize( + &mut self, + input: StreamingEditFileToolInput, + tool: &StreamingEditFileTool, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> Result { + let remaining_edits_start_ix = match self { + StreamingEditState::Idle => { + *self = Self::transition_to_buffer_resolved( + &input.path, + &input.display_description, + input.mode.clone(), + tool, + event_stream, + cx, + ) + .await?; + 0 + } + StreamingEditState::BufferResolved { edit_state, .. } => edit_state.applied_count, + }; + + let StreamingEditState::BufferResolved { + buffer, + old_text, + diff, + abs_path, + .. + } = self + else { + debug_panic!("Invalid state"); + return Ok(StreamingEditFileToolOutput::Error { + error: "Internal error. Try to apply the edits again".to_string(), + }); + }; + + let result: anyhow::Result = async { + let action_log = tool + .thread + .read_with(cx, |thread, _cx| thread.action_log().clone())?; + + match input.mode { + StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => { + action_log.update(cx, |log, cx| { + log.buffer_created(buffer.clone(), cx); + }); + let content = input.content.ok_or_else(|| { + anyhow!("'content' field is required for create and overwrite modes") + })?; + cx.update(|cx| { + buffer.update(cx, |buffer, cx| { + buffer.edit([(0..buffer.len(), content.as_str())], None, cx); + }); + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + }); + }); + } + StreamingEditFileMode::Edit => { + let edits = input + .edits + .ok_or_else(|| anyhow!("'edits' field is required for edit mode"))?; + + let remaining_edits = &edits[remaining_edits_start_ix..]; + apply_edits( + &buffer, + &action_log, + remaining_edits, + &diff, + event_stream, + &abs_path, + cx, + )?; + } + } + + let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| { + let settings = language_settings::language_settings( + buffer.language().map(|l| l.name()), + buffer.file(), + cx, + ); + settings.format_on_save != FormatOnSave::Off + }); + + if format_on_save_enabled { + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + }); + + let format_task = tool.project.update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, + FormatTrigger::Save, + cx, + ) + }); + futures::select! { + result = format_task.fuse() => { result.log_err(); }, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Edit cancelled by user"); + } + }; + } + + let save_task = tool + .project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)); + futures::select! { + result = save_task.fuse() => { result?; }, + _ = event_stream.cancelled_by_user().fuse() => { + anyhow::bail!("Edit cancelled by user"); + } + }; + + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + }); + + if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| { + buffer.file().and_then(|file| file.disk_state().mtime()) + }) { + tool.thread.update(cx, |thread, _| { + thread + .file_read_times + .insert(abs_path.to_path_buf(), new_mtime); + })?; + } + + let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let (new_text, unified_diff) = cx + .background_spawn({ + let new_snapshot = new_snapshot.clone(); + let old_text = old_text.clone(); + async move { + let new_text = new_snapshot.text(); + let diff = language::unified_diff(&old_text, &new_text); + (new_text, diff) + } + }) + .await; + + let output = StreamingEditFileToolOutput::Success { + input_path: PathBuf::from(input.path), + new_text, + old_text: old_text.clone(), + diff: unified_diff, + }; + Ok(output) + } + .await; + result.map_err(|e| StreamingEditFileToolOutput::Error { + error: e.to_string(), + }) + } + + async fn process( + &mut self, + partial: StreamingEditFileToolPartialInput, + tool: &StreamingEditFileTool, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> Result<(), StreamingEditFileToolOutput> { + match self { + Self::Idle => { + if let Some(path_str) = partial.path + && let Some(display_description) = partial.display_description + && let Some(mode) = partial.mode + { + *self = Self::transition_to_buffer_resolved( + &path_str, + &display_description, + mode, + tool, + event_stream, + cx, + ) + .await?; + } + } + Self::BufferResolved { + abs_path, + buffer, + edit_state, + diff, + .. + } => { + if let Some(edits) = partial.edits { + Self::process_streaming_edits( + buffer, + diff, + edit_state, + &edits, + abs_path, + tool, + event_stream, + cx, + )?; + } + } + } + Ok(()) + } + + async fn transition_to_buffer_resolved( + path_str: &str, + display_description: &str, + mode: StreamingEditFileMode, + tool: &StreamingEditFileTool, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> Result { + let path = PathBuf::from(path_str); + let project_path = cx + .update(|cx| resolve_path(mode, &path, &tool.project, cx)) + .map_err(|e| StreamingEditFileToolOutput::Error { + error: e.to_string(), + })?; + + let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx)) + else { + return Err(StreamingEditFileToolOutput::Error { + error: format!("File '{path_str}' does not exist"), + }); + }; + + event_stream.update_fields( + ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path.clone())]), + ); + + cx.update(|cx| tool.authorize(&path, &display_description, event_stream, cx)) + .await + .map_err(|e| StreamingEditFileToolOutput::Error { + error: e.to_string(), + })?; + + let buffer = tool + .project + .update(cx, |project, cx| project.open_buffer(project_path, cx)) + .await + .map_err(|e| StreamingEditFileToolOutput::Error { + error: e.to_string(), + })?; + + ensure_buffer_saved(&buffer, &abs_path, tool, cx)?; + + let diff = cx.new(|cx| Diff::new(buffer.clone(), cx)); + event_stream.update_diff(diff.clone()); + let finalize_diff_guard = util::defer(Box::new({ + let diff = diff.downgrade(); + let mut cx = cx.clone(); + move || { + diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok(); + } + }) as Box); + + let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let old_text = cx + .background_spawn({ + let old_snapshot = old_snapshot.clone(); + async move { Arc::new(old_snapshot.text()) } + }) + .await; + + Ok(Self::BufferResolved { + abs_path, + buffer, + old_text, + diff, + edit_state: IncrementalEditState::default(), + _finalize_diff_guard: finalize_diff_guard, + }) + } + + fn process_streaming_edits( + buffer: &Entity, + diff: &Entity, + edit_state: &mut IncrementalEditState, + edits: &[PartialEditOperation], + abs_path: &PathBuf, + tool: &StreamingEditFileTool, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> Result<(), StreamingEditFileToolOutput> { + if edits.is_empty() { + return Ok(()); + } + + // Edits at indices applied_count..edits.len()-1 are newly complete + // (a subsequent edit exists, proving the LLM moved on). + // The last edit (edits.len()-1) is potentially still in progress. + let completed_count = edits.len().saturating_sub(1); + + // Apply newly-complete edits + while edit_state.applied_count < completed_count { + let edit_index = edit_state.applied_count; + let partial_edit = &edits[edit_index]; + + let old_text = match &partial_edit.old_text { + Some(t) => t.clone(), + None => { + edit_state.applied_count += 1; + continue; + } + }; + let new_text = partial_edit.new_text.clone().unwrap_or_default(); + + edit_state.in_progress_matcher = None; + edit_state.last_old_text_len = 0; + + let edit_op = EditOperation { + old_text: old_text.clone(), + new_text: new_text.clone(), + }; + + let action_log = tool + .thread + .read_with(cx, |thread, _cx| thread.action_log().clone()) + .ok(); + + // On the first edit, mark the buffer as read + if edit_state.applied_count == 0 { + if let Some(action_log) = &action_log { + action_log.update(cx, |log, cx| { + log.buffer_read(buffer.clone(), cx); + }); + } + } + + resolve_reveal_and_apply_edit( + buffer, + diff, + &edit_op, + edit_index, + abs_path, + action_log.as_ref(), + event_stream, + cx, + ) + .map_err(|e| StreamingEditFileToolOutput::Error { + error: e.to_string(), + })?; + + edit_state.applied_count += 1; + } + + // Feed the in-progress last edit's old_text to the matcher for live preview + if let Some(partial_edit) = edits.last() { + if let Some(old_text) = &partial_edit.old_text { + let old_text_len = old_text.len(); + if old_text_len > edit_state.last_old_text_len { + let new_chunk = &old_text[edit_state.last_old_text_len..]; + + let matcher = edit_state.in_progress_matcher.get_or_insert_with(|| { + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot()); + StreamingFuzzyMatcher::new(snapshot) + }); + + if let Some(match_range) = matcher.push(new_chunk, None) { + let anchor_range = buffer.read_with(cx, |buffer, _cx| { + buffer.anchor_range_between(match_range.clone()) + }); + diff.update(cx, |card, cx| card.reveal_range(anchor_range, cx)); + } + + edit_state.last_old_text_len = old_text_len; + } + } + } + + Ok(()) + } +} + +fn ensure_buffer_saved( + buffer: &Entity, + abs_path: &PathBuf, + tool: &StreamingEditFileTool, + cx: &mut AsyncApp, +) -> Result<(), StreamingEditFileToolOutput> { + let check_result = tool.thread.update(cx, |thread, cx| { + let last_read = thread.file_read_times.get(abs_path).copied(); + let current = buffer + .read(cx) + .file() + .and_then(|file| file.disk_state().mtime()); + let dirty = buffer.read(cx).is_dirty(); + let has_save = thread.has_tool(SaveFileTool::NAME); + let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME); + (last_read, current, dirty, has_save, has_restore) + }); + + let Ok((last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool)) = + check_result + else { + return Ok(()); + }; + + if is_dirty { + let message = match (has_save_tool, has_restore_tool) { + (true, true) => { + "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ + If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \ + If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit." + } + (true, false) => { + "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ + If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \ + If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed." + } + (false, true) => { + "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ + If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \ + If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit." + } + (false, false) => { + "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \ + then ask them to save or revert the file manually and inform you when it's ok to proceed." + } + }; + return Err(StreamingEditFileToolOutput::Error { + error: message.to_string(), + }); + } + + if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) { + if current != last_read { + return Err(StreamingEditFileToolOutput::Error { + error: "The file has been modified since you last read it. \ + Please read the file again to get the current state before editing it." + .to_string(), + }); + } + } + + Ok(()) } #[derive(Debug, Serialize, Deserialize)] @@ -179,24 +655,17 @@ impl StreamingEditFileTool { } } - pub fn with_thread(&self, new_thread: WeakEntity) -> Self { - Self { - project: self.project.clone(), - thread: new_thread, - language_registry: self.language_registry.clone(), - } - } - fn authorize( &self, - input: &StreamingEditFileToolInput, + path: &PathBuf, + description: &str, event_stream: &ToolCallEventStream, cx: &mut App, ) -> Task> { super::tool_permissions::authorize_file_edit( EditFileTool::NAME, - &input.path, - &input.display_description, + path, + description, &self.thread, event_stream, cx, @@ -210,6 +679,10 @@ impl AgentTool for StreamingEditFileTool { const NAME: &'static str = "streaming_edit_file"; + fn supports_input_streaming() -> bool { + true + } + fn kind() -> acp::ToolKind { acp::ToolKind::Edit } @@ -229,28 +702,30 @@ impl AgentTool for StreamingEditFileTool { .read(cx) .short_full_path_for_project_path(&project_path, cx) }) - .unwrap_or(input.path.to_string_lossy().into_owned()) + .unwrap_or(input.path) .into(), Err(raw_input) => { if let Some(input) = serde_json::from_value::(raw_input).ok() { - let path = input.path.trim(); + let path = input.path.unwrap_or_default(); + let path = path.trim(); if !path.is_empty() { return self .project .read(cx) - .find_project_path(&input.path, cx) + .find_project_path(&path, cx) .and_then(|project_path| { self.project .read(cx) .short_full_path_for_project_path(&project_path, cx) }) - .unwrap_or(input.path) + .unwrap_or_else(|| path.to_string()) .into(); } - let description = input.display_description.trim(); + let description = input.display_description.unwrap_or_default(); + let description = description.trim(); if !description.is_empty() { return description.to_string().into(); } @@ -263,230 +738,36 @@ impl AgentTool for StreamingEditFileTool { fn run( self: Arc, - input: ToolInput, + mut input: ToolInput, event_stream: ToolCallEventStream, cx: &mut App, ) -> Task> { cx.spawn(async move |cx: &mut AsyncApp| { - let input = input.recv().await.map_err(|e| { - StreamingEditFileToolOutput::Error { - error: format!("Failed to receive tool input: {e}"), - } - })?; - - let project = self - .thread - .read_with(cx, |thread, _cx| thread.project().clone()) - .map_err(|_| StreamingEditFileToolOutput::Error { - error: "thread was dropped".to_string(), - })?; - - let (project_path, abs_path, authorize) = cx.update(|cx| { - let project_path = - resolve_path(&input, project.clone(), cx).map_err(|err| { - StreamingEditFileToolOutput::Error { - error: err.to_string(), - } - })?; - let abs_path = project.read(cx).absolute_path(&project_path, cx); - if let Some(abs_path) = abs_path.clone() { - event_stream.update_fields( - ToolCallUpdateFields::new() - .locations(vec![acp::ToolCallLocation::new(abs_path)]), - ); - } - let authorize = self.authorize(&input, &event_stream, cx); - Ok::<_, StreamingEditFileToolOutput>((project_path, abs_path, authorize)) - })?; - let result: anyhow::Result = async { - authorize.await?; - - let buffer = project - .update(cx, |project, cx| { - project.open_buffer(project_path.clone(), cx) - }) - .await?; - - if let Some(abs_path) = abs_path.as_ref() { - let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) = - self.thread.update(cx, |thread, cx| { - let last_read = thread.file_read_times.get(abs_path).copied(); - let current = buffer - .read(cx) - .file() - .and_then(|file| file.disk_state().mtime()); - let dirty = buffer.read(cx).is_dirty(); - let has_save = thread.has_tool(SaveFileTool::NAME); - let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME); - (last_read, current, dirty, has_save, has_restore) - })?; - - if is_dirty { - let message = match (has_save_tool, has_restore_tool) { - (true, true) => { - "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ - If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \ - If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit." - } - (true, false) => { - "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ - If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \ - If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed." - } - (false, true) => { - "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ - If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \ - If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit." - } - (false, false) => { - "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \ - then ask them to save or revert the file manually and inform you when it's ok to proceed." - } - }; - anyhow::bail!("{}", message); - } - - if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) { - if current != last_read { - anyhow::bail!( - "The file {} has been modified since you last read it. \ - Please read the file again to get the current state before editing it.", - input.path.display() - ); + let mut state = StreamingEditState::Idle; + loop { + futures::select! { + partial = input.recv_partial().fuse() => { + let Some(partial_value) = partial else { break }; + if let Ok(parsed) = serde_json::from_value::(partial_value) { + state.process(parsed, &self, &event_stream, cx).await?; } } - } - - let diff = cx.new(|cx| Diff::new(buffer.clone(), cx)); - event_stream.update_diff(diff.clone()); - let _finalize_diff = util::defer({ - let diff = diff.downgrade(); - let mut cx = cx.clone(); - move || { - diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok(); - } - }); - - let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let old_text = cx - .background_spawn({ - let old_snapshot = old_snapshot.clone(); - async move { Arc::new(old_snapshot.text()) } - }) - .await; - - let action_log = self.thread.read_with(cx, |thread, _cx| thread.action_log().clone())?; - - // Edit the buffer and report edits to the action log as part of the - // same effect cycle, otherwise the edit will be reported as if the - // user made it (due to the buffer subscription in action_log). - match input.mode { - StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => { - action_log.update(cx, |log, cx| { - log.buffer_created(buffer.clone(), cx); - }); - let content = input.content.ok_or_else(|| { - anyhow!("'content' field is required for create and overwrite modes") - })?; - cx.update(|cx| { - buffer.update(cx, |buffer, cx| { - buffer.edit([(0..buffer.len(), content.as_str())], None, cx); - }); - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); - }); - }); - } - StreamingEditFileMode::Edit => { - action_log.update(cx, |log, cx| { - log.buffer_read(buffer.clone(), cx); - }); - let edits = input.edits.ok_or_else(|| { - anyhow!("'edits' field is required for edit mode") - })?; - // apply_edits now handles buffer_edited internally in the same effect cycle - apply_edits(&buffer, &action_log, &edits, &diff, &event_stream, &abs_path, cx)?; - } - } - - let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| { - let settings = language_settings::language_settings( - buffer.language().map(|l| l.name()), - buffer.file(), - cx, - ); - settings.format_on_save != FormatOnSave::Off - }); - - if format_on_save_enabled { - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); - }); - - let format_task = project.update(cx, |project, cx| { - project.format( - HashSet::from_iter([buffer.clone()]), - LspFormatTarget::Buffers, - false, - FormatTrigger::Save, - cx, - ) - }); - futures::select! { - result = format_task.fuse() => { result.log_err(); }, - _ = event_stream.cancelled_by_user().fuse() => { - anyhow::bail!("Edit cancelled by user"); - } - }; - } - - let save_task = project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)); - futures::select! { - result = save_task.fuse() => { result?; }, _ = event_stream.cancelled_by_user().fuse() => { - anyhow::bail!("Edit cancelled by user"); - } - }; - - action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); - }); - - if let Some(abs_path) = abs_path.as_ref() { - if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| { - buffer.file().and_then(|file| file.disk_state().mtime()) - }) { - self.thread.update(cx, |thread, _| { - thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime); - })?; + return Err(StreamingEditFileToolOutput::Error { + error: "Edit cancelled by user".to_string(), + }); } } + } + let full_input = + input + .recv() + .await + .map_err(|e| StreamingEditFileToolOutput::Error { + error: format!("Failed to receive tool input: {e}"), + })?; - let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let (new_text, unified_diff) = cx - .background_spawn({ - let new_snapshot = new_snapshot.clone(); - let old_text = old_text.clone(); - async move { - let new_text = new_snapshot.text(); - let diff = language::unified_diff(&old_text, &new_text); - (new_text, diff) - } - }) - .await; - - let output = StreamingEditFileToolOutput::Success { - input_path: input.path, - new_text, - old_text, - diff: unified_diff, - }; - - Ok(output) - }.await; - result - .map_err(|e| StreamingEditFileToolOutput::Error { error: e.to_string() }) + state.finalize(full_input, &self, &event_stream, cx).await }) } @@ -526,42 +807,28 @@ fn apply_edits( edits: &[EditOperation], diff: &Entity, event_stream: &ToolCallEventStream, - abs_path: &Option, + abs_path: &PathBuf, cx: &mut AsyncApp, ) -> Result<()> { let mut failed_edits = Vec::new(); let mut ambiguous_edits = Vec::new(); let mut resolved_edits: Vec<(Range, String)> = Vec::new(); - // First pass: resolve all edits without applying them let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); for (index, edit) in edits.iter().enumerate() { - let result = resolve_edit(&snapshot, edit); - - match result { - Ok(Some((range, new_text))) => { - // Reveal the range in the diff view - let (start_anchor, end_anchor) = buffer.read_with(cx, |buffer, _cx| { - ( - buffer.anchor_before(range.start), - buffer.anchor_after(range.end), - ) - }); - diff.update(cx, |card, cx| { - card.reveal_range(start_anchor..end_anchor, cx) - }); + match resolve_and_reveal_edit(buffer, diff, &snapshot, edit, cx) { + Ok((range, new_text)) => { resolved_edits.push((range, new_text)); } - Ok(None) => { + Err(EditResolveError::NotFound) => { failed_edits.push(index); } - Err(ranges) => { + Err(EditResolveError::Ambiguous(ranges)) => { ambiguous_edits.push((index, ranges)); } } } - // Check for errors before applying any edits if !failed_edits.is_empty() { let indices = failed_edits .iter() @@ -595,22 +862,17 @@ fn apply_edits( ); } - // Sort edits by position so buffer.edit() can handle offset translation let mut edits_sorted = resolved_edits; edits_sorted.sort_by(|a, b| a.0.start.cmp(&b.0.start)); - // Emit location for the earliest edit in the file if let Some((first_range, _)) = edits_sorted.first() { - if let Some(abs_path) = abs_path.clone() { - let line = snapshot.offset_to_point(first_range.start).row; - event_stream.update_fields( - ToolCallUpdateFields::new() - .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]), - ); - } + let line = snapshot.offset_to_point(first_range.start).row; + event_stream.update_fields( + ToolCallUpdateFields::new() + .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]), + ); } - // Validate no overlaps (sorted ascending by start) for window in edits_sorted.windows(2) { if let [(earlier_range, _), (later_range, _)] = window && (earlier_range.end > later_range.start || earlier_range.start == later_range.start) @@ -630,9 +892,6 @@ fn apply_edits( } } - // Apply all edits in a single batch and report to action_log in the same - // effect cycle. This prevents the buffer subscription from treating these - // as user edits. if !edits_sorted.is_empty() { cx.update(|cx| { buffer.update(cx, |buffer, cx| { @@ -653,40 +912,111 @@ fn apply_edits( Ok(()) } -/// Resolves an edit operation by finding the matching text in the buffer. -/// Returns Ok(Some((range, new_text))) if a unique match is found, -/// Ok(None) if no match is found, or Err(ranges) if multiple matches are found. -fn resolve_edit( - snapshot: &BufferSnapshot, +/// Resolves, reveals, and applies a single edit to the buffer. Emits +/// a location update and reports the change to the action log. +fn resolve_reveal_and_apply_edit( + buffer: &Entity, + diff: &Entity, edit: &EditOperation, -) -> std::result::Result, String)>, Vec>> { - let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone()); - matcher.push(&edit.old_text, None); - let matches = matcher.finish(); - - if matches.is_empty() { - return Ok(None); - } + edit_index: usize, + abs_path: &PathBuf, + action_log: Option<&Entity>, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, +) -> Result<()> { + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - if matches.len() > 1 { - return Err(matches); - } + match resolve_and_reveal_edit(buffer, diff, &snapshot, edit, cx) { + Ok((range, new_text)) => { + let line = snapshot.offset_to_point(range.start).row; + event_stream.update_fields( + ToolCallUpdateFields::new() + .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]), + ); - let match_range = matches.into_iter().next().expect("checked len above"); - Ok(Some((match_range, edit.new_text.clone()))) + if let Some(action_log) = action_log { + cx.update(|cx| { + buffer.update(cx, |buffer, cx| { + buffer.edit([(range, new_text.as_str())], None, cx); + }); + action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + }); + }); + } + + Ok(()) + } + Err(EditResolveError::NotFound) => { + anyhow::bail!( + "Could not find matching text for edit at index {}. \ + The old_text did not match any content in the file. \ + Please read the file again to get the current content.", + edit_index + ); + } + Err(EditResolveError::Ambiguous(ranges)) => { + let lines = ranges + .iter() + .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string()) + .collect::>() + .join(", "); + anyhow::bail!( + "Edit {} matched multiple locations in the file at lines: {}. \ + Please provide more context in old_text to uniquely identify the location.", + edit_index, + lines + ); + } + } +} + +enum EditResolveError { + NotFound, + Ambiguous(Vec>), +} + +/// Resolves an edit operation by finding matching text in the buffer, +/// reveals the matched range in the diff view, and returns the resolved +/// range and replacement text. +fn resolve_and_reveal_edit( + buffer: &Entity, + diff: &Entity, + snapshot: &BufferSnapshot, + edit: &EditOperation, + cx: &mut AsyncApp, +) -> std::result::Result<(Range, String), EditResolveError> { + let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone()); + matcher.push(&edit.old_text, None); + let matches = matcher.finish(); + if matches.is_empty() { + return Err(EditResolveError::NotFound); + } + if matches.len() > 1 { + return Err(EditResolveError::Ambiguous(matches)); + } + + let range = matches.into_iter().next().expect("checked len above"); + + let anchor_range = + buffer.read_with(cx, |buffer, _cx| buffer.anchor_range_between(range.clone())); + diff.update(cx, |card, cx| card.reveal_range(anchor_range, cx)); + + Ok((range, edit.new_text.clone())) } fn resolve_path( - input: &StreamingEditFileToolInput, - project: Entity, + mode: StreamingEditFileMode, + path: &PathBuf, + project: &Entity, cx: &mut App, ) -> Result { let project = project.read(cx); - match input.mode { + match mode { StreamingEditFileMode::Edit | StreamingEditFileMode::Overwrite => { let path = project - .find_project_path(&input.path, cx) + .find_project_path(&path, cx) .context("Can't edit file: path not found")?; let entry = project @@ -698,17 +1028,14 @@ fn resolve_path( } StreamingEditFileMode::Create => { - if let Some(path) = project.find_project_path(&input.path, cx) { + if let Some(path) = project.find_project_path(&path, cx) { anyhow::ensure!( project.entry_for_path(&path, cx).is_none(), "Can't create file: file already exists" ); } - let parent_path = input - .path - .parent() - .context("Can't create file: incorrect path")?; + let parent_path = path.parent().context("Can't create file: incorrect path")?; let parent_project_path = project.find_project_path(&parent_path, cx); @@ -722,8 +1049,7 @@ fn resolve_path( "Can't create file: parent is not a directory" ); - let file_name = input - .path + let file_name = path .file_name() .and_then(|file_name| file_name.to_str()) .and_then(|file_name| RelPath::unix(file_name).ok()) @@ -742,13 +1068,17 @@ fn resolve_path( #[cfg(test)] mod tests { use super::*; - use crate::{ContextServerRegistry, Templates}; + use crate::{ContextServerRegistry, Templates, ToolInputSender}; + use fs::Fs as _; + use futures::StreamExt as _; use gpui::{TestAppContext, UpdateGlobal}; use language_model::fake_provider::FakeLanguageModel; use prompt_store::ProjectContext; use serde_json::json; + use settings::Settings; use settings::SettingsStore; use util::path; + use util::rel_path::rel_path; #[gpui::test] async fn test_streaming_edit_create_file(cx: &mut TestAppContext) { @@ -1302,6 +1632,2759 @@ mod tests { ); } + #[gpui::test] + async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "line 1\nline 2\nline 3\n" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Send partials simulating LLM streaming: description first, then path, then mode + sender.send_partial(json!({"display_description": "Edit lines"})); + cx.run_until_parked(); + + sender.send_partial(json!({ + "display_description": "Edit lines", + "path": "root/file.txt" + })); + cx.run_until_parked(); + + // Path is NOT yet complete because mode hasn't appeared — no buffer open yet + sender.send_partial(json!({ + "display_description": "Edit lines", + "path": "root/file.txt", + "mode": "edit" + })); + cx.run_until_parked(); + + // Now send the final complete input + sender.send_final(json!({ + "display_description": "Edit lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [{"old_text": "line 2", "new_text": "modified line 2"}] + })); + + let result = task.await; + let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "line 1\nmodified line 2\nline 3\n"); + } + + #[gpui::test] + async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "hello world" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Send partial with path but NO mode — path should NOT be treated as complete + sender.send_partial(json!({ + "display_description": "Overwrite file", + "path": "root/file" + })); + cx.run_until_parked(); + + // Now the path grows and mode appears + sender.send_partial(json!({ + "display_description": "Overwrite file", + "path": "root/file.txt", + "mode": "overwrite" + })); + cx.run_until_parked(); + + // Send final + sender.send_final(json!({ + "display_description": "Overwrite file", + "path": "root/file.txt", + "mode": "overwrite", + "content": "new content" + })); + + let result = task.await; + let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "new content"); + } + + #[gpui::test] + async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "hello world" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver, mut cancellation_tx) = + ToolCallEventStream::test_with_cancellation(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Send a partial + sender.send_partial(json!({"display_description": "Edit"})); + cx.run_until_parked(); + + // Cancel during streaming + ToolCallEventStream::signal_cancellation_with_sender(&mut cancellation_tx); + cx.run_until_parked(); + + // The sender is still alive so the partial loop should detect cancellation + // We need to drop the sender to also unblock recv() if the loop didn't catch it + drop(sender); + + let result = task.await; + let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + panic!("expected error"); + }; + assert!( + error.contains("cancelled"), + "Expected cancellation error but got: {error}" + ); + } + + #[gpui::test] + async fn test_streaming_edit_with_multiple_partials(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Simulate fine-grained streaming of the JSON + sender.send_partial(json!({"display_description": "Edit multiple"})); + cx.run_until_parked(); + + sender.send_partial(json!({ + "display_description": "Edit multiple lines", + "path": "root/file.txt" + })); + cx.run_until_parked(); + + sender.send_partial(json!({ + "display_description": "Edit multiple lines", + "path": "root/file.txt", + "mode": "edit" + })); + cx.run_until_parked(); + + sender.send_partial(json!({ + "display_description": "Edit multiple lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [{"old_text": "line 1"}] + })); + cx.run_until_parked(); + + sender.send_partial(json!({ + "display_description": "Edit multiple lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "line 1", "new_text": "modified line 1"}, + {"old_text": "line 5"} + ] + })); + cx.run_until_parked(); + + // Send final complete input + sender.send_final(json!({ + "display_description": "Edit multiple lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "line 1", "new_text": "modified line 1"}, + {"old_text": "line 5", "new_text": "modified line 5"} + ] + })); + + let result = task.await; + let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!( + new_text, + "modified line 1\nline 2\nline 3\nline 4\nmodified line 5\n" + ); + } + + #[gpui::test] + async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"dir": {}})).await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Stream partials for create mode + sender.send_partial(json!({"display_description": "Create new file"})); + cx.run_until_parked(); + + sender.send_partial(json!({ + "display_description": "Create new file", + "path": "root/dir/new_file.txt", + "mode": "create" + })); + cx.run_until_parked(); + + sender.send_partial(json!({ + "display_description": "Create new file", + "path": "root/dir/new_file.txt", + "mode": "create", + "content": "Hello, " + })); + cx.run_until_parked(); + + // Final with full content + sender.send_final(json!({ + "display_description": "Create new file", + "path": "root/dir/new_file.txt", + "mode": "create", + "content": "Hello, World!" + })); + + let result = task.await; + let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "Hello, World!"); + } + + #[gpui::test] + async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "line 1\nline 2\nline 3\n" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Send final immediately with no partials (simulates non-streaming path) + sender.send_final(json!({ + "display_description": "Edit lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [{"old_text": "line 2", "new_text": "modified line 2"}] + })); + + let result = task.await; + let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "line 1\nmodified line 2\nline 3\n"); + } + + #[gpui::test] + async fn test_streaming_incremental_edit_application(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Stream description, path, mode + sender.send_partial(json!({"display_description": "Edit multiple lines"})); + cx.run_until_parked(); + + sender.send_partial(json!({ + "display_description": "Edit multiple lines", + "path": "root/file.txt", + "mode": "edit" + })); + cx.run_until_parked(); + + // First edit starts streaming (old_text only, still in progress) + sender.send_partial(json!({ + "display_description": "Edit multiple lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [{"old_text": "line 1"}] + })); + cx.run_until_parked(); + + // Buffer should not have changed yet — the first edit is still in progress + // (no second edit has appeared to prove the first is complete) + let buffer_text = project.update(cx, |project, cx| { + let project_path = project.find_project_path(&PathBuf::from("root/file.txt"), cx); + project_path.and_then(|pp| { + project + .get_open_buffer(&pp, cx) + .map(|buffer| buffer.read(cx).text()) + }) + }); + // Buffer is open (from streaming) but edit 1 is still in-progress + assert_eq!( + buffer_text.as_deref(), + Some("line 1\nline 2\nline 3\nline 4\nline 5\n"), + "Buffer should not be modified while first edit is still in progress" + ); + + // Second edit appears — this proves the first edit is complete, so it + // should be applied immediately during streaming + sender.send_partial(json!({ + "display_description": "Edit multiple lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "line 1", "new_text": "MODIFIED 1"}, + {"old_text": "line 5"} + ] + })); + cx.run_until_parked(); + + // First edit should now be applied to the buffer + let buffer_text = project.update(cx, |project, cx| { + let project_path = project.find_project_path(&PathBuf::from("root/file.txt"), cx); + project_path.and_then(|pp| { + project + .get_open_buffer(&pp, cx) + .map(|buffer| buffer.read(cx).text()) + }) + }); + assert_eq!( + buffer_text.as_deref(), + Some("MODIFIED 1\nline 2\nline 3\nline 4\nline 5\n"), + "First edit should be applied during streaming when second edit appears" + ); + + // Send final complete input + sender.send_final(json!({ + "display_description": "Edit multiple lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "line 1", "new_text": "MODIFIED 1"}, + {"old_text": "line 5", "new_text": "MODIFIED 5"} + ] + })); + + let result = task.await; + let StreamingEditFileToolOutput::Success { + new_text, old_text, .. + } = result.unwrap() + else { + panic!("expected success"); + }; + assert_eq!(new_text, "MODIFIED 1\nline 2\nline 3\nline 4\nMODIFIED 5\n"); + assert_eq!( + *old_text, "line 1\nline 2\nline 3\nline 4\nline 5\n", + "old_text should reflect the original file content before any edits" + ); + } + + #[gpui::test] + async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "aaa\nbbb\nccc\nddd\neee\n" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Setup: description + path + mode + sender.send_partial(json!({ + "display_description": "Edit three lines", + "path": "root/file.txt", + "mode": "edit" + })); + cx.run_until_parked(); + + // Edit 1 in progress + sender.send_partial(json!({ + "display_description": "Edit three lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [{"old_text": "aaa", "new_text": "AAA"}] + })); + cx.run_until_parked(); + + // Edit 2 appears — edit 1 is now complete and should be applied + sender.send_partial(json!({ + "display_description": "Edit three lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "aaa", "new_text": "AAA"}, + {"old_text": "ccc", "new_text": "CCC"} + ] + })); + cx.run_until_parked(); + + // Verify edit 1 applied + let buffer_text = project.update(cx, |project, cx| { + let pp = project + .find_project_path(&PathBuf::from("root/file.txt"), cx) + .unwrap(); + project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text()) + }); + assert_eq!(buffer_text.as_deref(), Some("AAA\nbbb\nccc\nddd\neee\n")); + + // Edit 3 appears — edit 2 is now complete and should be applied + sender.send_partial(json!({ + "display_description": "Edit three lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "aaa", "new_text": "AAA"}, + {"old_text": "ccc", "new_text": "CCC"}, + {"old_text": "eee", "new_text": "EEE"} + ] + })); + cx.run_until_parked(); + + // Verify edits 1 and 2 both applied + let buffer_text = project.update(cx, |project, cx| { + let pp = project + .find_project_path(&PathBuf::from("root/file.txt"), cx) + .unwrap(); + project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text()) + }); + assert_eq!(buffer_text.as_deref(), Some("AAA\nbbb\nCCC\nddd\neee\n")); + + // Send final + sender.send_final(json!({ + "display_description": "Edit three lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "aaa", "new_text": "AAA"}, + {"old_text": "ccc", "new_text": "CCC"}, + {"old_text": "eee", "new_text": "EEE"} + ] + })); + + let result = task.await; + let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "AAA\nbbb\nCCC\nddd\nEEE\n"); + } + + #[gpui::test] + async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "line 1\nline 2\nline 3\n" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Setup + sender.send_partial(json!({ + "display_description": "Edit lines", + "path": "root/file.txt", + "mode": "edit" + })); + cx.run_until_parked(); + + // Edit 1 (valid) in progress — not yet complete (no second edit) + sender.send_partial(json!({ + "display_description": "Edit lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "line 1", "new_text": "MODIFIED"} + ] + })); + cx.run_until_parked(); + + // Edit 2 appears (will fail to match) — this makes edit 1 complete. + // Edit 1 should be applied. Edit 2 is still in-progress (last edit). + sender.send_partial(json!({ + "display_description": "Edit lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "line 1", "new_text": "MODIFIED"}, + {"old_text": "nonexistent text that does not appear anywhere in the file at all", "new_text": "whatever"} + ] + })); + cx.run_until_parked(); + + // Verify edit 1 was applied + let buffer_text = project.update(cx, |project, cx| { + let pp = project + .find_project_path(&PathBuf::from("root/file.txt"), cx) + .unwrap(); + project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text()) + }); + assert_eq!( + buffer_text.as_deref(), + Some("MODIFIED\nline 2\nline 3\n"), + "First edit should be applied even though second edit will fail" + ); + + // Edit 3 appears — this makes edit 2 "complete", triggering its + // resolution which should fail (old_text doesn't exist in the file). + sender.send_partial(json!({ + "display_description": "Edit lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "line 1", "new_text": "MODIFIED"}, + {"old_text": "nonexistent text that does not appear anywhere in the file at all", "new_text": "whatever"}, + {"old_text": "line 3", "new_text": "MODIFIED 3"} + ] + })); + cx.run_until_parked(); + + // The error from edit 2 should have propagated out of the partial loop. + // Drop sender to unblock recv() if the loop didn't catch it. + drop(sender); + + let result = task.await; + let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + panic!("expected error"); + }; + assert!( + error.contains("Could not find matching text for edit at index 1"), + "Expected error about edit 1 failing, got: {error}" + ); + } + + #[gpui::test] + async fn test_streaming_overlapping_edits_detected_naturally(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "line 1\nline 2\nline 3\n" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Setup + sender.send_partial(json!({ + "display_description": "Overlapping edits", + "path": "root/file.txt", + "mode": "edit" + })); + cx.run_until_parked(); + + // Edit 1 targets "line 1\nline 2" and replaces it. + // Edit 2 targets "line 2\nline 3" — but after edit 1 is applied, + // "line 2" has been removed so this should fail to match. + // Edit 3 exists to make edit 2 "complete" during streaming. + sender.send_partial(json!({ + "display_description": "Overlapping edits", + "path": "root/file.txt", + "mode": "edit", + "edits": [ + {"old_text": "line 1\nline 2", "new_text": "REPLACED"}, + {"old_text": "line 2\nline 3", "new_text": "ALSO REPLACED"}, + {"old_text": "line 3", "new_text": "DUMMY"} + ] + })); + cx.run_until_parked(); + + // Edit 1 was applied, edit 2 should fail since "line 2" no longer exists + drop(sender); + + let result = task.await; + let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + panic!("expected error"); + }; + assert!( + error.contains("Could not find matching text for edit at index 1"), + "Expected overlapping edit to fail naturally, got: {error}" + ); + } + + #[gpui::test] + async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "hello world\n" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + // Setup + single edit that stays in-progress (no second edit to prove completion) + sender.send_partial(json!({ + "display_description": "Single edit", + "path": "root/file.txt", + "mode": "edit", + "edits": [{"old_text": "hello world", "new_text": "goodbye world"}] + })); + cx.run_until_parked(); + + // Buffer should NOT be modified — the single edit is still in-progress + let buffer_text = project.update(cx, |project, cx| { + let pp = project + .find_project_path(&PathBuf::from("root/file.txt"), cx) + .unwrap(); + project.get_open_buffer(&pp, cx).map(|b| b.read(cx).text()) + }); + assert_eq!( + buffer_text.as_deref(), + Some("hello world\n"), + "Single in-progress edit should not be applied during streaming" + ); + + // Send final — the edit is applied during finalization + sender.send_final(json!({ + "display_description": "Single edit", + "path": "root/file.txt", + "mode": "edit", + "edits": [{"old_text": "hello world", "new_text": "goodbye world"}] + })); + + let result = task.await; + let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "goodbye world\n"); + } + + #[gpui::test] + async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "line 1\nline 2\nline 3\n" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input): (ToolInputSender, ToolInput) = + ToolInput::test(); + + let (event_stream, _event_rx) = ToolCallEventStream::test(); + let task = cx.update(|cx| { + Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )) + .run(input, event_stream, cx) + }); + + // Send progressively more complete partial snapshots, as the LLM would + sender.send_partial(json!({ + "display_description": "Edit lines" + })); + cx.run_until_parked(); + + sender.send_partial(json!({ + "display_description": "Edit lines", + "path": "root/file.txt", + "mode": "edit" + })); + cx.run_until_parked(); + + sender.send_partial(json!({ + "display_description": "Edit lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [{"old_text": "line 2", "new_text": "modified line 2"}] + })); + cx.run_until_parked(); + + // Send the final complete input + sender.send_final(json!({ + "display_description": "Edit lines", + "path": "root/file.txt", + "mode": "edit", + "edits": [{"old_text": "line 2", "new_text": "modified line 2"}] + })); + + let result = task.await; + let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "line 1\nmodified line 2\nline 3\n"); + } + + #[gpui::test] + async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "file.txt": "hello world\n" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + let (sender, input): (ToolInputSender, ToolInput) = + ToolInput::test(); + + let (event_stream, _event_rx) = ToolCallEventStream::test(); + let task = cx.update(|cx| { + Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )) + .run(input, event_stream, cx) + }); + + // Send a partial then drop the sender without sending final + sender.send_partial(json!({ + "display_description": "Edit file" + })); + cx.run_until_parked(); + + drop(sender); + + let result = task.await; + assert!( + result.is_err(), + "Tool should error when sender is dropped without sending final input" + ); + } + + #[gpui::test] + async fn test_streaming_input_recv_drains_partials(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"dir": {}})).await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + + // Create a channel and send multiple partials before a final, then use + // ToolInput::resolved-style immediate delivery to confirm recv() works + // when partials are already buffered. + let (sender, input): (ToolInputSender, ToolInput) = + ToolInput::test(); + + let (event_stream, _event_rx) = ToolCallEventStream::test(); + let task = cx.update(|cx| { + Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )) + .run(input, event_stream, cx) + }); + + // Buffer several partials before sending the final + sender.send_partial(json!({"display_description": "Create"})); + sender.send_partial(json!({"display_description": "Create", "path": "root/dir/new.txt"})); + sender.send_partial(json!({ + "display_description": "Create", + "path": "root/dir/new.txt", + "mode": "create" + })); + sender.send_final(json!({ + "display_description": "Create", + "path": "root/dir/new.txt", + "mode": "create", + "content": "streamed content" + })); + + let result = task.await; + let StreamingEditFileToolOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "streamed content"); + } + + #[gpui::test] + async fn test_streaming_resolve_path_for_creating_file(cx: &mut TestAppContext) { + let mode = StreamingEditFileMode::Create; + + let result = test_resolve_path(&mode, "root/new.txt", cx); + assert_resolved_path_eq(result.await, rel_path("new.txt")); + + let result = test_resolve_path(&mode, "new.txt", cx); + assert_resolved_path_eq(result.await, rel_path("new.txt")); + + let result = test_resolve_path(&mode, "dir/new.txt", cx); + assert_resolved_path_eq(result.await, rel_path("dir/new.txt")); + + let result = test_resolve_path(&mode, "root/dir/subdir/existing.txt", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't create file: file already exists" + ); + + let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't create file: parent directory doesn't exist" + ); + } + + #[gpui::test] + async fn test_streaming_resolve_path_for_editing_file(cx: &mut TestAppContext) { + let mode = StreamingEditFileMode::Edit; + + let path_with_root = "root/dir/subdir/existing.txt"; + let path_without_root = "dir/subdir/existing.txt"; + let result = test_resolve_path(&mode, path_with_root, cx); + assert_resolved_path_eq(result.await, rel_path(path_without_root)); + + let result = test_resolve_path(&mode, path_without_root, cx); + assert_resolved_path_eq(result.await, rel_path(path_without_root)); + + let result = test_resolve_path(&mode, "root/nonexistent.txt", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't edit file: path not found" + ); + + let result = test_resolve_path(&mode, "root/dir", cx); + assert_eq!( + result.await.unwrap_err().to_string(), + "Can't edit file: path is a directory" + ); + } + + async fn test_resolve_path( + mode: &StreamingEditFileMode, + path: &str, + cx: &mut TestAppContext, + ) -> anyhow::Result { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "dir": { + "subdir": { + "existing.txt": "hello" + } + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + cx.update(|cx| resolve_path(mode.clone(), &PathBuf::from(path), &project, cx)) + } + + #[track_caller] + fn assert_resolved_path_eq(path: anyhow::Result, expected: &RelPath) { + let actual = path.expect("Should return valid path").path; + assert_eq!(actual.as_ref(), expected); + } + + #[gpui::test] + async fn test_streaming_format_on_save(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + let rust_language = Arc::new(language::Language::new( + language::LanguageConfig { + name: "Rust".into(), + matcher: language::LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + None, + )); + + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(rust_language); + + let mut fake_language_servers = language_registry.register_fake_lsp( + "Rust", + language::FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + document_formatting_provider: Some(lsp::OneOf::Left(true)), + ..Default::default() + }, + ..Default::default() + }, + ); + + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + + // Open the buffer to trigger LSP initialization + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + // Register the buffer with language servers + let _handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + + const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n"; + const FORMATTED_CONTENT: &str = + "This file was formatted by the fake formatter in the test.\n"; + + // Get the fake language server and set up formatting handler + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.set_request_handler::({ + |_, _| async move { + Ok(Some(vec![lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)), + new_text: FORMATTED_CONTENT.to_string(), + }])) + } + }); + + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model.clone()), + cx, + ) + }); + + // Test with format_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings(cx, |settings| { + settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On); + settings.project.all_languages.defaults.formatter = + Some(language::language_settings::FormatterList::default()); + }); + }); + }); + + // Use streaming pattern so executor can pump the LSP request/response + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry.clone(), + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + sender.send_partial(json!({ + "display_description": "Create main function", + "path": "root/src/main.rs", + "mode": "overwrite" + })); + cx.run_until_parked(); + + sender.send_final(json!({ + "display_description": "Create main function", + "path": "root/src/main.rs", + "mode": "overwrite", + "content": UNFORMATTED_CONTENT + })); + + let result = task.await; + assert!(result.is_ok()); + + cx.executor().run_until_parked(); + + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + new_content.replace("\r\n", "\n"), + FORMATTED_CONTENT, + "Code should be formatted when format_on_save is enabled" + ); + + let stale_buffer_count = thread + .read_with(cx, |thread, _cx| thread.action_log.clone()) + .read_with(cx, |log, cx| log.stale_buffers(cx).count()); + + assert_eq!( + stale_buffer_count, 0, + "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers.", + stale_buffer_count + ); + + // Test with format_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings(cx, |settings| { + settings.project.all_languages.defaults.format_on_save = + Some(FormatOnSave::Off); + }); + }); + }); + + let (sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let task = cx.update(|cx| tool.run(input, event_stream, cx)); + + sender.send_partial(json!({ + "display_description": "Update main function", + "path": "root/src/main.rs", + "mode": "overwrite" + })); + cx.run_until_parked(); + + sender.send_final(json!({ + "display_description": "Update main function", + "path": "root/src/main.rs", + "mode": "overwrite", + "content": UNFORMATTED_CONTENT + })); + + let result = task.await; + assert!(result.is_ok()); + + cx.executor().run_until_parked(); + + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + new_content.replace("\r\n", "\n"), + UNFORMATTED_CONTENT, + "Code should not be formatted when format_on_save is disabled" + ); + } + + #[gpui::test] + async fn test_streaming_remove_trailing_whitespace(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model.clone()), + cx, + ) + }); + + // Test with remove_trailing_whitespace_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings(cx, |settings| { + settings + .project + .all_languages + .defaults + .remove_trailing_whitespace_on_save = Some(true); + }); + }); + }); + + const CONTENT_WITH_TRAILING_WHITESPACE: &str = + "fn main() { \n println!(\"Hello!\"); \n}\n"; + + let result = cx + .update(|cx| { + let input = StreamingEditFileToolInput { + display_description: "Create main function".into(), + path: "root/src/main.rs".into(), + mode: StreamingEditFileMode::Overwrite, + content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), + edits: None, + }; + Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry.clone(), + )) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + assert!(result.is_ok()); + + cx.executor().run_until_parked(); + + assert_eq!( + fs.load(path!("/root/src/main.rs").as_ref()) + .await + .unwrap() + .replace("\r\n", "\n"), + "fn main() {\n println!(\"Hello!\");\n}\n", + "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled" + ); + + // Test with remove_trailing_whitespace_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings(cx, |settings| { + settings + .project + .all_languages + .defaults + .remove_trailing_whitespace_on_save = Some(false); + }); + }); + }); + + let result = cx + .update(|cx| { + let input = StreamingEditFileToolInput { + display_description: "Update main function".into(), + path: "root/src/main.rs".into(), + mode: StreamingEditFileMode::Overwrite, + content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), + edits: None, + }; + Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )) + .run( + ToolInput::resolved(input), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + assert!(result.is_ok()); + + cx.executor().run_until_parked(); + + let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + final_content.replace("\r\n", "\n"), + CONTENT_WITH_TRAILING_WHITESPACE, + "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" + ); + } + + #[gpui::test] + async fn test_streaming_authorize(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model.clone()), + cx, + ) + }); + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + fs.insert_tree("/root", json!({})).await; + + // Test 1: Path with .zed component should require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &PathBuf::from(".zed/settings.json"), + "test 1", + &stream_tx, + cx, + ) + }); + + let event = stream_rx.expect_authorization().await; + assert_eq!( + event.tool_call.fields.title, + Some("test 1 (local settings)".into()) + ); + + // Test 2: Path outside project should require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = + cx.update(|cx| tool.authorize(&PathBuf::from("/etc/hosts"), "test 2", &stream_tx, cx)); + + let event = stream_rx.expect_authorization().await; + assert_eq!(event.tool_call.fields.title, Some("test 2".into())); + + // Test 3: Relative path without .zed should not require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + tool.authorize(&PathBuf::from("root/src/main.rs"), "test 3", &stream_tx, cx) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + + // Test 4: Path with .zed in the middle should require confirmation + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &PathBuf::from("root/.zed/tasks.json"), + "test 4", + &stream_tx, + cx, + ) + }); + let event = stream_rx.expect_authorization().await; + assert_eq!( + event.tool_call.fields.title, + Some("test 4 (local settings)".into()) + ); + + // Test 5: When global default is allow, sensitive and outside-project + // paths still require confirmation + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Allow; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + // 5.1: .zed/settings.json is a sensitive path — still prompts + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &PathBuf::from(".zed/settings.json"), + "test 5.1", + &stream_tx, + cx, + ) + }); + let event = stream_rx.expect_authorization().await; + assert_eq!( + event.tool_call.fields.title, + Some("test 5.1 (local settings)".into()) + ); + + // 5.2: /etc/hosts is outside the project, but Allow auto-approves + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| tool.authorize(&PathBuf::from("/etc/hosts"), "test 5.2", &stream_tx, cx)) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + + // 5.3: Normal in-project path with allow — no confirmation needed + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + tool.authorize( + &PathBuf::from("root/src/main.rs"), + "test 5.3", + &stream_tx, + cx, + ) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + + // 5.4: With Confirm default, non-project paths still prompt + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Confirm; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx + .update(|cx| tool.authorize(&PathBuf::from("/etc/hosts"), "test 5.4", &stream_tx, cx)); + + let event = stream_rx.expect_authorization().await; + assert_eq!(event.tool_call.fields.title, Some("test 5.4".into())); + } + + #[gpui::test] + async fn test_streaming_authorize_create_under_symlink_with_allow(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({})).await; + fs.insert_tree("/outside", json!({})).await; + fs.insert_symlink("/root/link", PathBuf::from("/outside")) + .await; + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + let tool = Arc::new(StreamingEditFileTool::new( + project, + thread.downgrade(), + language_registry, + )); + + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Allow; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let authorize_task = cx.update(|cx| { + tool.authorize( + &PathBuf::from("link/new.txt"), + "create through symlink", + &stream_tx, + cx, + ) + }); + + let event = stream_rx.expect_authorization().await; + assert!( + event + .tool_call + .fields + .title + .as_deref() + .is_some_and(|title| title.contains("points outside the project")), + "Expected symlink escape authorization for create under external symlink" + ); + + event + .response + .send(acp::PermissionOptionId::new("allow")) + .unwrap(); + authorize_task.await.unwrap(); + } + + #[gpui::test] + async fn test_streaming_edit_file_symlink_escape_requests_authorization( + cx: &mut TestAppContext, + ) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "src": { "main.rs": "fn main() {}" } + }), + ) + .await; + fs.insert_tree( + path!("/outside"), + json!({ + "config.txt": "old content" + }), + ) + .await; + fs.create_symlink( + path!("/root/link_to_external").as_ref(), + PathBuf::from("/outside"), + ) + .await + .unwrap(); + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + cx.executor().run_until_parked(); + + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _authorize_task = cx.update(|cx| { + tool.authorize( + &PathBuf::from("link_to_external/config.txt"), + "edit through symlink", + &stream_tx, + cx, + ) + }); + + let auth = stream_rx.expect_authorization().await; + let title = auth.tool_call.fields.title.as_deref().unwrap_or(""); + assert!( + title.contains("points outside the project"), + "title should mention symlink escape, got: {title}" + ); + } + + #[gpui::test] + async fn test_streaming_edit_file_symlink_escape_denied(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "src": { "main.rs": "fn main() {}" } + }), + ) + .await; + fs.insert_tree( + path!("/outside"), + json!({ + "config.txt": "old content" + }), + ) + .await; + fs.create_symlink( + path!("/root/link_to_external").as_ref(), + PathBuf::from("/outside"), + ) + .await + .unwrap(); + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + cx.executor().run_until_parked(); + + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let authorize_task = cx.update(|cx| { + tool.authorize( + &PathBuf::from("link_to_external/config.txt"), + "edit through symlink", + &stream_tx, + cx, + ) + }); + + let auth = stream_rx.expect_authorization().await; + drop(auth); // deny by dropping + + let result = authorize_task.await; + assert!(result.is_err(), "should fail when denied"); + } + + #[gpui::test] + async fn test_streaming_edit_file_symlink_escape_honors_deny_policy(cx: &mut TestAppContext) { + init_test(cx); + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.tools.insert( + "edit_file".into(), + agent_settings::ToolRules { + default: Some(settings::ToolPermissionMode::Deny), + ..Default::default() + }, + ); + agent_settings::AgentSettings::override_global(settings, cx); + }); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "src": { "main.rs": "fn main() {}" } + }), + ) + .await; + fs.insert_tree( + path!("/outside"), + json!({ + "config.txt": "old content" + }), + ) + .await; + fs.create_symlink( + path!("/root/link_to_external").as_ref(), + PathBuf::from("/outside"), + ) + .await + .unwrap(); + + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + cx.executor().run_until_parked(); + + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let result = cx + .update(|cx| { + tool.authorize( + &PathBuf::from("link_to_external/config.txt"), + "edit through symlink", + &stream_tx, + cx, + ) + }) + .await; + + assert!(result.is_err(), "Tool should fail when policy denies"); + assert!( + !matches!( + stream_rx.try_next(), + Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_)))) + ), + "Deny policy should not emit symlink authorization prompt", + ); + } + + #[gpui::test] + async fn test_streaming_authorize_global_config(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/project", json!({})).await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model.clone()), + cx, + ) + }); + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let test_cases = vec![ + ( + "/etc/hosts", + true, + "System file should require confirmation", + ), + ( + "/usr/local/bin/script", + true, + "System bin file should require confirmation", + ), + ( + "project/normal_file.rs", + false, + "Normal project file should not require confirmation", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let auth = + cx.update(|cx| tool.authorize(&PathBuf::from(path), "Edit file", &stream_tx, cx)); + + if should_confirm { + stream_rx.expect_authorization().await; + } else { + auth.await.unwrap(); + assert!( + stream_rx.try_next().is_err(), + "Failed for case: {} - path: {} - expected no confirmation but got one", + description, + path + ); + } + } + } + + #[gpui::test] + async fn test_streaming_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + + fs.insert_tree( + "/workspace/frontend", + json!({ + "src": { + "main.js": "console.log('frontend');" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/backend", + json!({ + "src": { + "main.rs": "fn main() {}" + } + }), + ) + .await; + fs.insert_tree( + "/workspace/shared", + json!({ + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + + let project = Project::test( + fs.clone(), + [ + path!("/workspace/frontend").as_ref(), + path!("/workspace/backend").as_ref(), + path!("/workspace/shared").as_ref(), + ], + cx, + ) + .await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry.clone(), + Templates::new(), + Some(model.clone()), + cx, + ) + }); + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let test_cases = vec![ + ("frontend/src/main.js", false, "File in first worktree"), + ("backend/src/main.rs", false, "File in second worktree"), + ( + "shared/.zed/settings.json", + true, + ".zed file in third worktree", + ), + ("/etc/hosts", true, "Absolute path outside all worktrees"), + ( + "../outside/file.txt", + true, + "Relative path outside worktrees", + ), + ]; + + for (path, should_confirm, description) in test_cases { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let auth = + cx.update(|cx| tool.authorize(&PathBuf::from(path), "Edit file", &stream_tx, cx)); + + if should_confirm { + stream_rx.expect_authorization().await; + } else { + auth.await.unwrap(); + assert!( + stream_rx.try_next().is_err(), + "Failed for case: {} - path: {} - expected no confirmation but got one", + description, + path + ); + } + } + } + + #[gpui::test] + async fn test_streaming_needs_confirmation_edge_cases(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + ".zed": { + "settings.json": "{}" + }, + "src": { + ".zed": { + "local.json": "{}" + } + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry.clone(), + Templates::new(), + Some(model.clone()), + cx, + ) + }); + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let test_cases = vec![ + ("", false, "Empty path is treated as project root"), + ("/", true, "Root directory should be outside project"), + ( + "project/../other", + true, + "Path with .. that goes outside of root directory", + ), + ( + "project/./src/file.rs", + false, + "Path with . should work normally", + ), + #[cfg(target_os = "windows")] + ("C:\\Windows\\System32\\hosts", true, "Windows system path"), + #[cfg(target_os = "windows")] + ("project\\src\\main.rs", false, "Windows-style project path"), + ]; + + for (path, should_confirm, description) in test_cases { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let auth = + cx.update(|cx| tool.authorize(&PathBuf::from(path), "Edit file", &stream_tx, cx)); + + cx.run_until_parked(); + + if should_confirm { + stream_rx.expect_authorization().await; + } else { + assert!( + stream_rx.try_next().is_err(), + "Failed for case: {} - path: {} - expected no confirmation but got one", + description, + path + ); + auth.await.unwrap(); + } + } + } + + #[gpui::test] + async fn test_streaming_needs_confirmation_with_different_modes(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/project", + json!({ + "existing.txt": "content", + ".zed": { + "settings.json": "{}" + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry.clone(), + Templates::new(), + Some(model.clone()), + cx, + ) + }); + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + language_registry, + )); + + let modes = vec![ + StreamingEditFileMode::Edit, + StreamingEditFileMode::Create, + StreamingEditFileMode::Overwrite, + ]; + + for _mode in modes { + // Test .zed path with different modes + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &PathBuf::from("project/.zed/settings.json"), + "Edit settings", + &stream_tx, + cx, + ) + }); + + stream_rx.expect_authorization().await; + + // Test outside path with different modes + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let _auth = cx.update(|cx| { + tool.authorize( + &PathBuf::from("/outside/file.txt"), + "Edit file", + &stream_tx, + cx, + ) + }); + + stream_rx.expect_authorization().await; + + // Test normal path with different modes + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + cx.update(|cx| { + tool.authorize( + &PathBuf::from("project/normal.txt"), + "Edit file", + &stream_tx, + cx, + ) + }) + .await + .unwrap(); + assert!(stream_rx.try_next().is_err()); + } + } + + #[gpui::test] + async fn test_streaming_initial_title_with_partial_input(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model.clone()), + cx, + ) + }); + let tool = Arc::new(StreamingEditFileTool::new( + project, + thread.downgrade(), + language_registry, + )); + + cx.update(|cx| { + assert_eq!( + tool.initial_title( + Err(json!({ + "path": "src/main.rs", + "display_description": "", + })), + cx + ), + "src/main.rs" + ); + assert_eq!( + tool.initial_title( + Err(json!({ + "path": "", + "display_description": "Fix error handling", + })), + cx + ), + "Fix error handling" + ); + assert_eq!( + tool.initial_title( + Err(json!({ + "path": "src/main.rs", + "display_description": "Fix error handling", + })), + cx + ), + "src/main.rs" + ); + assert_eq!( + tool.initial_title( + Err(json!({ + "path": "", + "display_description": "", + })), + cx + ), + DEFAULT_UI_TEXT + ); + assert_eq!( + tool.initial_title(Err(serde_json::Value::Null), cx), + DEFAULT_UI_TEXT + ); + }); + } + + #[gpui::test] + async fn test_streaming_diff_finalization(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/", json!({"main.rs": ""})).await; + + let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await; + let languages = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry.clone(), + Templates::new(), + Some(model.clone()), + cx, + ) + }); + + // Ensure the diff is finalized after the edit completes. + { + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + languages.clone(), + )); + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let edit = cx.update(|cx| { + tool.run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit file".into(), + path: path!("/main.rs").into(), + mode: StreamingEditFileMode::Overwrite, + content: Some("new content".into()), + edits: None, + }), + stream_tx, + cx, + ) + }); + stream_rx.expect_update_fields().await; + let diff = stream_rx.expect_diff().await; + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); + cx.run_until_parked(); + edit.await.unwrap(); + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); + } + + // Ensure the diff is finalized if the tool call gets dropped. + { + let tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + languages.clone(), + )); + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let edit = cx.update(|cx| { + tool.run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit file".into(), + path: path!("/main.rs").into(), + mode: StreamingEditFileMode::Overwrite, + content: Some("dropped content".into()), + edits: None, + }), + stream_tx, + cx, + ) + }); + stream_rx.expect_update_fields().await; + let diff = stream_rx.expect_diff().await; + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); + drop(edit); + cx.run_until_parked(); + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); + } + } + + #[gpui::test] + async fn test_streaming_consecutive_edits_work(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "test.txt": "original content" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model.clone()), + cx, + ) + }); + let languages = project.read_with(cx, |project, _| project.languages().clone()); + let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + + let read_tool = Arc::new(crate::ReadFileTool::new( + thread.downgrade(), + project.clone(), + action_log, + )); + let edit_tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + languages, + )); + + // Read the file first + cx.update(|cx| { + read_tool.clone().run( + ToolInput::resolved(crate::ReadFileToolInput { + path: "root/test.txt".to_string(), + start_line: None, + end_line: None, + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await + .unwrap(); + + // First edit should work + let edit_result = cx + .update(|cx| { + edit_tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "First edit".into(), + path: "root/test.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![EditOperation { + old_text: "original content".into(), + new_text: "modified content".into(), + }]), + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + assert!( + edit_result.is_ok(), + "First edit should succeed, got error: {:?}", + edit_result.as_ref().err() + ); + + // Second edit should also work because the edit updated the recorded read time + let edit_result = cx + .update(|cx| { + edit_tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Second edit".into(), + path: "root/test.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![EditOperation { + old_text: "modified content".into(), + new_text: "further modified content".into(), + }]), + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + assert!( + edit_result.is_ok(), + "Second consecutive edit should succeed, got error: {:?}", + edit_result.as_ref().err() + ); + } + + #[gpui::test] + async fn test_streaming_external_modification_detected(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "test.txt": "original content" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model.clone()), + cx, + ) + }); + let languages = project.read_with(cx, |project, _| project.languages().clone()); + let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + + let read_tool = Arc::new(crate::ReadFileTool::new( + thread.downgrade(), + project.clone(), + action_log, + )); + let edit_tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + languages, + )); + + // Read the file first + cx.update(|cx| { + read_tool.clone().run( + ToolInput::resolved(crate::ReadFileToolInput { + path: "root/test.txt".to_string(), + start_line: None, + end_line: None, + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await + .unwrap(); + + // Simulate external modification + cx.background_executor + .advance_clock(std::time::Duration::from_secs(2)); + fs.save( + path!("/root/test.txt").as_ref(), + &"externally modified content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + + // Reload the buffer to pick up the new mtime + let project_path = project + .read_with(cx, |project, cx| { + project.find_project_path("root/test.txt", cx) + }) + .expect("Should find project path"); + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx)) + .await + .unwrap(); + buffer + .update(cx, |buffer, cx| buffer.reload(cx)) + .await + .unwrap(); + + cx.executor().run_until_parked(); + + // Try to edit - should fail because file was modified externally + let result = cx + .update(|cx| { + edit_tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit after external change".into(), + path: "root/test.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![EditOperation { + old_text: "externally modified content".into(), + new_text: "new content".into(), + }]), + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + + let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + panic!("expected error"); + }; + assert!( + error.contains("has been modified since you last read it"), + "Error should mention file modification, got: {}", + error + ); + } + + #[gpui::test] + async fn test_streaming_dirty_buffer_detected(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "test.txt": "original content" + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model.clone()), + cx, + ) + }); + let languages = project.read_with(cx, |project, _| project.languages().clone()); + let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + + let read_tool = Arc::new(crate::ReadFileTool::new( + thread.downgrade(), + project.clone(), + action_log, + )); + let edit_tool = Arc::new(StreamingEditFileTool::new( + project.clone(), + thread.downgrade(), + languages, + )); + + // Read the file first + cx.update(|cx| { + read_tool.clone().run( + ToolInput::resolved(crate::ReadFileToolInput { + path: "root/test.txt".to_string(), + start_line: None, + end_line: None, + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await + .unwrap(); + + // Open the buffer and make it dirty + let project_path = project + .read_with(cx, |project, cx| { + project.find_project_path("root/test.txt", cx) + }) + .expect("Should find project path"); + let buffer = project + .update(cx, |project, cx| project.open_buffer(project_path, cx)) + .await + .unwrap(); + + buffer.update(cx, |buffer, cx| { + let end_point = buffer.max_point(); + buffer.edit([(end_point..end_point, " added text")], None, cx); + }); + + let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty()); + assert!(is_dirty, "Buffer should be dirty after in-memory edit"); + + // Try to edit - should fail because buffer has unsaved changes + let result = cx + .update(|cx| { + edit_tool.clone().run( + ToolInput::resolved(StreamingEditFileToolInput { + display_description: "Edit with dirty buffer".into(), + path: "root/test.txt".into(), + mode: StreamingEditFileMode::Edit, + content: None, + edits: Some(vec![EditOperation { + old_text: "original content".into(), + new_text: "new content".into(), + }]), + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + + let StreamingEditFileToolOutput::Error { error } = result.unwrap_err() else { + panic!("expected error"); + }; + assert!( + error.contains("This file has unsaved changes."), + "Error should mention unsaved changes, got: {}", + error + ); + assert!( + error.contains("keep or discard"), + "Error should ask whether to keep or discard changes, got: {}", + error + ); + assert!( + error.contains("save or revert the file manually"), + "Error should ask user to manually save or revert when tools aren't available, got: {}", + error + ); + } + fn init_test(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); diff --git a/crates/agent_ui/src/buffer_codegen.rs b/crates/agent_ui/src/buffer_codegen.rs index 39759f264996ee07a7efd2b2bee8b1d1e3847f51..0376fda47e0b20820e19cf9cc2b09493b06898b8 100644 --- a/crates/agent_ui/src/buffer_codegen.rs +++ b/crates/agent_ui/src/buffer_codegen.rs @@ -526,11 +526,13 @@ impl CodegenAlternative { name: REWRITE_SECTION_TOOL_NAME.to_string(), description: "Replaces text in tags with your replacement_text.".to_string(), input_schema: language_model::tool_schema::root_schema_for::(tool_input_format).to_value(), + use_input_streaming: false, }, LanguageModelRequestTool { name: FAILURE_MESSAGE_TOOL_NAME.to_string(), description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(), input_schema: language_model::tool_schema::root_schema_for::(tool_input_format).to_value(), + use_input_streaming: false, }, ]; diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index bc2516b8b0f53e79a03fca40f6ce4dc5b564efc1..56baf4b58fe9ac568ea22012234510ff617fab25 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -906,11 +906,17 @@ pub struct ImageSource { pub data: String, } +fn is_false(value: &bool) -> bool { + !value +} + #[derive(Debug, Serialize, Deserialize)] pub struct Tool { pub name: String, pub description: String, pub input_schema: serde_json::Value, + #[serde(default, skip_serializing_if = "is_false")] + pub eager_input_streaming: bool, } #[derive(Debug, Serialize, Deserialize)] diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs index 04a61ae79474ea525cfe522dac2ac75048e7510b..cb2f6a27de65739bb684626ce5bd985a187bf28f 100644 --- a/crates/language_model/src/request.rs +++ b/crates/language_model/src/request.rs @@ -431,6 +431,7 @@ pub struct LanguageModelRequestTool { pub name: String, pub description: String, pub input_schema: serde_json::Value, + pub use_input_streaming: bool, } #[derive(Debug, PartialEq, Hash, Clone, Serialize, Deserialize)] diff --git a/crates/language_models/src/provider/anthropic.rs b/crates/language_models/src/provider/anthropic.rs index c1de89e4f8505433972d4c5673b130a1e4d0e72e..5b7ad62e0e66977465502d61f3db3707274a9718 100644 --- a/crates/language_models/src/provider/anthropic.rs +++ b/crates/language_models/src/provider/anthropic.rs @@ -370,6 +370,7 @@ pub fn into_anthropic_count_tokens_request( name: tool.name, description: tool.description, input_schema: tool.input_schema, + eager_input_streaming: tool.use_input_streaming, }) .collect(), tool_choice: request.tool_choice.map(|choice| match choice { @@ -713,6 +714,7 @@ pub fn into_anthropic( name: tool.name, description: tool.description, input_schema: tool.input_schema, + eager_input_streaming: tool.use_input_streaming, }) .collect(), tool_choice: request.tool_choice.map(|choice| match choice { diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index 5998242d13f335f5343ed61d739cdd1dcad08c38..40cc67098a76d0430f597feb8f1045859863486a 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -1566,6 +1566,7 @@ mod tests { name: "get_weather".into(), description: "Fetches the weather".into(), input_schema: json!({ "type": "object" }), + use_input_streaming: false, }], tool_choice: Some(LanguageModelToolChoice::Any), stop: vec!["".into()],