1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Fallback title for the tool call, used by `initial_title` when the (possibly partial) input
/// provides neither a path nor a display description.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
// NOTE: the `///` doc comments on this struct and its fields are model-facing. The `JsonSchema`
// derive turns doc comments into schema descriptions, and the struct-level text is written as
// instructions to the model (presumably surfaced as the tool description). Treat edits to them
// as prompt changes, not internal documentation changes.
// NOTE(review): "which file path need changing" in the `path` docs should read "paths need";
// left untouched here because changing it alters the schema text sent to the model.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    // Also forwarded verbatim to the edit agent (`edit`/`overwrite` receive it in `run`).
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    // Resolved against the project's worktrees by `resolve_path`, which applies mode-specific
    // existence checks.
    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    // `Edit` routes to the edit agent's `edit`; `Create` and `Overwrite` route to `overwrite`.
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
79
80#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
81struct EditFileToolPartialInput {
82 #[serde(default)]
83 path: String,
84 #[serde(default)]
85 display_description: String,
86}
87
88#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
89#[serde(rename_all = "lowercase")]
90#[schemars(inline)]
91pub enum EditFileMode {
92 Edit,
93 Create,
94 Overwrite,
95}
96
/// Result of a completed `edit_file` call.
///
/// Derives serde so it can be stored and read back; `replay` rebuilds the finalized diff view
/// from `input_path`, `old_text` and `new_text`. The `alias` attributes let older records that
/// used previous field names still deserialize (presumably persisted thread data — confirm).
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as the model supplied it in `EditFileToolInput::path`.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit, any formatting, and the save.
    new_text: String,
    /// Full buffer text captured before the edit agent ran.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty means no edits landed.
    /// Defaults to empty when absent from the serialized data.
    #[serde(default)]
    diff: String,
    /// Raw output returned by the edit agent.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
