Detailed changes
@@ -1110,6 +1110,7 @@
"diagnostics": true,
"apply_code_action": true,
"edit_file": true,
+ "write_file": true,
"fetch": true,
"find_path": true,
"find_references": true,
@@ -6062,9 +6062,7 @@ async fn test_edit_file_tool_deny_rule_blocks_edit(cx: &mut TestAppContext) {
tool.run(
ToolInput::resolved(crate::EditFileToolInput {
path: "root/sensitive_config.txt".into(),
- mode: crate::EditFileMode::Edit,
- content: None,
- edits: Some(vec![]),
+ edits: vec![],
}),
event_stream,
cx,
@@ -6496,9 +6494,7 @@ async fn test_edit_file_tool_allow_rule_skips_confirmation(cx: &mut TestAppConte
tool.run(
ToolInput::resolved(crate::EditFileToolInput {
path: "root/README.md".into(),
- mode: crate::EditFileMode::Edit,
- content: None,
- edits: Some(vec![]),
+ edits: vec![],
}),
event_stream,
cx,
@@ -6568,9 +6564,7 @@ async fn test_edit_file_tool_allow_still_prompts_for_local_settings(cx: &mut Tes
tool.run(
ToolInput::resolved(crate::EditFileToolInput {
path: "root/.zed/settings.json".into(),
- mode: crate::EditFileMode::Edit,
- content: None,
- edits: Some(vec![]),
+ edits: vec![],
}),
event_stream,
cx,
@@ -4,7 +4,7 @@ use crate::{
FindPathTool, FindReferencesTool, GetCodeActionsTool, GoToDefinitionTool, GrepTool,
ListDirectoryTool, MovePathTool, NowTool, OpenTool, ProjectSnapshot, ReadFileTool, RenameTool,
RestoreFileFromDiskTool, SaveFileTool, SpawnAgentTool, SystemPromptTemplate, Template,
- Templates, TerminalTool, ToolPermissionDecision, UpdatePlanTool, WebSearchTool,
+ Templates, TerminalTool, ToolPermissionDecision, UpdatePlanTool, WebSearchTool, WriteFileTool,
decide_permission_from_settings,
};
use acp_thread::{MentionUri, UserMessageId};
@@ -822,6 +822,7 @@ impl ToolPermissionContext {
} else if tool_name == CopyPathTool::NAME
|| tool_name == MovePathTool::NAME
|| tool_name == EditFileTool::NAME
+ || tool_name == WriteFileTool::NAME
|| tool_name == DeletePathTool::NAME
|| tool_name == CreateDirectoryTool::NAME
|| tool_name == SaveFileTool::NAME
@@ -1544,6 +1545,12 @@ impl Thread {
self.action_log.clone(),
));
self.add_tool(EditFileTool::new(
+ self.project.clone(),
+ cx.weak_entity(),
+ self.action_log.clone(),
+ language_registry.clone(),
+ ));
+ self.add_tool(WriteFileTool::new(
self.project.clone(),
cx.weak_entity(),
self.action_log.clone(),
@@ -5,6 +5,7 @@ mod create_directory_tool;
mod delete_path_tool;
mod diagnostics_tool;
mod edit_file_tool;
+mod edit_session;
#[cfg(all(test, feature = "unit-eval"))]
mod evals;
mod fetch_tool;
@@ -27,6 +28,7 @@ mod terminal_tool;
mod tool_permissions;
mod update_plan_tool;
mod web_search_tool;
+mod write_file_tool;
use crate::AgentTool;
use language_model::{LanguageModelRequestTool, LanguageModelToolSchemaFormat};
@@ -85,6 +87,7 @@ pub use terminal_tool::*;
pub use tool_permissions::*;
pub use update_plan_tool::*;
pub use web_search_tool::*;
+pub use write_file_tool::*;
macro_rules! tools {
($($tool:ty),* $(,)?) => {
@@ -179,4 +182,5 @@ tools! {
TerminalTool,
UpdatePlanTool,
WebSearchTool,
+ WriteFileTool,
}
@@ -1,53 +1,36 @@
-mod reindent;
-mod streaming_fuzzy_matcher;
-mod streaming_parser;
-
use super::deserialize_maybe_stringified;
-use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
-use super::save_file_tool::SaveFileTool;
-use crate::ToolInputPayload;
-use crate::tools::edit_file_tool::{
- reindent::{Reindenter, compute_indent_delta},
- streaming_fuzzy_matcher::StreamingFuzzyMatcher,
- streaming_parser::{EditEvent, StreamingParser, WriteEvent},
+pub(crate) use super::edit_session::PartialEdit;
+pub use super::edit_session::{Edit, EditSessionOutput as EditFileToolOutput};
+use super::edit_session::{
+ EditSession, EditSessionContext, EditSessionMode, EditSessionResult,
+ initial_title_from_partial_path, run_session,
};
-use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput};
-use acp_thread::Diff;
+use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput, ToolInputPayload};
use action_log::ActionLog;
-use agent_client_protocol::schema::{self as acp, ToolCallLocation, ToolCallUpdateFields};
+use agent_client_protocol::schema as acp;
use anyhow::Result;
-use collections::HashSet;
use futures::FutureExt as _;
-use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
-use language::language_settings::{self, FormatOnSave};
-use language::{Buffer, LanguageRegistry};
-use language_model::LanguageModelToolResultContent;
-use project::lsp_store::{FormatTrigger, LspFormatTarget};
-use project::{AgentLocation, Project, ProjectPath};
+use gpui::{App, AsyncApp, Entity, Task, WeakEntity};
+use language::LanguageRegistry;
+use project::Project;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use std::ops::Range;
use std::path::PathBuf;
use std::sync::Arc;
-use streaming_diff::{CharOperation, StreamingDiff};
-use text::ToOffset;
use ui::SharedString;
-use util::rel_path::RelPath;
-use util::{Deferred, ResultExt};
const DEFAULT_UI_TEXT: &str = "Editing file";
-/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `move_path` tool instead.
+/// This is a tool for applying edits to an existing file.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
-/// 2. Verify the directory path is correct (only applicable when creating new files):
-/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
+/// To create a new file or overwrite an existing one with completely new contents, use the `write_file` tool instead.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
- /// The full path of the file to create or modify in the project.
+ /// The full path of the file to edit in the project.
///
/// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
///
@@ -66,50 +49,10 @@ pub struct EditFileToolInput {
/// </example>
pub path: PathBuf,
- /// The mode of operation on the file. Possible values:
- /// - 'write': Replace the entire contents of the file. If the file doesn't exist, it will be created. Requires 'content' field.
- /// - 'edit': Make granular edits to an existing file. Requires 'edits' field.
- ///
- /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
- #[serde(deserialize_with = "deserialize_maybe_stringified")]
- pub mode: EditFileMode,
-
- /// The complete content for the new file (required for 'write' mode).
- /// This field should contain the entire file content.
- #[serde(default, skip_serializing_if = "Option::is_none")]
- pub content: Option<String>,
-
- /// List of edit operations to apply sequentially (required for 'edit' mode).
+ /// List of edit operations to apply sequentially.
/// Each edit finds `old_text` in the file and replaces it with `new_text`.
- #[serde(
- default,
- skip_serializing_if = "Option::is_none",
- deserialize_with = "deserialize_maybe_stringified"
- )]
- pub edits: Option<Vec<Edit>>,
-}
-
-#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema)]
-#[serde(rename_all = "snake_case")]
-pub enum EditFileMode {
- Write,
- Edit,
-}
-
-/// A single edit operation that replaces old text with new text
-/// Properly escape all text fields as valid JSON strings.
-/// Remember to escape special characters like newlines (`\n`) and quotes (`"`) in JSON strings.
-#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
-pub struct Edit {
- /// The exact text to find in the file. This will be matched using fuzzy matching
- /// to handle minor differences in whitespace or formatting.
- ///
- /// Be minimal with replacements:
- /// - For unique lines, include only those lines
- /// - For non-unique lines, include enough context to identify them
- pub old_text: String,
- /// The text to replace it with
- pub new_text: String,
+ #[serde(deserialize_with = "deserialize_maybe_stringified")]
+ pub edits: Vec<Edit>,
}
#[derive(Clone, Default, Debug, Deserialize)]
@@ -117,108 +60,11 @@ struct EditFileToolPartialInput {
#[serde(default)]
path: Option<String>,
#[serde(default, deserialize_with = "deserialize_maybe_stringified")]
- mode: Option<EditFileMode>,
- #[serde(default)]
- content: Option<String>,
- #[serde(default, deserialize_with = "deserialize_maybe_stringified")]
edits: Option<Vec<PartialEdit>>,
}
-#[derive(Clone, Default, Debug, Deserialize)]
-pub struct PartialEdit {
- #[serde(default)]
- pub old_text: Option<String>,
- #[serde(default)]
- pub new_text: Option<String>,
-}
-
-#[derive(Debug, Serialize, Deserialize)]
-#[serde(untagged)]
-pub enum EditFileToolOutput {
- Success {
- #[serde(alias = "original_path")]
- input_path: PathBuf,
- new_text: String,
- old_text: Arc<String>,
- #[serde(default)]
- diff: String,
- },
- Error {
- error: String,
- #[serde(default, skip_serializing_if = "Option::is_none")]
- input_path: Option<PathBuf>,
- #[serde(default, skip_serializing_if = "String::is_empty")]
- diff: String,
- },
-}
-
-impl EditFileToolOutput {
- pub fn error(error: impl Into<String>) -> Self {
- Self::Error {
- error: error.into(),
- input_path: None,
- diff: String::new(),
- }
- }
-}
-
-impl std::fmt::Display for EditFileToolOutput {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- EditFileToolOutput::Success {
- diff, input_path, ..
- } => {
- if diff.is_empty() {
- write!(f, "No edits were made.")
- } else {
- write!(
- f,
- "Edited {}:\n\n```diff\n{diff}\n```",
- input_path.display()
- )
- }
- }
- EditFileToolOutput::Error {
- error,
- diff,
- input_path,
- } => {
- write!(f, "{error}\n")?;
- if let Some(input_path) = input_path
- && !diff.is_empty()
- {
- write!(
- f,
- "Edited {}:\n\n```diff\n{diff}\n```",
- input_path.display()
- )
- } else {
- write!(f, "No edits were made.")
- }
- }
- }
- }
-}
-
-impl From<EditFileToolOutput> for LanguageModelToolResultContent {
- fn from(output: EditFileToolOutput) -> Self {
- output.to_string().into()
- }
-}
-
pub struct EditFileTool {
- project: Entity<Project>,
- thread: WeakEntity<Thread>,
- action_log: Entity<ActionLog>,
- language_registry: Arc<LanguageRegistry>,
-}
-
-enum EditSessionResult {
- Completed(EditSession),
- Failed {
- error: String,
- session: Option<EditSession>,
- },
+ session_context: Arc<EditSessionContext>,
}
impl EditFileTool {
@@ -229,69 +75,24 @@ impl EditFileTool {
language_registry: Arc<LanguageRegistry>,
) -> Self {
Self {
- project,
- thread,
- action_log,
- language_registry,
+ session_context: Arc::new(EditSessionContext::new(
+ project,
+ thread,
+ action_log,
+ language_registry,
+ )),
}
}
+ #[cfg(test)]
fn authorize(
&self,
path: &PathBuf,
event_stream: &ToolCallEventStream,
cx: &mut App,
) -> Task<Result<()>> {
- super::tool_permissions::authorize_file_edit(
- EditFileTool::NAME,
- path,
- &self.thread,
- event_stream,
- cx,
- )
- }
-
- fn set_agent_location(&self, buffer: WeakEntity<Buffer>, position: text::Anchor, cx: &mut App) {
- let should_update_agent_location = self
- .thread
- .read_with(cx, |thread, _cx| !thread.is_subagent())
- .unwrap_or_default();
- if should_update_agent_location {
- self.project.update(cx, |project, cx| {
- project.set_agent_location(Some(AgentLocation { buffer, position }), cx);
- });
- }
- }
-
- async fn ensure_buffer_saved(&self, buffer: &Entity<Buffer>, cx: &mut AsyncApp) {
- let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
- let settings = language_settings::LanguageSettings::for_buffer(buffer, cx);
- settings.format_on_save != FormatOnSave::Off
- });
-
- if format_on_save_enabled {
- self.project
- .update(cx, |project, cx| {
- project.format(
- HashSet::from_iter([buffer.clone()]),
- LspFormatTarget::Buffers,
- false,
- FormatTrigger::Save,
- cx,
- )
- })
- .await
- .log_err();
- }
-
- self.project
- .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
- .await
- .log_err();
-
- self.action_log.update(cx, |log, cx| {
- log.buffer_edited(buffer.clone(), cx);
- });
+ self.session_context
+ .authorize(Self::NAME, path, event_stream, cx)
}
async fn process_streaming_edits(
@@ -301,7 +102,7 @@ impl EditFileTool {
cx: &mut AsyncApp,
) -> EditSessionResult {
let mut session: Option<EditSession> = None;
- let mut last_partial: Option<EditFileToolPartialInput> = None;
+ let mut last_path: Option<String> = None;
loop {
futures::select! {
@@ -311,22 +112,19 @@ impl EditFileTool {
ToolInputPayload::Partial(partial) => {
if let Ok(parsed) = serde_json::from_value::<EditFileToolPartialInput>(partial) {
let path_complete = parsed.path.is_some()
- && parsed.path.as_ref() == last_partial.as_ref().and_then(|partial| partial.path.as_ref());
+ && parsed.path.as_ref() == last_path.as_ref();
- last_partial = Some(parsed.clone());
+ last_path = parsed.path.clone();
if session.is_none()
&& path_complete
- && let EditFileToolPartialInput {
- path: Some(path),
- mode: Some(mode),
- ..
- } = &parsed
+ && let Some(path) = parsed.path.as_ref()
{
match EditSession::new(
PathBuf::from(path),
- *mode,
- self,
+ EditSessionMode::Edit,
+ Self::NAME,
+ self.session_context.clone(),
event_stream,
cx,
)
@@ -344,7 +142,7 @@ impl EditFileTool {
}
if let Some(current_session) = &mut session
- && let Err(error) = current_session.process(parsed, self, event_stream, cx)
+ && let Err(error) = current_session.process_edit(parsed.edits.as_deref(), event_stream, cx)
{
log::error!("Failed to process edit: {}", error);
return EditSessionResult::Failed { error, session };
@@ -357,8 +155,9 @@ impl EditFileTool {
} else {
match EditSession::new(
full_input.path.clone(),
- full_input.mode,
- self,
+ EditSessionMode::Edit,
+ Self::NAME,
+ self.session_context.clone(),
event_stream,
cx,
)
@@ -375,7 +174,7 @@ impl EditFileTool {
}
};
- return match session.finalize(full_input, self, event_stream, cx).await {
+ return match session.finalize_edit(full_input.edits, event_stream, cx).await {
Ok(()) => EditSessionResult::Completed(session),
Err(error) => {
log::error!("Failed to finalize edit: {}", error);
@@ -433,38 +232,17 @@ impl AgentTool for EditFileTool {
cx: &mut App,
) -> SharedString {
match input {
- Ok(input) => self
- .project
- .read(cx)
- .find_project_path(&input.path, cx)
- .and_then(|project_path| {
- self.project
- .read(cx)
- .short_full_path_for_project_path(&project_path, cx)
- })
- .unwrap_or(input.path.to_string_lossy().into_owned())
- .into(),
- Err(raw_input) => {
- if let Ok(input) = serde_json::from_value::<EditFileToolPartialInput>(raw_input) {
- let path = input.path.unwrap_or_default();
- let path = path.trim();
- if !path.is_empty() {
- return self
- .project
- .read(cx)
- .find_project_path(&path, cx)
- .and_then(|project_path| {
- self.project
- .read(cx)
- .short_full_path_for_project_path(&project_path, cx)
- })
- .unwrap_or_else(|| path.to_string())
- .into();
- }
- }
-
- DEFAULT_UI_TEXT.into()
+ Ok(input) => {
+ self.session_context
+ .initial_title_from_path(&input.path, DEFAULT_UI_TEXT, cx)
}
+ Err(raw_input) => initial_title_from_partial_path::<EditFileToolPartialInput>(
+ &self.session_context,
+ raw_input,
+ |partial| partial.path.clone(),
+ DEFAULT_UI_TEXT,
+ cx,
+ ),
}
}
@@ -475,41 +253,12 @@ impl AgentTool for EditFileTool {
cx: &mut App,
) -> Task<Result<Self::Output, Self::Output>> {
cx.spawn(async move |cx: &mut AsyncApp| {
- match self
- .process_streaming_edits(&mut input, &event_stream, cx)
- .await
- {
- EditSessionResult::Completed(session) => {
- self.ensure_buffer_saved(&session.buffer, cx).await;
- let (new_text, diff) = session.compute_new_text_and_diff(cx).await;
- Ok(EditFileToolOutput::Success {
- old_text: session.old_text.clone(),
- new_text,
- input_path: session.input_path,
- diff,
- })
- }
- EditSessionResult::Failed {
- error,
- session: Some(session),
- } => {
- self.ensure_buffer_saved(&session.buffer, cx).await;
- let (_new_text, diff) = session.compute_new_text_and_diff(cx).await;
- Err(EditFileToolOutput::Error {
- error,
- input_path: Some(session.input_path),
- diff,
- })
- }
- EditSessionResult::Failed {
- error,
- session: None,
- } => Err(EditFileToolOutput::Error {
- error,
- input_path: None,
- diff: String::new(),
- }),
- }
+ run_session(
+ self.process_streaming_edits(&mut input, &event_stream, cx)
+ .await,
+ cx,
+ )
+ .await
})
}
@@ -520,706 +269,7 @@ impl AgentTool for EditFileTool {
event_stream: ToolCallEventStream,
cx: &mut App,
) -> Result<()> {
- match output {
- EditFileToolOutput::Success {
- input_path,
- old_text,
- new_text,
- ..
- } => {
- event_stream.update_diff(cx.new(|cx| {
- Diff::finalized(
- input_path.to_string_lossy().into_owned(),
- Some(old_text.to_string()),
- new_text,
- self.language_registry.clone(),
- cx,
- )
- }));
- Ok(())
- }
- EditFileToolOutput::Error { .. } => Ok(()),
- }
- }
-}
-
-pub struct EditSession {
- abs_path: PathBuf,
- input_path: PathBuf,
- buffer: Entity<Buffer>,
- old_text: Arc<String>,
- diff: Entity<Diff>,
- parser: StreamingParser,
- pipeline: Pipeline,
- _finalize_diff_guard: Deferred<Box<dyn FnOnce()>>,
-}
-
-enum Pipeline {
- Write(WritePipeline),
- Edit(EditPipeline),
-}
-
-struct WritePipeline {
- content_written: bool,
-}
-
-struct EditPipeline {
- current_edit: Option<EditPipelineEntry>,
- file_changed_since_last_read: bool,
-}
-
-enum EditPipelineEntry {
- ResolvingOldText {
- matcher: StreamingFuzzyMatcher,
- },
- StreamingNewText {
- streaming_diff: StreamingDiff,
- edit_cursor: usize,
- reindenter: Reindenter,
- original_snapshot: text::BufferSnapshot,
- },
-}
-
-impl Pipeline {
- fn new(mode: EditFileMode, file_changed_since_last_read: bool) -> Self {
- match mode {
- EditFileMode::Write => Self::Write(WritePipeline {
- content_written: false,
- }),
- EditFileMode::Edit => Self::Edit(EditPipeline {
- current_edit: None,
- file_changed_since_last_read,
- }),
- }
- }
-}
-
-impl WritePipeline {
- fn process_event(
- &mut self,
- event: &WriteEvent,
- buffer: &Entity<Buffer>,
- tool: &EditFileTool,
- cx: &mut AsyncApp,
- ) {
- let WriteEvent::ContentChunk { chunk } = event;
-
- let (buffer_id, buffer_len) =
- buffer.read_with(cx, |buffer, _cx| (buffer.remote_id(), buffer.len()));
- let edit_range = if self.content_written {
- buffer_len..buffer_len
- } else {
- 0..buffer_len
- };
-
- agent_edit_buffer(buffer, [(edit_range, chunk.as_str())], &tool.action_log, cx);
- cx.update(|cx| {
- tool.set_agent_location(
- buffer.downgrade(),
- text::Anchor::max_for_buffer(buffer_id),
- cx,
- );
- });
- self.content_written = true;
- }
-}
-
-impl EditPipeline {
- fn ensure_resolving_old_text(&mut self, buffer: &Entity<Buffer>, cx: &mut AsyncApp) {
- if self.current_edit.is_none() {
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot());
- self.current_edit = Some(EditPipelineEntry::ResolvingOldText {
- matcher: StreamingFuzzyMatcher::new(snapshot),
- });
- }
- }
-
- fn process_event(
- &mut self,
- event: &EditEvent,
- buffer: &Entity<Buffer>,
- diff: &Entity<Diff>,
- abs_path: &PathBuf,
- tool: &EditFileTool,
- event_stream: &ToolCallEventStream,
- cx: &mut AsyncApp,
- ) -> Result<(), String> {
- match event {
- EditEvent::OldTextChunk {
- chunk, done: false, ..
- } => {
- log::debug!("old_text_chunk: done=false, chunk='{}'", chunk);
- self.ensure_resolving_old_text(buffer, cx);
-
- if let Some(EditPipelineEntry::ResolvingOldText { matcher }) =
- &mut self.current_edit
- && !chunk.is_empty()
- {
- if let Some(match_range) = matcher.push(chunk, None) {
- let anchor_range = buffer.read_with(cx, |buffer, _cx| {
- buffer.anchor_range_outside(match_range.clone())
- });
- diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
-
- cx.update(|cx| {
- let position = buffer.read(cx).anchor_before(match_range.end);
- tool.set_agent_location(buffer.downgrade(), position, cx);
- });
- }
- }
- }
- EditEvent::OldTextChunk {
- edit_index,
- chunk,
- done: true,
- } => {
- log::debug!("old_text_chunk: done=true, chunk='{}'", chunk);
-
- self.ensure_resolving_old_text(buffer, cx);
-
- let Some(EditPipelineEntry::ResolvingOldText { matcher }) = &mut self.current_edit
- else {
- return Ok(());
- };
-
- if !chunk.is_empty() {
- matcher.push(chunk, None);
- }
- let range = extract_match(
- matcher.finish(),
- buffer,
- edit_index,
- self.file_changed_since_last_read,
- cx,
- )?;
-
- let anchor_range =
- buffer.read_with(cx, |buffer, _cx| buffer.anchor_range_outside(range.clone()));
- diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
-
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
-
- let line = snapshot.offset_to_point(range.start).row;
- event_stream.update_fields(
- ToolCallUpdateFields::new()
- .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
- );
-
- let buffer_indent = snapshot.line_indent_for_row(line);
- let query_indent = text::LineIndent::from_iter(
- matcher
- .query_lines()
- .first()
- .map(|s| s.as_str())
- .unwrap_or("")
- .chars(),
- );
- let indent_delta = compute_indent_delta(buffer_indent, query_indent);
-
- let old_text_in_buffer = snapshot.text_for_range(range.clone()).collect::<String>();
-
- log::debug!(
- "edit[{}] old_text matched at {}..{}: {:?}",
- edit_index,
- range.start,
- range.end,
- old_text_in_buffer,
- );
-
- let text_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot());
- self.current_edit = Some(EditPipelineEntry::StreamingNewText {
- streaming_diff: StreamingDiff::new(old_text_in_buffer),
- edit_cursor: range.start,
- reindenter: Reindenter::new(indent_delta),
- original_snapshot: text_snapshot,
- });
-
- cx.update(|cx| {
- let position = buffer.read(cx).anchor_before(range.end);
- tool.set_agent_location(buffer.downgrade(), position, cx);
- });
- }
- EditEvent::NewTextChunk {
- chunk, done: false, ..
- } => {
- log::debug!("new_text_chunk: done=false, chunk='{}'", chunk);
-
- let Some(EditPipelineEntry::StreamingNewText {
- streaming_diff,
- edit_cursor,
- reindenter,
- original_snapshot,
- ..
- }) = &mut self.current_edit
- else {
- return Ok(());
- };
-
- let reindented = reindenter.push(chunk);
- if reindented.is_empty() {
- return Ok(());
- }
-
- let char_ops = streaming_diff.push_new(&reindented);
- apply_char_operations(
- &char_ops,
- buffer,
- original_snapshot,
- edit_cursor,
- &tool.action_log,
- cx,
- );
-
- let position = original_snapshot.anchor_before(*edit_cursor);
- cx.update(|cx| {
- tool.set_agent_location(buffer.downgrade(), position, cx);
- });
- }
- EditEvent::NewTextChunk {
- chunk, done: true, ..
- } => {
- log::debug!("new_text_chunk: done=true, chunk='{}'", chunk);
-
- let Some(EditPipelineEntry::StreamingNewText {
- mut streaming_diff,
- mut edit_cursor,
- mut reindenter,
- original_snapshot,
- }) = self.current_edit.take()
- else {
- return Ok(());
- };
-
- // Flush any remaining reindent buffer + final chunk.
- let mut final_text = reindenter.push(chunk);
- final_text.push_str(&reindenter.finish());
-
- log::debug!("new_text_chunk: done=true, final_text='{}'", final_text);
-
- if !final_text.is_empty() {
- let char_ops = streaming_diff.push_new(&final_text);
- apply_char_operations(
- &char_ops,
- buffer,
- &original_snapshot,
- &mut edit_cursor,
- &tool.action_log,
- cx,
- );
- }
-
- let remaining_ops = streaming_diff.finish();
- apply_char_operations(
- &remaining_ops,
- buffer,
- &original_snapshot,
- &mut edit_cursor,
- &tool.action_log,
- cx,
- );
-
- let position = original_snapshot.anchor_before(edit_cursor);
- cx.update(|cx| {
- tool.set_agent_location(buffer.downgrade(), position, cx);
- });
- }
- }
- Ok(())
- }
-}
-
-impl EditSession {
- async fn new(
- path: PathBuf,
- mode: EditFileMode,
- tool: &EditFileTool,
- event_stream: &ToolCallEventStream,
- cx: &mut AsyncApp,
- ) -> Result<Self, String> {
- let project_path = cx.update(|cx| resolve_path(mode, &path, &tool.project, cx))?;
-
- let Some(abs_path) = cx.update(|cx| tool.project.read(cx).absolute_path(&project_path, cx))
- else {
- return Err(format!(
- "Worktree at '{}' does not exist",
- path.to_string_lossy()
- ));
- };
-
- event_stream.update_fields(
- ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path.clone())]),
- );
-
- cx.update(|cx| tool.authorize(&path, event_stream, cx))
- .await
- .map_err(|e| e.to_string())?;
-
- let buffer = tool
- .project
- .update(cx, |project, cx| project.open_buffer(project_path, cx))
- .await
- .map_err(|e| e.to_string())?;
-
- let file_changed_since_last_read = ensure_buffer_saved(&buffer, &abs_path, tool, cx)?;
-
- let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
- event_stream.update_diff(diff.clone());
- let finalize_diff_guard = util::defer(Box::new({
- let diff = diff.downgrade();
- let mut cx = cx.clone();
- move || {
- diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
- }
- }) as Box<dyn FnOnce()>);
-
- tool.action_log.update(cx, |log, cx| match mode {
- EditFileMode::Write => log.buffer_created(buffer.clone(), cx),
- EditFileMode::Edit => log.buffer_read(buffer.clone(), cx),
- });
-
- let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let old_text = cx
- .background_spawn({
- let old_snapshot = old_snapshot.clone();
- async move { Arc::new(old_snapshot.text()) }
- })
- .await;
-
- Ok(Self {
- abs_path,
- input_path: path,
- buffer,
- old_text,
- diff,
- parser: StreamingParser::default(),
- pipeline: Pipeline::new(mode, file_changed_since_last_read),
- _finalize_diff_guard: finalize_diff_guard,
- })
- }
-
- async fn finalize(
- &mut self,
- input: EditFileToolInput,
- tool: &EditFileTool,
- event_stream: &ToolCallEventStream,
- cx: &mut AsyncApp,
- ) -> Result<(), String> {
- let Self {
- abs_path,
- buffer,
- diff,
- parser,
- pipeline,
- ..
- } = self;
- match pipeline {
- Pipeline::Write(write) => {
- let content = input
- .content
- .ok_or_else(|| "'content' field is required for write mode".to_string())?;
-
- for event in &parser.finalize_content(&content) {
- write.process_event(event, buffer, tool, cx);
- }
- }
- Pipeline::Edit(edit_pipeline) => {
- let edits = input
- .edits
- .ok_or_else(|| "'edits' field is required for edit mode".to_string())?;
- for event in &parser.finalize_edits(&edits) {
- edit_pipeline.process_event(
- event,
- buffer,
- diff,
- abs_path,
- tool,
- event_stream,
- cx,
- )?;
- }
-
- if log::log_enabled!(log::Level::Debug) {
- log::debug!("Got edits:");
- for edit in &edits {
- log::debug!(
- " old_text: '{}', new_text: '{}'",
- edit.old_text.replace('\n', "\\n"),
- edit.new_text.replace('\n', "\\n")
- );
- }
- }
- }
- }
- Ok(())
- }
-
- async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) {
- let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let (new_text, unified_diff) = cx
- .background_spawn({
- let new_snapshot = new_snapshot.clone();
- let old_text = self.old_text.clone();
- async move {
- let new_text = new_snapshot.text();
- let diff = language::unified_diff(&old_text, &new_text);
- (new_text, diff)
- }
- })
- .await;
- (new_text, unified_diff)
- }
-
- fn process(
- &mut self,
- partial: EditFileToolPartialInput,
- tool: &EditFileTool,
- event_stream: &ToolCallEventStream,
- cx: &mut AsyncApp,
- ) -> Result<(), String> {
- let Self {
- abs_path,
- buffer,
- diff,
- parser,
- pipeline,
- ..
- } = self;
- match pipeline {
- Pipeline::Write(write) => {
- if let Some(content) = &partial.content {
- for event in &parser.push_content(content) {
- write.process_event(event, buffer, tool, cx);
- }
- }
- }
- Pipeline::Edit(edit_pipeline) => {
- if let Some(edits) = partial.edits {
- for event in &parser.push_edits(&edits) {
- edit_pipeline.process_event(
- event,
- buffer,
- diff,
- abs_path,
- tool,
- event_stream,
- cx,
- )?;
- }
- }
- }
- }
- Ok(())
- }
-}
-
-fn apply_char_operations(
- ops: &[CharOperation],
- buffer: &Entity<Buffer>,
- snapshot: &text::BufferSnapshot,
- edit_cursor: &mut usize,
- action_log: &Entity<ActionLog>,
- cx: &mut AsyncApp,
-) {
- for op in ops {
- match op {
- CharOperation::Insert { text } => {
- let anchor = snapshot.anchor_after(*edit_cursor);
- agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx);
- }
- CharOperation::Delete { bytes } => {
- let delete_end = *edit_cursor + bytes;
- let anchor_range = snapshot.anchor_range_inside(*edit_cursor..delete_end);
- agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx);
- *edit_cursor = delete_end;
- }
- CharOperation::Keep { bytes } => {
- *edit_cursor += bytes;
- }
- }
- }
-}
-
-fn extract_match(
- matches: Vec<Range<usize>>,
- buffer: &Entity<Buffer>,
- edit_index: &usize,
- file_changed_since_last_read: bool,
- cx: &mut AsyncApp,
-) -> Result<Range<usize>, String> {
- let file_changed_since_last_read_message = if file_changed_since_last_read {
- " The file has changed on disk since you last read it."
- } else {
- ""
- };
-
- match matches.len() {
- 0 => Err(format!(
- "Could not find matching text for edit at index {}. \
- The old_text did not match any content in the file.{} \
- Please read the file again to get the current content.",
- edit_index, file_changed_since_last_read_message,
- )),
- 1 => Ok(matches.into_iter().next().unwrap()),
- _ => {
- let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
- let lines = matches
- .iter()
- .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
- .collect::<Vec<_>>()
- .join(", ");
- Err(format!(
- "Edit {} matched multiple locations in the file at lines: {}. \
- Please provide more context in old_text to uniquely \
- identify the location.",
- edit_index, lines
- ))
- }
- }
-}
-
-/// Edits a buffer and reports the edit to the action log in the same effect
-/// cycle. This ensures the action log's subscription handler sees the version
-/// already updated by `buffer_edited`, so it does not misattribute the agent's
-/// edit as a user edit.
-fn agent_edit_buffer<I, S, T>(
- buffer: &Entity<Buffer>,
- edits: I,
- action_log: &Entity<ActionLog>,
- cx: &mut AsyncApp,
-) where
- I: IntoIterator<Item = (Range<S>, T)>,
- S: ToOffset,
- T: Into<Arc<str>>,
-{
- cx.update(|cx| {
- buffer.update(cx, |buffer, cx| {
- buffer.edit(edits, None, cx);
- });
- action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
- });
-}
-
-fn ensure_buffer_saved(
- buffer: &Entity<Buffer>,
- abs_path: &PathBuf,
- tool: &EditFileTool,
- cx: &mut AsyncApp,
-) -> Result<bool, String> {
- let last_read_mtime = tool
- .action_log
- .read_with(cx, |log, _| log.file_read_time(abs_path));
- let check_result = tool.thread.read_with(cx, |thread, cx| {
- let current = buffer
- .read(cx)
- .file()
- .and_then(|file| file.disk_state().mtime());
- let dirty = buffer.read(cx).is_dirty();
- let has_save = thread.has_tool(SaveFileTool::NAME);
- let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
- (current, dirty, has_save, has_restore)
- });
-
- let Ok((current_mtime, is_dirty, has_save_tool, has_restore_tool)) = check_result else {
- return Ok(false);
- };
-
- if is_dirty {
- let message = match (has_save_tool, has_restore_tool) {
- (true, true) => {
- "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
- If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
- If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
- }
- (true, false) => {
- "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
- If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
- If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
- }
- (false, true) => {
- "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
- If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
- If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
- }
- (false, false) => {
- "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
- then ask them to save or revert the file manually and inform you when it's ok to proceed."
- }
- };
- return Err(message.to_string());
- }
-
- if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime)
- && current != last_read
- {
- return Ok(true);
- }
-
- Ok(false)
-}
-
-fn resolve_path(
- mode: EditFileMode,
- path: &PathBuf,
- project: &Entity<Project>,
- cx: &mut App,
-) -> Result<ProjectPath, String> {
- let project = project.read(cx);
-
- match mode {
- EditFileMode::Edit => {
- let path = project
- .find_project_path(&path, cx)
- .ok_or_else(|| "Can't edit file: path not found".to_string())?;
-
- let entry = project
- .entry_for_path(&path, cx)
- .ok_or_else(|| "Can't edit file: path not found".to_string())?;
-
- if entry.is_file() {
- Ok(path)
- } else {
- Err("Can't edit file: path is a directory".to_string())
- }
- }
- EditFileMode::Write => {
- if let Some(path) = project.find_project_path(&path, cx)
- && let Some(entry) = project.entry_for_path(&path, cx)
- {
- if entry.is_file() {
- return Ok(path);
- } else {
- return Err("Can't write to file: path is a directory".to_string());
- }
- }
-
- let parent_path = path
- .parent()
- .ok_or_else(|| "Can't create file: incorrect path".to_string())?;
-
- let parent_project_path = project.find_project_path(&parent_path, cx);
-
- let parent_entry = parent_project_path
- .as_ref()
- .and_then(|path| project.entry_for_path(path, cx))
- .ok_or_else(|| "Can't create file: parent directory doesn't exist")?;
-
- if !parent_entry.is_dir() {
- return Err("Can't create file: parent is not a directory".to_string());
- }
-
- let file_name = path
- .file_name()
- .and_then(|file_name| file_name.to_str())
- .and_then(|file_name| RelPath::unix(file_name).ok())
- .ok_or_else(|| "Can't create file: invalid filename".to_string())?;
-
- let new_file_path = parent_project_path.map(|parent| ProjectPath {
- path: parent.path.join(file_name),
- ..parent
- });
-
- new_file_path.ok_or_else(|| "Can't create file".to_string())
- }
+ self.session_context.replay_output(output, event_stream, cx)
}
}
@@ -0,0 +1,1067 @@
+mod reindent;
+mod streaming_fuzzy_matcher;
+mod streaming_parser;
+
+use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
+use super::save_file_tool::SaveFileTool;
+use crate::{AgentTool, Thread, ToolCallEventStream};
+use acp_thread::Diff;
+use action_log::ActionLog;
+use agent_client_protocol::schema::{ToolCallLocation, ToolCallUpdateFields};
+use anyhow::Result;
+use collections::HashSet;
+use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
+use language::language_settings::{self, FormatOnSave};
+use language::{Buffer, LanguageRegistry};
+use language_model::LanguageModelToolResultContent;
+use project::lsp_store::{FormatTrigger, LspFormatTarget};
+use project::{AgentLocation, Project, ProjectPath};
+use reindent::{Reindenter, compute_indent_delta};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize, de::DeserializeOwned};
+use std::ops::Range;
+use std::path::PathBuf;
+use std::sync::Arc;
+use streaming_diff::{CharOperation, StreamingDiff};
+use streaming_fuzzy_matcher::StreamingFuzzyMatcher;
+use streaming_parser::{EditEvent, StreamingParser, WriteEvent};
+use text::ToOffset;
+use ui::SharedString;
+use util::rel_path::RelPath;
+use util::{Deferred, ResultExt};
+
+/// Operating mode used internally by `EditSession`/`Pipeline` to choose between
+/// applying granular edits (the `edit_file` tool) or replacing/creating the
+/// entire file content (the `write_file` tool).
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub(crate) enum EditSessionMode {
+ Write,
+ Edit,
+}
+
+/// A single edit operation that replaces old text with new text
+/// Properly escape all text fields as valid JSON strings.
+/// Remember to escape special characters like newlines (`\n`) and quotes (`"`) in JSON strings.
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
+pub struct Edit {
+ /// The exact text to find in the file. This will be matched using fuzzy matching
+ /// to handle minor differences in whitespace or formatting.
+ ///
+ /// Be minimal with replacements:
+ /// - For unique lines, include only those lines
+ /// - For non-unique lines, include enough context to identify them
+ pub old_text: String,
+ /// The text to replace it with
+ pub new_text: String,
+}
+
+#[derive(Clone, Default, Debug, Deserialize)]
+pub struct PartialEdit {
+ #[serde(default)]
+ pub old_text: Option<String>,
+ #[serde(default)]
+ pub new_text: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum EditSessionOutput {
+ Success {
+ #[serde(alias = "original_path")]
+ input_path: PathBuf,
+ new_text: String,
+ old_text: Arc<String>,
+ #[serde(default)]
+ diff: String,
+ },
+ Error {
+ error: String,
+ #[serde(default, skip_serializing_if = "Option::is_none")]
+ input_path: Option<PathBuf>,
+ #[serde(default, skip_serializing_if = "String::is_empty")]
+ diff: String,
+ },
+}
+
+impl EditSessionOutput {
+ pub fn error(error: impl Into<String>) -> Self {
+ Self::Error {
+ error: error.into(),
+ input_path: None,
+ diff: String::new(),
+ }
+ }
+}
+
+impl std::fmt::Display for EditSessionOutput {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ EditSessionOutput::Success {
+ diff, input_path, ..
+ } => {
+ if diff.is_empty() {
+ write!(f, "No edits were made.")
+ } else {
+ write!(
+ f,
+ "Edited {}:\n\n```diff\n{diff}\n```",
+ input_path.display()
+ )
+ }
+ }
+ EditSessionOutput::Error {
+ error,
+ diff,
+ input_path,
+ } => {
+ write!(f, "{error}\n")?;
+ if let Some(input_path) = input_path
+ && !diff.is_empty()
+ {
+ write!(
+ f,
+ "Edited {}:\n\n```diff\n{diff}\n```",
+ input_path.display()
+ )
+ } else {
+ write!(f, "No edits were made.")
+ }
+ }
+ }
+ }
+}
+
+impl From<EditSessionOutput> for LanguageModelToolResultContent {
+ fn from(output: EditSessionOutput) -> Self {
+ output.to_string().into()
+ }
+}
+
+pub(crate) struct EditSessionContext {
+ project: Entity<Project>,
+ thread: WeakEntity<Thread>,
+ action_log: Entity<ActionLog>,
+ language_registry: Arc<LanguageRegistry>,
+}
+
+impl EditSessionContext {
+ pub(crate) fn new(
+ project: Entity<Project>,
+ thread: WeakEntity<Thread>,
+ action_log: Entity<ActionLog>,
+ language_registry: Arc<LanguageRegistry>,
+ ) -> Self {
+ Self {
+ project,
+ thread,
+ action_log,
+ language_registry,
+ }
+ }
+
+ pub(crate) fn authorize(
+ &self,
+ tool_name: &str,
+ path: &PathBuf,
+ event_stream: &ToolCallEventStream,
+ cx: &mut App,
+ ) -> Task<Result<()>> {
+ super::tool_permissions::authorize_file_edit(
+ tool_name,
+ path,
+ &self.thread,
+ event_stream,
+ cx,
+ )
+ }
+
+ fn set_agent_location(&self, buffer: WeakEntity<Buffer>, position: text::Anchor, cx: &mut App) {
+ let should_update_agent_location = self
+ .thread
+ .read_with(cx, |thread, _cx| !thread.is_subagent())
+ .unwrap_or_default();
+ if should_update_agent_location {
+ self.project.update(cx, |project, cx| {
+ project.set_agent_location(Some(AgentLocation { buffer, position }), cx);
+ });
+ }
+ }
+
+ async fn ensure_buffer_saved(&self, buffer: &Entity<Buffer>, cx: &mut AsyncApp) {
+ let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
+ let settings = language_settings::LanguageSettings::for_buffer(buffer, cx);
+ settings.format_on_save != FormatOnSave::Off
+ });
+
+ if format_on_save_enabled {
+ self.project
+ .update(cx, |project, cx| {
+ project.format(
+ HashSet::from_iter([buffer.clone()]),
+ LspFormatTarget::Buffers,
+ false,
+ FormatTrigger::Save,
+ cx,
+ )
+ })
+ .await
+ .log_err();
+ }
+
+ self.project
+ .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+ .await
+ .log_err();
+
+ self.action_log.update(cx, |log, cx| {
+ log.buffer_edited(buffer.clone(), cx);
+ });
+ }
+
+ pub(crate) fn initial_title_from_path(
+ &self,
+ path: &std::path::Path,
+ default: &str,
+ cx: &App,
+ ) -> SharedString {
+ let project = self.project.read(cx);
+ if let Some(project_path) = project.find_project_path(path, cx)
+ && let Some(short) = project.short_full_path_for_project_path(&project_path, cx)
+ {
+ return short.into();
+ }
+
+ let display = path.to_string_lossy();
+ if display.is_empty() {
+ default.into()
+ } else {
+ display.into_owned().into()
+ }
+ }
+
+ pub(crate) fn replay_output(
+ &self,
+ output: EditSessionOutput,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Result<()> {
+ match output {
+ EditSessionOutput::Success {
+ input_path,
+ old_text,
+ new_text,
+ ..
+ } => {
+ event_stream.update_diff(cx.new(|cx| {
+ Diff::finalized(
+ input_path.to_string_lossy().into_owned(),
+ Some(old_text.to_string()),
+ new_text,
+ self.language_registry.clone(),
+ cx,
+ )
+ }));
+ Ok(())
+ }
+ EditSessionOutput::Error { .. } => Ok(()),
+ }
+ }
+}
+
+pub(crate) enum EditSessionResult {
+ Completed(EditSession),
+ Failed {
+ error: String,
+ session: Option<EditSession>,
+ },
+}
+
+pub(crate) async fn run_session(
+ result: EditSessionResult,
+ cx: &mut AsyncApp,
+) -> Result<EditSessionOutput, EditSessionOutput> {
+ match result {
+ EditSessionResult::Completed(session) => {
+ session
+ .context
+ .ensure_buffer_saved(&session.buffer, cx)
+ .await;
+ let (new_text, diff) = session.compute_new_text_and_diff(cx).await;
+ Ok(EditSessionOutput::Success {
+ old_text: session.old_text.clone(),
+ new_text,
+ input_path: session.input_path,
+ diff,
+ })
+ }
+ EditSessionResult::Failed {
+ error,
+ session: Some(session),
+ } => {
+ session
+ .context
+ .ensure_buffer_saved(&session.buffer, cx)
+ .await;
+ let (_new_text, diff) = session.compute_new_text_and_diff(cx).await;
+ Err(EditSessionOutput::Error {
+ error,
+ input_path: Some(session.input_path),
+ diff,
+ })
+ }
+ EditSessionResult::Failed {
+ error,
+ session: None,
+ } => Err(EditSessionOutput::Error {
+ error,
+ input_path: None,
+ diff: String::new(),
+ }),
+ }
+}
+
+pub(crate) fn initial_title_from_partial_path<P>(
+ context: &EditSessionContext,
+ raw_input: serde_json::Value,
+ extract_path: impl FnOnce(&P) -> Option<String>,
+ default: &str,
+ cx: &App,
+) -> SharedString
+where
+ P: DeserializeOwned,
+{
+ if let Ok(partial) = serde_json::from_value::<P>(raw_input)
+ && let Some(raw_path) = extract_path(&partial)
+ {
+ let trimmed = raw_path.trim();
+ if !trimmed.is_empty() {
+ return context.initial_title_from_path(std::path::Path::new(trimmed), default, cx);
+ }
+ }
+ default.into()
+}
+
+pub(crate) struct EditSession {
+ abs_path: PathBuf,
+ pub(crate) input_path: PathBuf,
+ pub(crate) buffer: Entity<Buffer>,
+ pub(crate) old_text: Arc<String>,
+ diff: Entity<Diff>,
+ parser: StreamingParser,
+ pipeline: Pipeline,
+ context: Arc<EditSessionContext>,
+ _finalize_diff_guard: Deferred<Box<dyn FnOnce()>>,
+}
+
+enum Pipeline {
+ Write(WritePipeline),
+ Edit(EditPipeline),
+}
+
+struct WritePipeline {
+ content_written: bool,
+}
+
+struct EditPipeline {
+ current_edit: Option<EditPipelineEntry>,
+ file_changed_since_last_read: bool,
+}
+
+enum EditPipelineEntry {
+ ResolvingOldText {
+ matcher: StreamingFuzzyMatcher,
+ },
+ StreamingNewText {
+ streaming_diff: StreamingDiff,
+ edit_cursor: usize,
+ reindenter: Reindenter,
+ original_snapshot: text::BufferSnapshot,
+ },
+}
+
+impl Pipeline {
+ fn new(mode: EditSessionMode, file_changed_since_last_read: bool) -> Self {
+ match mode {
+ EditSessionMode::Write => Self::Write(WritePipeline {
+ content_written: false,
+ }),
+ EditSessionMode::Edit => Self::Edit(EditPipeline {
+ current_edit: None,
+ file_changed_since_last_read,
+ }),
+ }
+ }
+}
+
+impl WritePipeline {
+ fn process_event(
+ &mut self,
+ event: &WriteEvent,
+ buffer: &Entity<Buffer>,
+ context: &EditSessionContext,
+ cx: &mut AsyncApp,
+ ) {
+ let WriteEvent::ContentChunk { chunk } = event;
+
+ let (buffer_id, buffer_len) =
+ buffer.read_with(cx, |buffer, _cx| (buffer.remote_id(), buffer.len()));
+ let edit_range = if self.content_written {
+ buffer_len..buffer_len
+ } else {
+ 0..buffer_len
+ };
+
+ agent_edit_buffer(
+ buffer,
+ [(edit_range, chunk.as_str())],
+ &context.action_log,
+ cx,
+ );
+ cx.update(|cx| {
+ context.set_agent_location(
+ buffer.downgrade(),
+ text::Anchor::max_for_buffer(buffer_id),
+ cx,
+ );
+ });
+ self.content_written = true;
+ }
+}
+
+impl EditPipeline {
+ fn ensure_resolving_old_text(&mut self, buffer: &Entity<Buffer>, cx: &mut AsyncApp) {
+ if self.current_edit.is_none() {
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot());
+ self.current_edit = Some(EditPipelineEntry::ResolvingOldText {
+ matcher: StreamingFuzzyMatcher::new(snapshot),
+ });
+ }
+ }
+
+ fn process_event(
+ &mut self,
+ event: &EditEvent,
+ buffer: &Entity<Buffer>,
+ diff: &Entity<Diff>,
+ abs_path: &PathBuf,
+ context: &EditSessionContext,
+ event_stream: &ToolCallEventStream,
+ cx: &mut AsyncApp,
+ ) -> Result<(), String> {
+ match event {
+ EditEvent::OldTextChunk {
+ chunk, done: false, ..
+ } => {
+ log::debug!("old_text_chunk: done=false, chunk='{}'", chunk);
+ self.ensure_resolving_old_text(buffer, cx);
+
+ if let Some(EditPipelineEntry::ResolvingOldText { matcher }) =
+ &mut self.current_edit
+ && !chunk.is_empty()
+ {
+ if let Some(match_range) = matcher.push(chunk, None) {
+ let anchor_range = buffer.read_with(cx, |buffer, _cx| {
+ buffer.anchor_range_outside(match_range.clone())
+ });
+ diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
+
+ cx.update(|cx| {
+ let position = buffer.read(cx).anchor_before(match_range.end);
+ context.set_agent_location(buffer.downgrade(), position, cx);
+ });
+ }
+ }
+ }
+ EditEvent::OldTextChunk {
+ edit_index,
+ chunk,
+ done: true,
+ } => {
+ log::debug!("old_text_chunk: done=true, chunk='{}'", chunk);
+
+ self.ensure_resolving_old_text(buffer, cx);
+
+ let Some(EditPipelineEntry::ResolvingOldText { matcher }) = &mut self.current_edit
+ else {
+ return Ok(());
+ };
+
+ if !chunk.is_empty() {
+ matcher.push(chunk, None);
+ }
+ let range = extract_match(
+ matcher.finish(),
+ buffer,
+ edit_index,
+ self.file_changed_since_last_read,
+ cx,
+ )?;
+
+ let anchor_range =
+ buffer.read_with(cx, |buffer, _cx| buffer.anchor_range_outside(range.clone()));
+ diff.update(cx, |diff, cx| diff.reveal_range(anchor_range, cx));
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+
+ let line = snapshot.offset_to_point(range.start).row;
+ event_stream.update_fields(
+ ToolCallUpdateFields::new()
+ .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
+ );
+
+ let buffer_indent = snapshot.line_indent_for_row(line);
+ let query_indent = text::LineIndent::from_iter(
+ matcher
+ .query_lines()
+ .first()
+ .map(|s| s.as_str())
+ .unwrap_or("")
+ .chars(),
+ );
+ let indent_delta = compute_indent_delta(buffer_indent, query_indent);
+
+ let old_text_in_buffer = snapshot.text_for_range(range.clone()).collect::<String>();
+
+ log::debug!(
+ "edit[{}] old_text matched at {}..{}: {:?}",
+ edit_index,
+ range.start,
+ range.end,
+ old_text_in_buffer,
+ );
+
+ let text_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.text_snapshot());
+ self.current_edit = Some(EditPipelineEntry::StreamingNewText {
+ streaming_diff: StreamingDiff::new(old_text_in_buffer),
+ edit_cursor: range.start,
+ reindenter: Reindenter::new(indent_delta),
+ original_snapshot: text_snapshot,
+ });
+
+ cx.update(|cx| {
+ let position = buffer.read(cx).anchor_before(range.end);
+ context.set_agent_location(buffer.downgrade(), position, cx);
+ });
+ }
+ EditEvent::NewTextChunk {
+ chunk, done: false, ..
+ } => {
+ log::debug!("new_text_chunk: done=false, chunk='{}'", chunk);
+
+ let Some(EditPipelineEntry::StreamingNewText {
+ streaming_diff,
+ edit_cursor,
+ reindenter,
+ original_snapshot,
+ ..
+ }) = &mut self.current_edit
+ else {
+ return Ok(());
+ };
+
+ let reindented = reindenter.push(chunk);
+ if reindented.is_empty() {
+ return Ok(());
+ }
+
+ let char_ops = streaming_diff.push_new(&reindented);
+ apply_char_operations(
+ &char_ops,
+ buffer,
+ original_snapshot,
+ edit_cursor,
+ &context.action_log,
+ cx,
+ );
+
+ let position = original_snapshot.anchor_before(*edit_cursor);
+ cx.update(|cx| {
+ context.set_agent_location(buffer.downgrade(), position, cx);
+ });
+ }
+ EditEvent::NewTextChunk {
+ chunk, done: true, ..
+ } => {
+ log::debug!("new_text_chunk: done=true, chunk='{}'", chunk);
+
+ let Some(EditPipelineEntry::StreamingNewText {
+ mut streaming_diff,
+ mut edit_cursor,
+ mut reindenter,
+ original_snapshot,
+ }) = self.current_edit.take()
+ else {
+ return Ok(());
+ };
+
+ let mut final_text = reindenter.push(chunk);
+ final_text.push_str(&reindenter.finish());
+
+ log::debug!("new_text_chunk: done=true, final_text='{}'", final_text);
+
+ if !final_text.is_empty() {
+ let char_ops = streaming_diff.push_new(&final_text);
+ apply_char_operations(
+ &char_ops,
+ buffer,
+ &original_snapshot,
+ &mut edit_cursor,
+ &context.action_log,
+ cx,
+ );
+ }
+
+ let remaining_ops = streaming_diff.finish();
+ apply_char_operations(
+ &remaining_ops,
+ buffer,
+ &original_snapshot,
+ &mut edit_cursor,
+ &context.action_log,
+ cx,
+ );
+
+ let position = original_snapshot.anchor_before(edit_cursor);
+ cx.update(|cx| {
+ context.set_agent_location(buffer.downgrade(), position, cx);
+ });
+ }
+ }
+ Ok(())
+ }
+}
+
+impl EditSession {
+ pub(crate) async fn new(
+ path: PathBuf,
+ mode: EditSessionMode,
+ tool_name: &str,
+ context: Arc<EditSessionContext>,
+ event_stream: &ToolCallEventStream,
+ cx: &mut AsyncApp,
+ ) -> Result<Self, String> {
+ let project_path = cx.update(|cx| resolve_path(mode, &path, &context.project, cx))?;
+
+ let Some(abs_path) =
+ cx.update(|cx| context.project.read(cx).absolute_path(&project_path, cx))
+ else {
+ return Err(format!(
+ "Worktree at '{}' does not exist",
+ path.to_string_lossy()
+ ));
+ };
+
+ event_stream.update_fields(
+ ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path.clone())]),
+ );
+
+ cx.update(|cx| context.authorize(tool_name, &path, event_stream, cx))
+ .await
+ .map_err(|e| e.to_string())?;
+
+ let buffer = context
+ .project
+ .update(cx, |project, cx| project.open_buffer(project_path, cx))
+ .await
+ .map_err(|e| e.to_string())?;
+
+ let file_changed_since_last_read = ensure_buffer_saved(&buffer, &abs_path, &context, cx)?;
+
+ let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
+ event_stream.update_diff(diff.clone());
+ let finalize_diff_guard = util::defer(Box::new({
+ let diff = diff.downgrade();
+ let mut cx = cx.clone();
+ move || {
+ diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
+ }
+ }) as Box<dyn FnOnce()>);
+
+ context.action_log.update(cx, |log, cx| match mode {
+ EditSessionMode::Write => log.buffer_created(buffer.clone(), cx),
+ EditSessionMode::Edit => log.buffer_read(buffer.clone(), cx),
+ });
+
+ let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let old_text = cx
+ .background_spawn({
+ let old_snapshot = old_snapshot.clone();
+ async move { Arc::new(old_snapshot.text()) }
+ })
+ .await;
+
+ Ok(Self {
+ abs_path,
+ input_path: path,
+ buffer,
+ old_text,
+ diff,
+ parser: StreamingParser::default(),
+ pipeline: Pipeline::new(mode, file_changed_since_last_read),
+ context,
+ _finalize_diff_guard: finalize_diff_guard,
+ })
+ }
+
+ pub(crate) async fn finalize_edit(
+ &mut self,
+ edits: Vec<Edit>,
+ event_stream: &ToolCallEventStream,
+ cx: &mut AsyncApp,
+ ) -> Result<(), String> {
+ let Self {
+ abs_path,
+ buffer,
+ diff,
+ parser,
+ pipeline,
+ context,
+ ..
+ } = self;
+ let Pipeline::Edit(edit_pipeline) = pipeline else {
+ return Err("Cannot finalize edits on a write session".to_string());
+ };
+
+ for event in &parser.finalize_edits(&edits) {
+ edit_pipeline.process_event(
+ event,
+ buffer,
+ diff,
+ abs_path,
+ context,
+ event_stream,
+ cx,
+ )?;
+ }
+
+ if log::log_enabled!(log::Level::Debug) {
+ log::debug!("Got edits:");
+ for edit in &edits {
+ log::debug!(
+ " old_text: '{}', new_text: '{}'",
+ edit.old_text.replace('\n', "\\n"),
+ edit.new_text.replace('\n', "\\n")
+ );
+ }
+ }
+ Ok(())
+ }
+
+ pub(crate) async fn finalize_write(
+ &mut self,
+ content: &str,
+ cx: &mut AsyncApp,
+ ) -> Result<(), String> {
+ let Self {
+ buffer,
+ parser,
+ pipeline,
+ context,
+ ..
+ } = self;
+ let Pipeline::Write(write) = pipeline else {
+ return Err("Cannot finalize a write on an edit session".to_string());
+ };
+
+ for event in &parser.finalize_content(content) {
+ write.process_event(event, buffer, context, cx);
+ }
+ Ok(())
+ }
+
+ async fn compute_new_text_and_diff(&self, cx: &mut AsyncApp) -> (String, String) {
+ let new_snapshot = self.buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let (new_text, unified_diff) = cx
+ .background_spawn({
+ let new_snapshot = new_snapshot.clone();
+ let old_text = self.old_text.clone();
+ async move {
+ let new_text = new_snapshot.text();
+ let diff = language::unified_diff(&old_text, &new_text);
+ (new_text, diff)
+ }
+ })
+ .await;
+ (new_text, unified_diff)
+ }
+
+ pub(crate) fn process_edit(
+ &mut self,
+ edits: Option<&[PartialEdit]>,
+ event_stream: &ToolCallEventStream,
+ cx: &mut AsyncApp,
+ ) -> Result<(), String> {
+ let Self {
+ abs_path,
+ buffer,
+ diff,
+ parser,
+ pipeline,
+ context,
+ ..
+ } = self;
+ let Pipeline::Edit(edit_pipeline) = pipeline else {
+ return Err("Cannot apply partial edits on a write session".to_string());
+ };
+ let Some(edits) = edits else {
+ return Ok(());
+ };
+ for event in &parser.push_edits(edits) {
+ edit_pipeline.process_event(
+ event,
+ buffer,
+ diff,
+ abs_path,
+ context,
+ event_stream,
+ cx,
+ )?;
+ }
+ Ok(())
+ }
+
+ pub(crate) fn process_write(
+ &mut self,
+ content: Option<&str>,
+ cx: &mut AsyncApp,
+ ) -> Result<(), String> {
+ let Self {
+ buffer,
+ parser,
+ pipeline,
+ context,
+ ..
+ } = self;
+ let Pipeline::Write(write) = pipeline else {
+ return Err("Cannot apply partial content on an edit session".to_string());
+ };
+ let Some(content) = content else {
+ return Ok(());
+ };
+ for event in &parser.push_content(content) {
+ write.process_event(event, buffer, context, cx);
+ }
+ Ok(())
+ }
+}
+
+fn apply_char_operations(
+ ops: &[CharOperation],
+ buffer: &Entity<Buffer>,
+ snapshot: &text::BufferSnapshot,
+ edit_cursor: &mut usize,
+ action_log: &Entity<ActionLog>,
+ cx: &mut AsyncApp,
+) {
+ for op in ops {
+ match op {
+ CharOperation::Insert { text } => {
+ let anchor = snapshot.anchor_after(*edit_cursor);
+ agent_edit_buffer(&buffer, [(anchor..anchor, text.as_str())], action_log, cx);
+ }
+ CharOperation::Delete { bytes } => {
+ let delete_end = *edit_cursor + bytes;
+ let anchor_range = snapshot.anchor_range_inside(*edit_cursor..delete_end);
+ agent_edit_buffer(&buffer, [(anchor_range, "")], action_log, cx);
+ *edit_cursor = delete_end;
+ }
+ CharOperation::Keep { bytes } => {
+ *edit_cursor += bytes;
+ }
+ }
+ }
+}
+
+fn extract_match(
+ matches: Vec<Range<usize>>,
+ buffer: &Entity<Buffer>,
+ edit_index: &usize,
+ file_changed_since_last_read: bool,
+ cx: &mut AsyncApp,
+) -> Result<Range<usize>, String> {
+ let file_changed_since_last_read_message = if file_changed_since_last_read {
+ " The file has changed on disk since you last read it."
+ } else {
+ ""
+ };
+
+ match matches.len() {
+ 0 => Err(format!(
+ "Could not find matching text for edit at index {}. \
+ The old_text did not match any content in the file.{} \
+ Please read the file again to get the current content.",
+ edit_index, file_changed_since_last_read_message,
+ )),
+ 1 => Ok(matches.into_iter().next().unwrap()),
+ _ => {
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let lines = matches
+ .iter()
+ .map(|range| (snapshot.offset_to_point(range.start).row + 1).to_string())
+ .collect::<Vec<_>>()
+ .join(", ");
+ Err(format!(
+ "Edit {} matched multiple locations in the file at lines: {}. \
+ Please provide more context in old_text to uniquely \
+ identify the location.",
+ edit_index, lines
+ ))
+ }
+ }
+}
+
+/// Edits a buffer and reports the edit to the action log in the same effect
+/// cycle. This ensures the action log's subscription handler sees the version
+/// already updated by `buffer_edited`, so it does not misattribute the agent's
+/// edit as a user edit.
+fn agent_edit_buffer<I, S, T>(
+ buffer: &Entity<Buffer>,
+ edits: I,
+ action_log: &Entity<ActionLog>,
+ cx: &mut AsyncApp,
+) where
+ I: IntoIterator<Item = (Range<S>, T)>,
+ S: ToOffset,
+ T: Into<Arc<str>>,
+{
+ cx.update(|cx| {
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit(edits, None, cx);
+ });
+ action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+ });
+}
+
+fn ensure_buffer_saved(
+ buffer: &Entity<Buffer>,
+ abs_path: &PathBuf,
+ context: &EditSessionContext,
+ cx: &mut AsyncApp,
+) -> Result<bool, String> {
+ let last_read_mtime = context
+ .action_log
+ .read_with(cx, |log, _| log.file_read_time(abs_path));
+ let check_result = context.thread.read_with(cx, |thread, cx| {
+ let current = buffer
+ .read(cx)
+ .file()
+ .and_then(|file| file.disk_state().mtime());
+ let dirty = buffer.read(cx).is_dirty();
+ let has_save = thread.has_tool(SaveFileTool::NAME);
+ let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
+ (current, dirty, has_save, has_restore)
+ });
+
+ let Ok((current_mtime, is_dirty, has_save_tool, has_restore_tool)) = check_result else {
+ return Ok(false);
+ };
+
+ if is_dirty {
+ let message = match (has_save_tool, has_restore_tool) {
+ (true, true) => {
+ "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
+ If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
+ If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
+ }
+ (true, false) => {
+ "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
+ If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
+ If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
+ }
+ (false, true) => {
+ "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
+ If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
+ If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
+ }
+ (false, false) => {
+ "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
+ then ask them to save or revert the file manually and inform you when it's ok to proceed."
+ }
+ };
+ return Err(message.to_string());
+ }
+
+ if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime)
+ && current != last_read
+ {
+ return Ok(true);
+ }
+
+ Ok(false)
+}
+
+fn resolve_path(
+ mode: EditSessionMode,
+ path: &PathBuf,
+ project: &Entity<Project>,
+ cx: &mut App,
+) -> Result<ProjectPath, String> {
+ let project = project.read(cx);
+
+ match mode {
+ EditSessionMode::Edit => {
+ let path = project
+ .find_project_path(&path, cx)
+ .ok_or_else(|| "Can't edit file: path not found".to_string())?;
+
+ let entry = project
+ .entry_for_path(&path, cx)
+ .ok_or_else(|| "Can't edit file: path not found".to_string())?;
+
+ if entry.is_file() {
+ Ok(path)
+ } else {
+ Err("Can't edit file: path is a directory".to_string())
+ }
+ }
+ EditSessionMode::Write => {
+ if let Some(path) = project.find_project_path(&path, cx)
+ && let Some(entry) = project.entry_for_path(&path, cx)
+ {
+ if entry.is_file() {
+ return Ok(path);
+ } else {
+ return Err("Can't write to file: path is a directory".to_string());
+ }
+ }
+
+ let parent_path = path
+ .parent()
+ .ok_or_else(|| "Can't create file: incorrect path".to_string())?;
+
+ let parent_project_path = project.find_project_path(&parent_path, cx);
+
+ let parent_entry = parent_project_path
+ .as_ref()
+ .and_then(|path| project.entry_for_path(path, cx))
+ .ok_or_else(|| "Can't create file: parent directory doesn't exist")?;
+
+ if !parent_entry.is_dir() {
+ return Err("Can't create file: parent is not a directory".to_string());
+ }
+
+ let file_name = path
+ .file_name()
+ .and_then(|file_name| file_name.to_str())
+ .and_then(|file_name| RelPath::unix(file_name).ok())
+ .ok_or_else(|| "Can't create file: invalid filename".to_string())?;
+
+ let new_file_path = parent_project_path.map(|parent| ProjectPath {
+ path: parent.path.join(file_name),
+ ..parent
+ });
+
+ new_file_path.ok_or_else(|| "Can't create file".to_string())
+ }
+ }
+}
+
+#[cfg(test)]
+pub(crate) async fn test_resolve_path(
+ mode: &EditSessionMode,
+ path: &str,
+ project: &Entity<Project>,
+ cx: &mut gpui::TestAppContext,
+) -> Result<ProjectPath, String> {
+ cx.update(|cx| resolve_path(*mode, &PathBuf::from(path), project, cx))
+}
@@ -1,6 +1,6 @@
use smallvec::SmallVec;
-use crate::{Edit, PartialEdit};
+use super::{Edit, PartialEdit};
/// Events emitted by `StreamingParser` for edit-mode input.
#[derive(Debug, PartialEq, Eq)]
@@ -2,3 +2,5 @@
mod edit_file;
#[cfg(all(test, feature = "unit-eval"))]
mod terminal_tool;
+#[cfg(all(test, feature = "unit-eval"))]
+mod write_file;
@@ -1,8 +1,7 @@
use crate::tools::edit_file_tool::*;
use crate::{
- AgentTool, ContextServerRegistry, EditFileTool, GrepTool, GrepToolInput, ListDirectoryTool,
- ListDirectoryToolInput, ReadFileTool, ReadFileToolInput, Template, Templates, Thread,
- ToolCallEventStream, ToolInput,
+ AgentTool, ContextServerRegistry, EditFileTool, GrepTool, GrepToolInput, ReadFileTool,
+ ReadFileToolInput, Template, Templates, Thread, ToolCallEventStream, ToolInput,
};
use Role::*;
use anyhow::{Context as _, Result};
@@ -124,20 +123,6 @@ impl EvalAssertion {
EvalAssertion(Arc::new(f))
}
- fn assert_eq(expected: impl Into<String>) -> Self {
- let expected = expected.into();
- Self::new(async move |sample, _judge, _cx| {
- Ok(EvalAssertionOutcome {
- score: if strip_empty_lines(&sample.text_after) == strip_empty_lines(&expected) {
- 100
- } else {
- 0
- },
- message: None,
- })
- })
- }
-
fn assert_diff_any(expected_diffs: Vec<impl Into<String>>) -> Self {
let expected_diffs: Vec<String> = expected_diffs.into_iter().map(Into::into).collect();
Self::new(async move |sample, _judge, _cx| {
@@ -1499,46 +1484,3 @@ fn eval_add_overwrite_test() {
))
});
}
-
-#[test]
-#[cfg_attr(not(feature = "unit-eval"), ignore)]
-fn eval_create_empty_file() {
- let input_file_path = "root/TODO3";
- let input_file_content = None;
- let expected_output_content = String::new();
-
- eval_utils::eval(100, 0.99, eval_utils::NoProcessor, move || {
- run_eval(EvalInput::new(
- vec![
- message(User, [text("Create a second empty todo file ")]),
- message(
- Assistant,
- [
- text(indoc::formatdoc! {"
- I'll help you create a second empty todo file.
- First, let me examine the project structure to see if there's already a todo file, which will help me determine the appropriate name and location for the second one.
- "}),
- tool_use(
- "toolu_01GAF8TtsgpjKxCr8fgQLDgR",
- ListDirectoryTool::NAME,
- ListDirectoryToolInput {
- path: "root".to_string(),
- },
- ),
- ],
- ),
- message(
- User,
- [tool_result(
- "toolu_01GAF8TtsgpjKxCr8fgQLDgR",
- ListDirectoryTool::NAME,
- "root/TODO\nroot/TODO2\nroot/new.txt\n",
- )],
- ),
- ],
- input_file_path,
- input_file_content.clone(),
- EvalAssertion::assert_eq(expected_output_content.clone()),
- ))
- });
-}
@@ -0,0 +1,561 @@
+use crate::{
+ AgentTool, ContextServerRegistry, ListDirectoryTool, ListDirectoryToolInput, Template,
+ Templates, Thread, ToolCallEventStream, ToolInput, WriteFileTool, WriteFileToolInput,
+};
+use Role::*;
+use anyhow::{Context as _, Result};
+use client::{Client, RefreshLlmTokenListener, UserStore};
+use fs::FakeFs;
+use futures::StreamExt;
+use gpui::{AppContext as _, AsyncApp, Entity, TestAppContext, UpdateGlobal as _};
+use http_client::StatusCode;
+use language::language_settings::FormatOnSave;
+use language_model::{
+ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
+ LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+ LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolUse,
+ LanguageModelToolUseId, MessageContent, Role, SelectedModel,
+};
+use project::Project;
+use prompt_store::{ProjectContext, WorktreeContext};
+use rand::prelude::*;
+use reqwest_client::ReqwestClient;
+use serde::Serialize;
+use settings::SettingsStore;
+use std::{
+ fmt::{self, Display},
+ path::{Path, PathBuf},
+ str::FromStr,
+ sync::Arc,
+ time::Duration,
+};
+use util::path;
+
+#[derive(Clone)]
+struct EvalInput {
+ conversation: Vec<LanguageModelRequestMessage>,
+ input_file_path: PathBuf,
+ input_content: Option<String>,
+ expected_output_content: String,
+}
+
+impl EvalInput {
+ fn new(
+ conversation: Vec<LanguageModelRequestMessage>,
+ input_file_path: impl Into<PathBuf>,
+ input_content: Option<String>,
+ expected_output_content: String,
+ ) -> Self {
+ Self {
+ conversation,
+ input_file_path: input_file_path.into(),
+ input_content,
+ expected_output_content,
+ }
+ }
+}
+
+struct WriteEvalOutput {
+ tool_input: WriteFileToolInput,
+ text_after: String,
+}
+
+impl Display for WriteEvalOutput {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ writeln!(f, "Tool Input:\n{:#?}", self.tool_input)?;
+ writeln!(f, "Text After:\n{}", self.text_after)?;
+ Ok(())
+ }
+}
+
+struct WriteToolTest {
+ fs: Arc<FakeFs>,
+ project: Entity<Project>,
+ model: Arc<dyn LanguageModel>,
+ model_thinking_effort: Option<String>,
+}
+
+impl WriteToolTest {
+ async fn new(cx: &mut TestAppContext) -> Self {
+ cx.executor().allow_parking();
+
+ let fs = FakeFs::new(cx.executor());
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ SettingsStore::update_global(cx, |store: &mut SettingsStore, cx| {
+ store.update_user_settings(cx, |settings| {
+ settings
+ .project
+ .all_languages
+ .defaults
+ .ensure_final_newline_on_save = Some(false);
+ settings.project.all_languages.defaults.format_on_save =
+ Some(FormatOnSave::Off);
+ });
+ });
+
+ gpui_tokio::init(cx);
+ let http_client = Arc::new(ReqwestClient::user_agent("agent tests").unwrap());
+ cx.set_http_client(http_client);
+ let client = Client::production(cx);
+ let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
+ language_model::init(cx);
+ RefreshLlmTokenListener::register(client.clone(), user_store.clone(), cx);
+ language_models::init(user_store, client, cx);
+ });
+
+ fs.insert_tree("/root", serde_json::json!({})).await;
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let agent_model = SelectedModel::from_str(
+ &std::env::var("ZED_AGENT_MODEL")
+ .unwrap_or("anthropic/claude-sonnet-4-6-latest".into()),
+ )
+ .unwrap();
+
+ let authenticate_provider_tasks = cx.update(|cx| {
+ LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ registry
+ .providers()
+ .iter()
+ .map(|p| p.authenticate(cx))
+ .collect::<Vec<_>>()
+ })
+ });
+ let model = cx
+ .update(|cx| {
+ cx.spawn(async move |cx| {
+ futures::future::join_all(authenticate_provider_tasks).await;
+ Self::load_model(&agent_model, cx).await.unwrap()
+ })
+ })
+ .await;
+
+ let model_thinking_effort = model
+ .default_effort_level()
+ .map(|effort_level| effort_level.value.to_string());
+
+ Self {
+ fs,
+ project,
+ model,
+ model_thinking_effort,
+ }
+ }
+
+ async fn load_model(
+ selected_model: &SelectedModel,
+ cx: &mut AsyncApp,
+ ) -> Result<Arc<dyn LanguageModel>> {
+ cx.update(|cx| {
+ let registry = LanguageModelRegistry::read_global(cx);
+ let provider = registry
+ .provider(&selected_model.provider)
+ .expect("Provider not found");
+ provider.authenticate(cx)
+ })
+ .await?;
+ Ok(cx.update(|cx| {
+ let models = LanguageModelRegistry::read_global(cx);
+ models
+ .available_models(cx)
+ .find(|model| {
+ model.provider_id() == selected_model.provider
+ && model.id() == selected_model.model
+ })
+ .unwrap_or_else(|| panic!("Model {} not found", selected_model.model.0))
+ }))
+ }
+
+ async fn eval(&self, mut eval: EvalInput, cx: &mut TestAppContext) -> Result<WriteEvalOutput> {
+ eval.conversation
+ .last_mut()
+ .context("Conversation must not be empty")?
+ .cache = true;
+
+ if let Some(input_content) = eval.input_content.as_deref() {
+ let abs_path = Path::new("/root").join(
+ eval.input_file_path
+ .strip_prefix("root")
+ .unwrap_or(&eval.input_file_path),
+ );
+ self.fs.insert_file(&abs_path, input_content.into()).await;
+ cx.run_until_parked();
+ }
+
+ let tools = crate::built_in_tools().collect::<Vec<_>>();
+
+ let system_prompt = {
+ let worktrees = vec![WorktreeContext {
+ root_name: "root".to_string(),
+ abs_path: Path::new("/path/to/root").into(),
+ rules_file: None,
+ }];
+ let project_context = ProjectContext::new(worktrees, Vec::default());
+ let tool_names = tools
+ .iter()
+ .map(|tool| tool.name.clone().into())
+ .collect::<Vec<_>>();
+ let template = crate::SystemPromptTemplate {
+ project: &project_context,
+ available_tools: tool_names,
+ model_name: None,
+ };
+ let templates = Templates::new();
+ template.render(&templates)?
+ };
+
+ let messages = [LanguageModelRequestMessage {
+ role: Role::System,
+ content: vec![MessageContent::Text(system_prompt)],
+ cache: true,
+ reasoning_details: None,
+ }]
+ .into_iter()
+ .chain(eval.conversation)
+ .collect::<Vec<_>>();
+
+ let request = LanguageModelRequest {
+ messages,
+ tools,
+ thinking_allowed: true,
+ thinking_effort: self.model_thinking_effort.clone(),
+ ..Default::default()
+ };
+
+ let tool_input =
+ retry_on_rate_limit(async || self.extract_tool_use(request.clone(), cx).await).await?;
+
+ let language_registry = self
+ .project
+ .read_with(cx, |project, _cx| project.languages().clone());
+
+ let context_server_registry = cx
+ .new(|cx| ContextServerRegistry::new(self.project.read(cx).context_server_store(), cx));
+ let thread = cx.new(|cx| {
+ Thread::new(
+ self.project.clone(),
+ cx.new(|_cx| ProjectContext::default()),
+ context_server_registry,
+ Templates::new(),
+ Some(self.model.clone()),
+ cx,
+ )
+ });
+ let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
+
+ let tool = Arc::new(WriteFileTool::new(
+ self.project.clone(),
+ thread.downgrade(),
+ action_log,
+ language_registry,
+ ));
+
+ let result = cx
+ .update(|cx| {
+ tool.clone().run(
+ ToolInput::resolved(tool_input.clone()),
+ ToolCallEventStream::test().0,
+ cx,
+ )
+ })
+ .await;
+
+ let output = match result {
+ Ok(output) => output,
+ Err(output) => anyhow::bail!("Tool returned error: {}", output),
+ };
+
+ let crate::EditFileToolOutput::Success { new_text, .. } = &output else {
+ anyhow::bail!("Tool returned error output: {}", output);
+ };
+
+ if tool_input.path != eval.input_file_path {
+ anyhow::bail!(
+ "Tool path mismatch. Expected {:?}, got {:?}",
+ eval.input_file_path,
+ tool_input.path,
+ );
+ }
+
+ if new_text != &eval.expected_output_content {
+ anyhow::bail!(
+ "Output content mismatch. Expected {:?}, got {:?}",
+ eval.expected_output_content,
+ new_text,
+ );
+ }
+
+ Ok(WriteEvalOutput {
+ tool_input,
+ text_after: new_text.clone(),
+ })
+ }
+
+ async fn extract_tool_use(
+ &self,
+ request: LanguageModelRequest,
+ cx: &mut TestAppContext,
+ ) -> Result<WriteFileToolInput> {
+ let model = self.model.clone();
+ let events = cx
+ .update(|cx| {
+ let async_cx = cx.to_async();
+ cx.foreground_executor()
+ .spawn(async move { model.stream_completion(request, &async_cx).await })
+ })
+ .await
+ .map_err(|err| anyhow::anyhow!("completion error: {}", err))?;
+
+ let mut streamed_text = String::new();
+ let mut stop_reason = None;
+ let mut parse_errors = Vec::new();
+
+ let mut events = events.fuse();
+ while let Some(event) = events.next().await {
+ match event {
+ Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
+ if tool_use.is_input_complete
+ && tool_use.name.as_ref() == WriteFileTool::NAME =>
+ {
+ let input: WriteFileToolInput = serde_json::from_value(tool_use.input)
+ .context("Failed to parse tool input as WriteFileToolInput")?;
+ return Ok(input);
+ }
+ Ok(LanguageModelCompletionEvent::Text(text)) => {
+ if streamed_text.len() < 2_000 {
+ streamed_text.push_str(&text);
+ }
+ }
+ Ok(LanguageModelCompletionEvent::Stop(reason)) => {
+ stop_reason = Some(reason);
+ }
+ Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
+ tool_name,
+ raw_input,
+ json_parse_error,
+ ..
+ }) if tool_name.as_ref() == WriteFileTool::NAME => {
+ parse_errors.push(format!("{json_parse_error}\nRaw input:\n{raw_input:?}"));
+ }
+ Err(err) => return Err(anyhow::anyhow!("completion error: {}", err)),
+ _ => {}
+ }
+ }
+
+ let streamed_text = streamed_text.trim();
+ let streamed_text_suffix = if streamed_text.is_empty() {
+ String::new()
+ } else {
+ format!("\nStreamed text:\n{streamed_text}")
+ };
+ let stop_reason_suffix = stop_reason
+ .map(|reason| format!("\nStop reason: {reason:?}"))
+ .unwrap_or_default();
+ let parse_errors_suffix = if parse_errors.is_empty() {
+ String::new()
+ } else {
+ format!("\nTool parse errors:\n{}", parse_errors.join("\n"))
+ };
+
+ anyhow::bail!(
+ "Stream ended without a write_file tool use{stop_reason_suffix}{parse_errors_suffix}{streamed_text_suffix}"
+ )
+ }
+}
+
+fn run_eval(eval: EvalInput) -> eval_utils::EvalOutput<()> {
+ let dispatcher = gpui::TestDispatcher::new(rand::random());
+ let mut cx = TestAppContext::build(dispatcher, None);
+ let foreground_executor = cx.foreground_executor().clone();
+ let result = foreground_executor.block_test(async {
+ let test = WriteToolTest::new(&mut cx).await;
+ let result = test.eval(eval, &mut cx).await;
+ drop(test);
+ cx.run_until_parked();
+ result
+ });
+ cx.quit();
+ match result {
+ Ok(output) => eval_utils::EvalOutput {
+ data: output.to_string(),
+ outcome: eval_utils::OutcomeKind::Passed,
+ metadata: (),
+ },
+ Err(err) => eval_utils::EvalOutput {
+ data: format!("{err:?}"),
+ outcome: eval_utils::OutcomeKind::Error,
+ metadata: (),
+ },
+ }
+}
+
+fn message(
+ role: Role,
+ content: impl IntoIterator<Item = MessageContent>,
+) -> LanguageModelRequestMessage {
+ LanguageModelRequestMessage {
+ role,
+ content: content.into_iter().collect(),
+ cache: false,
+ reasoning_details: None,
+ }
+}
+
+fn text(text: impl Into<String>) -> MessageContent {
+ MessageContent::Text(text.into())
+}
+
+fn tool_use(
+ id: impl Into<Arc<str>>,
+ name: impl Into<Arc<str>>,
+ input: impl Serialize,
+) -> MessageContent {
+ MessageContent::ToolUse(LanguageModelToolUse {
+ id: LanguageModelToolUseId::from(id.into()),
+ name: name.into(),
+ raw_input: serde_json::to_string_pretty(&input).unwrap(),
+ input: serde_json::to_value(input).unwrap(),
+ is_input_complete: true,
+ thought_signature: None,
+ })
+}
+
+fn tool_result(
+ id: impl Into<Arc<str>>,
+ name: impl Into<Arc<str>>,
+ result: impl Into<Arc<str>>,
+) -> MessageContent {
+ MessageContent::ToolResult(LanguageModelToolResult {
+ tool_use_id: LanguageModelToolUseId::from(id.into()),
+ tool_name: name.into(),
+ is_error: false,
+ content: vec![LanguageModelToolResultContent::Text(result.into())],
+ output: None,
+ })
+}
+
+async fn retry_on_rate_limit<R>(mut request: impl AsyncFnMut() -> Result<R>) -> Result<R> {
+ const MAX_RETRIES: usize = 20;
+ let mut attempt = 0;
+
+ loop {
+ attempt += 1;
+ let response = request().await;
+
+ if attempt >= MAX_RETRIES {
+ return response;
+ }
+
+ let retry_delay = match &response {
+ Ok(_) => None,
+ Err(err) => match err.downcast_ref::<LanguageModelCompletionError>() {
+ Some(err) => match &err {
+ LanguageModelCompletionError::RateLimitExceeded { retry_after, .. }
+ | LanguageModelCompletionError::ServerOverloaded { retry_after, .. } => {
+ Some(retry_after.unwrap_or(Duration::from_secs(5)))
+ }
+ LanguageModelCompletionError::UpstreamProviderError {
+ status,
+ retry_after,
+ ..
+ } => {
+ let should_retry = matches!(
+ *status,
+ StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE
+ ) || status.as_u16() == 529;
+
+ if should_retry {
+ Some(retry_after.unwrap_or(Duration::from_secs(5)))
+ } else {
+ None
+ }
+ }
+ LanguageModelCompletionError::ApiReadResponseError { .. }
+ | LanguageModelCompletionError::ApiInternalServerError { .. }
+ | LanguageModelCompletionError::HttpSend { .. } => {
+ Some(Duration::from_secs(2_u64.pow((attempt - 1) as u32).min(30)))
+ }
+ _ => None,
+ },
+ _ => None,
+ },
+ };
+
+ if let Some(retry_after) = retry_delay {
+ let jitter = retry_after.mul_f64(rand::rng().random_range(0.0..1.0));
+ eprintln!("Attempt #{attempt}: Retry after {retry_after:?} + jitter of {jitter:?}");
+ #[allow(clippy::disallowed_methods)]
+ async_io::Timer::after(retry_after + jitter).await;
+ } else {
+ return response;
+ }
+ }
+}
+
+#[test]
+#[cfg_attr(not(feature = "unit-eval"), ignore)]
+fn eval_create_file() {
+ let input_file_path = "root/TODO3";
+ let expected_output_content = "todo".to_string();
+
+ eval_utils::eval(100, 1., eval_utils::NoProcessor, move || {
+ run_eval(EvalInput::new(
+ vec![
+ message(
+ User,
+ [text("Create a third todo file. Write 'todo' inside it.")],
+ ),
+ message(
+ Assistant,
+ [
+ text(indoc::formatdoc! {"
+                            I'll help you create a third todo file.
+                            First, let me examine the project structure to see if there's already a todo file, which will help me determine the appropriate name and location for the third one.
+ "}),
+ tool_use(
+ "toolu_01GAF8TtsgpjKxCr8fgQLDgR",
+ ListDirectoryTool::NAME,
+ ListDirectoryToolInput {
+ path: "root".to_string(),
+ },
+ ),
+ ],
+ ),
+ message(
+ User,
+ [tool_result(
+ "toolu_01GAF8TtsgpjKxCr8fgQLDgR",
+ ListDirectoryTool::NAME,
+ "root/TODO\nroot/TODO2\nroot/new.txt\n",
+ )],
+ ),
+ ],
+ input_file_path,
+ None,
+ expected_output_content.clone(),
+ ))
+ });
+}
+
+#[test]
+#[cfg_attr(not(feature = "unit-eval"), ignore)]
+fn eval_overwrite_file() {
+ let input_file_path = "root/notes.txt";
+ let input_file_content = "old notes\nkeep nothing\n".to_string();
+ let expected_output_content = "new notes".to_string();
+
+ eval_utils::eval(100, 1., eval_utils::NoProcessor, move || {
+ run_eval(EvalInput::new(
+ vec![message(
+ User,
+ [text(indoc::formatdoc! {"
+ Overwrite `{input_file_path}` so that its complete contents are exactly: 'new notes'
+ "})],
+ )],
+ input_file_path,
+ Some(input_file_content.clone()),
+ expected_output_content.clone(),
+ ))
+ });
+}
@@ -0,0 +1,1190 @@
+use super::edit_session::{
+ EditSession, EditSessionContext, EditSessionMode, EditSessionOutput, EditSessionResult,
+ initial_title_from_partial_path, run_session,
+};
+use crate::{AgentTool, Thread, ToolCallEventStream, ToolInput, ToolInputPayload};
+use action_log::ActionLog;
+use agent_client_protocol::schema as acp;
+use futures::FutureExt as _;
+use gpui::{App, AsyncApp, Entity, Task, WeakEntity};
+use language::LanguageRegistry;
+use project::Project;
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use std::path::PathBuf;
+use std::sync::Arc;
+use ui::SharedString;
+
+const DEFAULT_UI_TEXT: &str = "Writing file";
+
+/// This is a tool for creating a new file or overwriting an existing file with completely new contents.
+///
+/// To make granular edits to an existing file, prefer the `edit_file` tool instead.
+///
+/// Before using this tool:
+///
+/// 1. Verify the directory path is correct (only applicable when creating new files):
+/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
+#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
+pub struct WriteFileToolInput {
+ /// The full path of the file to create or overwrite in the project.
+ ///
+    /// WARNING: When specifying which file path needs changing, you MUST start each path with one of the project's root directories.
+ ///
+ /// The following examples assume we have two root directories in the project:
+ /// - /a/b/backend
+ /// - /c/d/frontend
+ ///
+ /// <example>
+ /// `backend/src/main.rs`
+ ///
+ /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
+ /// </example>
+ ///
+ /// <example>
+ /// `frontend/db.js`
+ /// </example>
+ pub path: PathBuf,
+
+ /// The complete content for the file.
+ /// This field should contain the entire file content.
+ pub content: String,
+}
+
+#[derive(Clone, Default, Debug, Deserialize)]
+struct WriteFileToolPartialInput {
+ #[serde(default)]
+ path: Option<String>,
+ #[serde(default)]
+ content: Option<String>,
+}
+
+pub struct WriteFileTool {
+ session_context: Arc<EditSessionContext>,
+}
+
+impl WriteFileTool {
+ pub fn new(
+ project: Entity<Project>,
+ thread: WeakEntity<Thread>,
+ action_log: Entity<ActionLog>,
+ language_registry: Arc<LanguageRegistry>,
+ ) -> Self {
+ Self {
+ session_context: Arc::new(EditSessionContext::new(
+ project,
+ thread,
+ action_log,
+ language_registry,
+ )),
+ }
+ }
+
+ async fn process_streaming_writes(
+ &self,
+ input: &mut ToolInput<WriteFileToolInput>,
+ event_stream: &ToolCallEventStream,
+ cx: &mut AsyncApp,
+ ) -> EditSessionResult {
+ let mut session: Option<EditSession> = None;
+ let mut last_path: Option<String> = None;
+
+ loop {
+ futures::select! {
+ payload = input.next().fuse() => {
+ match payload {
+ Ok(payload) => match payload {
+ ToolInputPayload::Partial(partial) => {
+ if let Ok(parsed) = serde_json::from_value::<WriteFileToolPartialInput>(partial) {
+ let path_complete = parsed.path.is_some()
+ && parsed.path.as_ref() == last_path.as_ref();
+
+ last_path = parsed.path.clone();
+
+ if session.is_none()
+ && path_complete
+ && let Some(path) = parsed.path.as_ref()
+ {
+ match EditSession::new(
+ PathBuf::from(path),
+ EditSessionMode::Write,
+ Self::NAME,
+ self.session_context.clone(),
+ event_stream,
+ cx,
+ )
+ .await
+ {
+ Ok(created_session) => session = Some(created_session),
+ Err(error) => {
+ log::error!("Failed to create edit session: {}", error);
+ return EditSessionResult::Failed {
+ error,
+ session: None,
+ };
+ }
+ }
+ }
+
+ if let Some(current_session) = &mut session
+ && let Err(error) = current_session.process_write(parsed.content.as_deref(), cx)
+ {
+ log::error!("Failed to process write: {}", error);
+ return EditSessionResult::Failed { error, session };
+ }
+ }
+ }
+ ToolInputPayload::Full(full_input) => {
+ let mut session = if let Some(session) = session {
+ session
+ } else {
+ match EditSession::new(
+ full_input.path.clone(),
+ EditSessionMode::Write,
+ Self::NAME,
+ self.session_context.clone(),
+ event_stream,
+ cx,
+ )
+ .await
+ {
+ Ok(created_session) => created_session,
+ Err(error) => {
+ log::error!("Failed to create edit session: {}", error);
+ return EditSessionResult::Failed {
+ error,
+ session: None,
+ };
+ }
+ }
+ };
+
+ return match session.finalize_write(&full_input.content, cx).await {
+ Ok(()) => EditSessionResult::Completed(session),
+ Err(error) => {
+ log::error!("Failed to finalize write: {}", error);
+ EditSessionResult::Failed {
+ error,
+ session: Some(session),
+ }
+ }
+ };
+ }
+ ToolInputPayload::InvalidJson { error_message } => {
+ log::error!("Received invalid JSON: {error_message}");
+ return EditSessionResult::Failed {
+ error: error_message,
+ session,
+ };
+ }
+ },
+ Err(error) => {
+ return EditSessionResult::Failed {
+ error: error.to_string(),
+ session,
+ };
+ }
+ }
+ }
+ _ = event_stream.cancelled_by_user().fuse() => {
+ return EditSessionResult::Failed {
+ error: "Write cancelled by user".to_string(),
+ session,
+ };
+ }
+ }
+ }
+ }
+}
+
+impl AgentTool for WriteFileTool {
+ type Input = WriteFileToolInput;
+ type Output = EditSessionOutput;
+
+ const NAME: &'static str = "write_file";
+
+ fn supports_input_streaming() -> bool {
+ true
+ }
+
+ fn kind() -> acp::ToolKind {
+ acp::ToolKind::Edit
+ }
+
+ fn initial_title(
+ &self,
+ input: Result<Self::Input, serde_json::Value>,
+ cx: &mut App,
+ ) -> SharedString {
+ match input {
+ Ok(input) => {
+ self.session_context
+ .initial_title_from_path(&input.path, DEFAULT_UI_TEXT, cx)
+ }
+ Err(raw_input) => initial_title_from_partial_path::<WriteFileToolPartialInput>(
+ &self.session_context,
+ raw_input,
+ |partial| partial.path.clone(),
+ DEFAULT_UI_TEXT,
+ cx,
+ ),
+ }
+ }
+
+ fn run(
+ self: Arc<Self>,
+ mut input: ToolInput<Self::Input>,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> Task<Result<Self::Output, Self::Output>> {
+ cx.spawn(async move |cx: &mut AsyncApp| {
+ run_session(
+ self.process_streaming_writes(&mut input, &event_stream, cx)
+ .await,
+ cx,
+ )
+ .await
+ })
+ }
+
+ fn replay(
+ &self,
+ _input: Self::Input,
+ output: Self::Output,
+ event_stream: ToolCallEventStream,
+ cx: &mut App,
+ ) -> anyhow::Result<()> {
+ self.session_context.replay_output(output, event_stream, cx)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ AgentTool, ContextServerRegistry, Templates, Thread, ToolCallEventStream, ToolInput,
+ ToolInputSender,
+ };
+ use acp_thread::Diff;
+ use action_log::ActionLog;
+ use fs::Fs as _;
+ use futures::StreamExt as _;
+ use gpui::{AppContext as _, Entity, TestAppContext, UpdateGlobal};
+ use language::language_settings::FormatOnSave;
+ use language_model::fake_provider::FakeLanguageModel;
+ use project::{Project, ProjectPath};
+ use prompt_store::ProjectContext;
+ use serde_json::json;
+ use settings::{Settings, SettingsStore};
+ use std::sync::Arc;
+ use util::path;
+ use util::rel_path::{RelPath, rel_path};
+
+ #[gpui::test]
+ async fn test_streaming_write_create_file(cx: &mut TestAppContext) {
+ let (write_tool, _project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"dir": {}})).await;
+ let result = cx
+ .update(|cx| {
+ write_tool.clone().run(
+ ToolInput::resolved(WriteFileToolInput {
+ path: "root/dir/new_file.txt".into(),
+ content: "Hello, World!".into(),
+ }),
+ ToolCallEventStream::test().0,
+ cx,
+ )
+ })
+ .await;
+
+ let EditSessionOutput::Success { new_text, diff, .. } = result.unwrap() else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "Hello, World!");
+ assert!(!diff.is_empty());
+ }
+
+ #[gpui::test]
+ async fn test_streaming_write_overwrite_file(cx: &mut TestAppContext) {
+ let (write_tool, _project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"file.txt": "old content"})).await;
+ let result = cx
+ .update(|cx| {
+ write_tool.clone().run(
+ ToolInput::resolved(WriteFileToolInput {
+ path: "root/file.txt".into(),
+ content: "new content".into(),
+ }),
+ ToolCallEventStream::test().0,
+ cx,
+ )
+ })
+ .await;
+
+ let EditSessionOutput::Success {
+ new_text, old_text, ..
+ } = result.unwrap()
+ else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "new content");
+ assert_eq!(*old_text, "old content");
+ }
+
+ #[gpui::test]
+ async fn test_streaming_path_completeness_heuristic(cx: &mut TestAppContext) {
+ let (write_tool, _project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"file.txt": "hello world"})).await;
+ let (mut sender, input) = ToolInput::<WriteFileToolInput>::test();
+ let (event_stream, _receiver) = ToolCallEventStream::test();
+ let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx));
+
+        // Send the first partial containing the path — on first sighting the path is NOT yet treated as complete
+ sender.send_partial(json!({
+ "path": "root/file"
+ }));
+ cx.run_until_parked();
+
+        // The path grows on the next partial, so it still isn't considered complete
+ sender.send_partial(json!({
+ "path": "root/file.txt",
+ }));
+ cx.run_until_parked();
+
+ // Send final
+ sender.send_full(json!({
+ "path": "root/file.txt",
+ "content": "new content"
+ }));
+
+ let result = task.await;
+ let EditSessionOutput::Success { new_text, .. } = result.unwrap() else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "new content");
+ }
+
+ #[gpui::test]
+ async fn test_streaming_create_file_with_partials(cx: &mut TestAppContext) {
+ let (write_tool, _project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"dir": {}})).await;
+ let (mut sender, input) = ToolInput::<WriteFileToolInput>::test();
+ let (event_stream, _receiver) = ToolCallEventStream::test();
+ let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx));
+
+        // Stream partials while creating a new file
+ sender.send_partial(json!({}));
+ cx.run_until_parked();
+
+ sender.send_partial(json!({
+ "path": "root/dir/new_file.txt",
+ }));
+ cx.run_until_parked();
+
+ sender.send_partial(json!({
+ "path": "root/dir/new_file.txt",
+ "content": "Hello, "
+ }));
+ cx.run_until_parked();
+
+ // Final with full content
+ sender.send_full(json!({
+ "path": "root/dir/new_file.txt",
+ "content": "Hello, World!"
+ }));
+
+ let result = task.await;
+ let EditSessionOutput::Success { new_text, .. } = result.unwrap() else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "Hello, World!");
+ }
+
+ #[gpui::test]
+ async fn test_streaming_input_recv_drains_partials(cx: &mut TestAppContext) {
+ let (write_tool, _project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"dir": {}})).await;
+ // Create a channel and send multiple partials before a final, then use
+ // ToolInput::resolved-style immediate delivery to confirm recv() works
+ // when partials are already buffered.
+ let (mut sender, input): (ToolInputSender, ToolInput<WriteFileToolInput>) =
+ ToolInput::test();
+ let (event_stream, _event_rx) = ToolCallEventStream::test();
+ let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx));
+
+ // Buffer several partials before sending the final
+ sender.send_partial(json!({}));
+ sender.send_partial(json!({"path": "root/dir/new.txt"}));
+ sender.send_partial(json!({
+ "path": "root/dir/new.txt",
+ }));
+ sender.send_full(json!({
+ "path": "root/dir/new.txt",
+ "content": "streamed content"
+ }));
+
+ let result = task.await;
+ let EditSessionOutput::Success { new_text, .. } = result.unwrap() else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "streamed content");
+ }
+
+ #[gpui::test]
+ async fn test_streaming_resolve_path_for_creating_file(cx: &mut TestAppContext) {
+ let mode = EditSessionMode::Write;
+
+ let result = test_resolve_path(&mode, "root/new.txt", cx);
+ assert_resolved_path_eq(result.await, rel_path("new.txt"));
+
+ let result = test_resolve_path(&mode, "new.txt", cx);
+ assert_resolved_path_eq(result.await, rel_path("new.txt"));
+
+ let result = test_resolve_path(&mode, "dir/new.txt", cx);
+ assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
+
+ let result = test_resolve_path(&mode, "root/dir/subdir/existing.txt", cx);
+ assert_resolved_path_eq(result.await, rel_path("dir/subdir/existing.txt"));
+
+ let result = test_resolve_path(&mode, "root/dir/subdir", cx);
+ assert_eq!(
+ result.await.unwrap_err(),
+ "Can't write to file: path is a directory"
+ );
+
+ let result = test_resolve_path(&mode, "root/dir/nonexistent_dir/new.txt", cx);
+ assert_eq!(
+ result.await.unwrap_err(),
+ "Can't create file: parent directory doesn't exist"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_streaming_format_on_save(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", json!({"src": {}})).await;
+ let (write_tool, project, action_log, fs, thread) =
+ setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await;
+
+ let rust_language = Arc::new(language::Language::new(
+ language::LanguageConfig {
+ name: "Rust".into(),
+ matcher: language::LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ None,
+ ));
+
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_language);
+
+ let mut fake_language_servers = language_registry.register_fake_lsp(
+ "Rust",
+ language::FakeLspAdapter {
+ capabilities: lsp::ServerCapabilities {
+ document_formatting_provider: Some(lsp::OneOf::Left(true)),
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ );
+
+ fs.save(
+ path!("/root/src/main.rs").as_ref(),
+ &"initial content".into(),
+ language::LineEnding::Unix,
+ )
+ .await
+ .unwrap();
+
+ // Open the buffer to trigger LSP initialization
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/root/src/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ // Register the buffer with language servers
+ let _handle = project.update(cx, |project, cx| {
+ project.register_buffer_with_language_servers(&buffer, cx)
+ });
+
+ const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\
+";
+ const FORMATTED_CONTENT: &str = "This file was formatted by the fake formatter in the test.\
+";
+
+ // Get the fake language server and set up formatting handler
+ let fake_language_server = fake_language_servers.next().await.unwrap();
+ fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
+ |_, _| async move {
+ Ok(Some(vec![lsp::TextEdit {
+ range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
+ new_text: FORMATTED_CONTENT.to_string(),
+ }]))
+ }
+ });
+
+ // Test with format_on_save enabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings(cx, |settings| {
+ settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
+ settings.project.all_languages.defaults.formatter =
+ Some(language::language_settings::FormatterList::default());
+ });
+ });
+ });
+
+ // Use streaming pattern so executor can pump the LSP request/response
+ let (mut sender, input) = ToolInput::<WriteFileToolInput>::test();
+ let (event_stream, _receiver) = ToolCallEventStream::test();
+
+ let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx));
+
+ sender.send_partial(json!({
+ "path": "root/src/main.rs",
+ }));
+ cx.run_until_parked();
+
+ sender.send_full(json!({
+ "path": "root/src/main.rs",
+ "content": UNFORMATTED_CONTENT
+ }));
+
+ let result = task.await;
+ assert!(result.is_ok());
+
+ cx.executor().run_until_parked();
+
+ let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ new_content.replace("\r\n", "\n"),
+ FORMATTED_CONTENT,
+ "Code should be formatted when format_on_save is enabled"
+ );
+
+ let stale_buffer_count = thread
+ .read_with(cx, |thread, _cx| thread.action_log.clone())
+ .read_with(cx, |log, cx| log.stale_buffers(cx).count());
+
+ assert_eq!(
+ stale_buffer_count, 0,
+ "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers.",
+ stale_buffer_count
+ );
+
+ // Test with format_on_save disabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings(cx, |settings| {
+ settings.project.all_languages.defaults.format_on_save =
+ Some(FormatOnSave::Off);
+ });
+ });
+ });
+
+ let (mut sender, input) = ToolInput::<WriteFileToolInput>::test();
+ let (event_stream, _receiver) = ToolCallEventStream::test();
+
+ let tool2 = Arc::new(WriteFileTool::new(
+ project.clone(),
+ thread.downgrade(),
+ action_log.clone(),
+ language_registry,
+ ));
+
+ let task = cx.update(|cx| tool2.run(input, event_stream, cx));
+
+ sender.send_partial(json!({
+ "path": "root/src/main.rs",
+ }));
+ cx.run_until_parked();
+
+ sender.send_full(json!({
+ "path": "root/src/main.rs",
+ "content": UNFORMATTED_CONTENT
+ }));
+
+ let result = task.await;
+ assert!(result.is_ok());
+
+ cx.executor().run_until_parked();
+
+ let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ new_content.replace("\r\n", "\n"),
+ UNFORMATTED_CONTENT,
+ "Code should not be formatted when format_on_save is disabled"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_streaming_remove_trailing_whitespace(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", json!({"src": {}})).await;
+ fs.save(
+ path!("/root/src/main.rs").as_ref(),
+ &"initial content".into(),
+ language::LineEnding::Unix,
+ )
+ .await
+ .unwrap();
+ let (write_tool, project, action_log, fs, thread) =
+ setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await;
+ let language_registry = project.read_with(cx, |p, _cx| p.languages().clone());
+
+ // Test with remove_trailing_whitespace_on_save enabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings(cx, |settings| {
+ settings
+ .project
+ .all_languages
+ .defaults
+ .remove_trailing_whitespace_on_save = Some(true);
+ });
+ });
+ });
+
+ const CONTENT_WITH_TRAILING_WHITESPACE: &str =
+ "fn main() { \n println!(\"Hello!\"); \n}\n";
+
+ let result = cx
+ .update(|cx| {
+ write_tool.clone().run(
+ ToolInput::resolved(WriteFileToolInput {
+ path: "root/src/main.rs".into(),
+ content: CONTENT_WITH_TRAILING_WHITESPACE.into(),
+ }),
+ ToolCallEventStream::test().0,
+ cx,
+ )
+ })
+ .await;
+ assert!(result.is_ok());
+
+ cx.executor().run_until_parked();
+
+ assert_eq!(
+ fs.load(path!("/root/src/main.rs").as_ref())
+ .await
+ .unwrap()
+ .replace("\r\n", "\n"),
+ "fn main() {\n println!(\"Hello!\");\n}\n",
+ "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
+ );
+
+ // Test with remove_trailing_whitespace_on_save disabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings(cx, |settings| {
+ settings
+ .project
+ .all_languages
+ .defaults
+ .remove_trailing_whitespace_on_save = Some(false);
+ });
+ });
+ });
+
+ let tool2 = Arc::new(WriteFileTool::new(
+ project.clone(),
+ thread.downgrade(),
+ action_log.clone(),
+ language_registry,
+ ));
+
+ let result = cx
+ .update(|cx| {
+ tool2.run(
+ ToolInput::resolved(WriteFileToolInput {
+ path: "root/src/main.rs".into(),
+ content: CONTENT_WITH_TRAILING_WHITESPACE.into(),
+ }),
+ ToolCallEventStream::test().0,
+ cx,
+ )
+ })
+ .await;
+ assert!(result.is_ok());
+
+ cx.executor().run_until_parked();
+
+ let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ final_content.replace("\r\n", "\n"),
+ CONTENT_WITH_TRAILING_WHITESPACE,
+ "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_streaming_diff_finalization(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree("/", json!({"main.rs": ""})).await;
+ let (write_tool, project, action_log, _fs, thread) =
+ setup_test_with_fs(cx, fs, &[path!("/").as_ref()]).await;
+ let language_registry = project.read_with(cx, |p, _cx| p.languages().clone());
+
+ // Ensure the diff is finalized after the edit completes.
+ {
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let edit = cx.update(|cx| {
+ write_tool.clone().run(
+ ToolInput::resolved(WriteFileToolInput {
+ path: path!("/main.rs").into(),
+ content: "new content".into(),
+ }),
+ stream_tx,
+ cx,
+ )
+ });
+ stream_rx.expect_update_fields().await;
+ let diff = stream_rx.expect_diff().await;
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
+ cx.run_until_parked();
+ edit.await.unwrap();
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
+ }
+
+ // Ensure the diff is finalized if the tool call gets dropped.
+ {
+ let tool = Arc::new(WriteFileTool::new(
+ project.clone(),
+ thread.downgrade(),
+ action_log,
+ language_registry,
+ ));
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let edit = cx.update(|cx| {
+ tool.run(
+ ToolInput::resolved(WriteFileToolInput {
+ path: path!("/main.rs").into(),
+ content: "dropped content".into(),
+ }),
+ stream_tx,
+ cx,
+ )
+ });
+ stream_rx.expect_update_fields().await;
+ let diff = stream_rx.expect_diff().await;
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
+ drop(edit);
+ cx.run_until_parked();
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
+ }
+ }
+
+ #[gpui::test]
+ async fn test_streaming_create_content_streamed(cx: &mut TestAppContext) {
+ let (write_tool, project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"dir": {}})).await;
+ let (mut sender, input) = ToolInput::<WriteFileToolInput>::test();
+ let (event_stream, _receiver) = ToolCallEventStream::test();
+ let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx));
+
+ // Path-only partial: lets the tool open the target buffer and enter its BufferResolved state.
+ sender.send_partial(json!({
+ "path": "root/dir/new_file.txt",
+ }));
+ cx.run_until_parked();
+
+ // Stream the content field incrementally — each partial carries the accumulated text so far.
+ sender.send_partial(json!({
+ "path": "root/dir/new_file.txt",
+ "content": "line 1\n"
+ }));
+ cx.run_until_parked();
+
+ // Verify buffer has partial content
+ let buffer = project.update(cx, |project, cx| {
+ let path = project
+ .find_project_path("root/dir/new_file.txt", cx)
+ .unwrap();
+ project.get_open_buffer(&path, cx).unwrap()
+ });
+ assert_eq!(buffer.read_with(cx, |b, _| b.text()), "line 1\n");
+
+ // Stream more content
+ sender.send_partial(json!({
+ "path": "root/dir/new_file.txt",
+ "content": "line 1\nline 2\n"
+ }));
+ cx.run_until_parked();
+ assert_eq!(buffer.read_with(cx, |b, _| b.text()), "line 1\nline 2\n");
+
+ // Stream final chunk
+ sender.send_partial(json!({
+ "path": "root/dir/new_file.txt",
+ "content": "line 1\nline 2\nline 3\n"
+ }));
+ cx.run_until_parked();
+ assert_eq!(
+ buffer.read_with(cx, |b, _| b.text()),
+ "line 1\nline 2\nline 3\n"
+ );
+
+ // Send final input
+ sender.send_full(json!({
+ "path": "root/dir/new_file.txt",
+ "content": "line 1\nline 2\nline 3\n"
+ }));
+
+ let result = task.await;
+ let EditSessionOutput::Success { new_text, .. } = result.unwrap() else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "line 1\nline 2\nline 3\n");
+ }
+
+ #[gpui::test]
+ async fn test_streaming_overwrite_diff_revealed_during_streaming(cx: &mut TestAppContext) {
+ let (write_tool, _project, _action_log, _fs, _thread) = setup_test(
+ cx,
+ json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}),
+ )
+ .await;
+ let (mut sender, input) = ToolInput::<WriteFileToolInput>::test();
+ let (event_stream, mut receiver) = ToolCallEventStream::test();
+ let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx));
+
+ // Path-only partial: lets the tool open the target buffer and enter its BufferResolved state.
+ sender.send_partial(json!({
+ "path": "root/file.txt",
+ }));
+ cx.run_until_parked();
+
+ sender.send_partial(json!({
+ "path": "root/file.txt",
+ }));
+ cx.run_until_parked();
+
+ // Get the diff entity from the event stream
+ receiver.expect_update_fields().await;
+ let diff = receiver.expect_diff().await;
+
+ // Diff starts pending with no revealed ranges
+ diff.read_with(cx, |diff, cx| {
+ assert!(matches!(diff, Diff::Pending(_)));
+ assert!(!diff.has_revealed_range(cx));
+ });
+
+ // Stream first content chunk
+ sender.send_partial(json!({
+ "path": "root/file.txt",
+ "content": "new line 1\n"
+ }));
+ cx.run_until_parked();
+
+ // Diff should now have revealed ranges showing the new content
+ diff.read_with(cx, |diff, cx| {
+ assert!(diff.has_revealed_range(cx));
+ });
+
+ // Send final input
+ sender.send_full(json!({
+ "path": "root/file.txt",
+ "content": "new line 1\nnew line 2\n"
+ }));
+
+ let result = task.await;
+ let EditSessionOutput::Success {
+ new_text, old_text, ..
+ } = result.unwrap()
+ else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "new line 1\nnew line 2\n");
+ assert_eq!(*old_text, "old line 1\nold line 2\nold line 3\n");
+
+ // Diff is finalized after completion
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
+ }
+
+ #[gpui::test]
+ async fn test_streaming_overwrite_content_streamed(cx: &mut TestAppContext) {
+ let (write_tool, project, _action_log, _fs, _thread) = setup_test(
+ cx,
+ json!({"file.txt": "old line 1\nold line 2\nold line 3\n"}),
+ )
+ .await;
+ let (mut sender, input) = ToolInput::<WriteFileToolInput>::test();
+ let (event_stream, _receiver) = ToolCallEventStream::test();
+ let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx));
+
+ // Path-only partial: lets the tool open the target buffer and enter its BufferResolved state.
+ sender.send_partial(json!({
+ "path": "root/file.txt",
+ }));
+ cx.run_until_parked();
+
+ // Verify buffer still has old content (no content partial yet)
+ let buffer = project.update(cx, |project, cx| {
+ let path = project.find_project_path("root/file.txt", cx).unwrap();
+ project.open_buffer(path, cx)
+ });
+ let buffer = buffer.await.unwrap();
+ assert_eq!(
+ buffer.read_with(cx, |b, _| b.text()),
+ "old line 1\nold line 2\nold line 3\n"
+ );
+
+ // First content partial replaces old content
+ sender.send_partial(json!({
+ "path": "root/file.txt",
+ "content": "new line 1\n"
+ }));
+ cx.run_until_parked();
+ assert_eq!(buffer.read_with(cx, |b, _| b.text()), "new line 1\n");
+
+ // Later partials carry the full accumulated content; the buffer is updated to match it.
+ sender.send_partial(json!({
+ "path": "root/file.txt",
+ "content": "new line 1\nnew line 2\n"
+ }));
+ cx.run_until_parked();
+ assert_eq!(
+ buffer.read_with(cx, |b, _| b.text()),
+ "new line 1\nnew line 2\n"
+ );
+
+ // Send final input with complete content
+ sender.send_full(json!({
+ "path": "root/file.txt",
+ "content": "new line 1\nnew line 2\nnew line 3\n"
+ }));
+
+ let result = task.await;
+ let EditSessionOutput::Success {
+ new_text, old_text, ..
+ } = result.unwrap()
+ else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "new line 1\nnew line 2\nnew line 3\n");
+ assert_eq!(*old_text, "old line 1\nold line 2\nold line 3\n");
+ }
+
+ #[gpui::test]
+ async fn test_streaming_write_file_tool_registers_changed_buffers(cx: &mut TestAppContext) {
+ let (write_tool, _project, action_log, _fs, _thread) =
+ setup_test(cx, json!({"file.txt": "original content"})).await;
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ let (event_stream, _rx) = ToolCallEventStream::test();
+ let task = cx.update(|cx| {
+ write_tool.clone().run(
+ ToolInput::resolved(WriteFileToolInput {
+ path: "root/file.txt".into(),
+ content: "completely new content".into(),
+ }),
+ event_stream,
+ cx,
+ )
+ });
+
+ let result = task.await;
+ assert!(result.is_ok(), "write should succeed: {:?}", result.err());
+
+ cx.run_until_parked();
+
+ let changed = action_log.read_with(cx, |log, cx| log.changed_buffers(cx));
+ assert!(
+ !changed.is_empty(),
+ "action_log.changed_buffers() should be non-empty after streaming write, \
+ but no changed buffers were found \u{2014} Accept All / Reject All will not appear"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_streaming_write_file_tool_fields_out_of_order(cx: &mut TestAppContext) {
+ let (write_tool, _project, _action_log, _fs, _thread) =
+ setup_test(cx, json!({"file.txt": "old_content"})).await;
+ let (mut sender, input) = ToolInput::<WriteFileToolInput>::test();
+ let (event_stream, _receiver) = ToolCallEventStream::test();
+ let task = cx.update(|cx| write_tool.clone().run(input, event_stream, cx));
+
+ sender.send_partial(json!({
+ "content": "new_content"
+ }));
+ cx.run_until_parked();
+
+ sender.send_partial(json!({
+ "content": "new_content",
+ "path": "root"
+ }));
+ cx.run_until_parked();
+
+ // Send the final input, now carrying the complete, corrected path.
+ sender.send_full(json!({
+ "content": "new_content",
+ "path": "root/file.txt"
+ }));
+
+ let result = task.await;
+ let EditSessionOutput::Success { new_text, .. } = result.unwrap() else {
+ panic!("expected success");
+ };
+ assert_eq!(new_text, "new_content");
+ }
+
+ #[gpui::test]
+ async fn test_streaming_reject_created_file_deletes_it(cx: &mut TestAppContext) {
+ let (write_tool, _project, action_log, fs, _thread) =
+ setup_test(cx, json!({"dir": {}})).await;
+ cx.update(|cx| {
+ let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
+ settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
+ agent_settings::AgentSettings::override_global(settings, cx);
+ });
+
+ // Create a new file via the streaming write file tool
+ let (event_stream, _rx) = ToolCallEventStream::test();
+ let task = cx.update(|cx| {
+ write_tool.clone().run(
+ ToolInput::resolved(WriteFileToolInput {
+ path: "root/dir/new_file.txt".into(),
+ content: "Hello, World!".into(),
+ }),
+ event_stream,
+ cx,
+ )
+ });
+ let result = task.await;
+ assert!(result.is_ok(), "create should succeed: {:?}", result.err());
+ cx.run_until_parked();
+
+ assert!(
+ fs.is_file(path!("/root/dir/new_file.txt").as_ref()).await,
+ "file should exist after creation"
+ );
+
+ // Reject all edits — this should delete the newly created file
+ let changed = action_log.read_with(cx, |log, cx| log.changed_buffers(cx));
+ assert!(
+ !changed.is_empty(),
+ "action_log should track the created file as changed"
+ );
+
+ action_log
+ .update(cx, |log, cx| log.reject_all_edits(None, cx))
+ .await;
+ cx.run_until_parked();
+
+ assert!(
+ !fs.is_file(path!("/root/dir/new_file.txt").as_ref()).await,
+ "file should be deleted after rejecting creation, but an empty file was left behind"
+ );
+ }
+
+ async fn setup_test_with_fs(
+ cx: &mut TestAppContext,
+ fs: Arc<project::FakeFs>,
+ worktree_paths: &[&std::path::Path],
+ ) -> (
+ Arc<WriteFileTool>,
+ Entity<Project>,
+ Entity<ActionLog>,
+ Arc<project::FakeFs>,
+ Entity<Thread>,
+ ) {
+ let project = Project::test(fs.clone(), worktree_paths.iter().copied(), cx).await;
+ let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+ let thread = cx.new(|cx| {
+ crate::Thread::new(
+ project.clone(),
+ cx.new(|_cx| ProjectContext::default()),
+ context_server_registry,
+ Templates::new(),
+ Some(model),
+ cx,
+ )
+ });
+ let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
+ let write_tool = Arc::new(WriteFileTool::new(
+ project.clone(),
+ thread.downgrade(),
+ action_log.clone(),
+ language_registry,
+ ));
+ (write_tool, project, action_log, fs, thread)
+ }
+
+ async fn setup_test(
+ cx: &mut TestAppContext,
+ initial_tree: serde_json::Value,
+ ) -> (
+ Arc<WriteFileTool>,
+ Entity<Project>,
+ Entity<ActionLog>,
+ Arc<project::FakeFs>,
+ Entity<Thread>,
+ ) {
+ init_test(cx);
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree("/root", initial_tree).await;
+ setup_test_with_fs(cx, fs, &[path!("/root").as_ref()]).await
+ }
+
+ async fn test_resolve_path(
+ mode: &EditSessionMode,
+ path: &str,
+ cx: &mut TestAppContext,
+ ) -> Result<ProjectPath, String> {
+ init_test(cx);
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/root",
+ json!({
+ "dir": {
+ "subdir": {
+ "existing.txt": "content"
+ }
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+
+ crate::tools::edit_session::test_resolve_path(mode, path, &project, cx).await
+ }
+
+ #[track_caller]
+ fn assert_resolved_path_eq(path: Result<ProjectPath, String>, expected: &RelPath) {
+ let actual = path.expect("Should return valid path").path;
+ assert_eq!(actual.as_ref(), expected);
+ }
+
+ fn init_test(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ SettingsStore::update_global(cx, |store: &mut SettingsStore, cx| {
+ store.update_user_settings(cx, |settings| {
+ settings
+ .project
+ .all_languages
+ .defaults
+ .ensure_final_newline_on_save = Some(false);
+ });
+ });
+ });
+ }
+}
@@ -17,4 +17,5 @@ pub use tool_permissions_setup::{
render_delete_path_tool_config, render_edit_file_tool_config, render_fetch_tool_config,
render_move_path_tool_config, render_restore_file_from_disk_tool_config,
render_save_file_tool_config, render_terminal_tool_config, render_web_search_tool_config,
+ render_write_file_tool_config,
};
@@ -32,6 +32,12 @@ const TOOLS: &[ToolInfo] = &[
description: "File editing operations",
regex_explanation: "Patterns are matched against the file path being edited.",
},
+ ToolInfo {
+ id: "write_file",
+ name: "Write File",
+ description: "File creation and overwrite operations",
+ regex_explanation: "Patterns are matched against the file path being written.",
+ },
ToolInfo {
id: "delete_path",
name: "Delete Path",
@@ -303,6 +309,7 @@ fn get_tool_render_fn(
match tool_id {
"terminal" => render_terminal_tool_config,
"edit_file" => render_edit_file_tool_config,
+ "write_file" => render_write_file_tool_config,
"delete_path" => render_delete_path_tool_config,
"copy_path" => render_copy_path_tool_config,
"move_path" => render_move_path_tool_config,
@@ -1383,6 +1390,7 @@ macro_rules! tool_config_page_fn {
tool_config_page_fn!(render_terminal_tool_config, "terminal");
tool_config_page_fn!(render_edit_file_tool_config, "edit_file");
+tool_config_page_fn!(render_write_file_tool_config, "write_file");
tool_config_page_fn!(render_delete_path_tool_config, "delete_path");
tool_config_page_fn!(render_copy_path_tool_config, "copy_path");
tool_config_page_fn!(render_move_path_tool_config, "move_path");