From 70ee54da8fc7710316bb1ff7589dfc141e510f37 Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Wed, 6 May 2026 17:47:18 +0200 Subject: [PATCH] agent: Add `write_file` tool (#55865) Splits the edit tool into two separate tools `write_file` (previously `mode = write`), and `edit_file` (previously `mode = edit`). This makes the JSON schema for the `edit_tool` much simpler. We've seen models (especially older ones) struggle with providing `mode = edit + edits` and `mode = write + content` fields. This seems to improve eval scores for Sonnet 4.6 slightly. Also added two unit evals to ensure that the model uses the tool to create new/override existing files Release Notes: - N/A --------- Co-authored-by: Ben Brandt --- assets/settings/default.json | 1 + crates/agent/src/tests/mod.rs | 12 +- crates/agent/src/thread.rs | 9 +- crates/agent/src/tools.rs | 4 + crates/agent/src/tools/edit_file_tool.rs | 2279 ++--------------- crates/agent/src/tools/edit_session.rs | 1067 ++++++++ .../reindent.rs | 0 .../streaming_fuzzy_matcher.rs | 0 .../streaming_parser.rs | 2 +- crates/agent/src/tools/evals.rs | 2 + crates/agent/src/tools/evals/edit_file.rs | 62 +- crates/agent/src/tools/evals/write_file.rs | 561 ++++ crates/agent/src/tools/write_file_tool.rs | 1190 +++++++++ crates/settings_ui/src/pages.rs | 1 + .../src/pages/tool_permissions_setup.rs | 8 + 15 files changed, 3041 insertions(+), 2157 deletions(-) create mode 100644 crates/agent/src/tools/edit_session.rs rename crates/agent/src/tools/{edit_file_tool => edit_session}/reindent.rs (100%) rename crates/agent/src/tools/{edit_file_tool => edit_session}/streaming_fuzzy_matcher.rs (100%) rename crates/agent/src/tools/{edit_file_tool => edit_session}/streaming_parser.rs (99%) create mode 100644 crates/agent/src/tools/evals/write_file.rs create mode 100644 crates/agent/src/tools/write_file_tool.rs diff --git a/assets/settings/default.json b/assets/settings/default.json index 624dcc0f01233a8ae92035207e59aa0668c811fc..64f97c451b00ea6fad4a11b0f2787ad53f408d78 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -1110,6 +1110,7 @@ "diagnostics": true, "apply_code_action": true, "edit_file": true, + "write_file": true, "fetch": true, "find_path": true, "find_references": true, diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 2a4e9c255fb3ceae2ad4bf99ce36ee77866017fe..57cec0bc5d07a99578d5a64d1ef442f47a5e04e9 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -6062,9 +6062,7 @@ async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) { tool.run( ToolInput::resolved(crate::EditFileToolInput { path: "root/sensitive_config.txt".into(), - mode: crate::EditFileMode::Edit, - content: None, - edits: Some(vec![]), + edits: vec![], }), event_stream, cx, @@ -6496,9 +6494,7 @@ async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppConte tool.run( ToolInput::resolved(crate::EditFileToolInput { path: "root/README.md".into(), - mode: crate::EditFileMode::Edit, - content: None, - edits: Some(vec![]), + edits: vec![], }), event_stream, cx, @@ -6568,9 +6564,7 @@ async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut Tes tool.run( ToolInput::resolved(crate::EditFileToolInput { path: "root/.zed/settings.json".into(), - mode: crate::EditFileMode::Edit, - content: None, - edits: Some(vec![]), + edits: vec![], }), event_stream, cx, diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index c6979391673ec67b892a4f3e55c79e406c20cd8e..78a4b2fd488918a0e2e773e7f07c55e013a52b48 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -4,7 +4,7 @@ use crate::{ FindPathTool, FindReferencesTool, GetCodeActionsTool, GoToDefinitionTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool, ProjectSnapshot, ReadFileTool, RenameTool, RestoreFileFromDiskTool, SaveFileTool, SpawnAgentTool, SystemPromptTemplate, Template, - Templates, TerminalTool, ToolPermissionDecision, UpdatePlanTool, WebSearchTool, + Templates, TerminalTool, ToolPermissionDecision, UpdatePlanTool, WebSearchTool, WriteFileTool, decide_permission_from_settings, }; use acp_thread::{MentionUri, UserMessageId}; @@ -822,6 +822,7 @@ impl ToolPermissionContext { } else if tool_name == CopyPathTool::NAME || tool_name == MovePathTool::NAME || tool_name == EditFileTool::NAME + || tool_name == WriteFileTool::NAME || tool_name == DeletePathTool::NAME || tool_name == CreateDirectoryTool::NAME || tool_name == SaveFileTool::NAME @@ -1544,6 +1545,12 @@ impl Thread { self.action_log.clone(), )); self.add_tool(EditFileTool::new( + self.project.clone(), + cx.weak_entity(), + self.action_log.clone(), + language_registry.clone(), + )); + self.add_tool(WriteFileTool::new( self.project.clone(), cx.weak_entity(), self.action_log.clone(), diff --git a/crates/agent/src/tools.rs b/crates/agent/src/tools.rs index 71ee0b2ba1714fa7a8f9c282685ef901e6e8fb2e..c52e0e4745e43853487cec9167e8a483763035b1 100644 --- a/crates/agent/src/tools.rs +++ b/crates/agent/src/tools.rs @@ -5,6 +5,7 @@ mod create_directory_tool; mod delete_path_tool; mod diagnostics_tool; mod edit_file_tool; +mod edit_session; #[cfg(all(test, feature = "unit-eval"))] mod evals; mod fetch_tool; @@ -27,6 +28,7 @@ mod terminal_tool; mod tool_permissions; mod update_plan_tool; mod web_search_tool; +mod write_file_tool; use crate::AgentTool; use language_model::{LanguageModelRequestTool, LanguageModelToolSchemaFormat}; @@ -85,6 +87,7 @@ pub use terminal_tool::*; pub use tool_permissions::*; pub use update_plan_tool::*; pub use web_search_tool::*; +pub use write_file_tool::*; macro_rules! tools { ($($tool:ty),* $(,)?) => { @@ -179,4 +182,5 @@ tools! { TerminalTool, UpdatePlanTool, WebSearchTool, + WriteFileTool, } diff --git a/crates/agent/src/tools/edit_file_tool.rs b/crates/agent/src/tools/edit_file_tool.rs index 69f7be4662abcbe29968cc572a8620222bbb0d6a..1061d5a5b7e4cc69179636f76235c042068174a3 100644 --- a/crates/agent/src/tools/edit_file_tool.rs +++ b/crates/agent/src/tools/edit_file_tool.rs @@ -1,53 +1,36 @@ -mod reindent; -mod streaming_fuzzy_matcher; -mod streaming_parser; - use super::deserialize_maybe_stringified; -use super::restore_file_from_disk_tool::RestoreFileFromDiskTool; -use super::save_file_tool::SaveFileTool; -use crate::ToolInputPayload; -use crate::tools::edit_file_tool::{ - reindent::{Reindenter, compute_indent_delta}, - streaming_fuzzy_matcher::StreamingFuzzyMatcher, - streaming_parser::{EditEvent, StreamingParser, WriteEvent}, +pub(crate) use super::edit_session::PartialEdit; +pub use super::edit_session::{Edit, EditSessionOutput as EditFileToolOutput}; +use super::edit_session::{ + EditSession, EditSessionContext, EditSessionMode, EditSessionResult, + initial_title_from_partial_path, run_session, }; -use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput}; -use acp_thread::Diff; +use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput, ToolInputPayload}; use action_log::ActionLog; -use agent_client_protocol::schema::{self as acp, ToolCallLocation, ToolCallUpdateFields}; +use agent_client_protocol::schema as acp; use anyhow::Result; -use collections::HashSet; use futures::FutureExt as _; -use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; -use language::language_settings::{self, FormatOnSave}; -use language::{Buffer, LanguageRegistry}; -use language_model::LanguageModelToolResultContent; -use project::lsp_store::{FormatTrigger, LspFormatTarget}; -use project::{AgentLocation, Project, ProjectPath}; +use gpui::{App, AsyncApp, Entity, Task, WeakEntity}; +use language::LanguageRegistry; +use project::Project; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use std::ops::Range; use std::path::PathBuf; use std::sync::Arc; -use streaming_diff::{CharOperation, StreamingDiff}; -use text::ToOffset; use ui::SharedString; -use util::rel_path::RelPath; -use util::{Deferred, ResultExt}; const DEFAULT_UI_TEXT: &str = "Editing file"; -/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `move_path` tool instead. +/// This is a tool for applying edits to an existing file. /// /// Before using this tool: /// /// 1. Use the `read_file` tool to understand the file's contents and context /// -/// 2. Verify the directory path is correct (only applicable when creating new files): -/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location +/// To create a new file or overwrite an existing one with completely new contents, use the `write_file` tool instead. #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] pub struct EditFileToolInput { - /// The full path of the file to create or modify in the project. + /// The full path of the file to edit in the project. /// /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories. /// @@ -66,50 +49,10 @@ pub struct EditFileToolInput { /// pub path: PathBuf, - /// The mode of operation on the file. Possible values: - /// - 'write': Replace the entire contents of the file. If the file doesn't exist, it will be created. Requires 'content' field. - /// - 'edit': Make granular edits to an existing file. Requires 'edits' field. - /// - /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch. - #[serde(deserialize_with = "deserialize_maybe_stringified")] - pub mode: EditFileMode, - - /// The complete content for the new file (required for 'write' mode). - /// This field should contain the entire file content. - #[serde(default, skip_serializing_if = "Option::is_none")] - pub content: Option, - - /// List of edit operations to apply sequentially (required for 'edit' mode). + /// List of edit operations to apply sequentially. /// Each edit finds `old_text` in the file and replaces it with `new_text`. - #[serde( - default, - skip_serializing_if = "Option::is_none", - deserialize_with = "deserialize_maybe_stringified" - )] - pub edits: Option>, -} - -#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema)] -#[serde(rename_all = "snake_case")] -pub enum EditFileMode { - Write, - Edit, -} - -/// A single edit operation that replaces old text with new text -/// Properly escape all text fields as valid JSON strings. -/// Remember to escape special characters like newlines (`\n`) and quotes (`"`) in JSON strings. -#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] -pub struct Edit { - /// The exact text to find in the file. This will be matched using fuzzy matching - /// to handle minor differences in whitespace or formatting. - /// - /// Be minimal with replacements: - /// - For unique lines, include only those lines - /// - For non-unique lines, include enough context to identify them - pub old_text: String, - /// The text to replace it with - pub new_text: String, + #[serde(deserialize_with = "deserialize_maybe_stringified")] + pub edits: Vec, } #[derive(Clone, Default, Debug, Deserialize)] @@ -117,108 +60,11 @@ struct EditFileToolPartialInput { #[serde(default)] path: Option, #[serde(default, deserialize_with = "deserialize_maybe_stringified")] - mode: Option, - #[serde(default)] - content: Option, - #[serde(default, deserialize_with = "deserialize_maybe_stringified")] edits: Option>, } -#[derive(Clone, Default, Debug, Deserialize)] -pub struct PartialEdit { - #[serde(default)] - pub old_text: Option, - #[serde(default)] - pub new_text: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(untagged)] -pub enum EditFileToolOutput { - Success { - #[serde(alias = "original_path")] - input_path: PathBuf, - new_text: String, - old_text: Arc, - #[serde(default)] - diff: String, - }, - Error { - error: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - input_path: Option, - #[serde(default, skip_serializing_if = "String::is_empty")] - diff: String, - }, -} - -impl EditFileToolOutput { - pub fn error(error: impl Into) -> Self { - Self::Error { - error: error.into(), - input_path: None, - diff: String::new(), - } - } -} - -impl std::fmt::Display for EditFileToolOutput { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - EditFileToolOutput::Success { - diff, input_path, .. - } => { - if diff.is_empty() { - write!(f, "No edits were made.") - } else { - write!( - f, - "Edited {}:\n\n```diff\n{diff}\n```", - input_path.display() - ) - } - } - EditFileToolOutput::Error { - error, - diff, - input_path, - } => { - write!(f, "{error}\n")?; - if let Some(input_path) = input_path - && !diff.is_empty() - { - write!( - f, - "Edited {}:\n\n```diff\n{diff}\n```", - input_path.display() - ) - } else { - write!(f, "No edits were made.") - } - } - } - } -} - -impl From for LanguageModelToolResultContent { - fn from(output: EditFileToolOutput) -> Self { - output.to_string().into() - } -} - pub struct EditFileTool { - project: Entity, - thread: WeakEntity, - action_log: Entity, - language_registry: Arc, -} - -enum EditSessionResult { - Completed(EditSession), - Failed { - error: String, - session: Option, - }, + session_context: Arc, } impl EditFileTool { @@ -229,69 +75,24 @@ impl EditFileTool { language_registry: Arc, ) -> Self { Self { - project, - thread, - action_log, - language_registry, + session_context: Arc::new(EditSessionContext::new( + project, + thread, + action_log, + language_registry, + )), } } + #[cfg(test)] fn authorize( &self, path: &PathBuf, event_stream: &ToolCallEventStream, cx: &mut App, ) -> Task> { - super::tool_permissions::authorize_file_edit( - EditFileTool::NAME, - path, - &self.thread, - event_stream, - cx, - ) - } - - fn set_agent_location(&self, buffer: WeakEntity, position: text::Anchor, cx: &mut App) { - let should_update_agent_location = self - .thread - .read_with(cx, |thread, _cx| !thread.is_subagent()) - .unwrap_or_default(); - if should_update_agent_location { - self.project.update(cx, |project, cx| { - project.set_agent_location(Some(AgentLocation { buffer, position }), cx); - }); - } - } - - async fn ensure_buffer_saved(&self, buffer: &Entity, cx: &mut AsyncApp) { - let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| { - let settings = language_settings::LanguageSettings::for_buffer(buffer, cx); - settings.format_on_save != FormatOnSave::Off - }); - - if format_on_save_enabled { - self.project - .update(cx, |project, cx| { - project.format( - HashSet::from_iter([buffer.clone()]), - LspFormatTarget::Buffers, - false, - FormatTrigger::Save, - cx, - ) - }) - .await - .log_err(); - } - - self.project - .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) - .await - .log_err(); - - self.action_log.update(cx, |log, cx| { - log.buffer_edited(buffer.clone(), cx); - }); + self.session_context + .authorize(Self::NAME, path, event_stream, cx) } async fn process_streaming_edits( @@ -301,7 +102,7 @@ impl EditFileTool { cx: &mut AsyncApp, ) -> EditSessionResult { let mut session: Option = None; - let mut last_partial: Option = None; + let mut last_path: Option = None; loop { futures::select! { @@ -311,22 +112,19 @@ impl EditFileTool { ToolInputPayload::Partial(partial) => { if let Ok(parsed) = serde_json::from_value::(partial) { let path_complete = parsed.path.is_some() - && parsed.path.as_ref() == last_partial.as_ref().and_then(|partial| partial.path.as_ref()); + && parsed.path.as_ref() == last_path.as_ref(); - last_partial = Some(parsed.clone()); + last_path = parsed.path.clone(); if session.is_none() && path_complete - && let EditFileToolPartialInput { - path: Some(path), - mode: Some(mode), - .. - } = &parsed + && let Some(path) = parsed.path.as_ref() { match EditSession::new( PathBuf::from(path), - *mode, - self, + EditSessionMode::Edit, + Self::NAME, + self.session_context.clone(), event_stream, cx, ) @@ -344,7 +142,7 @@ impl EditFileTool { } if let Some(current_session) = &mut session - && let Err(error) = current_session.process(parsed, self, event_stream, cx) + && let Err(error) = current_session.process_edit(parsed.edits.as_deref(), event_stream, cx) { log::error!("Failed to process edit: {}", error); return EditSessionResult::Failed { error, session }; @@ -357,8 +155,9 @@ impl EditFileTool { } else { match EditSession::new( full_input.path.clone(), - full_input.mode, - self, + EditSessionMode::Edit, + Self::NAME, + self.session_context.clone(), event_stream, cx, ) @@ -375,7 +174,7 @@ impl EditFileTool { } }; - return match session.finalize(full_input, self, event_stream, cx).await { + return match session.finalize_edit(full_input.edits, event_stream, cx).await { Ok(()) => EditSessionResult::Completed(session), Err(error) => { log::error!("Failed to finalize edit: {}", error); @@ -433,38 +232,17 @@ impl AgentTool for EditFileTool { cx: &mut App, ) -> SharedString { match input { - Ok(input) => self - .project - .read(cx) - .find_project_path(&input.path, cx) - .and_then(|project_path| { - self.project - .read(cx) - .short_full_path_for_project_path(&project_path, cx) - }) - .unwrap_or(input.path.to_string_lossy().into_owned()) - .into(), - Err(raw_input) => { - if let Ok(input) = serde_json::from_value::(raw_input) { - let path = input.path.unwrap_or_default(); - let path = path.trim(); - if !path.is_empty() { - return self - .project - .read(cx) - .find_project_path(&path, cx) - .and_then(|project_path| { - self.project - .read(cx) - .short_full_path_for_project_path(&project_path, cx) - }) - .unwrap_or_else(|| path.to_string()) - .into(); - } - } - - DEFAULT_UI_TEXT.into() + Ok(input) => { + self.session_context + .initial_title_from_path(&input.path, DEFAULT_UI_TEXT, cx) } + Err(raw_input) => initial_title_from_partial_path::( + &self.session_context, + raw_input, + |partial| partial.path.clone(), + DEFAULT_UI_TEXT, + cx, + ), } } @@ -475,41 +253,12 @@ impl AgentTool for EditFileTool { cx: &mut App, ) -> Task> { cx.spawn(async move |cx: &mut AsyncApp| { - match self - .process_streaming_edits(&mut input, &event_stream, cx) - .await - { - EditSessionResult::Completed(session) => { - self.ensure_buffer_saved(&session.buffer, cx).await; - let (new_text, diff) = session.compute_new_text_and_diff(cx).await; - Ok(EditFileToolOutput::Success { - old_text: session.old_text.clone(), - new_text, - input_path: session.input_path, - diff, - }) - } - EditSessionResult::Failed { - error, - session: Some(session), - } => { - self.ensure_buffer_saved(&session.buffer, cx).await; - let (_new_text, diff) = session.compute_new_text_and_diff(cx).await; - Err(EditFileToolOutput::Error { - error, - input_path: Some(session.input_path), - diff, - }) - } - EditSessionResult::Failed { - error, - session: None, - } => Err(EditFileToolOutput::Error { - error, - input_path: None, - diff: String::new(), - }), - } + run_session( + self.process_streaming_edits(&mut input, &event_stream, cx) + .await, + cx, + ) + .await }) } @@ -520,706 +269,7 @@ impl AgentTool for EditFileTool { event_stream: ToolCallEventStream, cx: &mut App, ) -> Result<()> { - match output { - EditFileToolOutput::Success { - input_path, - old_text, - new_text, - .. - } => { - event_stream.update_diff(cx.new(|cx| { - Diff::finalized( - input_path.to_string_lossy().into_owned(), - Some(old_text.to_string()), - new_text, - self.language_registry.clone(), - cx, - ) - })); - Ok(()) - } - EditFileToolOutput::Error { .. } => Ok(()), - } - } -} - -pub struct EditSession { - abs_path: PathBuf, - input_path: PathBuf, - buffer: Entity, - old_text: Arc, - diff: Entity, - parser: StreamingParser, - pipeline: Pipeline, - _finalize_diff_guard: Deferred>, -} - -enum Pipeline { - Write(WritePipeline), - Edit(EditPipeline), -} - -struct WritePipeline { - content_written: bool, -} - -struct EditPipeline { - current_edit: Option, - file_changed_since_last_read: bool, -} - -enum EditPipelineEntry { - ResolvingOldText { - matcher: StreamingFuzzyMatcher, - }, - StreamingNewText { - streaming_diff: StreamingDiff, - edit_cursor: usize, - reindenter: Reindenter, - original_snapshot: text::BufferSnapshot, - }, -} - -impl Pipeline { - fn new(mode: EditFileMode, file_changed_since_last_read: bool) -> Self { - match mode { - EditFileMode::Write => Self::Write(WritePipeline { - content_written: false, - }), - EditFileMode::Edit => Self::Edit(EditPipeline { - current_edit: None, - file_changed_since_last_read, - }), - } - } -} - -impl WritePipeline { - fn process_event( - &mut self, - event: &WriteEvent, - buffer: &Entity, - tool: &EditFileTool, - cx: &mut AsyncApp, - ) { - let WriteEvent::ContentChunk { chunk } = event; - - let (buffer_id, buffer_len) = - buffer.read_with(cx, |buffer, _cx| (buffer.remote_id(), buffer.len())); - let edit_range = if self.content_written { - buffer_len..buffer_len - } else { - 0..buffer_len - }; - - agent_edit_buffer(buffer, [(edit_range, chunk.as_str())], &tool.action_log, cx); - cx.update(|cx| { - tool.set_agent_location( - buffer.downgrade(), - text::Anchor::max_for_buffer(buffer_id), - cx, - ); - }); - self.content_written = true; - } -} - -impl EditPipeline { - fn ensure_resolving_old_text(&mut self, buffer: &Entity, cx: &mut AsyncApp) { - if self.current_edit.is_none() { - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot()); - self.current_edit = Some(EditPipelineEntry::ResolvingOldText { - matcher: StreamingFuzzyMatcher::new(snapshot), - }); - } - } - - fn process_event( - &mut self, - event: &EditEvent, - buffer: &Entity, - diff: &Entity, - abs_path: &PathBuf, - tool: &EditFileTool, - event_stream: &ToolCallEventStream, - cx: &mut AsyncApp, - ) -> Result<(), String> { - match event { - EditEvent::OldTextChunk { - chunk, done: false, .. - } => { - log::debug!("old_text_chunk: done=false, chunk='{}'", chunk); - self.ensure_resolving_old_text(buffer, cx); - - if let Some(EditPipelineEntry::ResolvingOldText { matcher }) = - &mut self.current_edit - && !chunk.is_empty() - { - if let Some(match_range) = matcher.push(chunk, None) { - let anchor_range = buffer.read_with(cx, |buffer, _cx| { - buffer.anchor_range_outside(match_range.clone()) - }); - diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); - - cx.update(|cx| { - let position = buffer.read(cx).anchor_before(match_range.end); - tool.set_agent_location(buffer.downgrade(), position, cx); - }); - } - } - } - EditEvent::OldTextChunk { - edit_index, - chunk, - done: true, - } => { - log::debug!("old_text_chunk: done=true, chunk='{}'", chunk); - - self.ensure_resolving_old_text(buffer, cx); - - let Some(EditPipelineEntry::ResolvingOldText { matcher }) = &mut self.current_edit - else { - return Ok(()); - }; - - if !chunk.is_empty() { - matcher.push(chunk, None); - } - let range = extract_match( - matcher.finish(), - buffer, - edit_index, - self.file_changed_since_last_read, - cx, - )?; - - let anchor_range = - buffer.read_with(cx, |buffer, _cx| buffer.anchor_range_outside(range.clone())); - diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); - - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - - let line = snapshot.offset_to_point(range.start).row; - event_stream.update_fields( - ToolCallUpdateFields::new() - .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]), - ); - - let buffer_indent = snapshot.line_indent_for_row(line); - let query_indent = text::LineIndent::from_iter( - matcher - .query_lines() - .first() - .map(|s| s.as_str()) - .unwrap_or("") - .chars(), - ); - let indent_delta = compute_indent_delta(buffer_indent, query_indent); - - let old_text_in_buffer = snapshot.text_for_range(range.clone()).collect::(); - - log::debug!( - "edit[{}] old_text matched at {}..{}: {:?}", - edit_index, - range.start, - range.end, - old_text_in_buffer, - ); - - let text_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot()); - self.current_edit = Some(EditPipelineEntry::StreamingNewText { - streaming_diff: StreamingDiff::new(old_text_in_buffer), - edit_cursor: range.start, - reindenter: Reindenter::new(indent_delta), - original_snapshot: text_snapshot, - }); - - cx.update(|cx| { - let position = buffer.read(cx).anchor_before(range.end); - tool.set_agent_location(buffer.downgrade(), position, cx); - }); - } - EditEvent::NewTextChunk { - chunk, done: false, .. - } => { - log::debug!("new_text_chunk: done=false, chunk='{}'", chunk); - - let Some(EditPipelineEntry::StreamingNewText { - streaming_diff, - edit_cursor, - reindenter, - original_snapshot, - .. - }) = &mut self.current_edit - else { - return Ok(()); - }; - - let reindented = reindenter.push(chunk); - if reindented.is_empty() { - return Ok(()); - } - - let char_ops = streaming_diff.push_new(&reindented); - apply_char_operations( - &char_ops, - buffer, - original_snapshot, - edit_cursor, - &tool.action_log, - cx, - ); - - let position = original_snapshot.anchor_before(*edit_cursor); - cx.update(|cx| { - tool.set_agent_location(buffer.downgrade(), position, cx); - }); - } - EditEvent::NewTextChunk { - chunk, done: true, .. - } => { - log::debug!("new_text_chunk: done=true, chunk='{}'", chunk); - - let Some(EditPipelineEntry::StreamingNewText { - mut streaming_diff, - mut edit_cursor, - mut reindenter, - original_snapshot, - }) = self.current_edit.take() - else { - return Ok(()); - }; - - // Flush any remaining reindent buffer + final chunk. - let mut final_text = reindenter.push(chunk); - final_text.push_str(&reindenter.finish()); - - log::debug!("new_text_chunk: done=true, final_text='{}'", final_text); - - if !final_text.is_empty() { - let char_ops = streaming_diff.push_new(&final_text); - apply_char_operations( - &char_ops, - buffer, - &original_snapshot, - &mut edit_cursor, - &tool.action_log, - cx, - ); - } - - let remaining_ops = streaming_diff.finish(); - apply_char_operations( - &remaining_ops, - buffer, - &original_snapshot, - &mut edit_cursor, - &tool.action_log, - cx, - ); - - let position = original_snapshot.anchor_before(edit_cursor); - cx.update(|cx| { - tool.set_agent_location(buffer.downgrade(), position, cx); - }); - } - } - Ok(()) - } -} - -impl EditSession { - async fn new( - path: PathBuf, - mode: EditFileMode, - tool: &EditFileTool, - event_stream: &ToolCallEventStream, - cx: &mut AsyncApp, - ) -> Result { - let project_path = cx.update(|cx| resolve_path(mode, &path, &tool.project, cx))?; - - let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx)) - else { - return Err(format!( - "Worktree at '{}' does not exist", - path.to_string_lossy() - )); - }; - - event_stream.update_fields( - ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path.clone())]), - ); - - cx.update(|cx| tool.authorize(&path, event_stream, cx)) - .await - .map_err(|e| e.to_string())?; - - let buffer = tool - .project - .update(cx, |project, cx| project.open_buffer(project_path, cx)) - .await - .map_err(|e| e.to_string())?; - - let file_changed_since_last_read = ensure_buffer_saved(&buffer, &abs_path, tool, cx)?; - - let diff = cx.new(|cx| Diff::new(buffer.clone(), cx)); - event_stream.update_diff(diff.clone()); - let finalize_diff_guard = util::defer(Box::new({ - let diff = diff.downgrade(); - let mut cx = cx.clone(); - move || { - diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok(); - } - }) as Box); - - tool.action_log.update(cx, |log, cx| match mode { - EditFileMode::Write => log.buffer_created(buffer.clone(), cx), - EditFileMode::Edit => log.buffer_read(buffer.clone(), cx), - }); - - let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let old_text = cx - .background_spawn({ - let old_snapshot = old_snapshot.clone(); - async move { Arc::new(old_snapshot.text()) } - }) - .await; - - Ok(Self { - abs_path, - input_path: path, - buffer, - old_text, - diff, - parser: StreamingParser::default(), - pipeline: Pipeline::new(mode, file_changed_since_last_read), - _finalize_diff_guard: finalize_diff_guard, - }) - } - - async fn finalize( - &mut self, - input: EditFileToolInput, - tool: &EditFileTool, - event_stream: &ToolCallEventStream, - cx: &mut AsyncApp, - ) -> Result<(), String> { - let Self { - abs_path, - buffer, - diff, - parser, - pipeline, - .. - } = self; - match pipeline { - Pipeline::Write(write) => { - let content = input - .content - .ok_or_else(|| "'content' field is required for write mode".to_string())?; - - for event in &parser.finalize_content(&content) { - write.process_event(event, buffer, tool, cx); - } - } - Pipeline::Edit(edit_pipeline) => { - let edits = input - .edits - .ok_or_else(|| "'edits' field is required for edit mode".to_string())?; - for event in &parser.finalize_edits(&edits) { - edit_pipeline.process_event( - event, - buffer, - diff, - abs_path, - tool, - event_stream, - cx, - )?; - } - - if log::log_enabled!(log::Level::Debug) { - log::debug!("Got edits:"); - for edit in &edits { - log::debug!( - " old_text: '{}', new_text: '{}'", - edit.old_text.replace('\n', "\\n"), - edit.new_text.replace('\n', "\\n") - ); - } - } - } - } - Ok(()) - } - - async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) { - let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let (new_text, unified_diff) = cx - .background_spawn({ - let new_snapshot = new_snapshot.clone(); - let old_text = self.old_text.clone(); - async move { - let new_text = new_snapshot.text(); - let diff = language::unified_diff(&old_text, &new_text); - (new_text, diff) - } - }) - .await; - (new_text, unified_diff) - } - - fn process( - &mut self, - partial: EditFileToolPartialInput, - tool: &EditFileTool, - event_stream: &ToolCallEventStream, - cx: &mut AsyncApp, - ) -> Result<(), String> { - let Self { - abs_path, - buffer, - diff, - parser, - pipeline, - .. - } = self; - match pipeline { - Pipeline::Write(write) => { - if let Some(content) = &partial.content { - for event in &parser.push_content(content) { - write.process_event(event, buffer, tool, cx); - } - } - } - Pipeline::Edit(edit_pipeline) => { - if let Some(edits) = partial.edits { - for event in &parser.push_edits(&edits) { - edit_pipeline.process_event( - event, - buffer, - diff, - abs_path, - tool, - event_stream, - cx, - )?; - } - } - } - } - Ok(()) - } -} - -fn apply_char_operations( - ops: &[CharOperation], - buffer: &Entity, - snapshot: &text::BufferSnapshot, - edit_cursor: &mut usize, - action_log: &Entity, - cx: &mut AsyncApp, -) { - for op in ops { - match op { - CharOperation::Insert { text } => { - let anchor = snapshot.anchor_after(*edit_cursor); - agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx); - } - CharOperation::Delete { bytes } => { - let delete_end = *edit_cursor + bytes; - let anchor_range = snapshot.anchor_range_inside(*edit_cursor..delete_end); - agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx); - *edit_cursor = delete_end; - } - CharOperation::Keep { bytes } => { - *edit_cursor += bytes; - } - } - } -} - -fn extract_match( - matches: Vec>, - buffer: &Entity, - edit_index: &usize, - file_changed_since_last_read: bool, - cx: &mut AsyncApp, -) -> Result, String> { - let file_changed_since_last_read_message = if file_changed_since_last_read { - " The file has changed on disk since you last read it." - } else { - "" - }; - - match matches.len() { - 0 => Err(format!( - "Could not find matching text for edit at index {}. \ - The old_text did not match any content in the file.{} \ - Please read the file again to get the current content.", - edit_index, file_changed_since_last_read_message, - )), - 1 => Ok(matches.into_iter().next().unwrap()), - _ => { - let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); - let lines = matches - .iter() - .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string()) - .collect::>() - .join(", "); - Err(format!( - "Edit {} matched multiple locations in the file at lines: {}. \ - Please provide more context in old_text to uniquely \ - identify the location.", - edit_index, lines - )) - } - } -} - -/// Edits a buffer and reports the edit to the action log in the same effect -/// cycle. This ensures the action log's subscription handler sees the version -/// already updated by `buffer_edited`, so it does not misattribute the agent's -/// edit as a user edit. -fn agent_edit_buffer( - buffer: &Entity, - edits: I, - action_log: &Entity, - cx: &mut AsyncApp, -) where - I: IntoIterator, T)>, - S: ToOffset, - T: Into>, -{ - cx.update(|cx| { - buffer.update(cx, |buffer, cx| { - buffer.edit(edits, None, cx); - }); - action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); - }); -} - -fn ensure_buffer_saved( - buffer: &Entity, - abs_path: &PathBuf, - tool: &EditFileTool, - cx: &mut AsyncApp, -) -> Result { - let last_read_mtime = tool - .action_log - .read_with(cx, |log, _| log.file_read_time(abs_path)); - let check_result = tool.thread.read_with(cx, |thread, cx| { - let current = buffer - .read(cx) - .file() - .and_then(|file| file.disk_state().mtime()); - let dirty = buffer.read(cx).is_dirty(); - let has_save = thread.has_tool(SaveFileTool::NAME); - let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME); - (current, dirty, has_save, has_restore) - }); - - let Ok((current_mtime, is_dirty, has_save_tool, has_restore_tool)) = check_result else { - return Ok(false); - }; - - if is_dirty { - let message = match (has_save_tool, has_restore_tool) { - (true, true) => { - "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ - If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \ - If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit." - } - (true, false) => { - "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ - If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \ - If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed." - } - (false, true) => { - "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ - If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \ - If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit." - } - (false, false) => { - "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \ - then ask them to save or revert the file manually and inform you when it's ok to proceed." - } - }; - return Err(message.to_string()); - } - - if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) - && current != last_read - { - return Ok(true); - } - - Ok(false) -} - -fn resolve_path( - mode: EditFileMode, - path: &PathBuf, - project: &Entity, - cx: &mut App, -) -> Result { - let project = project.read(cx); - - match mode { - EditFileMode::Edit => { - let path = project - .find_project_path(&path, cx) - .ok_or_else(|| "Can't edit file: path not found".to_string())?; - - let entry = project - .entry_for_path(&path, cx) - .ok_or_else(|| "Can't edit file: path not found".to_string())?; - - if entry.is_file() { - Ok(path) - } else { - Err("Can't edit file: path is a directory".to_string()) - } - } - EditFileMode::Write => { - if let Some(path) = project.find_project_path(&path, cx) - && let Some(entry) = project.entry_for_path(&path, cx) - { - if entry.is_file() { - return Ok(path); - } else { - return Err("Can't write to file: path is a directory".to_string()); - } - } - - let parent_path = path - .parent() - .ok_or_else(|| "Can't create file: incorrect path".to_string())?; - - let parent_project_path = project.find_project_path(&parent_path, cx); - - let parent_entry = parent_project_path - .as_ref() - .and_then(|path| project.entry_for_path(path, cx)) - .ok_or_else(|| "Can't create file: parent directory doesn't exist")?; - - if !parent_entry.is_dir() { - return Err("Can't create file: parent is not a directory".to_string()); - } - - let file_name = path - .file_name() - .and_then(|file_name| file_name.to_str()) - .and_then(|file_name| RelPath::unix(file_name).ok()) - .ok_or_else(|| "Can't create file: invalid filename".to_string())?; - - let new_file_path = parent_project_path.map(|parent| ProjectPath { - path: parent.path.join(file_name), - ..parent - }); - - new_file_path.ok_or_else(|| "Can't create file".to_string()) - } + self.session_context.replay_output(output, event_stream, cx) } } @@ -1228,85 +278,29 @@ mod tests { use super::*; use crate::{ContextServerRegistry, Templates, ToolInputSender}; use fs::Fs as _; - use futures::StreamExt as _; - use gpui::{TestAppContext, UpdateGlobal}; + use gpui::{AppContext as _, TestAppContext, UpdateGlobal}; use language_model::fake_provider::FakeLanguageModel; + use project::ProjectPath; use prompt_store::ProjectContext; - use serde_json::json; - use settings::Settings; - use settings::SettingsStore; - use util::path; - use util::rel_path::rel_path; - - #[gpui::test] - async fn test_streaming_edit_create_file(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - let result = cx - .update(|cx| { - tool.clone().run( - ToolInput::resolved(EditFileToolInput { - path: "root/dir/new_file.txt".into(), - mode: EditFileMode::Write, - content: Some("Hello, World!".into()), - edits: None, - }), - ToolCallEventStream::test().0, - cx, - ) - }) - .await; - - let EditFileToolOutput::Success { new_text, diff, .. } = result.unwrap() else { - panic!("expected success"); - }; - assert_eq!(new_text, "Hello, World!"); - assert!(!diff.is_empty()); - } - - #[gpui::test] - async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = - setup_test(cx, json!({"file.txt": "old content"})).await; - let result = cx - .update(|cx| { - tool.clone().run( - ToolInput::resolved(EditFileToolInput { - path: "root/file.txt".into(), - mode: EditFileMode::Write, - content: Some("new content".into()), - edits: None, - }), - ToolCallEventStream::test().0, - cx, - ) - }) - .await; - - let EditFileToolOutput::Success { - new_text, old_text, .. - } = result.unwrap() - else { - panic!("expected success"); - }; - assert_eq!(new_text, "new content"); - assert_eq!(*old_text, "old content"); - } + use serde_json::json; + use settings::Settings; + use settings::SettingsStore; + use util::path; + use util::rel_path::{RelPath, rel_path}; #[gpui::test] async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/file.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![Edit { + edits: vec![Edit { old_text: "line 2".into(), new_text: "modified line 2".into(), - }]), + }], }), ToolCallEventStream::test().0, cx, @@ -1322,19 +316,17 @@ mod tests { #[gpui::test] async fn test_streaming_edit_multiple_edits(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = setup_test( + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test( cx, json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; let result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/file.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![ + edits: vec![ Edit { old_text: "line 5".into(), new_text: "modified line 5".into(), @@ -1343,7 +335,7 @@ mod tests { old_text: "line 1".into(), new_text: "modified line 1".into(), }, - ]), + ], }), ToolCallEventStream::test().0, cx, @@ -1362,19 +354,17 @@ mod tests { #[gpui::test] async fn test_streaming_edit_adjacent_edits(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = setup_test( + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test( cx, json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; let result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/file.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![ + edits: vec![ Edit { old_text: "line 2".into(), new_text: "modified line 2".into(), @@ -1383,7 +373,7 @@ mod tests { old_text: "line 3".into(), new_text: "modified line 3".into(), }, - ]), + ], }), ToolCallEventStream::test().0, cx, @@ -1402,19 +392,17 @@ mod tests { #[gpui::test] async fn test_streaming_edit_ascending_order_edits(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = setup_test( + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test( cx, json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; let result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/file.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![ + edits: vec![ Edit { old_text: "line 1".into(), new_text: "modified line 1".into(), @@ -1423,7 +411,7 @@ mod tests { old_text: "line 5".into(), new_text: "modified line 5".into(), }, - ]), + ], }), ToolCallEventStream::test().0, cx, @@ -1442,18 +430,16 @@ mod tests { #[gpui::test] async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({})).await; + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({})).await; let result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/nonexistent_file.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![Edit { + edits: vec![Edit { old_text: "foo".into(), new_text: "bar".into(), - }]), + }], }), ToolCallEventStream::test().0, cx, @@ -1476,19 +462,17 @@ mod tests { #[gpui::test] async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world"})).await; let result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/file.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![Edit { + edits: vec![Edit { old_text: "nonexistent text that is not in the file".into(), new_text: "replacement".into(), - }]), + }], }), ToolCallEventStream::test().0, cx, @@ -1507,11 +491,11 @@ mod tests { #[gpui::test] async fn test_streaming_early_buffer_open(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Send partials simulating LLM streaming: description first, then path, then mode sender.send_partial(json!({})); @@ -1525,14 +509,12 @@ mod tests { // Path is NOT yet complete because mode hasn't appeared — no buffer open yet sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit" })); cx.run_until_parked(); // Now send the final complete input sender.send_full(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "line 2", "new_text": "modified line 2"}] })); @@ -1543,49 +525,14 @@ mod tests { assert_eq!(new_text, "line 1\nmodified line 2\nline 3\n"); } - #[gpui::test] - async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = - setup_test(cx, json!({"file.txt": "hello world"})).await; - let (mut sender, input) = ToolInput::::test(); - let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - - // Send partial with path but NO mode — path should NOT be treated as complete - sender.send_partial(json!({ - "path": "root/file" - })); - cx.run_until_parked(); - - // Now the path grows and mode appears - sender.send_partial(json!({ - "path": "root/file.txt", - "mode": "write" - })); - cx.run_until_parked(); - - // Send final - sender.send_full(json!({ - "path": "root/file.txt", - "mode": "write", - "content": "new content" - })); - - let result = task.await; - let EditFileToolOutput::Success { new_text, .. } = result.unwrap() else { - panic!("expected success"); - }; - assert_eq!(new_text, "new content"); - } - #[gpui::test] async fn test_streaming_cancellation_during_partials(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world"})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver, mut cancellation_tx) = ToolCallEventStream::test_with_cancellation(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Send a partial sender.send_partial(json!({})); @@ -1611,14 +558,14 @@ mod tests { #[gpui::test] async fn test_streaming_edit_with_multiple_partials(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = setup_test( + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test( cx, json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Simulate fine-grained streaming of the JSON sender.send_partial(json!({})); @@ -1631,20 +578,17 @@ mod tests { sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit" })); cx.run_until_parked(); sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "line 1"}] })); cx.run_until_parked(); sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "line 1", "new_text": "modified line 1"}, {"old_text": "line 5"} @@ -1655,7 +599,6 @@ mod tests { // Send final complete input sender.send_full(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "line 1", "new_text": "modified line 1"}, {"old_text": "line 5", "new_text": "modified line 5"} @@ -1672,56 +615,17 @@ mod tests { ); } - #[gpui::test] - async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - let (mut sender, input) = ToolInput::::test(); - let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - - // Stream partials for create mode - sender.send_partial(json!({})); - cx.run_until_parked(); - - sender.send_partial(json!({ - "path": "root/dir/new_file.txt", - "mode": "write" - })); - cx.run_until_parked(); - - sender.send_partial(json!({ - "path": "root/dir/new_file.txt", - "mode": "write", - "content": "Hello, " - })); - cx.run_until_parked(); - - // Final with full content - sender.send_full(json!({ - "path": "root/dir/new_file.txt", - "mode": "write", - "content": "Hello, World!" - })); - - let result = task.await; - let EditFileToolOutput::Success { new_text, .. } = result.unwrap() else { - panic!("expected success"); - }; - assert_eq!(new_text, "Hello, World!"); - } - #[gpui::test] async fn test_streaming_no_partials_direct_final(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Send final immediately with no partials (simulates non-streaming path) sender.send_full(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "line 2", "new_text": "modified line 2"}] })); @@ -1734,14 +638,14 @@ mod tests { #[gpui::test] async fn test_streaming_incremental_edit_application(cx: &mut TestAppContext) { - let (tool, project, _action_log, _fs, _thread) = setup_test( + let (edit_tool, project, _action_log, _fs, _thread) = setup_test( cx, json!({"file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"}), ) .await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Stream description, path, mode sender.send_partial(json!({})); @@ -1749,14 +653,12 @@ mod tests { sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit" })); cx.run_until_parked(); // First edit starts streaming (old_text only, still in progress) sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "line 1"}] })); cx.run_until_parked(); @@ -1782,7 +684,6 @@ mod tests { // should be applied immediately during streaming sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "line 1", "new_text": "MODIFIED 1"}, {"old_text": "line 5"} @@ -1808,7 +709,6 @@ mod tests { // Send final complete input sender.send_full(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "line 1", "new_text": "MODIFIED 1"}, {"old_text": "line 5", "new_text": "MODIFIED 5"} @@ -1831,23 +731,21 @@ mod tests { #[gpui::test] async fn test_streaming_incremental_three_edits(cx: &mut TestAppContext) { - let (tool, project, _action_log, _fs, _thread) = + let (edit_tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Setup: description + path + mode sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit" })); cx.run_until_parked(); // Edit 1 in progress sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "aaa", "new_text": "AAA"}] })); cx.run_until_parked(); @@ -1855,7 +753,6 @@ mod tests { // Edit 2 appears — edit 1 is now complete and should be applied sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "aaa", "new_text": "AAA"}, {"old_text": "ccc", "new_text": "CCC"} @@ -1877,7 +774,6 @@ mod tests { // Edit 3 appears — edit 2 is now complete and should be applied sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "aaa", "new_text": "AAA"}, {"old_text": "ccc", "new_text": "CCC"}, @@ -1899,7 +795,6 @@ mod tests { // Send final sender.send_full(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "aaa", "new_text": "AAA"}, {"old_text": "ccc", "new_text": "CCC"}, @@ -1916,23 +811,21 @@ mod tests { #[gpui::test] async fn test_streaming_edit_failure_mid_stream(cx: &mut TestAppContext) { - let (tool, project, _action_log, _fs, _thread) = + let (edit_tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Setup sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit" })); cx.run_until_parked(); // Edit 1 (valid) in progress — not yet complete (no second edit) sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "line 1", "new_text": "MODIFIED"} ] @@ -1943,7 +836,6 @@ mod tests { // Edit 1 should be applied. Edit 2 is still in-progress (last edit). sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "line 1", "new_text": "MODIFIED"}, {"old_text": "nonexistent text that does not appear anywhere in the file at all", "new_text": "whatever"} @@ -1969,7 +861,6 @@ mod tests { // resolution which should fail (old_text doesn't exist in the file). sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "line 1", "new_text": "MODIFIED"}, {"old_text": "nonexistent text that does not appear anywhere in the file at all", "new_text": "whatever"}, @@ -2006,22 +897,20 @@ mod tests { #[gpui::test] async fn test_streaming_single_edit_no_incremental(cx: &mut TestAppContext) { - let (tool, project, _action_log, _fs, _thread) = + let (edit_tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world\n"})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Setup + single edit that stays in-progress (no second edit to prove completion) sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", })); cx.run_until_parked(); sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "hello world", "new_text": "goodbye world"}] })); cx.run_until_parked(); @@ -2045,7 +934,6 @@ mod tests { // Send final — the edit is applied during finalization sender.send_full(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "hello world", "new_text": "goodbye world"}] })); @@ -2058,12 +946,12 @@ mod tests { #[gpui::test] async fn test_streaming_input_partials_then_final(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Send progressively more complete partial snapshots, as the LLM would sender.send_partial(json!({})); @@ -2071,13 +959,11 @@ mod tests { sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit" })); cx.run_until_parked(); sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "line 2", "new_text": "modified line 2"}] })); cx.run_until_parked(); @@ -2085,7 +971,6 @@ mod tests { // Send the final complete input sender.send_full(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "line 2", "new_text": "modified line 2"}] })); @@ -2098,12 +983,12 @@ mod tests { #[gpui::test] async fn test_streaming_input_sender_dropped_before_final(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello world\n"})).await; let (mut sender, input): (ToolInputSender, ToolInput) = ToolInput::test(); let (event_stream, _event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Send a partial then drop the sender without sending final sender.send_partial(json!({})); @@ -2118,69 +1003,9 @@ mod tests { ); } - #[gpui::test] - async fn test_streaming_input_recv_drains_partials(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - // Create a channel and send multiple partials before a final, then use - // ToolInput::resolved-style immediate delivery to confirm recv() works - // when partials are already buffered. - let (mut sender, input): (ToolInputSender, ToolInput) = - ToolInput::test(); - let (event_stream, _event_rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - - // Buffer several partials before sending the final - sender.send_partial(json!({})); - sender.send_partial(json!({"path": "root/dir/new.txt"})); - sender.send_partial(json!({ - "path": "root/dir/new.txt", - "mode": "write" - })); - sender.send_full(json!({ - "path": "root/dir/new.txt", - "mode": "write", - "content": "streamed content" - })); - - let result = task.await; - let EditFileToolOutput::Success { new_text, .. } = result.unwrap() else { - panic!("expected success"); - }; - assert_eq!(new_text, "streamed content"); - } - - #[gpui::test] - async fn test_streaming_resolve_path_for_creating_file(cx: &mut TestAppContext) { - let mode = EditFileMode::Write; - - let result = test_resolve_path(&mode, "root/new.txt", cx); - assert_resolved_path_eq(result.await, rel_path("new.txt")); - - let result = test_resolve_path(&mode, "new.txt", cx); - assert_resolved_path_eq(result.await, rel_path("new.txt")); - - let result = test_resolve_path(&mode, "dir/new.txt", cx); - assert_resolved_path_eq(result.await, rel_path("dir/new.txt")); - - let result = test_resolve_path(&mode, "root/dir/subdir/existing.txt", cx); - assert_resolved_path_eq(result.await, rel_path("dir/subdir/existing.txt")); - - let result = test_resolve_path(&mode, "root/dir/subdir", cx); - assert_eq!( - result.await.unwrap_err(), - "Can't write to file: path is a directory" - ); - - let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx); - assert_eq!( - result.await.unwrap_err(), - "Can't create file: parent directory doesn't exist" - ); - } - #[gpui::test] async fn test_streaming_resolve_path_for_editing_file(cx: &mut TestAppContext) { - let mode = EditFileMode::Edit; + let mode = EditSessionMode::Edit; let path_with_root = "root/dir/subdir/existing.txt"; let path_without_root = "dir/subdir/existing.txt"; @@ -2201,7 +1026,7 @@ mod tests { } async fn test_resolve_path( - mode: &EditFileMode, + mode: &EditSessionMode, path: &str, cx: &mut TestAppContext, ) -> Result { @@ -2221,7 +1046,7 @@ mod tests { .await; let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - cx.update(|cx| resolve_path(*mode, &PathBuf::from(path), &project, cx)) + crate::tools::edit_session::test_resolve_path(mode, path, &project, cx).await } #[track_caller] @@ -2230,290 +1055,14 @@ mod tests { assert_eq!(actual.as_ref(), expected); } - #[gpui::test] - async fn test_streaming_format_on_save(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"src": {}})).await; - let (tool, project, action_log, fs, thread) = - setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; - - let rust_language = Arc::new(language::Language::new( - language::LanguageConfig { - name: "Rust".into(), - matcher: language::LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - None, - )); - - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - language_registry.add(rust_language); - - let mut fake_language_servers = language_registry.register_fake_lsp( - "Rust", - language::FakeLspAdapter { - capabilities: lsp::ServerCapabilities { - document_formatting_provider: Some(lsp::OneOf::Left(true)), - ..Default::default() - }, - ..Default::default() - }, - ); - - fs.save( - path!("/root/src/main.rs").as_ref(), - &"initial content".into(), - language::LineEnding::Unix, - ) - .await - .unwrap(); - - // Open the buffer to trigger LSP initialization - let buffer = project - .update(cx, |project, cx| { - project.open_local_buffer(path!("/root/src/main.rs"), cx) - }) - .await - .unwrap(); - - // Register the buffer with language servers - let _handle = project.update(cx, |project, cx| { - project.register_buffer_with_language_servers(&buffer, cx) - }); - - const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\ -"; - const FORMATTED_CONTENT: &str = "This file was formatted by the fake formatter in the test.\ -"; - - // Get the fake language server and set up formatting handler - let fake_language_server = fake_language_servers.next().await.unwrap(); - fake_language_server.set_request_handler::({ - |_, _| async move { - Ok(Some(vec![lsp::TextEdit { - range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)), - new_text: FORMATTED_CONTENT.to_string(), - }])) - } - }); - - // Test with format_on_save enabled - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On); - settings.project.all_languages.defaults.formatter = - Some(language::language_settings::FormatterList::default()); - }); - }); - }); - - // Use streaming pattern so executor can pump the LSP request/response - let (mut sender, input) = ToolInput::::test(); - let (event_stream, _receiver) = ToolCallEventStream::test(); - - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - - sender.send_partial(json!({ - "path": "root/src/main.rs", - "mode": "write" - })); - cx.run_until_parked(); - - sender.send_full(json!({ - "path": "root/src/main.rs", - "mode": "write", - "content": UNFORMATTED_CONTENT - })); - - let result = task.await; - assert!(result.is_ok()); - - cx.executor().run_until_parked(); - - let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); - assert_eq!( - new_content.replace("\r\n", "\n"), - FORMATTED_CONTENT, - "Code should be formatted when format_on_save is enabled" - ); - - let stale_buffer_count = thread - .read_with(cx, |thread, _cx| thread.action_log.clone()) - .read_with(cx, |log, cx| log.stale_buffers(cx).count()); - - assert_eq!( - stale_buffer_count, 0, - "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers.", - stale_buffer_count - ); - - // Test with format_on_save disabled - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings.project.all_languages.defaults.format_on_save = - Some(FormatOnSave::Off); - }); - }); - }); - - let (mut sender, input) = ToolInput::::test(); - let (event_stream, _receiver) = ToolCallEventStream::test(); - - let tool2 = Arc::new(EditFileTool::new( - project.clone(), - thread.downgrade(), - action_log.clone(), - language_registry, - )); - - let task = cx.update(|cx| tool2.run(input, event_stream, cx)); - - sender.send_partial(json!({ - "path": "root/src/main.rs", - "mode": "write" - })); - cx.run_until_parked(); - - sender.send_full(json!({ - "path": "root/src/main.rs", - "mode": "write", - "content": UNFORMATTED_CONTENT - })); - - let result = task.await; - assert!(result.is_ok()); - - cx.executor().run_until_parked(); - - let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); - assert_eq!( - new_content.replace("\r\n", "\n"), - UNFORMATTED_CONTENT, - "Code should not be formatted when format_on_save is disabled" - ); - } - - #[gpui::test] - async fn test_streaming_remove_trailing_whitespace(cx: &mut TestAppContext) { - init_test(cx); - - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/root", json!({"src": {}})).await; - fs.save( - path!("/root/src/main.rs").as_ref(), - &"initial content".into(), - language::LineEnding::Unix, - ) - .await - .unwrap(); - let (tool, project, action_log, fs, thread) = - setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; - let language_registry = project.read_with(cx, |p, _cx| p.languages().clone()); - - // Test with remove_trailing_whitespace_on_save enabled - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings - .project - .all_languages - .defaults - .remove_trailing_whitespace_on_save = Some(true); - }); - }); - }); - - const CONTENT_WITH_TRAILING_WHITESPACE: &str = - "fn main() { \n println!(\"Hello!\"); \n}\n"; - - let result = cx - .update(|cx| { - tool.clone().run( - ToolInput::resolved(EditFileToolInput { - path: "root/src/main.rs".into(), - mode: EditFileMode::Write, - content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), - edits: None, - }), - ToolCallEventStream::test().0, - cx, - ) - }) - .await; - assert!(result.is_ok()); - - cx.executor().run_until_parked(); - - assert_eq!( - fs.load(path!("/root/src/main.rs").as_ref()) - .await - .unwrap() - .replace("\r\n", "\n"), - "fn main() {\n println!(\"Hello!\");\n}\n", - "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled" - ); - - // Test with remove_trailing_whitespace_on_save disabled - cx.update(|cx| { - SettingsStore::update_global(cx, |store, cx| { - store.update_user_settings(cx, |settings| { - settings - .project - .all_languages - .defaults - .remove_trailing_whitespace_on_save = Some(false); - }); - }); - }); - - let tool2 = Arc::new(EditFileTool::new( - project.clone(), - thread.downgrade(), - action_log.clone(), - language_registry, - )); - - let result = cx - .update(|cx| { - tool2.run( - ToolInput::resolved(EditFileToolInput { - path: "root/src/main.rs".into(), - mode: EditFileMode::Write, - content: Some(CONTENT_WITH_TRAILING_WHITESPACE.into()), - edits: None, - }), - ToolCallEventStream::test().0, - cx, - ) - }) - .await; - assert!(result.is_ok()); - - cx.executor().run_until_parked(); - - let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); - assert_eq!( - final_content.replace("\r\n", "\n"), - CONTENT_WITH_TRAILING_WHITESPACE, - "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" - ); - } - #[gpui::test] async fn test_streaming_authorize(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({})).await; + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({})).await; // Test 1: Path with .zed component should require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let _auth = - cx.update(|cx| tool.authorize(&PathBuf::from(".zed/settings.json"), &stream_tx, cx)); + let _auth = cx + .update(|cx| edit_tool.authorize(&PathBuf::from(".zed/settings.json"), &stream_tx, cx)); let event = stream_rx.expect_authorization().await; assert_eq!( @@ -2523,7 +1072,8 @@ mod tests { // Test 2: Path outside project should require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let _auth = cx.update(|cx| tool.authorize(&PathBuf::from("/etc/hosts"), &stream_tx, cx)); + let _auth = + cx.update(|cx| edit_tool.authorize(&PathBuf::from("/etc/hosts"), &stream_tx, cx)); let event = stream_rx.expect_authorization().await; assert_eq!( @@ -2533,15 +1083,16 @@ mod tests { // Test 3: Relative path without .zed should not require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - cx.update(|cx| tool.authorize(&PathBuf::from("root/src/main.rs"), &stream_tx, cx)) + cx.update(|cx| edit_tool.authorize(&PathBuf::from("root/src/main.rs"), &stream_tx, cx)) .await .unwrap(); assert!(stream_rx.try_recv().is_err()); // Test 4: Path with .zed in the middle should require confirmation let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let _auth = - cx.update(|cx| tool.authorize(&PathBuf::from("root/.zed/tasks.json"), &stream_tx, cx)); + let _auth = cx.update(|cx| { + edit_tool.authorize(&PathBuf::from("root/.zed/tasks.json"), &stream_tx, cx) + }); let event = stream_rx.expect_authorization().await; assert_eq!( event.tool_call.fields.title, @@ -2558,8 +1109,8 @@ mod tests { // 5.1: .zed/settings.json is a sensitive path — still prompts let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let _auth = - cx.update(|cx| tool.authorize(&PathBuf::from(".zed/settings.json"), &stream_tx, cx)); + let _auth = cx + .update(|cx| edit_tool.authorize(&PathBuf::from(".zed/settings.json"), &stream_tx, cx)); let event = stream_rx.expect_authorization().await; assert_eq!( event.tool_call.fields.title, @@ -2568,14 +1119,14 @@ mod tests { // 5.2: /etc/hosts is outside the project, but Allow auto-approves let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - cx.update(|cx| tool.authorize(&PathBuf::from("/etc/hosts"), &stream_tx, cx)) + cx.update(|cx| edit_tool.authorize(&PathBuf::from("/etc/hosts"), &stream_tx, cx)) .await .unwrap(); assert!(stream_rx.try_recv().is_err()); // 5.3: Normal in-project path with allow — no confirmation needed let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - cx.update(|cx| tool.authorize(&PathBuf::from("root/src/main.rs"), &stream_tx, cx)) + cx.update(|cx| edit_tool.authorize(&PathBuf::from("root/src/main.rs"), &stream_tx, cx)) .await .unwrap(); assert!(stream_rx.try_recv().is_err()); @@ -2588,7 +1139,8 @@ mod tests { }); let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let _auth = cx.update(|cx| tool.authorize(&PathBuf::from("/etc/hosts"), &stream_tx, cx)); + let _auth = + cx.update(|cx| edit_tool.authorize(&PathBuf::from("/etc/hosts"), &stream_tx, cx)); let event = stream_rx.expect_authorization().await; assert_eq!( @@ -2606,7 +1158,7 @@ mod tests { fs.insert_tree("/outside", json!({})).await; fs.insert_symlink("/root/link", PathBuf::from("/outside")) .await; - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; cx.update(|cx| { @@ -2617,7 +1169,7 @@ mod tests { let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let authorize_task = - cx.update(|cx| tool.authorize(&PathBuf::from("link/new.txt"), &stream_tx, cx)); + cx.update(|cx| edit_tool.authorize(&PathBuf::from("link/new.txt"), &stream_tx, cx)); let event = stream_rx.expect_authorization().await; assert!( @@ -2667,12 +1219,12 @@ mod tests { ) .await .unwrap(); - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let _authorize_task = cx.update(|cx| { - tool.authorize( + edit_tool.authorize( &PathBuf::from("link_to_external/config.txt"), &stream_tx, cx, @@ -2712,12 +1264,12 @@ mod tests { ) .await .unwrap(); - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let authorize_task = cx.update(|cx| { - tool.authorize( + edit_tool.authorize( &PathBuf::from("link_to_external/config.txt"), &stream_tx, cx, @@ -2767,13 +1319,13 @@ mod tests { ) .await .unwrap(); - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let result = cx .update(|cx| { - tool.authorize( + edit_tool.authorize( &PathBuf::from("link_to_external/config.txt"), &stream_tx, cx, @@ -2796,7 +1348,7 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; let test_cases = vec![ @@ -2819,7 +1371,7 @@ mod tests { for (path, should_confirm, description) in test_cases { let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let auth = cx.update(|cx| tool.authorize(&PathBuf::from(path), &stream_tx, cx)); + let auth = cx.update(|cx| edit_tool.authorize(&PathBuf::from(path), &stream_tx, cx)); if should_confirm { stream_rx.expect_authorization().await; @@ -2866,7 +1418,7 @@ mod tests { }), ) .await; - let (tool, _project, _action_log, _fs, _thread) = setup_test_with_fs( + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test_with_fs( cx, fs, &[ @@ -2895,7 +1447,7 @@ mod tests { for (path, should_confirm, description) in test_cases { let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let auth = cx.update(|cx| tool.authorize(&PathBuf::from(path), &stream_tx, cx)); + let auth = cx.update(|cx| edit_tool.authorize(&PathBuf::from(path), &stream_tx, cx)); if should_confirm { stream_rx.expect_authorization().await; @@ -2929,7 +1481,7 @@ mod tests { }), ) .await; - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; let test_cases = vec![ @@ -2953,7 +1505,7 @@ mod tests { for (path, should_confirm, description) in test_cases { let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let auth = cx.update(|cx| tool.authorize(&PathBuf::from(path), &stream_tx, cx)); + let auth = cx.update(|cx| edit_tool.authorize(&PathBuf::from(path), &stream_tx, cx)); cx.run_until_parked(); @@ -2985,32 +1537,35 @@ mod tests { }), ) .await; - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; - let modes = vec![EditFileMode::Edit, EditFileMode::Write]; + let modes = vec![EditSessionMode::Edit, EditSessionMode::Write]; for _mode in modes { // Test .zed path with different modes let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); let _auth = cx.update(|cx| { - tool.authorize(&PathBuf::from("project/.zed/settings.json"), &stream_tx, cx) + edit_tool.authorize(&PathBuf::from("project/.zed/settings.json"), &stream_tx, cx) }); stream_rx.expect_authorization().await; // Test outside path with different modes let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let _auth = - cx.update(|cx| tool.authorize(&PathBuf::from("/outside/file.txt"), &stream_tx, cx)); + let _auth = cx.update(|cx| { + edit_tool.authorize(&PathBuf::from("/outside/file.txt"), &stream_tx, cx) + }); stream_rx.expect_authorization().await; // Test normal path with different modes let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - cx.update(|cx| tool.authorize(&PathBuf::from("project/normal.txt"), &stream_tx, cx)) - .await - .unwrap(); + cx.update(|cx| { + edit_tool.authorize(&PathBuf::from("project/normal.txt"), &stream_tx, cx) + }) + .await + .unwrap(); assert!(stream_rx.try_recv().is_err()); } } @@ -3020,12 +1575,12 @@ mod tests { init_test(cx); let fs = project::FakeFs::new(cx.executor()); fs.insert_tree("/project", json!({})).await; - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test_with_fs(cx, fs, &[path!("/project").as_ref()]).await; cx.update(|cx| { assert_eq!( - tool.initial_title( + edit_tool.initial_title( Err(json!({ "path": "src/main.rs", })), @@ -3034,7 +1589,7 @@ mod tests { "src/main.rs" ); assert_eq!( - tool.initial_title( + edit_tool.initial_title( Err(json!({ "path": "", })), @@ -3043,77 +1598,15 @@ mod tests { DEFAULT_UI_TEXT ); assert_eq!( - tool.initial_title(Err(serde_json::Value::Null), cx), + edit_tool.initial_title(Err(serde_json::Value::Null), cx), DEFAULT_UI_TEXT ); }); } - #[gpui::test] - async fn test_streaming_diff_finalization(cx: &mut TestAppContext) { - init_test(cx); - let fs = project::FakeFs::new(cx.executor()); - fs.insert_tree("/", json!({"main.rs": ""})).await; - let (tool, project, action_log, _fs, thread) = - setup_test_with_fs(cx, fs, &[path!("/").as_ref()]).await; - let language_registry = project.read_with(cx, |p, _cx| p.languages().clone()); - - // Ensure the diff is finalized after the edit completes. - { - let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let edit = cx.update(|cx| { - tool.clone().run( - ToolInput::resolved(EditFileToolInput { - path: path!("/main.rs").into(), - mode: EditFileMode::Write, - content: Some("new content".into()), - edits: None, - }), - stream_tx, - cx, - ) - }); - stream_rx.expect_update_fields().await; - let diff = stream_rx.expect_diff().await; - diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); - cx.run_until_parked(); - edit.await.unwrap(); - diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); - } - - // Ensure the diff is finalized if the tool call gets dropped. - { - let tool = Arc::new(EditFileTool::new( - project.clone(), - thread.downgrade(), - action_log, - language_registry, - )); - let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); - let edit = cx.update(|cx| { - tool.run( - ToolInput::resolved(EditFileToolInput { - path: path!("/main.rs").into(), - mode: EditFileMode::Write, - content: Some("dropped content".into()), - edits: None, - }), - stream_tx, - cx, - ) - }); - stream_rx.expect_update_fields().await; - let diff = stream_rx.expect_diff().await; - diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); - drop(edit); - cx.run_until_parked(); - diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); - } - } - #[gpui::test] async fn test_streaming_consecutive_edits_work(cx: &mut TestAppContext) { - let (tool, project, action_log, _fs, _thread) = + let (edit_tool, project, action_log, _fs, _thread) = setup_test(cx, json!({"test.txt": "original content"})).await; let read_tool = Arc::new(crate::ReadFileTool::new( project.clone(), @@ -3139,15 +1632,13 @@ mod tests { // First edit should work let edit_result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/test.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![Edit { + edits: vec![Edit { old_text: "original content".into(), new_text: "modified content".into(), - }]), + }], }), ToolCallEventStream::test().0, cx, @@ -3163,15 +1654,13 @@ mod tests { // Second edit should also work because the edit updated the recorded read time let edit_result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/test.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![Edit { + edits: vec![Edit { old_text: "modified content".into(), new_text: "further modified content".into(), - }]), + }], }), ToolCallEventStream::test().0, cx, @@ -3187,7 +1676,7 @@ mod tests { #[gpui::test] async fn test_streaming_external_modification_matching_edit_succeeds(cx: &mut TestAppContext) { - let (tool, project, action_log, fs, _thread) = + let (edit_tool, project, action_log, fs, _thread) = setup_test(cx, json!({"test.txt": "original content"})).await; let read_tool = Arc::new(crate::ReadFileTool::new( project.clone(), @@ -3240,15 +1729,13 @@ mod tests { let result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/test.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![Edit { + edits: vec![Edit { old_text: "externally modified content".into(), new_text: "new content".into(), - }]), + }], }), ToolCallEventStream::test().0, cx, @@ -3274,7 +1761,7 @@ mod tests { async fn test_streaming_external_modification_mentioned_when_match_fails( cx: &mut TestAppContext, ) { - let (tool, project, action_log, fs, _thread) = + let (edit_tool, project, action_log, fs, _thread) = setup_test(cx, json!({"test.txt": "original content"})).await; let read_tool = Arc::new(crate::ReadFileTool::new( project.clone(), @@ -3324,15 +1811,13 @@ mod tests { let result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/test.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![Edit { + edits: vec![Edit { old_text: "original content".into(), new_text: "new content".into(), - }]), + }], }), ToolCallEventStream::test().0, cx, @@ -3363,7 +1848,7 @@ mod tests { #[gpui::test] async fn test_streaming_dirty_buffer_detected(cx: &mut TestAppContext) { - let (tool, project, action_log, _fs, _thread) = + let (edit_tool, project, action_log, _fs, _thread) = setup_test(cx, json!({"test.txt": "original content"})).await; let read_tool = Arc::new(crate::ReadFileTool::new( project.clone(), @@ -3408,15 +1893,13 @@ mod tests { // Try to edit - should fail because buffer has unsaved changes let result = cx .update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/test.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![Edit { + edits: vec![Edit { old_text: "original content".into(), new_text: "new content".into(), - }]), + }], }), ToolCallEventStream::test().0, cx, @@ -3457,16 +1940,15 @@ mod tests { // old_text as a substring. Because edits resolve sequentially // against the current buffer, edit 2 finds a unique match in // the modified buffer and succeeds. - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "aaa\nbbb\nccc\nddd\neee\n"})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); // Setup: resolve the buffer sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit" })); cx.run_until_parked(); @@ -3477,7 +1959,6 @@ mod tests { // Edit 3 exists only to mark edit 2 as "complete" during streaming. sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "bbb\nccc", "new_text": "XXX\nccc\nddd"}, {"old_text": "ccc\nddd", "new_text": "ZZZ"}, @@ -3489,7 +1970,6 @@ mod tests { // Send the final input with all three edits. sender.send_full(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [ {"old_text": "bbb\nccc", "new_text": "XXX\nccc\nddd"}, {"old_text": "ccc\nddd", "new_text": "ZZZ"}, @@ -3504,218 +1984,16 @@ mod tests { assert_eq!(new_text, "aaa\nXXX\nZZZ\nddd\nDUMMY\n"); } - #[gpui::test] - async fn test_streaming_create_content_streamed(cx: &mut TestAppContext) { - let (tool, project, _action_log, _fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - let (mut sender, input) = ToolInput::::test(); - let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - - // Transition to BufferResolved - sender.send_partial(json!({ - "path": "root/dir/new_file.txt", - "mode": "write" - })); - cx.run_until_parked(); - - // Stream content incrementally - sender.send_partial(json!({ - "path": "root/dir/new_file.txt", - "mode": "write", - "content": "line 1\n" - })); - cx.run_until_parked(); - - // Verify buffer has partial content - let buffer = project.update(cx, |project, cx| { - let path = project - .find_project_path("root/dir/new_file.txt", cx) - .unwrap(); - project.get_open_buffer(&path, cx).unwrap() - }); - assert_eq!(buffer.read_with(cx, |b, _| b.text()), "line 1\n"); - - // Stream more content - sender.send_partial(json!({ - "path": "root/dir/new_file.txt", - "mode": "write", - "content": "line 1\nline 2\n" - })); - cx.run_until_parked(); - assert_eq!(buffer.read_with(cx, |b, _| b.text()), "line 1\nline 2\n"); - - // Stream final chunk - sender.send_partial(json!({ - "path": "root/dir/new_file.txt", - "mode": "write", - "content": "line 1\nline 2\nline 3\n" - })); - cx.run_until_parked(); - assert_eq!( - buffer.read_with(cx, |b, _| b.text()), - "line 1\nline 2\nline 3\n" - ); - - // Send final input - sender.send_full(json!({ - "path": "root/dir/new_file.txt", - "mode": "write", - "content": "line 1\nline 2\nline 3\n" - })); - - let result = task.await; - let EditFileToolOutput::Success { new_text, .. } = result.unwrap() else { - panic!("expected success"); - }; - assert_eq!(new_text, "line 1\nline 2\nline 3\n"); - } - - #[gpui::test] - async fn test_streaming_overwrite_diff_revealed_during_streaming(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = setup_test( - cx, - json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), - ) - .await; - let (mut sender, input) = ToolInput::::test(); - let (event_stream, mut receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - - // Transition to BufferResolved - sender.send_partial(json!({ - "path": "root/file.txt", - })); - cx.run_until_parked(); - - sender.send_partial(json!({ - "path": "root/file.txt", - "mode": "write" - })); - cx.run_until_parked(); - - // Get the diff entity from the event stream - receiver.expect_update_fields().await; - let diff = receiver.expect_diff().await; - - // Diff starts pending with no revealed ranges - diff.read_with(cx, |diff, cx| { - assert!(matches!(diff, Diff::Pending(_))); - assert!(!diff.has_revealed_range(cx)); - }); - - // Stream first content chunk - sender.send_partial(json!({ - "path": "root/file.txt", - "mode": "write", - "content": "new line 1\n" - })); - cx.run_until_parked(); - - // Diff should now have revealed ranges showing the new content - diff.read_with(cx, |diff, cx| { - assert!(diff.has_revealed_range(cx)); - }); - - // Send final input - sender.send_full(json!({ - "path": "root/file.txt", - "mode": "write", - "content": "new line 1\nnew line 2\n" - })); - - let result = task.await; - let EditFileToolOutput::Success { - new_text, old_text, .. - } = result.unwrap() - else { - panic!("expected success"); - }; - assert_eq!(new_text, "new line 1\nnew line 2\n"); - assert_eq!(*old_text, "old line 1\nold line 2\nold line 3\n"); - - // Diff is finalized after completion - diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); - } - - #[gpui::test] - async fn test_streaming_overwrite_content_streamed(cx: &mut TestAppContext) { - let (tool, project, _action_log, _fs, _thread) = setup_test( - cx, - json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), - ) - .await; - let (mut sender, input) = ToolInput::::test(); - let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - - // Transition to BufferResolved - sender.send_partial(json!({ - "path": "root/file.txt", - "mode": "write" - })); - cx.run_until_parked(); - - // Verify buffer still has old content (no content partial yet) - let buffer = project.update(cx, |project, cx| { - let path = project.find_project_path("root/file.txt", cx).unwrap(); - project.open_buffer(path, cx) - }); - let buffer = buffer.await.unwrap(); - assert_eq!( - buffer.read_with(cx, |b, _| b.text()), - "old line 1\nold line 2\nold line 3\n" - ); - - // First content partial replaces old content - sender.send_partial(json!({ - "path": "root/file.txt", - "mode": "write", - "content": "new line 1\n" - })); - cx.run_until_parked(); - assert_eq!(buffer.read_with(cx, |b, _| b.text()), "new line 1\n"); - - // Subsequent content partials append - sender.send_partial(json!({ - "path": "root/file.txt", - "mode": "write", - "content": "new line 1\nnew line 2\n" - })); - cx.run_until_parked(); - assert_eq!( - buffer.read_with(cx, |b, _| b.text()), - "new line 1\nnew line 2\n" - ); - - // Send final input with complete content - sender.send_full(json!({ - "path": "root/file.txt", - "mode": "write", - "content": "new line 1\nnew line 2\nnew line 3\n" - })); - - let result = task.await; - let EditFileToolOutput::Success { - new_text, old_text, .. - } = result.unwrap() - else { - panic!("expected success"); - }; - assert_eq!(new_text, "new line 1\nnew line 2\nnew line 3\n"); - assert_eq!(*old_text, "old line 1\nold line 2\nold line 3\n"); - } - #[gpui::test] async fn test_streaming_edit_json_fixer_escape_corruption(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello\nworld\nfoo\n"})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit" })); cx.run_until_parked(); @@ -3726,7 +2004,6 @@ mod tests { // partial 2: old_text = "hello\nworld" (fixer corrected the escape) sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "hello\\"}] })); cx.run_until_parked(); @@ -3734,7 +2011,6 @@ mod tests { // Now the fixer corrects it to the real newline. sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "hello\nworld"}] })); cx.run_until_parked(); @@ -3742,7 +2018,6 @@ mod tests { // Send final. sender.send_full(json!({ "path": "root/file.txt", - "mode": "edit", "edits": [{"old_text": "hello\nworld", "new_text": "HELLO\nWORLD"}] })); @@ -3755,21 +2030,19 @@ mod tests { #[gpui::test] async fn test_streaming_final_input_stringified_edits_succeeds(cx: &mut TestAppContext) { - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "hello\nworld\n"})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); sender.send_partial(json!({ "path": "root/file.txt", - "mode": "edit" })); cx.run_until_parked(); sender.send_full(json!({ "path": "root/file.txt", - "mode": "edit", "edits": "[{\"old_text\": \"hello\\nworld\", \"new_text\": \"HELLO\\nWORLD\"}]" })); @@ -3784,7 +2057,7 @@ mod tests { // reports changed buffers so that the Accept All / Reject All review UI appears. #[gpui::test] async fn test_streaming_edit_file_tool_registers_changed_buffers(cx: &mut TestAppContext) { - let (tool, _project, action_log, _fs, _thread) = + let (edit_tool, _project, action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "line 1\nline 2\nline 3\n"})).await; cx.update(|cx| { let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); @@ -3794,15 +2067,13 @@ mod tests { let (event_stream, _rx) = ToolCallEventStream::test(); let task = cx.update(|cx| { - tool.clone().run( + edit_tool.clone().run( ToolInput::resolved(EditFileToolInput { path: "root/file.txt".into(), - mode: EditFileMode::Edit, - content: None, - edits: Some(vec![Edit { + edits: vec![Edit { old_text: "line 2".into(), new_text: "modified line 2".into(), - }]), + }], }), event_stream, cx, @@ -3823,116 +2094,28 @@ mod tests { } // Same test but for Write mode (overwrite entire file). - #[gpui::test] - async fn test_streaming_edit_file_tool_write_mode_registers_changed_buffers( - cx: &mut TestAppContext, - ) { - let (tool, _project, action_log, _fs, _thread) = - setup_test(cx, json!({"file.txt": "original content"})).await; - cx.update(|cx| { - let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); - settings.tool_permissions.default = settings::ToolPermissionMode::Allow; - agent_settings::AgentSettings::override_global(settings, cx); - }); - - let (event_stream, _rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - tool.clone().run( - ToolInput::resolved(EditFileToolInput { - path: "root/file.txt".into(), - mode: EditFileMode::Write, - content: Some("completely new content".into()), - edits: None, - }), - event_stream, - cx, - ) - }); - - let result = task.await; - assert!(result.is_ok(), "write should succeed: {:?}", result.err()); - - cx.run_until_parked(); - - let changed = action_log.read_with(cx, |log, cx| log.changed_buffers(cx)); - assert!( - !changed.is_empty(), - "action_log.changed_buffers() should be non-empty after streaming write, \ - but no changed buffers were found \u{2014} Accept All / Reject All will not appear" - ); - } - - #[gpui::test] - async fn test_streaming_edit_file_tool_fields_out_of_order_in_write_mode( - cx: &mut TestAppContext, - ) { - let (tool, _project, _action_log, _fs, _thread) = - setup_test(cx, json!({"file.txt": "old_content"})).await; - let (mut sender, input) = ToolInput::::test(); - let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - - sender.send_partial(json!({ - "mode": "write" - })); - cx.run_until_parked(); - - sender.send_partial(json!({ - "mode": "write", - "content": "new_content" - })); - cx.run_until_parked(); - - sender.send_partial(json!({ - "mode": "write", - "content": "new_content", - "path": "root" - })); - cx.run_until_parked(); - - // Send final. - sender.send_full(json!({ - "mode": "write", - "content": "new_content", - "path": "root/file.txt" - })); - - let result = task.await; - let EditFileToolOutput::Success { new_text, .. } = result.unwrap() else { - panic!("expected success"); - }; - assert_eq!(new_text, "new_content"); - } #[gpui::test] async fn test_streaming_edit_file_tool_fields_out_of_order_in_edit_mode( cx: &mut TestAppContext, ) { - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.txt": "old_content"})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); - - sender.send_partial(json!({ - "mode": "edit" - })); - cx.run_until_parked(); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); sender.send_partial(json!({ - "mode": "edit", "edits": [{"old_text": "old_content"}] })); cx.run_until_parked(); sender.send_partial(json!({ - "mode": "edit", "edits": [{"old_text": "old_content", "new_text": "new_content"}] })); cx.run_until_parked(); sender.send_partial(json!({ - "mode": "edit", "edits": [{"old_text": "old_content", "new_text": "new_content"}], "path": "root" })); @@ -3940,7 +2123,6 @@ mod tests { // Send final. sender.send_full(json!({ - "mode": "edit", "edits": [{"old_text": "old_content", "new_text": "new_content"}], "path": "root/file.txt" })); @@ -3968,7 +2150,7 @@ mod tests { "#} .to_string(); - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.rs": file_content})).await; // The model sends old_text with a PARTIAL last line. @@ -3977,11 +2159,10 @@ mod tests { let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); sender.send_full(json!({ "path": "root/file.rs", - "mode": "edit", "edits": [{"old_text": old_text, "new_text": new_text}] })); @@ -4013,15 +2194,14 @@ mod tests { let new_text = "one\ntwo\ntarget\n"; let expected = "before\none\ntwo\ntarget\n\nafter\n"; - let (tool, _project, _action_log, _fs, _thread) = + let (edit_tool, _project, _action_log, _fs, _thread) = setup_test(cx, json!({"file.rs": file_content})).await; let (mut sender, input) = ToolInput::::test(); let (event_stream, _receiver) = ToolCallEventStream::test(); - let task = cx.update(|cx| tool.clone().run(input, event_stream, cx)); + let task = cx.update(|cx| edit_tool.clone().run(input, event_stream, cx)); sender.send_full(json!({ "path": "root/file.rs", - "mode": "edit", "edits": [{"old_text": old_text, "new_text": new_text}] })); @@ -4042,94 +2222,24 @@ mod tests { ); } - #[gpui::test] - async fn test_streaming_reject_created_file_deletes_it(cx: &mut TestAppContext) { - let (tool, _project, action_log, fs, _thread) = setup_test(cx, json!({"dir": {}})).await; - cx.update(|cx| { - let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); - settings.tool_permissions.default = settings::ToolPermissionMode::Allow; - agent_settings::AgentSettings::override_global(settings, cx); - }); - - // Create a new file via the streaming edit file tool - let (event_stream, _rx) = ToolCallEventStream::test(); - let task = cx.update(|cx| { - tool.clone().run( - ToolInput::resolved(EditFileToolInput { - path: "root/dir/new_file.txt".into(), - mode: EditFileMode::Write, - content: Some("Hello, World!".into()), - edits: None, - }), - event_stream, - cx, - ) - }); - let result = task.await; - assert!(result.is_ok(), "create should succeed: {:?}", result.err()); - cx.run_until_parked(); - - assert!( - fs.is_file(path!("/root/dir/new_file.txt").as_ref()).await, - "file should exist after creation" - ); - - // Reject all edits — this should delete the newly created file - let changed = action_log.read_with(cx, |log, cx| log.changed_buffers(cx)); - assert!( - !changed.is_empty(), - "action_log should track the created file as changed" - ); - - action_log - .update(cx, |log, cx| log.reject_all_edits(None, cx)) - .await; - cx.run_until_parked(); - - assert!( - !fs.is_file(path!("/root/dir/new_file.txt").as_ref()).await, - "file should be deleted after rejecting creation, but an empty file was left behind" - ); - } - #[test] fn test_input_deserializes_double_encoded_fields() { let input = serde_json::from_value::(json!({ "path": "root/file.txt", - "mode": "\"edit\"", "edits": "[{\"old_text\": \"hello\\nworld\", \"new_text\": \"HELLO\\nWORLD\"}]" })) .expect("input should deserialize"); - assert!(matches!(input.mode, EditFileMode::Edit)); - let edits = input.edits.expect("edits should deserialize"); - assert_eq!(edits.len(), 1); - assert_eq!(edits[0].old_text, "hello\nworld"); - assert_eq!(edits[0].new_text, "HELLO\nWORLD"); - - let input = serde_json::from_value::(json!({ - "path": "root/file.txt", - "mode": "\"edit\"" - })) - .expect("input should deserialize"); - assert!(input.edits.is_none()); - - let input = serde_json::from_value::(json!({ - "path": "root/file.txt", - "mode": "\"edit\"", - "edits": null - })) - .expect("input should deserialize"); - assert!(input.edits.is_none()); + assert_eq!(input.edits.len(), 1); + assert_eq!(input.edits[0].old_text, "hello\nworld"); + assert_eq!(input.edits[0].new_text, "HELLO\nWORLD"); let input = serde_json::from_value::(json!({ "path": "root/file.txt", - "mode": "\"edit\"", "edits": "[{\"old_text\": \"hello\\nworld\", \"new_text\": \"HELLO\\nWORLD\"}]" })) .expect("input should deserialize"); - assert!(matches!(input.mode, Some(EditFileMode::Edit))); let edits = input.edits.expect("edits should deserialize"); assert_eq!(edits.len(), 1); assert_eq!(edits[0].old_text.as_deref(), Some("hello\nworld")); @@ -4139,16 +2249,13 @@ mod tests { "path": "root/file.txt" })) .expect("input should deserialize"); - assert!(input.mode.is_none()); assert!(input.edits.is_none()); let input = serde_json::from_value::(json!({ "path": "root/file.txt", - "mode": null, "edits": null })) .expect("input should deserialize"); - assert!(input.mode.is_none()); assert!(input.edits.is_none()); } @@ -4179,13 +2286,13 @@ mod tests { ) }); let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); - let tool = Arc::new(EditFileTool::new( + let edit_tool = Arc::new(EditFileTool::new( project.clone(), thread.downgrade(), action_log.clone(), language_registry, )); - (tool, project, action_log, fs, thread) + (edit_tool, project, action_log, fs, thread) } async fn setup_test( diff --git a/crates/agent/src/tools/edit_session.rs b/crates/agent/src/tools/edit_session.rs new file mode 100644 index 0000000000000000000000000000000000000000..1be22a579a0fdd201995ccbac9da98c653bc6e98 --- /dev/null +++ b/crates/agent/src/tools/edit_session.rs @@ -0,0 +1,1067 @@ +mod reindent; +mod streaming_fuzzy_matcher; +mod streaming_parser; + +use super::restore_file_from_disk_tool::RestoreFileFromDiskTool; +use super::save_file_tool::SaveFileTool; +use crate::{AgentTool, Thread, ToolCallEventStream}; +use acp_thread::Diff; +use action_log::ActionLog; +use agent_client_protocol::schema::{ToolCallLocation, ToolCallUpdateFields}; +use anyhow::Result; +use collections::HashSet; +use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity}; +use language::language_settings::{self, FormatOnSave}; +use language::{Buffer, LanguageRegistry}; +use language_model::LanguageModelToolResultContent; +use project::lsp_store::{FormatTrigger, LspFormatTarget}; +use project::{AgentLocation, Project, ProjectPath}; +use reindent::{Reindenter, compute_indent_delta}; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use std::ops::Range; +use std::path::PathBuf; +use std::sync::Arc; +use streaming_diff::{CharOperation, StreamingDiff}; +use streaming_fuzzy_matcher::StreamingFuzzyMatcher; +use streaming_parser::{EditEvent, StreamingParser, WriteEvent}; +use text::ToOffset; +use ui::SharedString; +use util::rel_path::RelPath; +use util::{Deferred, ResultExt}; + +/// Operating mode used internally by `EditSession`/`Pipeline` to choose between +/// applying granular edits (the `edit_file` tool) or replacing/creating the +/// entire file content (the `write_file` tool). +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub(crate) enum EditSessionMode { + Write, + Edit, +} + +/// A single edit operation that replaces old text with new text +/// Properly escape all text fields as valid JSON strings. +/// Remember to escape special characters like newlines (`\n`) and quotes (`"`) in JSON strings. +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct Edit { + /// The exact text to find in the file. This will be matched using fuzzy matching + /// to handle minor differences in whitespace or formatting. + /// + /// Be minimal with replacements: + /// - For unique lines, include only those lines + /// - For non-unique lines, include enough context to identify them + pub old_text: String, + /// The text to replace it with + pub new_text: String, +} + +#[derive(Clone, Default, Debug, Deserialize)] +pub struct PartialEdit { + #[serde(default)] + pub old_text: Option, + #[serde(default)] + pub new_text: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum EditSessionOutput { + Success { + #[serde(alias = "original_path")] + input_path: PathBuf, + new_text: String, + old_text: Arc, + #[serde(default)] + diff: String, + }, + Error { + error: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + input_path: Option, + #[serde(default, skip_serializing_if = "String::is_empty")] + diff: String, + }, +} + +impl EditSessionOutput { + pub fn error(error: impl Into) -> Self { + Self::Error { + error: error.into(), + input_path: None, + diff: String::new(), + } + } +} + +impl std::fmt::Display for EditSessionOutput { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EditSessionOutput::Success { + diff, input_path, .. + } => { + if diff.is_empty() { + write!(f, "No edits were made.") + } else { + write!( + f, + "Edited {}:\n\n```diff\n{diff}\n```", + input_path.display() + ) + } + } + EditSessionOutput::Error { + error, + diff, + input_path, + } => { + write!(f, "{error}\n")?; + if let Some(input_path) = input_path + && !diff.is_empty() + { + write!( + f, + "Edited {}:\n\n```diff\n{diff}\n```", + input_path.display() + ) + } else { + write!(f, "No edits were made.") + } + } + } + } +} + +impl From for LanguageModelToolResultContent { + fn from(output: EditSessionOutput) -> Self { + output.to_string().into() + } +} + +pub(crate) struct EditSessionContext { + project: Entity, + thread: WeakEntity, + action_log: Entity, + language_registry: Arc, +} + +impl EditSessionContext { + pub(crate) fn new( + project: Entity, + thread: WeakEntity, + action_log: Entity, + language_registry: Arc, + ) -> Self { + Self { + project, + thread, + action_log, + language_registry, + } + } + + pub(crate) fn authorize( + &self, + tool_name: &str, + path: &PathBuf, + event_stream: &ToolCallEventStream, + cx: &mut App, + ) -> Task> { + super::tool_permissions::authorize_file_edit( + tool_name, + path, + &self.thread, + event_stream, + cx, + ) + } + + fn set_agent_location(&self, buffer: WeakEntity, position: text::Anchor, cx: &mut App) { + let should_update_agent_location = self + .thread + .read_with(cx, |thread, _cx| !thread.is_subagent()) + .unwrap_or_default(); + if should_update_agent_location { + self.project.update(cx, |project, cx| { + project.set_agent_location(Some(AgentLocation { buffer, position }), cx); + }); + } + } + + async fn ensure_buffer_saved(&self, buffer: &Entity, cx: &mut AsyncApp) { + let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| { + let settings = language_settings::LanguageSettings::for_buffer(buffer, cx); + settings.format_on_save != FormatOnSave::Off + }); + + if format_on_save_enabled { + self.project + .update(cx, |project, cx| { + project.format( + HashSet::from_iter([buffer.clone()]), + LspFormatTarget::Buffers, + false, + FormatTrigger::Save, + cx, + ) + }) + .await + .log_err(); + } + + self.project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .log_err(); + + self.action_log.update(cx, |log, cx| { + log.buffer_edited(buffer.clone(), cx); + }); + } + + pub(crate) fn initial_title_from_path( + &self, + path: &std::path::Path, + default: &str, + cx: &App, + ) -> SharedString { + let project = self.project.read(cx); + if let Some(project_path) = project.find_project_path(path, cx) + && let Some(short) = project.short_full_path_for_project_path(&project_path, cx) + { + return short.into(); + } + + let display = path.to_string_lossy(); + if display.is_empty() { + default.into() + } else { + display.into_owned().into() + } + } + + pub(crate) fn replay_output( + &self, + output: EditSessionOutput, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Result<()> { + match output { + EditSessionOutput::Success { + input_path, + old_text, + new_text, + .. + } => { + event_stream.update_diff(cx.new(|cx| { + Diff::finalized( + input_path.to_string_lossy().into_owned(), + Some(old_text.to_string()), + new_text, + self.language_registry.clone(), + cx, + ) + })); + Ok(()) + } + EditSessionOutput::Error { .. } => Ok(()), + } + } +} + +pub(crate) enum EditSessionResult { + Completed(EditSession), + Failed { + error: String, + session: Option, + }, +} + +pub(crate) async fn run_session( + result: EditSessionResult, + cx: &mut AsyncApp, +) -> Result { + match result { + EditSessionResult::Completed(session) => { + session + .context + .ensure_buffer_saved(&session.buffer, cx) + .await; + let (new_text, diff) = session.compute_new_text_and_diff(cx).await; + Ok(EditSessionOutput::Success { + old_text: session.old_text.clone(), + new_text, + input_path: session.input_path, + diff, + }) + } + EditSessionResult::Failed { + error, + session: Some(session), + } => { + session + .context + .ensure_buffer_saved(&session.buffer, cx) + .await; + let (_new_text, diff) = session.compute_new_text_and_diff(cx).await; + Err(EditSessionOutput::Error { + error, + input_path: Some(session.input_path), + diff, + }) + } + EditSessionResult::Failed { + error, + session: None, + } => Err(EditSessionOutput::Error { + error, + input_path: None, + diff: String::new(), + }), + } +} + +pub(crate) fn initial_title_from_partial_path

( + context: &EditSessionContext, + raw_input: serde_json::Value, + extract_path: impl FnOnce(&P) -> Option, + default: &str, + cx: &App, +) -> SharedString +where + P: DeserializeOwned, +{ + if let Ok(partial) = serde_json::from_value::

(raw_input) + && let Some(raw_path) = extract_path(&partial) + { + let trimmed = raw_path.trim(); + if !trimmed.is_empty() { + return context.initial_title_from_path(std::path::Path::new(trimmed), default, cx); + } + } + default.into() +} + +pub(crate) struct EditSession { + abs_path: PathBuf, + pub(crate) input_path: PathBuf, + pub(crate) buffer: Entity, + pub(crate) old_text: Arc, + diff: Entity, + parser: StreamingParser, + pipeline: Pipeline, + context: Arc, + _finalize_diff_guard: Deferred>, +} + +enum Pipeline { + Write(WritePipeline), + Edit(EditPipeline), +} + +struct WritePipeline { + content_written: bool, +} + +struct EditPipeline { + current_edit: Option, + file_changed_since_last_read: bool, +} + +enum EditPipelineEntry { + ResolvingOldText { + matcher: StreamingFuzzyMatcher, + }, + StreamingNewText { + streaming_diff: StreamingDiff, + edit_cursor: usize, + reindenter: Reindenter, + original_snapshot: text::BufferSnapshot, + }, +} + +impl Pipeline { + fn new(mode: EditSessionMode, file_changed_since_last_read: bool) -> Self { + match mode { + EditSessionMode::Write => Self::Write(WritePipeline { + content_written: false, + }), + EditSessionMode::Edit => Self::Edit(EditPipeline { + current_edit: None, + file_changed_since_last_read, + }), + } + } +} + +impl WritePipeline { + fn process_event( + &mut self, + event: &WriteEvent, + buffer: &Entity, + context: &EditSessionContext, + cx: &mut AsyncApp, + ) { + let WriteEvent::ContentChunk { chunk } = event; + + let (buffer_id, buffer_len) = + buffer.read_with(cx, |buffer, _cx| (buffer.remote_id(), buffer.len())); + let edit_range = if self.content_written { + buffer_len..buffer_len + } else { + 0..buffer_len + }; + + agent_edit_buffer( + buffer, + [(edit_range, chunk.as_str())], + &context.action_log, + cx, + ); + cx.update(|cx| { + context.set_agent_location( + buffer.downgrade(), + text::Anchor::max_for_buffer(buffer_id), + cx, + ); + }); + self.content_written = true; + } +} + +impl EditPipeline { + fn ensure_resolving_old_text(&mut self, buffer: &Entity, cx: &mut AsyncApp) { + if self.current_edit.is_none() { + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot()); + self.current_edit = Some(EditPipelineEntry::ResolvingOldText { + matcher: StreamingFuzzyMatcher::new(snapshot), + }); + } + } + + fn process_event( + &mut self, + event: &EditEvent, + buffer: &Entity, + diff: &Entity, + abs_path: &PathBuf, + context: &EditSessionContext, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> Result<(), String> { + match event { + EditEvent::OldTextChunk { + chunk, done: false, .. + } => { + log::debug!("old_text_chunk: done=false, chunk='{}'", chunk); + self.ensure_resolving_old_text(buffer, cx); + + if let Some(EditPipelineEntry::ResolvingOldText { matcher }) = + &mut self.current_edit + && !chunk.is_empty() + { + if let Some(match_range) = matcher.push(chunk, None) { + let anchor_range = buffer.read_with(cx, |buffer, _cx| { + buffer.anchor_range_outside(match_range.clone()) + }); + diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); + + cx.update(|cx| { + let position = buffer.read(cx).anchor_before(match_range.end); + context.set_agent_location(buffer.downgrade(), position, cx); + }); + } + } + } + EditEvent::OldTextChunk { + edit_index, + chunk, + done: true, + } => { + log::debug!("old_text_chunk: done=true, chunk='{}'", chunk); + + self.ensure_resolving_old_text(buffer, cx); + + let Some(EditPipelineEntry::ResolvingOldText { matcher }) = &mut self.current_edit + else { + return Ok(()); + }; + + if !chunk.is_empty() { + matcher.push(chunk, None); + } + let range = extract_match( + matcher.finish(), + buffer, + edit_index, + self.file_changed_since_last_read, + cx, + )?; + + let anchor_range = + buffer.read_with(cx, |buffer, _cx| buffer.anchor_range_outside(range.clone())); + diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx)); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + + let line = snapshot.offset_to_point(range.start).row; + event_stream.update_fields( + ToolCallUpdateFields::new() + .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]), + ); + + let buffer_indent = snapshot.line_indent_for_row(line); + let query_indent = text::LineIndent::from_iter( + matcher + .query_lines() + .first() + .map(|s| s.as_str()) + .unwrap_or("") + .chars(), + ); + let indent_delta = compute_indent_delta(buffer_indent, query_indent); + + let old_text_in_buffer = snapshot.text_for_range(range.clone()).collect::(); + + log::debug!( + "edit[{}] old_text matched at {}..{}: {:?}", + edit_index, + range.start, + range.end, + old_text_in_buffer, + ); + + let text_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot()); + self.current_edit = Some(EditPipelineEntry::StreamingNewText { + streaming_diff: StreamingDiff::new(old_text_in_buffer), + edit_cursor: range.start, + reindenter: Reindenter::new(indent_delta), + original_snapshot: text_snapshot, + }); + + cx.update(|cx| { + let position = buffer.read(cx).anchor_before(range.end); + context.set_agent_location(buffer.downgrade(), position, cx); + }); + } + EditEvent::NewTextChunk { + chunk, done: false, .. + } => { + log::debug!("new_text_chunk: done=false, chunk='{}'", chunk); + + let Some(EditPipelineEntry::StreamingNewText { + streaming_diff, + edit_cursor, + reindenter, + original_snapshot, + .. + }) = &mut self.current_edit + else { + return Ok(()); + }; + + let reindented = reindenter.push(chunk); + if reindented.is_empty() { + return Ok(()); + } + + let char_ops = streaming_diff.push_new(&reindented); + apply_char_operations( + &char_ops, + buffer, + original_snapshot, + edit_cursor, + &context.action_log, + cx, + ); + + let position = original_snapshot.anchor_before(*edit_cursor); + cx.update(|cx| { + context.set_agent_location(buffer.downgrade(), position, cx); + }); + } + EditEvent::NewTextChunk { + chunk, done: true, .. + } => { + log::debug!("new_text_chunk: done=true, chunk='{}'", chunk); + + let Some(EditPipelineEntry::StreamingNewText { + mut streaming_diff, + mut edit_cursor, + mut reindenter, + original_snapshot, + }) = self.current_edit.take() + else { + return Ok(()); + }; + + let mut final_text = reindenter.push(chunk); + final_text.push_str(&reindenter.finish()); + + log::debug!("new_text_chunk: done=true, final_text='{}'", final_text); + + if !final_text.is_empty() { + let char_ops = streaming_diff.push_new(&final_text); + apply_char_operations( + &char_ops, + buffer, + &original_snapshot, + &mut edit_cursor, + &context.action_log, + cx, + ); + } + + let remaining_ops = streaming_diff.finish(); + apply_char_operations( + &remaining_ops, + buffer, + &original_snapshot, + &mut edit_cursor, + &context.action_log, + cx, + ); + + let position = original_snapshot.anchor_before(edit_cursor); + cx.update(|cx| { + context.set_agent_location(buffer.downgrade(), position, cx); + }); + } + } + Ok(()) + } +} + +impl EditSession { + pub(crate) async fn new( + path: PathBuf, + mode: EditSessionMode, + tool_name: &str, + context: Arc, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> Result { + let project_path = cx.update(|cx| resolve_path(mode, &path, &context.project, cx))?; + + let Some(abs_path) = + cx.update(|cx| context.project.read(cx).absolute_path(&project_path, cx)) + else { + return Err(format!( + "Worktree at '{}' does not exist", + path.to_string_lossy() + )); + }; + + event_stream.update_fields( + ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path.clone())]), + ); + + cx.update(|cx| context.authorize(tool_name, &path, event_stream, cx)) + .await + .map_err(|e| e.to_string())?; + + let buffer = context + .project + .update(cx, |project, cx| project.open_buffer(project_path, cx)) + .await + .map_err(|e| e.to_string())?; + + let file_changed_since_last_read = ensure_buffer_saved(&buffer, &abs_path, &context, cx)?; + + let diff = cx.new(|cx| Diff::new(buffer.clone(), cx)); + event_stream.update_diff(diff.clone()); + let finalize_diff_guard = util::defer(Box::new({ + let diff = diff.downgrade(); + let mut cx = cx.clone(); + move || { + diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok(); + } + }) as Box); + + context.action_log.update(cx, |log, cx| match mode { + EditSessionMode::Write => log.buffer_created(buffer.clone(), cx), + EditSessionMode::Edit => log.buffer_read(buffer.clone(), cx), + }); + + let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let old_text = cx + .background_spawn({ + let old_snapshot = old_snapshot.clone(); + async move { Arc::new(old_snapshot.text()) } + }) + .await; + + Ok(Self { + abs_path, + input_path: path, + buffer, + old_text, + diff, + parser: StreamingParser::default(), + pipeline: Pipeline::new(mode, file_changed_since_last_read), + context, + _finalize_diff_guard: finalize_diff_guard, + }) + } + + pub(crate) async fn finalize_edit( + &mut self, + edits: Vec, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> Result<(), String> { + let Self { + abs_path, + buffer, + diff, + parser, + pipeline, + context, + .. + } = self; + let Pipeline::Edit(edit_pipeline) = pipeline else { + return Err("Cannot finalize edits on a write session".to_string()); + }; + + for event in &parser.finalize_edits(&edits) { + edit_pipeline.process_event( + event, + buffer, + diff, + abs_path, + context, + event_stream, + cx, + )?; + } + + if log::log_enabled!(log::Level::Debug) { + log::debug!("Got edits:"); + for edit in &edits { + log::debug!( + " old_text: '{}', new_text: '{}'", + edit.old_text.replace('\n', "\\n"), + edit.new_text.replace('\n', "\\n") + ); + } + } + Ok(()) + } + + pub(crate) async fn finalize_write( + &mut self, + content: &str, + cx: &mut AsyncApp, + ) -> Result<(), String> { + let Self { + buffer, + parser, + pipeline, + context, + .. + } = self; + let Pipeline::Write(write) = pipeline else { + return Err("Cannot finalize a write on an edit session".to_string()); + }; + + for event in &parser.finalize_content(content) { + write.process_event(event, buffer, context, cx); + } + Ok(()) + } + + async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) { + let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let (new_text, unified_diff) = cx + .background_spawn({ + let new_snapshot = new_snapshot.clone(); + let old_text = self.old_text.clone(); + async move { + let new_text = new_snapshot.text(); + let diff = language::unified_diff(&old_text, &new_text); + (new_text, diff) + } + }) + .await; + (new_text, unified_diff) + } + + pub(crate) fn process_edit( + &mut self, + edits: Option<&[PartialEdit]>, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> Result<(), String> { + let Self { + abs_path, + buffer, + diff, + parser, + pipeline, + context, + .. + } = self; + let Pipeline::Edit(edit_pipeline) = pipeline else { + return Err("Cannot apply partial edits on a write session".to_string()); + }; + let Some(edits) = edits else { + return Ok(()); + }; + for event in &parser.push_edits(edits) { + edit_pipeline.process_event( + event, + buffer, + diff, + abs_path, + context, + event_stream, + cx, + )?; + } + Ok(()) + } + + pub(crate) fn process_write( + &mut self, + content: Option<&str>, + cx: &mut AsyncApp, + ) -> Result<(), String> { + let Self { + buffer, + parser, + pipeline, + context, + .. + } = self; + let Pipeline::Write(write) = pipeline else { + return Err("Cannot apply partial content on an edit session".to_string()); + }; + let Some(content) = content else { + return Ok(()); + }; + for event in &parser.push_content(content) { + write.process_event(event, buffer, context, cx); + } + Ok(()) + } +} + +fn apply_char_operations( + ops: &[CharOperation], + buffer: &Entity, + snapshot: &text::BufferSnapshot, + edit_cursor: &mut usize, + action_log: &Entity, + cx: &mut AsyncApp, +) { + for op in ops { + match op { + CharOperation::Insert { text } => { + let anchor = snapshot.anchor_after(*edit_cursor); + agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx); + } + CharOperation::Delete { bytes } => { + let delete_end = *edit_cursor + bytes; + let anchor_range = snapshot.anchor_range_inside(*edit_cursor..delete_end); + agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx); + *edit_cursor = delete_end; + } + CharOperation::Keep { bytes } => { + *edit_cursor += bytes; + } + } + } +} + +fn extract_match( + matches: Vec>, + buffer: &Entity, + edit_index: &usize, + file_changed_since_last_read: bool, + cx: &mut AsyncApp, +) -> Result, String> { + let file_changed_since_last_read_message = if file_changed_since_last_read { + " The file has changed on disk since you last read it." + } else { + "" + }; + + match matches.len() { + 0 => Err(format!( + "Could not find matching text for edit at index {}. \ + The old_text did not match any content in the file.{} \ + Please read the file again to get the current content.", + edit_index, file_changed_since_last_read_message, + )), + 1 => Ok(matches.into_iter().next().unwrap()), + _ => { + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + let lines = matches + .iter() + .map(|range| (snapshot.offset_to_point(range.start).row + 1).to_string()) + .collect::>() + .join(", "); + Err(format!( + "Edit {} matched multiple locations in the file at lines: {}. \ + Please provide more context in old_text to uniquely \ + identify the location.", + edit_index, lines + )) + } + } +} + +/// Edits a buffer and reports the edit to the action log in the same effect +/// cycle. This ensures the action log's subscription handler sees the version +/// already updated by `buffer_edited`, so it does not misattribute the agent's +/// edit as a user edit. +fn agent_edit_buffer( + buffer: &Entity, + edits: I, + action_log: &Entity, + cx: &mut AsyncApp, +) where + I: IntoIterator, T)>, + S: ToOffset, + T: Into>, +{ + cx.update(|cx| { + buffer.update(cx, |buffer, cx| { + buffer.edit(edits, None, cx); + }); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); +} + +fn ensure_buffer_saved( + buffer: &Entity, + abs_path: &PathBuf, + context: &EditSessionContext, + cx: &mut AsyncApp, +) -> Result { + let last_read_mtime = context + .action_log + .read_with(cx, |log, _| log.file_read_time(abs_path)); + let check_result = context.thread.read_with(cx, |thread, cx| { + let current = buffer + .read(cx) + .file() + .and_then(|file| file.disk_state().mtime()); + let dirty = buffer.read(cx).is_dirty(); + let has_save = thread.has_tool(SaveFileTool::NAME); + let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME); + (current, dirty, has_save, has_restore) + }); + + let Ok((current_mtime, is_dirty, has_save_tool, has_restore_tool)) = check_result else { + return Ok(false); + }; + + if is_dirty { + let message = match (has_save_tool, has_restore_tool) { + (true, true) => { + "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ + If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \ + If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit." + } + (true, false) => { + "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ + If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \ + If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed." + } + (false, true) => { + "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \ + If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \ + If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit." + } + (false, false) => { + "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \ + then ask them to save or revert the file manually and inform you when it's ok to proceed." + } + }; + return Err(message.to_string()); + } + + if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) + && current != last_read + { + return Ok(true); + } + + Ok(false) +} + +fn resolve_path( + mode: EditSessionMode, + path: &PathBuf, + project: &Entity, + cx: &mut App, +) -> Result { + let project = project.read(cx); + + match mode { + EditSessionMode::Edit => { + let path = project + .find_project_path(&path, cx) + .ok_or_else(|| "Can't edit file: path not found".to_string())?; + + let entry = project + .entry_for_path(&path, cx) + .ok_or_else(|| "Can't edit file: path not found".to_string())?; + + if entry.is_file() { + Ok(path) + } else { + Err("Can't edit file: path is a directory".to_string()) + } + } + EditSessionMode::Write => { + if let Some(path) = project.find_project_path(&path, cx) + && let Some(entry) = project.entry_for_path(&path, cx) + { + if entry.is_file() { + return Ok(path); + } else { + return Err("Can't write to file: path is a directory".to_string()); + } + } + + let parent_path = path + .parent() + .ok_or_else(|| "Can't create file: incorrect path".to_string())?; + + let parent_project_path = project.find_project_path(&parent_path, cx); + + let parent_entry = parent_project_path + .as_ref() + .and_then(|path| project.entry_for_path(path, cx)) + .ok_or_else(|| "Can't create file: parent directory doesn't exist")?; + + if !parent_entry.is_dir() { + return Err("Can't create file: parent is not a directory".to_string()); + } + + let file_name = path + .file_name() + .and_then(|file_name| file_name.to_str()) + .and_then(|file_name| RelPath::unix(file_name).ok()) + .ok_or_else(|| "Can't create file: invalid filename".to_string())?; + + let new_file_path = parent_project_path.map(|parent| ProjectPath { + path: parent.path.join(file_name), + ..parent + }); + + new_file_path.ok_or_else(|| "Can't create file".to_string()) + } + } +} + +#[cfg(test)] +pub(crate) async fn test_resolve_path( + mode: &EditSessionMode, + path: &str, + project: &Entity, + cx: &mut gpui::TestAppContext, +) -> Result { + cx.update(|cx| resolve_path(*mode, &PathBuf::from(path), project, cx)) +} diff --git a/crates/agent/src/tools/edit_file_tool/reindent.rs b/crates/agent/src/tools/edit_session/reindent.rs similarity index 100% rename from crates/agent/src/tools/edit_file_tool/reindent.rs rename to crates/agent/src/tools/edit_session/reindent.rs diff --git a/crates/agent/src/tools/edit_file_tool/streaming_fuzzy_matcher.rs b/crates/agent/src/tools/edit_session/streaming_fuzzy_matcher.rs similarity index 100% rename from crates/agent/src/tools/edit_file_tool/streaming_fuzzy_matcher.rs rename to crates/agent/src/tools/edit_session/streaming_fuzzy_matcher.rs diff --git a/crates/agent/src/tools/edit_file_tool/streaming_parser.rs b/crates/agent/src/tools/edit_session/streaming_parser.rs similarity index 99% rename from crates/agent/src/tools/edit_file_tool/streaming_parser.rs rename to crates/agent/src/tools/edit_session/streaming_parser.rs index 6a44959a141c804815981011f995f0ef2749b2d5..a976b08b004771131d7eefe8de77bb5968e37565 100644 --- a/crates/agent/src/tools/edit_file_tool/streaming_parser.rs +++ b/crates/agent/src/tools/edit_session/streaming_parser.rs @@ -1,6 +1,6 @@ use smallvec::SmallVec; -use crate::{Edit, PartialEdit}; +use super::{Edit, PartialEdit}; /// Events emitted by `StreamingParser` for edit-mode input. #[derive(Debug, PartialEq, Eq)] diff --git a/crates/agent/src/tools/evals.rs b/crates/agent/src/tools/evals.rs index a2e09b3f8aa9ed039cd2bc349eff1ca3b30b0317..3096068931161d0963028b28af12a1d178ca9b0d 100644 --- a/crates/agent/src/tools/evals.rs +++ b/crates/agent/src/tools/evals.rs @@ -2,3 +2,5 @@ mod edit_file; #[cfg(all(test, feature = "unit-eval"))] mod terminal_tool; +#[cfg(all(test, feature = "unit-eval"))] +mod write_file; diff --git a/crates/agent/src/tools/evals/edit_file.rs b/crates/agent/src/tools/evals/edit_file.rs index cce9f41c6efd8de112d0e5660537998ac5b41fd2..4c96b0797f87709ba93a40040110bbf6c943990a 100644 --- a/crates/agent/src/tools/evals/edit_file.rs +++ b/crates/agent/src/tools/evals/edit_file.rs @@ -1,8 +1,7 @@ use crate::tools::edit_file_tool::*; use crate::{ - AgentTool, ContextServerRegistry, EditFileTool, GrepTool, GrepToolInput, ListDirectoryTool, - ListDirectoryToolInput, ReadFileTool, ReadFileToolInput, Template, Templates, Thread, - ToolCallEventStream, ToolInput, + AgentTool, ContextServerRegistry, EditFileTool, GrepTool, GrepToolInput, ReadFileTool, + ReadFileToolInput, Template, Templates, Thread, ToolCallEventStream, ToolInput, }; use Role::*; use anyhow::{Context as _, Result}; @@ -124,20 +123,6 @@ impl EvalAssertion { EvalAssertion(Arc::new(f)) } - fn assert_eq(expected: impl Into) -> Self { - let expected = expected.into(); - Self::new(async move |sample, _judge, _cx| { - Ok(EvalAssertionOutcome { - score: if strip_empty_lines(&sample.text_after) == strip_empty_lines(&expected) { - 100 - } else { - 0 - }, - message: None, - }) - }) - } - fn assert_diff_any(expected_diffs: Vec>) -> Self { let expected_diffs: Vec = expected_diffs.into_iter().map(Into::into).collect(); Self::new(async move |sample, _judge, _cx| { @@ -1499,46 +1484,3 @@ fn eval_add_overwrite_test() { )) }); } - -#[test] -#[cfg_attr(not(feature = "unit-eval"), ignore)] -fn eval_create_empty_file() { - let input_file_path = "root/TODO3"; - let input_file_content = None; - let expected_output_content = String::new(); - - eval_utils::eval(100, 0.99, eval_utils::NoProcessor, move || { - run_eval(EvalInput::new( - vec![ - message(User, [text("Create a second empty todo file ")]), - message( - Assistant, - [ - text(indoc::formatdoc! {" - I'll help you create a second empty todo file. - First, let me examine the project structure to see if there's already a todo file, which will help me determine the appropriate name and location for the second one. - "}), - tool_use( - "toolu_01GAF8TtsgpjKxCr8fgQLDgR", - ListDirectoryTool::NAME, - ListDirectoryToolInput { - path: "root".to_string(), - }, - ), - ], - ), - message( - User, - [tool_result( - "toolu_01GAF8TtsgpjKxCr8fgQLDgR", - ListDirectoryTool::NAME, - "root/TODO\nroot/TODO2\nroot/new.txt\n", - )], - ), - ], - input_file_path, - input_file_content.clone(), - EvalAssertion::assert_eq(expected_output_content.clone()), - )) - }); -} diff --git a/crates/agent/src/tools/evals/write_file.rs b/crates/agent/src/tools/evals/write_file.rs new file mode 100644 index 0000000000000000000000000000000000000000..f34528fcd78577757a2189ae77389d35a1e0bd4a --- /dev/null +++ b/crates/agent/src/tools/evals/write_file.rs @@ -0,0 +1,561 @@ +use crate::{ + AgentTool, ContextServerRegistry, ListDirectoryTool, ListDirectoryToolInput, Template, + Templates, Thread, ToolCallEventStream, ToolInput, WriteFileTool, WriteFileToolInput, +}; +use Role::*; +use anyhow::{Context as _, Result}; +use client::{Client, RefreshLlmTokenListener, UserStore}; +use fs::FakeFs; +use futures::StreamExt; +use gpui::{AppContext as _, AsyncApp, Entity, TestAppContext, UpdateGlobal as _}; +use http_client::StatusCode; +use language::language_settings::FormatOnSave; +use language_model::{ + LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, + LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, + LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse, + LanguageModelToolUseId, MessageContent, Role, SelectedModel, +}; +use project::Project; +use prompt_store::{ProjectContext, WorktreeContext}; +use rand::prelude::*; +use reqwest_client::ReqwestClient; +use serde::Serialize; +use settings::SettingsStore; +use std::{ + fmt::{self, Display}, + path::{Path, PathBuf}, + str::FromStr, + sync::Arc, + time::Duration, +}; +use util::path; + +#[derive(Clone)] +struct EvalInput { + conversation: Vec, + input_file_path: PathBuf, + input_content: Option, + expected_output_content: String, +} + +impl EvalInput { + fn new( + conversation: Vec, + input_file_path: impl Into, + input_content: Option, + expected_output_content: String, + ) -> Self { + Self { + conversation, + input_file_path: input_file_path.into(), + input_content, + expected_output_content, + } + } +} + +struct WriteEvalOutput { + tool_input: WriteFileToolInput, + text_after: String, +} + +impl Display for WriteEvalOutput { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + writeln!(f, "Tool Input:\n{:#?}", self.tool_input)?; + writeln!(f, "Text After:\n{}", self.text_after)?; + Ok(()) + } +} + +struct WriteToolTest { + fs: Arc, + project: Entity, + model: Arc, + model_thinking_effort: Option, +} + +impl WriteToolTest { + async fn new(cx: &mut TestAppContext) -> Self { + cx.executor().allow_parking(); + + let fs = FakeFs::new(cx.executor()); + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + SettingsStore::update_global(cx, |store: &mut SettingsStore, cx| { + store.update_user_settings(cx, |settings| { + settings + .project + .all_languages + .defaults + .ensure_final_newline_on_save = Some(false); + settings.project.all_languages.defaults.format_on_save = + Some(FormatOnSave::Off); + }); + }); + + gpui_tokio::init(cx); + let http_client = Arc::new(ReqwestClient::user_agent("agent tests").unwrap()); + cx.set_http_client(http_client); + let client = Client::production(cx); + let user_store = cx.new(|cx| UserStore::new(client.clone(), cx)); + language_model::init(cx); + RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx); + language_models::init(user_store, client, cx); + }); + + fs.insert_tree("/root", serde_json::json!({})).await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let agent_model = SelectedModel::from_str( + &std::env::var("ZED_AGENT_MODEL") + .unwrap_or("anthropic/claude-sonnet-4-6-latest".into()), + ) + .unwrap(); + + let authenticate_provider_tasks = cx.update(|cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry + .providers() + .iter() + .map(|p| p.authenticate(cx)) + .collect::>() + }) + }); + let model = cx + .update(|cx| { + cx.spawn(async move |cx| { + futures::future::join_all(authenticate_provider_tasks).await; + Self::load_model(&agent_model, cx).await.unwrap() + }) + }) + .await; + + let model_thinking_effort = model + .default_effort_level() + .map(|effort_level| effort_level.value.to_string()); + + Self { + fs, + project, + model, + model_thinking_effort, + } + } + + async fn load_model( + selected_model: &SelectedModel, + cx: &mut AsyncApp, + ) -> Result> { + cx.update(|cx| { + let registry = LanguageModelRegistry::read_global(cx); + let provider = registry + .provider(&selected_model.provider) + .expect("Provider not found"); + provider.authenticate(cx) + }) + .await?; + Ok(cx.update(|cx| { + let models = LanguageModelRegistry::read_global(cx); + models + .available_models(cx) + .find(|model| { + model.provider_id() == selected_model.provider + && model.id() == selected_model.model + }) + .unwrap_or_else(|| panic!("Model {} not found", selected_model.model.0)) + })) + } + + async fn eval(&self, mut eval: EvalInput, cx: &mut TestAppContext) -> Result { + eval.conversation + .last_mut() + .context("Conversation must not be empty")? + .cache = true; + + if let Some(input_content) = eval.input_content.as_deref() { + let abs_path = Path::new("/root").join( + eval.input_file_path + .strip_prefix("root") + .unwrap_or(&eval.input_file_path), + ); + self.fs.insert_file(&abs_path, input_content.into()).await; + cx.run_until_parked(); + } + + let tools = crate::built_in_tools().collect::>(); + + let system_prompt = { + let worktrees = vec![WorktreeContext { + root_name: "root".to_string(), + abs_path: Path::new("/path/to/root").into(), + rules_file: None, + }]; + let project_context = ProjectContext::new(worktrees, Vec::default()); + let tool_names = tools + .iter() + .map(|tool| tool.name.clone().into()) + .collect::>(); + let template = crate::SystemPromptTemplate { + project: &project_context, + available_tools: tool_names, + model_name: None, + }; + let templates = Templates::new(); + template.render(&templates)? + }; + + let messages = [LanguageModelRequestMessage { + role: Role::System, + content: vec![MessageContent::Text(system_prompt)], + cache: true, + reasoning_details: None, + }] + .into_iter() + .chain(eval.conversation) + .collect::>(); + + let request = LanguageModelRequest { + messages, + tools, + thinking_allowed: true, + thinking_effort: self.model_thinking_effort.clone(), + ..Default::default() + }; + + let tool_input = + retry_on_rate_limit(async || self.extract_tool_use(request.clone(), cx).await).await?; + + let language_registry = self + .project + .read_with(cx, |project, _cx| project.languages().clone()); + + let context_server_registry = cx + .new(|cx| ContextServerRegistry::new(self.project.read(cx).context_server_store(), cx)); + let thread = cx.new(|cx| { + Thread::new( + self.project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(self.model.clone()), + cx, + ) + }); + let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + + let tool = Arc::new(WriteFileTool::new( + self.project.clone(), + thread.downgrade(), + action_log, + language_registry, + )); + + let result = cx + .update(|cx| { + tool.clone().run( + ToolInput::resolved(tool_input.clone()), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + + let output = match result { + Ok(output) => output, + Err(output) => anyhow::bail!("Tool returned error: {}", output), + }; + + let crate::EditFileToolOutput::Success { new_text, .. } = &output else { + anyhow::bail!("Tool returned error output: {}", output); + }; + + if tool_input.path != eval.input_file_path { + anyhow::bail!( + "Tool path mismatch. Expected {:?}, got {:?}", + eval.input_file_path, + tool_input.path, + ); + } + + if new_text != &eval.expected_output_content { + anyhow::bail!( + "Output content mismatch. Expected {:?}, got {:?}", + eval.expected_output_content, + new_text, + ); + } + + Ok(WriteEvalOutput { + tool_input, + text_after: new_text.clone(), + }) + } + + async fn extract_tool_use( + &self, + request: LanguageModelRequest, + cx: &mut TestAppContext, + ) -> Result { + let model = self.model.clone(); + let events = cx + .update(|cx| { + let async_cx = cx.to_async(); + cx.foreground_executor() + .spawn(async move { model.stream_completion(request, &async_cx).await }) + }) + .await + .map_err(|err| anyhow::anyhow!("completion error: {}", err))?; + + let mut streamed_text = String::new(); + let mut stop_reason = None; + let mut parse_errors = Vec::new(); + + let mut events = events.fuse(); + while let Some(event) = events.next().await { + match event { + Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) + if tool_use.is_input_complete + && tool_use.name.as_ref() == WriteFileTool::NAME => + { + let input: WriteFileToolInput = serde_json::from_value(tool_use.input) + .context("Failed to parse tool input as WriteFileToolInput")?; + return Ok(input); + } + Ok(LanguageModelCompletionEvent::Text(text)) => { + if streamed_text.len() < 2_000 { + streamed_text.push_str(&text); + } + } + Ok(LanguageModelCompletionEvent::Stop(reason)) => { + stop_reason = Some(reason); + } + Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { + tool_name, + raw_input, + json_parse_error, + .. + }) if tool_name.as_ref() == WriteFileTool::NAME => { + parse_errors.push(format!("{json_parse_error}\nRaw input:\n{raw_input:?}")); + } + Err(err) => return Err(anyhow::anyhow!("completion error: {}", err)), + _ => {} + } + } + + let streamed_text = streamed_text.trim(); + let streamed_text_suffix = if streamed_text.is_empty() { + String::new() + } else { + format!("\nStreamed text:\n{streamed_text}") + }; + let stop_reason_suffix = stop_reason + .map(|reason| format!("\nStop reason: {reason:?}")) + .unwrap_or_default(); + let parse_errors_suffix = if parse_errors.is_empty() { + String::new() + } else { + format!("\nTool parse errors:\n{}", parse_errors.join("\n")) + }; + + anyhow::bail!( + "Stream ended without a write_file tool use{stop_reason_suffix}{parse_errors_suffix}{streamed_text_suffix}" + ) + } +} + +fn run_eval(eval: EvalInput) -> eval_utils::EvalOutput<()> { + let dispatcher = gpui::TestDispatcher::new(rand::random()); + let mut cx = TestAppContext::build(dispatcher, None); + let foreground_executor = cx.foreground_executor().clone(); + let result = foreground_executor.block_test(async { + let test = WriteToolTest::new(&mut cx).await; + let result = test.eval(eval, &mut cx).await; + drop(test); + cx.run_until_parked(); + result + }); + cx.quit(); + match result { + Ok(output) => eval_utils::EvalOutput { + data: output.to_string(), + outcome: eval_utils::OutcomeKind::Passed, + metadata: (), + }, + Err(err) => eval_utils::EvalOutput { + data: format!("{err:?}"), + outcome: eval_utils::OutcomeKind::Error, + metadata: (), + }, + } +} + +fn message( + role: Role, + content: impl IntoIterator, +) -> LanguageModelRequestMessage { + LanguageModelRequestMessage { + role, + content: content.into_iter().collect(), + cache: false, + reasoning_details: None, + } +} + +fn text(text: impl Into) -> MessageContent { + MessageContent::Text(text.into()) +} + +fn tool_use( + id: impl Into>, + name: impl Into>, + input: impl Serialize, +) -> MessageContent { + MessageContent::ToolUse(LanguageModelToolUse { + id: LanguageModelToolUseId::from(id.into()), + name: name.into(), + raw_input: serde_json::to_string_pretty(&input).unwrap(), + input: serde_json::to_value(input).unwrap(), + is_input_complete: true, + thought_signature: None, + }) +} + +fn tool_result( + id: impl Into>, + name: impl Into>, + result: impl Into>, +) -> MessageContent { + MessageContent::ToolResult(LanguageModelToolResult { + tool_use_id: LanguageModelToolUseId::from(id.into()), + tool_name: name.into(), + is_error: false, + content: vec![LanguageModelToolResultContent::Text(result.into())], + output: None, + }) +} + +async fn retry_on_rate_limit(mut request: impl AsyncFnMut() -> Result) -> Result { + const MAX_RETRIES: usize = 20; + let mut attempt = 0; + + loop { + attempt += 1; + let response = request().await; + + if attempt >= MAX_RETRIES { + return response; + } + + let retry_delay = match &response { + Ok(_) => None, + Err(err) => match err.downcast_ref::() { + Some(err) => match &err { + LanguageModelCompletionError::RateLimitExceeded { retry_after, .. } + | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => { + Some(retry_after.unwrap_or(Duration::from_secs(5))) + } + LanguageModelCompletionError::UpstreamProviderError { + status, + retry_after, + .. + } => { + let should_retry = matches!( + *status, + StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE + ) || status.as_u16() == 529; + + if should_retry { + Some(retry_after.unwrap_or(Duration::from_secs(5))) + } else { + None + } + } + LanguageModelCompletionError::ApiReadResponseError { .. } + | LanguageModelCompletionError::ApiInternalServerError { .. } + | LanguageModelCompletionError::HttpSend { .. } => { + Some(Duration::from_secs(2_u64.pow((attempt - 1) as u32).min(30))) + } + _ => None, + }, + _ => None, + }, + }; + + if let Some(retry_after) = retry_delay { + let jitter = retry_after.mul_f64(rand::rng().random_range(0.0..1.0)); + eprintln!("Attempt #{attempt}: Retry after {retry_after:?} + jitter of {jitter:?}"); + #[allow(clippy::disallowed_methods)] + async_io::Timer::after(retry_after + jitter).await; + } else { + return response; + } + } +} + +#[test] +#[cfg_attr(not(feature = "unit-eval"), ignore)] +fn eval_create_file() { + let input_file_path = "root/TODO3"; + let expected_output_content = "todo".to_string(); + + eval_utils::eval(100, 1., eval_utils::NoProcessor, move || { + run_eval(EvalInput::new( + vec![ + message( + User, + [text("Create a third todo file. Write 'todo' inside it.")], + ), + message( + Assistant, + [ + text(indoc::formatdoc! {" + I'll help you create a third empty todo file. + First, let me examine the project structure to see if there's already a todo file, which will help me determine the appropriate name and location for the second one. + "}), + tool_use( + "toolu_01GAF8TtsgpjKxCr8fgQLDgR", + ListDirectoryTool::NAME, + ListDirectoryToolInput { + path: "root".to_string(), + }, + ), + ], + ), + message( + User, + [tool_result( + "toolu_01GAF8TtsgpjKxCr8fgQLDgR", + ListDirectoryTool::NAME, + "root/TODO\nroot/TODO2\nroot/new.txt\n", + )], + ), + ], + input_file_path, + None, + expected_output_content.clone(), + )) + }); +} + +#[test] +#[cfg_attr(not(feature = "unit-eval"), ignore)] +fn eval_overwrite_file() { + let input_file_path = "root/notes.txt"; + let input_file_content = "old notes\nkeep nothing\n".to_string(); + let expected_output_content = "new notes".to_string(); + + eval_utils::eval(100, 1., eval_utils::NoProcessor, move || { + run_eval(EvalInput::new( + vec![message( + User, + [text(indoc::formatdoc! {" + Overwrite `{input_file_path}` so that its complete contents are exactly: 'new notes' + "})], + )], + input_file_path, + Some(input_file_content.clone()), + expected_output_content.clone(), + )) + }); +} diff --git a/crates/agent/src/tools/write_file_tool.rs b/crates/agent/src/tools/write_file_tool.rs new file mode 100644 index 0000000000000000000000000000000000000000..c9cd548f316ed898fc983f77bd2fdb8c5f94db4e --- /dev/null +++ b/crates/agent/src/tools/write_file_tool.rs @@ -0,0 +1,1190 @@ +use super::edit_session::{ + EditSession, EditSessionContext, EditSessionMode, EditSessionOutput, EditSessionResult, + initial_title_from_partial_path, run_session, +}; +use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput, ToolInputPayload}; +use action_log::ActionLog; +use agent_client_protocol::schema as acp; +use futures::FutureExt as _; +use gpui::{App, AsyncApp, Entity, Task, WeakEntity}; +use language::LanguageRegistry; +use project::Project; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use std::sync::Arc; +use ui::SharedString; + +const DEFAULT_UI_TEXT: &str = "Writing file"; + +/// This is a tool for creating a new file or overwriting an existing file with completely new contents. +/// +/// To make granular edits to an existing file, prefer the `edit_file` tool instead. +/// +/// Before using this tool: +/// +/// 1. Verify the directory path is correct (only applicable when creating new files): +/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location +#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)] +pub struct WriteFileToolInput { + /// The full path of the file to create or overwrite in the project. + /// + /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories. + /// + /// The following examples assume we have two root directories in the project: + /// - /a/b/backend + /// - /c/d/frontend + /// + /// + /// `backend/src/main.rs` + /// + /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail! + /// + /// + /// + /// `frontend/db.js` + /// + pub path: PathBuf, + + /// The complete content for the file. + /// This field should contain the entire file content. + pub content: String, +} + +#[derive(Clone, Default, Debug, Deserialize)] +struct WriteFileToolPartialInput { + #[serde(default)] + path: Option, + #[serde(default)] + content: Option, +} + +pub struct WriteFileTool { + session_context: Arc, +} + +impl WriteFileTool { + pub fn new( + project: Entity, + thread: WeakEntity, + action_log: Entity, + language_registry: Arc, + ) -> Self { + Self { + session_context: Arc::new(EditSessionContext::new( + project, + thread, + action_log, + language_registry, + )), + } + } + + async fn process_streaming_writes( + &self, + input: &mut ToolInput, + event_stream: &ToolCallEventStream, + cx: &mut AsyncApp, + ) -> EditSessionResult { + let mut session: Option = None; + let mut last_path: Option = None; + + loop { + futures::select! { + payload = input.next().fuse() => { + match payload { + Ok(payload) => match payload { + ToolInputPayload::Partial(partial) => { + if let Ok(parsed) = serde_json::from_value::(partial) { + let path_complete = parsed.path.is_some() + && parsed.path.as_ref() == last_path.as_ref(); + + last_path = parsed.path.clone(); + + if session.is_none() + && path_complete + && let Some(path) = parsed.path.as_ref() + { + match EditSession::new( + PathBuf::from(path), + EditSessionMode::Write, + Self::NAME, + self.session_context.clone(), + event_stream, + cx, + ) + .await + { + Ok(created_session) => session = Some(created_session), + Err(error) => { + log::error!("Failed to create edit session: {}", error); + return EditSessionResult::Failed { + error, + session: None, + }; + } + } + } + + if let Some(current_session) = &mut session + && let Err(error) = current_session.process_write(parsed.content.as_deref(), cx) + { + log::error!("Failed to process write: {}", error); + return EditSessionResult::Failed { error, session }; + } + } + } + ToolInputPayload::Full(full_input) => { + let mut session = if let Some(session) = session { + session + } else { + match EditSession::new( + full_input.path.clone(), + EditSessionMode::Write, + Self::NAME, + self.session_context.clone(), + event_stream, + cx, + ) + .await + { + Ok(created_session) => created_session, + Err(error) => { + log::error!("Failed to create edit session: {}", error); + return EditSessionResult::Failed { + error, + session: None, + }; + } + } + }; + + return match session.finalize_write(&full_input.content, cx).await { + Ok(()) => EditSessionResult::Completed(session), + Err(error) => { + log::error!("Failed to finalize write: {}", error); + EditSessionResult::Failed { + error, + session: Some(session), + } + } + }; + } + ToolInputPayload::InvalidJson { error_message } => { + log::error!("Received invalid JSON: {error_message}"); + return EditSessionResult::Failed { + error: error_message, + session, + }; + } + }, + Err(error) => { + return EditSessionResult::Failed { + error: error.to_string(), + session, + }; + } + } + } + _ = event_stream.cancelled_by_user().fuse() => { + return EditSessionResult::Failed { + error: "Write cancelled by user".to_string(), + session, + }; + } + } + } + } +} + +impl AgentTool for WriteFileTool { + type Input = WriteFileToolInput; + type Output = EditSessionOutput; + + const NAME: &'static str = "write_file"; + + fn supports_input_streaming() -> bool { + true + } + + fn kind() -> acp::ToolKind { + acp::ToolKind::Edit + } + + fn initial_title( + &self, + input: Result, + cx: &mut App, + ) -> SharedString { + match input { + Ok(input) => { + self.session_context + .initial_title_from_path(&input.path, DEFAULT_UI_TEXT, cx) + } + Err(raw_input) => initial_title_from_partial_path::( + &self.session_context, + raw_input, + |partial| partial.path.clone(), + DEFAULT_UI_TEXT, + cx, + ), + } + } + + fn run( + self: Arc, + mut input: ToolInput, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> Task> { + cx.spawn(async move |cx: &mut AsyncApp| { + run_session( + self.process_streaming_writes(&mut input, &event_stream, cx) + .await, + cx, + ) + .await + }) + } + + fn replay( + &self, + _input: Self::Input, + output: Self::Output, + event_stream: ToolCallEventStream, + cx: &mut App, + ) -> anyhow::Result<()> { + self.session_context.replay_output(output, event_stream, cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + AgentTool, ContextServerRegistry, Templates, Thread, ToolCallEventStream, ToolInput, + ToolInputSender, + }; + use acp_thread::Diff; + use action_log::ActionLog; + use fs::Fs as _; + use futures::StreamExt as _; + use gpui::{AppContext as _, Entity, TestAppContext, UpdateGlobal}; + use language::language_settings::FormatOnSave; + use language_model::fake_provider::FakeLanguageModel; + use project::{Project, ProjectPath}; + use prompt_store::ProjectContext; + use serde_json::json; + use settings::{Settings, SettingsStore}; + use std::sync::Arc; + use util::path; + use util::rel_path::{RelPath, rel_path}; + + #[gpui::test] + async fn test_streaming_write_create_file(cx: &mut TestAppContext) { + let (write_tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"dir": {}})).await; + let result = cx + .update(|cx| { + write_tool.clone().run( + ToolInput::resolved(WriteFileToolInput { + path: "root/dir/new_file.txt".into(), + content: "Hello, World!".into(), + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + + let EditSessionOutput::Success { new_text, diff, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "Hello, World!"); + assert!(!diff.is_empty()); + } + + #[gpui::test] + async fn test_streaming_write_overwrite_file(cx: &mut TestAppContext) { + let (write_tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "old content"})).await; + let result = cx + .update(|cx| { + write_tool.clone().run( + ToolInput::resolved(WriteFileToolInput { + path: "root/file.txt".into(), + content: "new content".into(), + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + + let EditSessionOutput::Success { + new_text, old_text, .. + } = result.unwrap() + else { + panic!("expected success"); + }; + assert_eq!(new_text, "new content"); + assert_eq!(*old_text, "old content"); + } + + #[gpui::test] + async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) { + let (write_tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "hello world"})).await; + let (mut sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx)); + + // Send partial with path but NO mode — path should NOT be treated as complete + sender.send_partial(json!({ + "path": "root/file" + })); + cx.run_until_parked(); + + // Now the path grows and mode appears + sender.send_partial(json!({ + "path": "root/file.txt", + })); + cx.run_until_parked(); + + // Send final + sender.send_full(json!({ + "path": "root/file.txt", + "content": "new content" + })); + + let result = task.await; + let EditSessionOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "new content"); + } + + #[gpui::test] + async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) { + let (write_tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"dir": {}})).await; + let (mut sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx)); + + // Stream partials for create mode + sender.send_partial(json!({})); + cx.run_until_parked(); + + sender.send_partial(json!({ + "path": "root/dir/new_file.txt", + })); + cx.run_until_parked(); + + sender.send_partial(json!({ + "path": "root/dir/new_file.txt", + "content": "Hello, " + })); + cx.run_until_parked(); + + // Final with full content + sender.send_full(json!({ + "path": "root/dir/new_file.txt", + "content": "Hello, World!" + })); + + let result = task.await; + let EditSessionOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "Hello, World!"); + } + + #[gpui::test] + async fn test_streaming_input_recv_drains_partials(cx: &mut TestAppContext) { + let (write_tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"dir": {}})).await; + // Create a channel and send multiple partials before a final, then use + // ToolInput::resolved-style immediate delivery to confirm recv() works + // when partials are already buffered. + let (mut sender, input): (ToolInputSender, ToolInput) = + ToolInput::test(); + let (event_stream, _event_rx) = ToolCallEventStream::test(); + let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx)); + + // Buffer several partials before sending the final + sender.send_partial(json!({})); + sender.send_partial(json!({"path": "root/dir/new.txt"})); + sender.send_partial(json!({ + "path": "root/dir/new.txt", + })); + sender.send_full(json!({ + "path": "root/dir/new.txt", + "content": "streamed content" + })); + + let result = task.await; + let EditSessionOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "streamed content"); + } + + #[gpui::test] + async fn test_streaming_resolve_path_for_creating_file(cx: &mut TestAppContext) { + let mode = EditSessionMode::Write; + + let result = test_resolve_path(&mode, "root/new.txt", cx); + assert_resolved_path_eq(result.await, rel_path("new.txt")); + + let result = test_resolve_path(&mode, "new.txt", cx); + assert_resolved_path_eq(result.await, rel_path("new.txt")); + + let result = test_resolve_path(&mode, "dir/new.txt", cx); + assert_resolved_path_eq(result.await, rel_path("dir/new.txt")); + + let result = test_resolve_path(&mode, "root/dir/subdir/existing.txt", cx); + assert_resolved_path_eq(result.await, rel_path("dir/subdir/existing.txt")); + + let result = test_resolve_path(&mode, "root/dir/subdir", cx); + assert_eq!( + result.await.unwrap_err(), + "Can't write to file: path is a directory" + ); + + let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx); + assert_eq!( + result.await.unwrap_err(), + "Can't create file: parent directory doesn't exist" + ); + } + + #[gpui::test] + async fn test_streaming_format_on_save(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + let (write_tool, project, action_log, fs, thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; + + let rust_language = Arc::new(language::Language::new( + language::LanguageConfig { + name: "Rust".into(), + matcher: language::LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + None, + )); + + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + language_registry.add(rust_language); + + let mut fake_language_servers = language_registry.register_fake_lsp( + "Rust", + language::FakeLspAdapter { + capabilities: lsp::ServerCapabilities { + document_formatting_provider: Some(lsp::OneOf::Left(true)), + ..Default::default() + }, + ..Default::default() + }, + ); + + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + + // Open the buffer to trigger LSP initialization + let buffer = project + .update(cx, |project, cx| { + project.open_local_buffer(path!("/root/src/main.rs"), cx) + }) + .await + .unwrap(); + + // Register the buffer with language servers + let _handle = project.update(cx, |project, cx| { + project.register_buffer_with_language_servers(&buffer, cx) + }); + + const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\ +"; + const FORMATTED_CONTENT: &str = "This file was formatted by the fake formatter in the test.\ +"; + + // Get the fake language server and set up formatting handler + let fake_language_server = fake_language_servers.next().await.unwrap(); + fake_language_server.set_request_handler::({ + |_, _| async move { + Ok(Some(vec![lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)), + new_text: FORMATTED_CONTENT.to_string(), + }])) + } + }); + + // Test with format_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings(cx, |settings| { + settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On); + settings.project.all_languages.defaults.formatter = + Some(language::language_settings::FormatterList::default()); + }); + }); + }); + + // Use streaming pattern so executor can pump the LSP request/response + let (mut sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx)); + + sender.send_partial(json!({ + "path": "root/src/main.rs", + })); + cx.run_until_parked(); + + sender.send_full(json!({ + "path": "root/src/main.rs", + "content": UNFORMATTED_CONTENT + })); + + let result = task.await; + assert!(result.is_ok()); + + cx.executor().run_until_parked(); + + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + new_content.replace("\r\n", "\n"), + FORMATTED_CONTENT, + "Code should be formatted when format_on_save is enabled" + ); + + let stale_buffer_count = thread + .read_with(cx, |thread, _cx| thread.action_log.clone()) + .read_with(cx, |log, cx| log.stale_buffers(cx).count()); + + assert_eq!( + stale_buffer_count, 0, + "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers.", + stale_buffer_count + ); + + // Test with format_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings(cx, |settings| { + settings.project.all_languages.defaults.format_on_save = + Some(FormatOnSave::Off); + }); + }); + }); + + let (mut sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + + let tool2 = Arc::new(WriteFileTool::new( + project.clone(), + thread.downgrade(), + action_log.clone(), + language_registry, + )); + + let task = cx.update(|cx| tool2.run(input, event_stream, cx)); + + sender.send_partial(json!({ + "path": "root/src/main.rs", + })); + cx.run_until_parked(); + + sender.send_full(json!({ + "path": "root/src/main.rs", + "content": UNFORMATTED_CONTENT + })); + + let result = task.await; + assert!(result.is_ok()); + + cx.executor().run_until_parked(); + + let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + new_content.replace("\r\n", "\n"), + UNFORMATTED_CONTENT, + "Code should not be formatted when format_on_save is disabled" + ); + } + + #[gpui::test] + async fn test_streaming_remove_trailing_whitespace(cx: &mut TestAppContext) { + init_test(cx); + + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", json!({"src": {}})).await; + fs.save( + path!("/root/src/main.rs").as_ref(), + &"initial content".into(), + language::LineEnding::Unix, + ) + .await + .unwrap(); + let (write_tool, project, action_log, fs, thread) = + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await; + let language_registry = project.read_with(cx, |p, _cx| p.languages().clone()); + + // Test with remove_trailing_whitespace_on_save enabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings(cx, |settings| { + settings + .project + .all_languages + .defaults + .remove_trailing_whitespace_on_save = Some(true); + }); + }); + }); + + const CONTENT_WITH_TRAILING_WHITESPACE: &str = + "fn main() { \n println!(\"Hello!\"); \n}\n"; + + let result = cx + .update(|cx| { + write_tool.clone().run( + ToolInput::resolved(WriteFileToolInput { + path: "root/src/main.rs".into(), + content: CONTENT_WITH_TRAILING_WHITESPACE.into(), + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + assert!(result.is_ok()); + + cx.executor().run_until_parked(); + + assert_eq!( + fs.load(path!("/root/src/main.rs").as_ref()) + .await + .unwrap() + .replace("\r\n", "\n"), + "fn main() {\n println!(\"Hello!\");\n}\n", + "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled" + ); + + // Test with remove_trailing_whitespace_on_save disabled + cx.update(|cx| { + SettingsStore::update_global(cx, |store, cx| { + store.update_user_settings(cx, |settings| { + settings + .project + .all_languages + .defaults + .remove_trailing_whitespace_on_save = Some(false); + }); + }); + }); + + let tool2 = Arc::new(WriteFileTool::new( + project.clone(), + thread.downgrade(), + action_log.clone(), + language_registry, + )); + + let result = cx + .update(|cx| { + tool2.run( + ToolInput::resolved(WriteFileToolInput { + path: "root/src/main.rs".into(), + content: CONTENT_WITH_TRAILING_WHITESPACE.into(), + }), + ToolCallEventStream::test().0, + cx, + ) + }) + .await; + assert!(result.is_ok()); + + cx.executor().run_until_parked(); + + let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap(); + assert_eq!( + final_content.replace("\r\n", "\n"), + CONTENT_WITH_TRAILING_WHITESPACE, + "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled" + ); + } + + #[gpui::test] + async fn test_streaming_diff_finalization(cx: &mut TestAppContext) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/", json!({"main.rs": ""})).await; + let (write_tool, project, action_log, _fs, thread) = + setup_test_with_fs(cx, fs, &[path!("/").as_ref()]).await; + let language_registry = project.read_with(cx, |p, _cx| p.languages().clone()); + + // Ensure the diff is finalized after the edit completes. + { + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let edit = cx.update(|cx| { + write_tool.clone().run( + ToolInput::resolved(WriteFileToolInput { + path: path!("/main.rs").into(), + content: "new content".into(), + }), + stream_tx, + cx, + ) + }); + stream_rx.expect_update_fields().await; + let diff = stream_rx.expect_diff().await; + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); + cx.run_until_parked(); + edit.await.unwrap(); + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); + } + + // Ensure the diff is finalized if the tool call gets dropped. + { + let tool = Arc::new(WriteFileTool::new( + project.clone(), + thread.downgrade(), + action_log, + language_registry, + )); + let (stream_tx, mut stream_rx) = ToolCallEventStream::test(); + let edit = cx.update(|cx| { + tool.run( + ToolInput::resolved(WriteFileToolInput { + path: path!("/main.rs").into(), + content: "dropped content".into(), + }), + stream_tx, + cx, + ) + }); + stream_rx.expect_update_fields().await; + let diff = stream_rx.expect_diff().await; + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_)))); + drop(edit); + cx.run_until_parked(); + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); + } + } + + #[gpui::test] + async fn test_streaming_create_content_streamed(cx: &mut TestAppContext) { + let (write_tool, project, _action_log, _fs, _thread) = + setup_test(cx, json!({"dir": {}})).await; + let (mut sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx)); + + // Transition to BufferResolved + sender.send_partial(json!({ + "path": "root/dir/new_file.txt", + })); + cx.run_until_parked(); + + // Stream content incrementally + sender.send_partial(json!({ + "path": "root/dir/new_file.txt", + "content": "line 1\n" + })); + cx.run_until_parked(); + + // Verify buffer has partial content + let buffer = project.update(cx, |project, cx| { + let path = project + .find_project_path("root/dir/new_file.txt", cx) + .unwrap(); + project.get_open_buffer(&path, cx).unwrap() + }); + assert_eq!(buffer.read_with(cx, |b, _| b.text()), "line 1\n"); + + // Stream more content + sender.send_partial(json!({ + "path": "root/dir/new_file.txt", + "content": "line 1\nline 2\n" + })); + cx.run_until_parked(); + assert_eq!(buffer.read_with(cx, |b, _| b.text()), "line 1\nline 2\n"); + + // Stream final chunk + sender.send_partial(json!({ + "path": "root/dir/new_file.txt", + "content": "line 1\nline 2\nline 3\n" + })); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |b, _| b.text()), + "line 1\nline 2\nline 3\n" + ); + + // Send final input + sender.send_full(json!({ + "path": "root/dir/new_file.txt", + "content": "line 1\nline 2\nline 3\n" + })); + + let result = task.await; + let EditSessionOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "line 1\nline 2\nline 3\n"); + } + + #[gpui::test] + async fn test_streaming_overwrite_diff_revealed_during_streaming(cx: &mut TestAppContext) { + let (write_tool, _project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), + ) + .await; + let (mut sender, input) = ToolInput::::test(); + let (event_stream, mut receiver) = ToolCallEventStream::test(); + let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx)); + + // Transition to BufferResolved + sender.send_partial(json!({ + "path": "root/file.txt", + })); + cx.run_until_parked(); + + sender.send_partial(json!({ + "path": "root/file.txt", + })); + cx.run_until_parked(); + + // Get the diff entity from the event stream + receiver.expect_update_fields().await; + let diff = receiver.expect_diff().await; + + // Diff starts pending with no revealed ranges + diff.read_with(cx, |diff, cx| { + assert!(matches!(diff, Diff::Pending(_))); + assert!(!diff.has_revealed_range(cx)); + }); + + // Stream first content chunk + sender.send_partial(json!({ + "path": "root/file.txt", + "content": "new line 1\n" + })); + cx.run_until_parked(); + + // Diff should now have revealed ranges showing the new content + diff.read_with(cx, |diff, cx| { + assert!(diff.has_revealed_range(cx)); + }); + + // Send final input + sender.send_full(json!({ + "path": "root/file.txt", + "content": "new line 1\nnew line 2\n" + })); + + let result = task.await; + let EditSessionOutput::Success { + new_text, old_text, .. + } = result.unwrap() + else { + panic!("expected success"); + }; + assert_eq!(new_text, "new line 1\nnew line 2\n"); + assert_eq!(*old_text, "old line 1\nold line 2\nold line 3\n"); + + // Diff is finalized after completion + diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_)))); + } + + #[gpui::test] + async fn test_streaming_overwrite_content_streamed(cx: &mut TestAppContext) { + let (write_tool, project, _action_log, _fs, _thread) = setup_test( + cx, + json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}), + ) + .await; + let (mut sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx)); + + // Transition to BufferResolved + sender.send_partial(json!({ + "path": "root/file.txt", + })); + cx.run_until_parked(); + + // Verify buffer still has old content (no content partial yet) + let buffer = project.update(cx, |project, cx| { + let path = project.find_project_path("root/file.txt", cx).unwrap(); + project.open_buffer(path, cx) + }); + let buffer = buffer.await.unwrap(); + assert_eq!( + buffer.read_with(cx, |b, _| b.text()), + "old line 1\nold line 2\nold line 3\n" + ); + + // First content partial replaces old content + sender.send_partial(json!({ + "path": "root/file.txt", + "content": "new line 1\n" + })); + cx.run_until_parked(); + assert_eq!(buffer.read_with(cx, |b, _| b.text()), "new line 1\n"); + + // Subsequent content partials append + sender.send_partial(json!({ + "path": "root/file.txt", + "content": "new line 1\nnew line 2\n" + })); + cx.run_until_parked(); + assert_eq!( + buffer.read_with(cx, |b, _| b.text()), + "new line 1\nnew line 2\n" + ); + + // Send final input with complete content + sender.send_full(json!({ + "path": "root/file.txt", + "content": "new line 1\nnew line 2\nnew line 3\n" + })); + + let result = task.await; + let EditSessionOutput::Success { + new_text, old_text, .. + } = result.unwrap() + else { + panic!("expected success"); + }; + assert_eq!(new_text, "new line 1\nnew line 2\nnew line 3\n"); + assert_eq!(*old_text, "old line 1\nold line 2\nold line 3\n"); + } + + #[gpui::test] + async fn test_streaming_write_file_tool_registers_changed_buffers(cx: &mut TestAppContext) { + let (write_tool, _project, action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "original content"})).await; + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Allow; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + let (event_stream, _rx) = ToolCallEventStream::test(); + let task = cx.update(|cx| { + write_tool.clone().run( + ToolInput::resolved(WriteFileToolInput { + path: "root/file.txt".into(), + content: "completely new content".into(), + }), + event_stream, + cx, + ) + }); + + let result = task.await; + assert!(result.is_ok(), "write should succeed: {:?}", result.err()); + + cx.run_until_parked(); + + let changed = action_log.read_with(cx, |log, cx| log.changed_buffers(cx)); + assert!( + !changed.is_empty(), + "action_log.changed_buffers() should be non-empty after streaming write, \ + but no changed buffers were found \u{2014} Accept All / Reject All will not appear" + ); + } + + #[gpui::test] + async fn test_streaming_write_file_tool_fields_out_of_order(cx: &mut TestAppContext) { + let (write_tool, _project, _action_log, _fs, _thread) = + setup_test(cx, json!({"file.txt": "old_content"})).await; + let (mut sender, input) = ToolInput::::test(); + let (event_stream, _receiver) = ToolCallEventStream::test(); + let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx)); + + sender.send_partial(json!({ + "content": "new_content" + })); + cx.run_until_parked(); + + sender.send_partial(json!({ + "content": "new_content", + "path": "root" + })); + cx.run_until_parked(); + + // Send final. + sender.send_full(json!({ + "content": "new_content", + "path": "root/file.txt" + })); + + let result = task.await; + let EditSessionOutput::Success { new_text, .. } = result.unwrap() else { + panic!("expected success"); + }; + assert_eq!(new_text, "new_content"); + } + + #[gpui::test] + async fn test_streaming_reject_created_file_deletes_it(cx: &mut TestAppContext) { + let (write_tool, _project, action_log, fs, _thread) = + setup_test(cx, json!({"dir": {}})).await; + cx.update(|cx| { + let mut settings = agent_settings::AgentSettings::get_global(cx).clone(); + settings.tool_permissions.default = settings::ToolPermissionMode::Allow; + agent_settings::AgentSettings::override_global(settings, cx); + }); + + // Create a new file via the streaming write file tool + let (event_stream, _rx) = ToolCallEventStream::test(); + let task = cx.update(|cx| { + write_tool.clone().run( + ToolInput::resolved(WriteFileToolInput { + path: "root/dir/new_file.txt".into(), + content: "Hello, World!".into(), + }), + event_stream, + cx, + ) + }); + let result = task.await; + assert!(result.is_ok(), "create should succeed: {:?}", result.err()); + cx.run_until_parked(); + + assert!( + fs.is_file(path!("/root/dir/new_file.txt").as_ref()).await, + "file should exist after creation" + ); + + // Reject all edits — this should delete the newly created file + let changed = action_log.read_with(cx, |log, cx| log.changed_buffers(cx)); + assert!( + !changed.is_empty(), + "action_log should track the created file as changed" + ); + + action_log + .update(cx, |log, cx| log.reject_all_edits(None, cx)) + .await; + cx.run_until_parked(); + + assert!( + !fs.is_file(path!("/root/dir/new_file.txt").as_ref()).await, + "file should be deleted after rejecting creation, but an empty file was left behind" + ); + } + + async fn setup_test_with_fs( + cx: &mut TestAppContext, + fs: Arc, + worktree_paths: &[&std::path::Path], + ) -> ( + Arc, + Entity, + Entity, + Arc, + Entity, + ) { + let project = Project::test(fs.clone(), worktree_paths.iter().copied(), cx).await; + let language_registry = project.read_with(cx, |project, _cx| project.languages().clone()); + let context_server_registry = + cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx)); + let model = Arc::new(FakeLanguageModel::default()); + let thread = cx.new(|cx| { + crate::Thread::new( + project.clone(), + cx.new(|_cx| ProjectContext::default()), + context_server_registry, + Templates::new(), + Some(model), + cx, + ) + }); + let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone()); + let write_tool = Arc::new(WriteFileTool::new( + project.clone(), + thread.downgrade(), + action_log.clone(), + language_registry, + )); + (write_tool, project, action_log, fs, thread) + } + + async fn setup_test( + cx: &mut TestAppContext, + initial_tree: serde_json::Value, + ) -> ( + Arc, + Entity, + Entity, + Arc, + Entity, + ) { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree("/root", initial_tree).await; + setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await + } + + async fn test_resolve_path( + mode: &EditSessionMode, + path: &str, + cx: &mut TestAppContext, + ) -> Result { + init_test(cx); + let fs = project::FakeFs::new(cx.executor()); + fs.insert_tree( + "/root", + json!({ + "dir": { + "subdir": { + "existing.txt": "content" + } + } + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + + crate::tools::edit_session::test_resolve_path(mode, path, &project, cx).await + } + + #[track_caller] + fn assert_resolved_path_eq(path: Result, expected: &RelPath) { + let actual = path.expect("Should return valid path").path; + assert_eq!(actual.as_ref(), expected); + } + + fn init_test(cx: &mut TestAppContext) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + SettingsStore::update_global(cx, |store: &mut SettingsStore, cx| { + store.update_user_settings(cx, |settings| { + settings + .project + .all_languages + .defaults + .ensure_final_newline_on_save = Some(false); + }); + }); + }); + } +} diff --git a/crates/settings_ui/src/pages.rs b/crates/settings_ui/src/pages.rs index 401534b66059e61e52406b85f509ae9c935eeab2..63c3965d095a3ae769dbaa36d01a66281f48d868 100644 --- a/crates/settings_ui/src/pages.rs +++ b/crates/settings_ui/src/pages.rs @@ -17,4 +17,5 @@ pub use tool_permissions_setup::{ render_delete_path_tool_config, render_edit_file_tool_config, render_fetch_tool_config, render_move_path_tool_config, render_restore_file_from_disk_tool_config, render_save_file_tool_config, render_terminal_tool_config, render_web_search_tool_config, + render_write_file_tool_config, }; diff --git a/crates/settings_ui/src/pages/tool_permissions_setup.rs b/crates/settings_ui/src/pages/tool_permissions_setup.rs index 12693cb99d98fc022520352949e9c74b8501fad3..05cd51e3f924826609501f994a1abcb0ddf73ef4 100644 --- a/crates/settings_ui/src/pages/tool_permissions_setup.rs +++ b/crates/settings_ui/src/pages/tool_permissions_setup.rs @@ -32,6 +32,12 @@ const TOOLS: &[ToolInfo] = &[ description: "File editing operations", regex_explanation: "Patterns are matched against the file path being edited.", }, + ToolInfo { + id: "write_file", + name: "Write File", + description: "File creation and overwrite operations", + regex_explanation: "Patterns are matched against the file path being written.", + }, ToolInfo { id: "delete_path", name: "Delete Path", @@ -303,6 +309,7 @@ fn get_tool_render_fn( match tool_id { "terminal" => render_terminal_tool_config, "edit_file" => render_edit_file_tool_config, + "write_file" => render_write_file_tool_config, "delete_path" => render_delete_path_tool_config, "copy_path" => render_copy_path_tool_config, "move_path" => render_move_path_tool_config, @@ -1383,6 +1390,7 @@ macro_rules! tool_config_page_fn { tool_config_page_fn!(render_terminal_tool_config, "terminal"); tool_config_page_fn!(render_edit_file_tool_config, "edit_file"); +tool_config_page_fn!(render_write_file_tool_config, "write_file"); tool_config_page_fn!(render_delete_path_tool_config, "delete_path"); tool_config_page_fn!(render_copy_path_tool_config, "copy_path"); tool_config_page_fn!(render_move_path_tool_config, "move_path");