124pub struct EditFileTool {
125 thread: WeakEntity<Thread>,
126 language_registry: Arc<LanguageRegistry>,
127 project: Entity<Project>,
128 templates: Arc<Templates>,
129}
130
131impl EditFileTool {
132 pub fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
203 fn name() -> &'static str {
204 "edit_file"
205 }
206
207 fn kind() -> acp::ToolKind {
208 acp::ToolKind::Edit
209 }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(
277 ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
278 );
279 }
280
281 let authorize = self.authorize(&input, &event_stream, cx);
282 cx.spawn(async move |cx: &mut AsyncApp| {
283 authorize.await?;
284
285 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
286 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
287 (request, thread.model().cloned(), thread.action_log().clone())
288 })?;
289 let request = request?;
290 let model = model.context("No language model configured")?;
291
292 let edit_format = EditFormat::from_model(model.clone())?;
293 let edit_agent = EditAgent::new(
294 model,
295 project.clone(),
296 action_log.clone(),
297 self.templates.clone(),
298 edit_format,
299 );
300
301 let buffer = project
302 .update(cx, |project, cx| {
303 project.open_buffer(project_path.clone(), cx)
304 })
305 .await?;
306
307 // Check if the file has been modified since the agent last read it
308 if let Some(abs_path) = abs_path.as_ref() {
309 let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.update(cx, |thread, cx| {
310 let last_read = thread.file_read_times.get(abs_path).copied();
311 let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
312 let dirty = buffer.read(cx).is_dirty();
313 let has_save = thread.has_tool("save_file");
314 let has_restore = thread.has_tool("restore_file_from_disk");
315 (last_read, current, dirty, has_save, has_restore)
316 })?;
317
318 // Check for unsaved changes first - these indicate modifications we don't know about
319 if is_dirty {
320 let message = match (has_save_tool, has_restore_tool) {
321 (true, true) => {
322 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
323 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
324 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
325 }
326 (true, false) => {
327 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
328 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
329 If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
330 }
331 (false, true) => {
332 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
333 If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
334 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
335 }
336 (false, false) => {
337 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
338 then ask them to save or revert the file manually and inform you when it's ok to proceed."
339 }
340 };
341 anyhow::bail!("{}", message);
342 }
343
344 // Check if the file was modified on disk since we last read it
345 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
346 // MTime can be unreliable for comparisons, so our newtype intentionally
347 // doesn't support comparing them. If the mtime at all different
348 // (which could be because of a modification or because e.g. system clock changed),
349 // we pessimistically assume it was modified.
350 if current != last_read {
351 anyhow::bail!(
352 "The file {} has been modified since you last read it. \
353 Please read the file again to get the current state before editing it.",
354 input.path.display()
355 );
356 }
357 }
358 }
359
360 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
361 event_stream.update_diff(diff.clone());
362 let _finalize_diff = util::defer({
363 let diff = diff.downgrade();
364 let mut cx = cx.clone();
365 move || {
366 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
367 }
368 });
369
370 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
371 let old_text = cx
372 .background_spawn({
373 let old_snapshot = old_snapshot.clone();
374 async move { Arc::new(old_snapshot.text()) }
375 })
376 .await;
377
378
379 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
380 edit_agent.edit(
381 buffer.clone(),
382 input.display_description.clone(),
383 &request,
384 cx,
385 )
386 } else {
387 edit_agent.overwrite(
388 buffer.clone(),
389 input.display_description.clone(),
390 &request,
391 cx,
392 )
393 };
394
395 let mut hallucinated_old_text = false;
396 let mut ambiguous_ranges = Vec::new();
397 let mut emitted_location = false;
398 while let Some(event) = events.next().await {
399 match event {
400 EditAgentOutputEvent::Edited(range) => {
401 if !emitted_location {
402 let line = Some(buffer.update(cx, |buffer, _cx| {
403 range.start.to_point(&buffer.snapshot()).row
404 }));
405 if let Some(abs_path) = abs_path.clone() {
406 event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)]));
407 }
408 emitted_location = true;
409 }
410 },
411 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
412 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
413 EditAgentOutputEvent::ResolvingEditRange(range) => {
414 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx));
415 // if !emitted_location {
416 // let line = buffer.update(cx, |buffer, _cx| {
417 // range.start.to_point(&buffer.snapshot()).row
418 // }).ok();
419 // if let Some(abs_path) = abs_path.clone() {
420 // event_stream.update_fields(ToolCallUpdateFields {
421 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
422 // ..Default::default()
423 // });
424 // }
425 // }
426 }
427 }
428 }
429
430 // If format_on_save is enabled, format the buffer
431 let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
432 let settings = language_settings::language_settings(
433 buffer.language().map(|l| l.name()),
434 buffer.file(),
435 cx,
436 );
437 settings.format_on_save != FormatOnSave::Off
438 });
439
440 let edit_agent_output = output.await?;
441
442 if format_on_save_enabled {
443 action_log.update(cx, |log, cx| {
444 log.buffer_edited(buffer.clone(), cx);
445 });
446
447 let format_task = project.update(cx, |project, cx| {
448 project.format(
449 HashSet::from_iter([buffer.clone()]),
450 LspFormatTarget::Buffers,
451 false, // Don't push to history since the tool did it.
452 FormatTrigger::Save,
453 cx,
454 )
455 });
456 format_task.await.log_err();
457 }
458
459 project
460 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
461 .await?;
462
463 action_log.update(cx, |log, cx| {
464 log.buffer_edited(buffer.clone(), cx);
465 });
466
467 // Update the recorded read time after a successful edit so consecutive edits work
468 if let Some(abs_path) = abs_path.as_ref() {
469 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
470 buffer.file().and_then(|file| file.disk_state().mtime())
471 }) {
472 self.thread.update(cx, |thread, _| {
473 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
474 })?;
475 }
476 }
477
478 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
479 let (new_text, unified_diff) = cx
480 .background_spawn({
481 let new_snapshot = new_snapshot.clone();
482 let old_text = old_text.clone();
483 async move {
484 let new_text = new_snapshot.text();
485 let diff = language::unified_diff(&old_text, &new_text);
486 (new_text, diff)
487 }
488 })
489 .await;
490
491 let input_path = input.path.display();
492 if unified_diff.is_empty() {
493 anyhow::ensure!(
494 !hallucinated_old_text,
495 formatdoc! {"
496 Some edits were produced but none of them could be applied.
497 Read the relevant sections of {input_path} again so that
498 I can perform the requested edits.
499 "}
500 );
501 anyhow::ensure!(
502 ambiguous_ranges.is_empty(),
503 {
504 let line_numbers = ambiguous_ranges
505 .iter()
506 .map(|range| range.start.to_string())
507 .collect::<Vec<_>>()
508 .join(", ");
509 formatdoc! {"
510 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
511 relevant sections of {input_path} again and extend <old_text> so
512 that I can perform the requested edits.
513 "}
514 }
515 );
516 }
517
518 Ok(EditFileToolOutput {
519 input_path: input.path,
520 new_text,
521 old_text,
522 diff: unified_diff,
523 edit_agent_output,
524 })
525 })
526 }
527
528 fn replay(
529 &self,
530 _input: Self::Input,
531 output: Self::Output,
532 event_stream: ToolCallEventStream,
533 cx: &mut App,
534 ) -> Result<()> {
535 event_stream.update_diff(cx.new(|cx| {
536 Diff::finalized(
537 output.input_path.to_string_lossy().into_owned(),
538 Some(output.old_text.to_string()),
539 output.new_text,
540 self.language_registry.clone(),
541 cx,
542 )
543 }));
544 Ok(())
545 }
546}
547
548/// Validate that the file path is valid, meaning:
549///
550/// - For `edit` and `overwrite`, the path must point to an existing file.
551/// - For `create`, the file must not already exist, but it's parent dir must exist.
552fn resolve_path(
553 input: &EditFileToolInput,
554 project: Entity<Project>,
555 cx: &mut App,
556) -> Result<ProjectPath> {
557 let project = project.read(cx);
558
559 match input.mode {
560 EditFileMode::Edit | EditFileMode::Overwrite => {
561 let path = project
562 .find_project_path(&input.path, cx)
563 .context("Can't edit file: path not found")?;
564
565 let entry = project
566 .entry_for_path(&path, cx)
567 .context("Can't edit file: path not found")?;
568
569 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
570 Ok(path)
571 }
572
573 EditFileMode::Create => {
574 if let Some(path) = project.find_project_path(&input.path, cx) {
575 anyhow::ensure!(
576 project.entry_for_path(&path, cx).is_none(),
577 "Can't create file: file already exists"
578 );
579 }
580
581 let parent_path = input
582 .path
583 .parent()
584 .context("Can't create file: incorrect path")?;
585
586 let parent_project_path = project.find_project_path(&parent_path, cx);
587
588 let parent_entry = parent_project_path
589 .as_ref()
590 .and_then(|path| project.entry_for_path(path, cx))
591 .context("Can't create file: parent directory doesn't exist")?;
592
593 anyhow::ensure!(
594 parent_entry.is_dir(),
595 "Can't create file: parent is not a directory"
596 );
597
598 let file_name = input
599 .path
600 .file_name()
601 .and_then(|file_name| file_name.to_str())
602 .and_then(|file_name| RelPath::unix(file_name).ok())
603 .context("Can't create file: invalid filename")?;
604
605 let new_file_path = parent_project_path.map(|parent| ProjectPath {
606 path: parent.path.join(file_name),
607 ..parent
608 });
609
610 new_file_path.context("Can't create file")
611 }
612 }
613}
614
615#[cfg(test)]
616mod tests {
617 use super::*;
618 use crate::{ContextServerRegistry, Templates};
619 use fs::Fs;
620 use gpui::{TestAppContext, UpdateGlobal};
621 use language_model::fake_provider::FakeLanguageModel;
622 use prompt_store::ProjectContext;
623 use serde_json::json;
624 use settings::SettingsStore;
625 use util::{path, rel_path::rel_path};
626
627 #[gpui::test]
628 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
629 init_test(cx);
630
631 let fs = project::FakeFs::new(cx.executor());
632 fs.insert_tree("/root", json!({})).await;
633 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
634 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
635 let context_server_registry =
636 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
637 let model = Arc::new(FakeLanguageModel::default());
638 let thread = cx.new(|cx| {
639 Thread::new(
640 project.clone(),
641 cx.new(|_cx| ProjectContext::default()),
642 context_server_registry,
643 Templates::new(),
644 Some(model),
645 cx,
646 )
647 });
648 let result = cx
649 .update(|cx| {
650 let input = EditFileToolInput {
651 display_description: "Some edit".into(),
652 path: "root/nonexistent_file.txt".into(),
653 mode: EditFileMode::Edit,
654 };
655 Arc::new(EditFileTool::new(
656 project,
657 thread.downgrade(),
658 language_registry,
659 Templates::new(),
660 ))
661 .run(input, ToolCallEventStream::test().0, cx)
662 })
663 .await;
664 assert_eq!(
665 result.unwrap_err().to_string(),
666 "Can't edit file: path not found"
667 );
668 }
669
670 #[gpui::test]
671 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
672 let mode = &EditFileMode::Create;
673
674 let result = test_resolve_path(mode, "root/new.txt", cx);
675 assert_resolved_path_eq(result.await, rel_path("new.txt"));
676
677 let result = test_resolve_path(mode, "new.txt", cx);
678 assert_resolved_path_eq(result.await, rel_path("new.txt"));
679
680 let result = test_resolve_path(mode, "dir/new.txt", cx);
681 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
682
683 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
684 assert_eq!(
685 result.await.unwrap_err().to_string(),
686 "Can't create file: file already exists"
687 );
688
689 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
690 assert_eq!(
691 result.await.unwrap_err().to_string(),
692 "Can't create file: parent directory doesn't exist"
693 );
694 }
695
696 #[gpui::test]
697 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
698 let mode = &EditFileMode::Edit;
699
700 let path_with_root = "root/dir/subdir/existing.txt";
701 let path_without_root = "dir/subdir/existing.txt";
702 let result = test_resolve_path(mode, path_with_root, cx);
703 assert_resolved_path_eq(result.await, rel_path(path_without_root));
704
705 let result = test_resolve_path(mode, path_without_root, cx);
706 assert_resolved_path_eq(result.await, rel_path(path_without_root));
707
708 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
709 assert_eq!(
710 result.await.unwrap_err().to_string(),
711 "Can't edit file: path not found"
712 );
713
714 let result = test_resolve_path(mode, "root/dir", cx);
715 assert_eq!(
716 result.await.unwrap_err().to_string(),
717 "Can't edit file: path is a directory"
718 );
719 }
720
721 async fn test_resolve_path(
722 mode: &EditFileMode,
723 path: &str,
724 cx: &mut TestAppContext,
725 ) -> anyhow::Result<ProjectPath> {
726 init_test(cx);
727
728 let fs = project::FakeFs::new(cx.executor());
729 fs.insert_tree(
730 "/root",
731 json!({
732 "dir": {
733 "subdir": {
734 "existing.txt": "hello"
735 }
736 }
737 }),
738 )
739 .await;
740 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
741
742 let input = EditFileToolInput {
743 display_description: "Some edit".into(),
744 path: path.into(),
745 mode: mode.clone(),
746 };
747
748 cx.update(|cx| resolve_path(&input, project, cx))
749 }
750
751 #[track_caller]
752 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
753 let actual = path.expect("Should return valid path").path;
754 assert_eq!(actual.as_ref(), expected);
755 }
756
757 #[gpui::test]
758 async fn test_format_on_save(cx: &mut TestAppContext) {
759 init_test(cx);
760
761 let fs = project::FakeFs::new(cx.executor());
762 fs.insert_tree("/root", json!({"src": {}})).await;
763
764 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
765
766 // Set up a Rust language with LSP formatting support
767 let rust_language = Arc::new(language::Language::new(
768 language::LanguageConfig {
769 name: "Rust".into(),
770 matcher: language::LanguageMatcher {
771 path_suffixes: vec!["rs".to_string()],
772 ..Default::default()
773 },
774 ..Default::default()
775 },
776 None,
777 ));
778
779 // Register the language and fake LSP
780 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
781 language_registry.add(rust_language);
782
783 let mut fake_language_servers = language_registry.register_fake_lsp(
784 "Rust",
785 language::FakeLspAdapter {
786 capabilities: lsp::ServerCapabilities {
787 document_formatting_provider: Some(lsp::OneOf::Left(true)),
788 ..Default::default()
789 },
790 ..Default::default()
791 },
792 );
793
794 // Create the file
795 fs.save(
796 path!("/root/src/main.rs").as_ref(),
797 &"initial content".into(),
798 language::LineEnding::Unix,
799 )
800 .await
801 .unwrap();
802
803 // Open the buffer to trigger LSP initialization
804 let buffer = project
805 .update(cx, |project, cx| {
806 project.open_local_buffer(path!("/root/src/main.rs"), cx)
807 })
808 .await
809 .unwrap();
810
811 // Register the buffer with language servers
812 let _handle = project.update(cx, |project, cx| {
813 project.register_buffer_with_language_servers(&buffer, cx)
814 });
815
816 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
817 const FORMATTED_CONTENT: &str =
818 "This file was formatted by the fake formatter in the test.\n";
819
820 // Get the fake language server and set up formatting handler
821 let fake_language_server = fake_language_servers.next().await.unwrap();
822 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
823 |_, _| async move {
824 Ok(Some(vec![lsp::TextEdit {
825 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
826 new_text: FORMATTED_CONTENT.to_string(),
827 }]))
828 }
829 });
830
831 let context_server_registry =
832 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
833 let model = Arc::new(FakeLanguageModel::default());
834 let thread = cx.new(|cx| {
835 Thread::new(
836 project.clone(),
837 cx.new(|_cx| ProjectContext::default()),
838 context_server_registry,
839 Templates::new(),
840 Some(model.clone()),
841 cx,
842 )
843 });
844
845 // First, test with format_on_save enabled
846 cx.update(|cx| {
847 SettingsStore::update_global(cx, |store, cx| {
848 store.update_user_settings(cx, |settings| {
849 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
850 settings.project.all_languages.defaults.formatter =
851 Some(language::language_settings::FormatterList::default());
852 });
853 });
854 });
855
856 // Have the model stream unformatted content
857 let edit_result = {
858 let edit_task = cx.update(|cx| {
859 let input = EditFileToolInput {
860 display_description: "Create main function".into(),
861 path: "root/src/main.rs".into(),
862 mode: EditFileMode::Overwrite,
863 };
864 Arc::new(EditFileTool::new(
865 project.clone(),
866 thread.downgrade(),
867 language_registry.clone(),
868 Templates::new(),
869 ))
870 .run(input, ToolCallEventStream::test().0, cx)
871 });
872
873 // Stream the unformatted content
874 cx.executor().run_until_parked();
875 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
876 model.end_last_completion_stream();
877
878 edit_task.await
879 };
880 assert!(edit_result.is_ok());
881
882 // Wait for any async operations (e.g. formatting) to complete
883 cx.executor().run_until_parked();
884
885 // Read the file to verify it was formatted automatically
886 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
887 assert_eq!(
888 // Ignore carriage returns on Windows
889 new_content.replace("\r\n", "\n"),
890 FORMATTED_CONTENT,
891 "Code should be formatted when format_on_save is enabled"
892 );
893
894 let stale_buffer_count = thread
895 .read_with(cx, |thread, _cx| thread.action_log.clone())
896 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
897
898 assert_eq!(
899 stale_buffer_count, 0,
900 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
901 This causes the agent to think the file was modified externally when it was just formatted.",
902 stale_buffer_count
903 );
904
905 // Next, test with format_on_save disabled
906 cx.update(|cx| {
907 SettingsStore::update_global(cx, |store, cx| {
908 store.update_user_settings(cx, |settings| {
909 settings.project.all_languages.defaults.format_on_save =
910 Some(FormatOnSave::Off);
911 });
912 });
913 });
914
915 // Stream unformatted edits again
916 let edit_result = {
917 let edit_task = cx.update(|cx| {
918 let input = EditFileToolInput {
919 display_description: "Update main function".into(),
920 path: "root/src/main.rs".into(),
921 mode: EditFileMode::Overwrite,
922 };
923 Arc::new(EditFileTool::new(
924 project.clone(),
925 thread.downgrade(),
926 language_registry,
927 Templates::new(),
928 ))
929 .run(input, ToolCallEventStream::test().0, cx)
930 });
931
932 // Stream the unformatted content
933 cx.executor().run_until_parked();
934 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
935 model.end_last_completion_stream();
936
937 edit_task.await
938 };
939 assert!(edit_result.is_ok());
940
941 // Wait for any async operations (e.g. formatting) to complete
942 cx.executor().run_until_parked();
943
944 // Verify the file was not formatted
945 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
946 assert_eq!(
947 // Ignore carriage returns on Windows
948 new_content.replace("\r\n", "\n"),
949 UNFORMATTED_CONTENT,
950 "Code should not be formatted when format_on_save is disabled"
951 );
952 }
953
/// Checks that an `Overwrite` edit honors `remove_trailing_whitespace_on_save`:
/// trailing whitespace is stripped from the written file when the setting is
/// enabled and left in place when it is disabled.
#[gpui::test]
async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({"src": {}})).await;

    // Seed `main.rs` with placeholder content (it has no trailing whitespace
    // itself); the whitespace-laden content is streamed by the fake model below.
    fs.save(
        path!("/root/src/main.rs").as_ref(),
        &"initial content".into(),
        language::LineEnding::Unix,
    )
    .await
    .unwrap();

    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // First, test with remove_trailing_whitespace_on_save enabled
    cx.update(|cx| {
        SettingsStore::update_global(cx, |store, cx| {
            store.update_user_settings(cx, |settings| {
                settings
                    .project
                    .all_languages
                    .defaults
                    .remove_trailing_whitespace_on_save = Some(true);
            });
        });
    });

    // The first two lines each end with whitespace before the newline.
    const CONTENT_WITH_TRAILING_WHITESPACE: &str =
        "fn main() { \n println!(\"Hello!\"); \n}\n";

    // Have the model stream content that contains trailing whitespace
    let edit_result = {
        let edit_task = cx.update(|cx| {
            let input = EditFileToolInput {
                display_description: "Create main function".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Overwrite,
            };
            Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                language_registry.clone(),
                Templates::new(),
            ))
            .run(input, ToolCallEventStream::test().0, cx)
        });

        // Stream the content with trailing whitespace. Park first so the
        // tool's completion request is (presumably) pending before we feed
        // the "last completion" stream.
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(
            CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
        );
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(edit_result.is_ok());

    // Wait for any async operations (e.g. formatting) to complete
    cx.executor().run_until_parked();

    // Read the file to verify trailing whitespace was removed automatically
    assert_eq!(
        // Ignore carriage returns on Windows
        fs.load(path!("/root/src/main.rs").as_ref())
            .await
            .unwrap()
            .replace("\r\n", "\n"),
        "fn main() {\n println!(\"Hello!\");\n}\n",
        "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
    );

    // Next, test with remove_trailing_whitespace_on_save disabled
    cx.update(|cx| {
        SettingsStore::update_global(cx, |store, cx| {
            store.update_user_settings(cx, |settings| {
                settings
                    .project
                    .all_languages
                    .defaults
                    .remove_trailing_whitespace_on_save = Some(false);
            });
        });
    });

    // Stream edits again with trailing whitespace
    let edit_result = {
        let edit_task = cx.update(|cx| {
            let input = EditFileToolInput {
                display_description: "Update main function".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Overwrite,
            };
            Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                language_registry,
                Templates::new(),
            ))
            .run(input, ToolCallEventStream::test().0, cx)
        });

        // Stream the content with trailing whitespace
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(
            CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
        );
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(edit_result.is_ok());

    // Wait for any async operations (e.g. formatting) to complete
    cx.executor().run_until_parked();

    // Read the file again - with the setting disabled it should still have
    // exactly the streamed content, trailing whitespace included.
    let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
    assert_eq!(
        // Ignore carriage returns on Windows
        final_content.replace("\r\n", "\n"),
        CONTENT_WITH_TRAILING_WHITESPACE,
        "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
    );
}
1098
/// Exercises `EditFileTool::authorize` against a single `/root` worktree:
/// `.zed` paths and paths outside the project emit an authorization request
/// (`.zed` ones get " (local settings)" appended to the title), a plain
/// project path does not, and enabling `always_allow_tool_actions`
/// suppresses every request.
#[gpui::test]
async fn test_authorize(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });
    let tool = Arc::new(EditFileTool::new(
        project.clone(),
        thread.downgrade(),
        language_registry,
        Templates::new(),
    ));
    // `/root` gets an empty tree, so none of the paths below exist on disk;
    // authorization is presumably decided from the path alone.
    fs.insert_tree("/root", json!({})).await;

    // Test 1: Path with .zed component should require confirmation
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    // Bind the task to `_auth` (not `_`) so it stays alive while the
    // emitted authorization event is inspected.
    let _auth = cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 1".into(),
                path: ".zed/settings.json".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    });

    let event = stream_rx.expect_authorization().await;
    assert_eq!(
        event.tool_call.fields.title,
        Some("test 1 (local settings)".into())
    );

    // Test 2: Path outside project should require confirmation
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    let _auth = cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 2".into(),
                path: "/etc/hosts".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    });

    let event = stream_rx.expect_authorization().await;
    assert_eq!(event.tool_call.fields.title, Some("test 2".into()));

    // Test 3: Relative path without .zed should not require confirmation
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 3".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    })
    .await
    .unwrap();
    // Nothing queued on the event stream: no authorization request was made.
    assert!(stream_rx.try_next().is_err());

    // Test 4: Path with .zed in the middle should require confirmation
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    let _auth = cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 4".into(),
                path: "root/.zed/tasks.json".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    });
    let event = stream_rx.expect_authorization().await;
    assert_eq!(
        event.tool_call.fields.title,
        Some("test 4 (local settings)".into())
    );

    // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
    // (the override stays in effect for the rest of this test).
    cx.update(|cx| {
        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
        settings.always_allow_tool_actions = true;
        agent_settings::AgentSettings::override_global(settings, cx);
    });

    // 5.1: a `.zed` path that prompted in Test 1 is now authorized silently.
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 5.1".into(),
                path: ".zed/settings.json".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    })
    .await
    .unwrap();
    assert!(stream_rx.try_next().is_err());

    // 5.2: an outside-project path that prompted in Test 2 is now authorized silently.
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 5.2".into(),
                path: "/etc/hosts".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    })
    .await
    .unwrap();
    assert!(stream_rx.try_next().is_err());
}
1238
1239 #[gpui::test]
1240 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1241 init_test(cx);
1242 let fs = project::FakeFs::new(cx.executor());
1243 fs.insert_tree("/project", json!({})).await;
1244 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1245 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1246 let context_server_registry =
1247 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1248 let model = Arc::new(FakeLanguageModel::default());
1249 let thread = cx.new(|cx| {
1250 Thread::new(
1251 project.clone(),
1252 cx.new(|_cx| ProjectContext::default()),
1253 context_server_registry,
1254 Templates::new(),
1255 Some(model.clone()),
1256 cx,
1257 )
1258 });
1259 let tool = Arc::new(EditFileTool::new(
1260 project.clone(),
1261 thread.downgrade(),
1262 language_registry,
1263 Templates::new(),
1264 ));
1265
1266 // Test global config paths - these should require confirmation if they exist and are outside the project
1267 let test_cases = vec![
1268 (
1269 "/etc/hosts",
1270 true,
1271 "System file should require confirmation",
1272 ),
1273 (
1274 "/usr/local/bin/script",
1275 true,
1276 "System bin file should require confirmation",
1277 ),
1278 (
1279 "project/normal_file.rs",
1280 false,
1281 "Normal project file should not require confirmation",
1282 ),
1283 ];
1284
1285 for (path, should_confirm, description) in test_cases {
1286 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1287 let auth = cx.update(|cx| {
1288 tool.authorize(
1289 &EditFileToolInput {
1290 display_description: "Edit file".into(),
1291 path: path.into(),
1292 mode: EditFileMode::Edit,
1293 },
1294 &stream_tx,
1295 cx,
1296 )
1297 });
1298
1299 if should_confirm {
1300 stream_rx.expect_authorization().await;
1301 } else {
1302 auth.await.unwrap();
1303 assert!(
1304 stream_rx.try_next().is_err(),
1305 "Failed for case: {} - path: {} - expected no confirmation but got one",
1306 description,
1307 path
1308 );
1309 }
1310 }
1311 }
1312
1313 #[gpui::test]
1314 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1315 init_test(cx);
1316 let fs = project::FakeFs::new(cx.executor());
1317
1318 // Create multiple worktree directories
1319 fs.insert_tree(
1320 "/workspace/frontend",
1321 json!({
1322 "src": {
1323 "main.js": "console.log('frontend');"
1324 }
1325 }),
1326 )
1327 .await;
1328 fs.insert_tree(
1329 "/workspace/backend",
1330 json!({
1331 "src": {
1332 "main.rs": "fn main() {}"
1333 }
1334 }),
1335 )
1336 .await;
1337 fs.insert_tree(
1338 "/workspace/shared",
1339 json!({
1340 ".zed": {
1341 "settings.json": "{}"
1342 }
1343 }),
1344 )
1345 .await;
1346
1347 // Create project with multiple worktrees
1348 let project = Project::test(
1349 fs.clone(),
1350 [
1351 path!("/workspace/frontend").as_ref(),
1352 path!("/workspace/backend").as_ref(),
1353 path!("/workspace/shared").as_ref(),
1354 ],
1355 cx,
1356 )
1357 .await;
1358 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1359 let context_server_registry =
1360 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1361 let model = Arc::new(FakeLanguageModel::default());
1362 let thread = cx.new(|cx| {
1363 Thread::new(
1364 project.clone(),
1365 cx.new(|_cx| ProjectContext::default()),
1366 context_server_registry.clone(),
1367 Templates::new(),
1368 Some(model.clone()),
1369 cx,
1370 )
1371 });
1372 let tool = Arc::new(EditFileTool::new(
1373 project.clone(),
1374 thread.downgrade(),
1375 language_registry,
1376 Templates::new(),
1377 ));
1378
1379 // Test files in different worktrees
1380 let test_cases = vec![
1381 ("frontend/src/main.js", false, "File in first worktree"),
1382 ("backend/src/main.rs", false, "File in second worktree"),
1383 (
1384 "shared/.zed/settings.json",
1385 true,
1386 ".zed file in third worktree",
1387 ),
1388 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1389 (
1390 "../outside/file.txt",
1391 true,
1392 "Relative path outside worktrees",
1393 ),
1394 ];
1395
1396 for (path, should_confirm, description) in test_cases {
1397 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1398 let auth = cx.update(|cx| {
1399 tool.authorize(
1400 &EditFileToolInput {
1401 display_description: "Edit file".into(),
1402 path: path.into(),
1403 mode: EditFileMode::Edit,
1404 },
1405 &stream_tx,
1406 cx,
1407 )
1408 });
1409
1410 if should_confirm {
1411 stream_rx.expect_authorization().await;
1412 } else {
1413 auth.await.unwrap();
1414 assert!(
1415 stream_rx.try_next().is_err(),
1416 "Failed for case: {} - path: {} - expected no confirmation but got one",
1417 description,
1418 path
1419 );
1420 }
1421 }
1422 }
1423
/// Edge-case paths for `authorize` against a single `/project` worktree:
/// an empty path, the filesystem root, `..` escaping the worktree, and `.`
/// components (plus Windows-style paths on Windows). Each case checks
/// whether an authorization request is emitted.
#[gpui::test]
async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        json!({
            ".zed": {
                "settings.json": "{}"
            },
            "src": {
                ".zed": {
                    "local.json": "{}"
                }
            }
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });
    let tool = Arc::new(EditFileTool::new(
        project.clone(),
        thread.downgrade(),
        language_registry,
        Templates::new(),
    ));

    // Test edge cases: (requested path, whether a prompt is expected, case label).
    let test_cases = vec![
        // Empty path - find_project_path returns Some for empty paths
        ("", false, "Empty path is treated as project root"),
        // Root directory
        ("/", true, "Root directory should be outside project"),
        // Parent directory references - find_project_path resolves these
        (
            "project/../other",
            true,
            "Path with .. that goes outside of root directory",
        ),
        (
            "project/./src/file.rs",
            false,
            "Path with . should work normally",
        ),
        // Windows-style paths (if on Windows); the `cfg` attributes drop
        // these rows from the table on other platforms.
        #[cfg(target_os = "windows")]
        ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
        #[cfg(target_os = "windows")]
        ("project\\src\\main.rs", false, "Windows-style project path"),
    ];

    for (path, should_confirm, description) in test_cases {
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let auth = cx.update(|cx| {
            tool.authorize(
                &EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path.into(),
                    mode: EditFileMode::Edit,
                },
                &stream_tx,
                cx,
            )
        });

        // Drive pending tasks so that any authorization event the task
        // would emit is already queued before the no-event assertion below.
        cx.run_until_parked();

        if should_confirm {
            stream_rx.expect_authorization().await;
        } else {
            // Unlike the sibling table tests, the empty-channel check runs
            // before awaiting the task here; parking above makes that valid.
            assert!(
                stream_rx.try_next().is_err(),
                "Failed for case: {} - path: {} - expected no confirmation but got one",
                description,
                path
            );
            auth.await.unwrap();
        }
    }
}
1517
1518 #[gpui::test]
1519 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1520 init_test(cx);
1521 let fs = project::FakeFs::new(cx.executor());
1522 fs.insert_tree(
1523 "/project",
1524 json!({
1525 "existing.txt": "content",
1526 ".zed": {
1527 "settings.json": "{}"
1528 }
1529 }),
1530 )
1531 .await;
1532 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1533 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1534 let context_server_registry =
1535 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1536 let model = Arc::new(FakeLanguageModel::default());
1537 let thread = cx.new(|cx| {
1538 Thread::new(
1539 project.clone(),
1540 cx.new(|_cx| ProjectContext::default()),
1541 context_server_registry.clone(),
1542 Templates::new(),
1543 Some(model.clone()),
1544 cx,
1545 )
1546 });
1547 let tool = Arc::new(EditFileTool::new(
1548 project.clone(),
1549 thread.downgrade(),
1550 language_registry,
1551 Templates::new(),
1552 ));
1553
1554 // Test different EditFileMode values
1555 let modes = vec![
1556 EditFileMode::Edit,
1557 EditFileMode::Create,
1558 EditFileMode::Overwrite,
1559 ];
1560
1561 for mode in modes {
1562 // Test .zed path with different modes
1563 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1564 let _auth = cx.update(|cx| {
1565 tool.authorize(
1566 &EditFileToolInput {
1567 display_description: "Edit settings".into(),
1568 path: "project/.zed/settings.json".into(),
1569 mode: mode.clone(),
1570 },
1571 &stream_tx,
1572 cx,
1573 )
1574 });
1575
1576 stream_rx.expect_authorization().await;
1577
1578 // Test outside path with different modes
1579 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1580 let _auth = cx.update(|cx| {
1581 tool.authorize(
1582 &EditFileToolInput {
1583 display_description: "Edit file".into(),
1584 path: "/outside/file.txt".into(),
1585 mode: mode.clone(),
1586 },
1587 &stream_tx,
1588 cx,
1589 )
1590 });
1591
1592 stream_rx.expect_authorization().await;
1593
1594 // Test normal path with different modes
1595 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1596 cx.update(|cx| {
1597 tool.authorize(
1598 &EditFileToolInput {
1599 display_description: "Edit file".into(),
1600 path: "project/normal.txt".into(),
1601 mode: mode.clone(),
1602 },
1603 &stream_tx,
1604 cx,
1605 )
1606 })
1607 .await
1608 .unwrap();
1609 assert!(stream_rx.try_next().is_err());
1610 }
1611 }
1612
1613 #[gpui::test]
1614 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1615 init_test(cx);
1616 let fs = project::FakeFs::new(cx.executor());
1617 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1618 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1619 let context_server_registry =
1620 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1621 let model = Arc::new(FakeLanguageModel::default());
1622 let thread = cx.new(|cx| {
1623 Thread::new(
1624 project.clone(),
1625 cx.new(|_cx| ProjectContext::default()),
1626 context_server_registry,
1627 Templates::new(),
1628 Some(model.clone()),
1629 cx,
1630 )
1631 });
1632 let tool = Arc::new(EditFileTool::new(
1633 project,
1634 thread.downgrade(),
1635 language_registry,
1636 Templates::new(),
1637 ));
1638
1639 cx.update(|cx| {
1640 // ...
1641 assert_eq!(
1642 tool.initial_title(
1643 Err(json!({
1644 "path": "src/main.rs",
1645 "display_description": "",
1646 "old_string": "old code",
1647 "new_string": "new code"
1648 })),
1649 cx
1650 ),
1651 "src/main.rs"
1652 );
1653 assert_eq!(
1654 tool.initial_title(
1655 Err(json!({
1656 "path": "",
1657 "display_description": "Fix error handling",
1658 "old_string": "old code",
1659 "new_string": "new code"
1660 })),
1661 cx
1662 ),
1663 "Fix error handling"
1664 );
1665 assert_eq!(
1666 tool.initial_title(
1667 Err(json!({
1668 "path": "src/main.rs",
1669 "display_description": "Fix error handling",
1670 "old_string": "old code",
1671 "new_string": "new code"
1672 })),
1673 cx
1674 ),
1675 "src/main.rs"
1676 );
1677 assert_eq!(
1678 tool.initial_title(
1679 Err(json!({
1680 "path": "",
1681 "display_description": "",
1682 "old_string": "old code",
1683 "new_string": "new code"
1684 })),
1685 cx
1686 ),
1687 DEFAULT_UI_TEXT
1688 );
1689 assert_eq!(
1690 tool.initial_title(Err(serde_json::Value::Null), cx),
1691 DEFAULT_UI_TEXT
1692 );
1693 });
1694 }
1695
/// The `Diff` entity emitted by a run starts as `Diff::Pending` and reaches
/// `Diff::Finalized` in all three ways a run can end: successful completion,
/// an error, and the run's task being dropped.
#[gpui::test]
async fn test_diff_finalization(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree("/", json!({"main.rs": ""})).await;

    let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
    let languages = project.read_with(cx, |project, _cx| project.languages().clone());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // Ensure the diff is finalized after the edit completes.
    {
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        // The run emits a fields update first, then the diff entity.
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        cx.run_until_parked();
        // End the (empty) model response so the run can complete.
        model.end_last_completion_stream();
        edit.await.unwrap();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
    }

    // Ensure the diff is finalized if an error occurs while editing.
    {
        // Blocking model requests makes this run fail (it is awaited with
        // `unwrap_err` below); requests are re-allowed at the end of the scope.
        model.forbid_requests();
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        edit.await.unwrap_err();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        model.allow_requests();
    }

    // Ensure the diff is finalized if the tool call gets dropped.
    {
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        // Drop the in-flight task, then park so any cleanup work it
        // triggers gets a chance to run before the final check.
        drop(edit);
        cx.run_until_parked();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
    }
}
1804
/// Reading a file through `ReadFileTool` records exactly one entry in the
/// thread's `file_read_times`, keyed by a path ending in `test.txt`.
/// Re-reading the same file keeps the count at one.
#[gpui::test]
async fn test_file_read_times_tracking(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "test.txt": "original content"
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });
    let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

    // Initially, file_read_times should be empty
    let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
    assert!(is_empty, "file_read_times should start empty");

    // Create read tool
    let read_tool = Arc::new(crate::ReadFileTool::new(
        thread.downgrade(),
        project.clone(),
        action_log,
    ));

    // Read the file to record the read time
    cx.update(|cx| {
        read_tool.clone().run(
            crate::ReadFileToolInput {
                path: "root/test.txt".to_string(),
                start_line: None,
                end_line: None,
            },
            ToolCallEventStream::test().0,
            cx,
        )
    })
    .await
    .unwrap();

    // Verify that file_read_times now contains an entry for the file
    let has_entry = thread.read_with(cx, |thread, _| {
        thread.file_read_times.len() == 1
            && thread
                .file_read_times
                .keys()
                .any(|path| path.ends_with("test.txt"))
    });
    assert!(
        has_entry,
        "file_read_times should contain an entry after reading the file"
    );

    // Read the file again - should update the entry
    cx.update(|cx| {
        read_tool.clone().run(
            crate::ReadFileToolInput {
                path: "root/test.txt".to_string(),
                start_line: None,
                end_line: None,
            },
            ToolCallEventStream::test().0,
            cx,
        )
    })
    .await
    .unwrap();

    // Should still have exactly one entry (the re-read does not add a second key)
    let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
    assert!(
        has_one_entry,
        "file_read_times should still have one entry after re-reading"
    );
}
1894
1895 fn init_test(cx: &mut TestAppContext) {
1896 cx.update(|cx| {
1897 let settings_store = SettingsStore::test(cx);
1898 cx.set_global(settings_store);
1899 });
1900 }
1901
/// After a single read of `test.txt`, two back-to-back `Edit`-mode runs on
/// that file must both succeed.
#[gpui::test]
async fn test_consecutive_edits_work(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "test.txt": "original content"
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });
    let languages = project.read_with(cx, |project, _| project.languages().clone());
    let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

    let read_tool = Arc::new(crate::ReadFileTool::new(
        thread.downgrade(),
        project.clone(),
        action_log,
    ));
    let edit_tool = Arc::new(EditFileTool::new(
        project.clone(),
        thread.downgrade(),
        languages,
        Templates::new(),
    ));

    // Read the file first
    cx.update(|cx| {
        read_tool.clone().run(
            crate::ReadFileToolInput {
                path: "root/test.txt".to_string(),
                start_line: None,
                end_line: None,
            },
            ToolCallEventStream::test().0,
            cx,
        )
    })
    .await
    .unwrap();

    // First edit should work
    let edit_result = {
        let edit_task = cx.update(|cx| {
            edit_tool.clone().run(
                EditFileToolInput {
                    display_description: "First edit".into(),
                    path: "root/test.txt".into(),
                    mode: EditFileMode::Edit,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        });

        // Park so the tool's completion request is pending, then answer it
        // with an <old_text>/<new_text> replacement.
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(
            "<old_text>original content</old_text><new_text>modified content</new_text>"
                .to_string(),
        );
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(
        edit_result.is_ok(),
        "First edit should succeed, got error: {:?}",
        edit_result.as_ref().err()
    );

    // Second edit should also work because the edit updated the recorded read time
    let edit_result = {
        let edit_task = cx.update(|cx| {
            edit_tool.clone().run(
                EditFileToolInput {
                    display_description: "Second edit".into(),
                    path: "root/test.txt".into(),
                    mode: EditFileMode::Edit,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        });

        // The replacement targets the text written by the first edit.
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(
            "<old_text>modified content</old_text><new_text>further modified content</new_text>".to_string(),
        );
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(
        edit_result.is_ok(),
        "Second consecutive edit should succeed, got error: {:?}",
        edit_result.as_ref().err()
    );
}
2015
/// If the file changes on disk after it was read (and the buffer is reloaded
/// from disk), an `Edit` run fails with an error containing "has been
/// modified since you last read it".
#[gpui::test]
async fn test_external_modification_detected(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "test.txt": "original content"
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });
    let languages = project.read_with(cx, |project, _| project.languages().clone());
    let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

    let read_tool = Arc::new(crate::ReadFileTool::new(
        thread.downgrade(),
        project.clone(),
        action_log,
    ));
    let edit_tool = Arc::new(EditFileTool::new(
        project.clone(),
        thread.downgrade(),
        languages,
        Templates::new(),
    ));

    // Read the file first
    cx.update(|cx| {
        read_tool.clone().run(
            crate::ReadFileToolInput {
                path: "root/test.txt".to_string(),
                start_line: None,
                end_line: None,
            },
            ToolCallEventStream::test().0,
            cx,
        )
    })
    .await
    .unwrap();

    // Simulate external modification - advance time and save file.
    // Advancing the fake clock presumably gives the save a later mtime than
    // the recorded read.
    cx.background_executor
        .advance_clock(std::time::Duration::from_secs(2));
    fs.save(
        path!("/root/test.txt").as_ref(),
        &"externally modified content".into(),
        language::LineEnding::Unix,
    )
    .await
    .unwrap();

    // Reload the buffer to pick up the new mtime
    let project_path = project
        .read_with(cx, |project, cx| {
            project.find_project_path("root/test.txt", cx)
        })
        .expect("Should find project path");
    let buffer = project
        .update(cx, |project, cx| project.open_buffer(project_path, cx))
        .await
        .unwrap();
    buffer
        .update(cx, |buffer, cx| buffer.reload(cx))
        .await
        .unwrap();

    cx.executor().run_until_parked();

    // Try to edit - should fail because file was modified externally.
    // No model response is streamed: the run is awaited directly and is
    // expected to return its error without one.
    let result = cx
        .update(|cx| {
            edit_tool.clone().run(
                EditFileToolInput {
                    display_description: "Edit after external change".into(),
                    path: "root/test.txt".into(),
                    mode: EditFileMode::Edit,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await;

    assert!(
        result.is_err(),
        "Edit should fail after external modification"
    );
    let error_msg = result.unwrap_err().to_string();
    assert!(
        error_msg.contains("has been modified since you last read it"),
        "Error should mention file modification, got: {}",
        error_msg
    );
}
2126
2127 #[gpui::test]
2128 async fn test_dirty_buffer_detected(cx: &mut TestAppContext) {
2129 init_test(cx);
2130
2131 let fs = project::FakeFs::new(cx.executor());
2132 fs.insert_tree(
2133 "/root",
2134 json!({
2135 "test.txt": "original content"
2136 }),
2137 )
2138 .await;
2139 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2140 let context_server_registry =
2141 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2142 let model = Arc::new(FakeLanguageModel::default());
2143 let thread = cx.new(|cx| {
2144 Thread::new(
2145 project.clone(),
2146 cx.new(|_cx| ProjectContext::default()),
2147 context_server_registry,
2148 Templates::new(),
2149 Some(model.clone()),
2150 cx,
2151 )
2152 });
2153 let languages = project.read_with(cx, |project, _| project.languages().clone());
2154 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2155
2156 let read_tool = Arc::new(crate::ReadFileTool::new(
2157 thread.downgrade(),
2158 project.clone(),
2159 action_log,
2160 ));
2161 let edit_tool = Arc::new(EditFileTool::new(
2162 project.clone(),
2163 thread.downgrade(),
2164 languages,
2165 Templates::new(),
2166 ));
2167
2168 // Read the file first
2169 cx.update(|cx| {
2170 read_tool.clone().run(
2171 crate::ReadFileToolInput {
2172 path: "root/test.txt".to_string(),
2173 start_line: None,
2174 end_line: None,
2175 },
2176 ToolCallEventStream::test().0,
2177 cx,
2178 )
2179 })
2180 .await
2181 .unwrap();
2182
2183 // Open the buffer and make it dirty by editing without saving
2184 let project_path = project
2185 .read_with(cx, |project, cx| {
2186 project.find_project_path("root/test.txt", cx)
2187 })
2188 .expect("Should find project path");
2189 let buffer = project
2190 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2191 .await
2192 .unwrap();
2193
2194 // Make an in-memory edit to the buffer (making it dirty)
2195 buffer.update(cx, |buffer, cx| {
2196 let end_point = buffer.max_point();
2197 buffer.edit([(end_point..end_point, " added text")], None, cx);
2198 });
2199
2200 // Verify buffer is dirty
2201 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
2202 assert!(is_dirty, "Buffer should be dirty after in-memory edit");
2203
2204 // Try to edit - should fail because buffer has unsaved changes
2205 let result = cx
2206 .update(|cx| {
2207 edit_tool.clone().run(
2208 EditFileToolInput {
2209 display_description: "Edit with dirty buffer".into(),
2210 path: "root/test.txt".into(),
2211 mode: EditFileMode::Edit,
2212 },
2213 ToolCallEventStream::test().0,
2214 cx,
2215 )
2216 })
2217 .await;
2218
2219 assert!(result.is_err(), "Edit should fail when buffer is dirty");
2220 let error_msg = result.unwrap_err().to_string();
2221 assert!(
2222 error_msg.contains("This file has unsaved changes."),
2223 "Error should mention unsaved changes, got: {}",
2224 error_msg
2225 );
2226 assert!(
2227 error_msg.contains("keep or discard"),
2228 "Error should ask whether to keep or discard changes, got: {}",
2229 error_msg
2230 );
2231 // Since save_file and restore_file_from_disk tools aren't added to the thread,
2232 // the error message should ask the user to manually save or revert
2233 assert!(
2234 error_msg.contains("save or revert the file manually"),
2235 "Error should ask user to manually save or revert when tools aren't available, got: {}",
2236 error_msg
2237 );
2238 }
2239}