1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use futures::{FutureExt as _, StreamExt as _};
11use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
12use indoc::formatdoc;
13use language::language_settings::{self, FormatOnSave};
14use language::{LanguageRegistry, ToPoint};
15use language_model::LanguageModelToolResultContent;
16use paths;
17use project::lsp_store::{FormatTrigger, LspFormatTarget};
18use project::{Project, ProjectPath};
19use schemars::JsonSchema;
20use serde::{Deserialize, Serialize};
21use settings::Settings;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Title shown for an `edit_file` tool call when the (possibly still streaming)
/// input yields neither a usable path nor a description.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
31/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
32///
33/// Before using this tool:
34///
35/// 1. Use the `read_file` tool to understand the file's contents and context
36///
37/// 2. Verify the directory path is correct (only applicable when creating new files):
38/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
39#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
40pub struct EditFileToolInput {
41 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
42 ///
43 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
44 ///
45 /// NEVER mention the file path in this description.
46 ///
47 /// <example>Fix API endpoint URLs</example>
48 /// <example>Update copyright year in `page_footer`</example>
49 ///
50 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
51 pub display_description: String,
52
53 /// The full path of the file to create or modify in the project.
54 ///
55 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
56 ///
57 /// The following examples assume we have two root directories in the project:
58 /// - /a/b/backend
59 /// - /c/d/frontend
60 ///
61 /// <example>
62 /// `backend/src/main.rs`
63 ///
64 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
65 /// </example>
66 ///
67 /// <example>
68 /// `frontend/db.js`
69 /// </example>
70 pub path: PathBuf,
71 /// The mode of operation on the file. Possible values:
72 /// - 'edit': Make granular edits to an existing file.
73 /// - 'create': Create a new file if it doesn't exist.
74 /// - 'overwrite': Replace the entire contents of an existing file.
75 ///
76 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
77 pub mode: EditFileMode,
78}
79
80#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
81struct EditFileToolPartialInput {
82 #[serde(default)]
83 path: String,
84 #[serde(default)]
85 display_description: String,
86}
87
88#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
89#[serde(rename_all = "lowercase")]
90#[schemars(inline)]
91pub enum EditFileMode {
92 Edit,
93 Create,
94 Overwrite,
95}
96
97#[derive(Debug, Serialize, Deserialize)]
98pub struct EditFileToolOutput {
99 #[serde(alias = "original_path")]
100 input_path: PathBuf,
101 new_text: String,
102 old_text: Arc<String>,
103 #[serde(default)]
104 diff: String,
105 #[serde(alias = "raw_output")]
106 edit_agent_output: EditAgentOutput,
107}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
124pub struct EditFileTool {
125 thread: WeakEntity<Thread>,
126 language_registry: Arc<LanguageRegistry>,
127 project: Entity<Project>,
128 templates: Arc<Templates>,
129}
130
131impl EditFileTool {
132 pub fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
203 fn name() -> &'static str {
204 "edit_file"
205 }
206
207 fn kind() -> acp::ToolKind {
208 acp::ToolKind::Edit
209 }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(
277 ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
278 );
279 }
280
281 let authorize = self.authorize(&input, &event_stream, cx);
282 cx.spawn(async move |cx: &mut AsyncApp| {
283 authorize.await?;
284
285 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
286 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
287 (request, thread.model().cloned(), thread.action_log().clone())
288 })?;
289 let request = request?;
290 let model = model.context("No language model configured")?;
291
292 let edit_format = EditFormat::from_model(model.clone())?;
293 let edit_agent = EditAgent::new(
294 model,
295 project.clone(),
296 action_log.clone(),
297 self.templates.clone(),
298 edit_format,
299 );
300
301 let buffer = project
302 .update(cx, |project, cx| {
303 project.open_buffer(project_path.clone(), cx)
304 })
305 .await?;
306
307 // Check if the file has been modified since the agent last read it
308 if let Some(abs_path) = abs_path.as_ref() {
309 let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.update(cx, |thread, cx| {
310 let last_read = thread.file_read_times.get(abs_path).copied();
311 let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
312 let dirty = buffer.read(cx).is_dirty();
313 let has_save = thread.has_tool("save_file");
314 let has_restore = thread.has_tool("restore_file_from_disk");
315 (last_read, current, dirty, has_save, has_restore)
316 })?;
317
318 // Check for unsaved changes first - these indicate modifications we don't know about
319 if is_dirty {
320 let message = match (has_save_tool, has_restore_tool) {
321 (true, true) => {
322 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
323 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
324 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
325 }
326 (true, false) => {
327 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
328 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
329 If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
330 }
331 (false, true) => {
332 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
333 If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
334 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
335 }
336 (false, false) => {
337 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
338 then ask them to save or revert the file manually and inform you when it's ok to proceed."
339 }
340 };
341 anyhow::bail!("{}", message);
342 }
343
344 // Check if the file was modified on disk since we last read it
345 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
346 // MTime can be unreliable for comparisons, so our newtype intentionally
347 // doesn't support comparing them. If the mtime at all different
348 // (which could be because of a modification or because e.g. system clock changed),
349 // we pessimistically assume it was modified.
350 if current != last_read {
351 anyhow::bail!(
352 "The file {} has been modified since you last read it. \
353 Please read the file again to get the current state before editing it.",
354 input.path.display()
355 );
356 }
357 }
358 }
359
360 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
361 event_stream.update_diff(diff.clone());
362 let _finalize_diff = util::defer({
363 let diff = diff.downgrade();
364 let mut cx = cx.clone();
365 move || {
366 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
367 }
368 });
369
370 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
371 let old_text = cx
372 .background_spawn({
373 let old_snapshot = old_snapshot.clone();
374 async move { Arc::new(old_snapshot.text()) }
375 })
376 .await;
377
378
379 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
380 edit_agent.edit(
381 buffer.clone(),
382 input.display_description.clone(),
383 &request,
384 cx,
385 )
386 } else {
387 edit_agent.overwrite(
388 buffer.clone(),
389 input.display_description.clone(),
390 &request,
391 cx,
392 )
393 };
394
395 let mut hallucinated_old_text = false;
396 let mut ambiguous_ranges = Vec::new();
397 let mut emitted_location = false;
398 loop {
399 let event = futures::select! {
400 event = events.next().fuse() => match event {
401 Some(event) => event,
402 None => break,
403 },
404 _ = event_stream.cancelled_by_user().fuse() => {
405 anyhow::bail!("Edit cancelled by user");
406 }
407 };
408 match event {
409 EditAgentOutputEvent::Edited(range) => {
410 if !emitted_location {
411 let line = Some(buffer.update(cx, |buffer, _cx| {
412 range.start.to_point(&buffer.snapshot()).row
413 }));
414 if let Some(abs_path) = abs_path.clone() {
415 event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)]));
416 }
417 emitted_location = true;
418 }
419 },
420 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
421 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
422 EditAgentOutputEvent::ResolvingEditRange(range) => {
423 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx));
424 // if !emitted_location {
425 // let line = buffer.update(cx, |buffer, _cx| {
426 // range.start.to_point(&buffer.snapshot()).row
427 // }).ok();
428 // if let Some(abs_path) = abs_path.clone() {
429 // event_stream.update_fields(ToolCallUpdateFields {
430 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
431 // ..Default::default()
432 // });
433 // }
434 // }
435 }
436 }
437 }
438
439 // If format_on_save is enabled, format the buffer
440 let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
441 let settings = language_settings::language_settings(
442 buffer.language().map(|l| l.name()),
443 buffer.file(),
444 cx,
445 );
446 settings.format_on_save != FormatOnSave::Off
447 });
448
449 let edit_agent_output = output.await?;
450
451 if format_on_save_enabled {
452 action_log.update(cx, |log, cx| {
453 log.buffer_edited(buffer.clone(), cx);
454 });
455
456 let format_task = project.update(cx, |project, cx| {
457 project.format(
458 HashSet::from_iter([buffer.clone()]),
459 LspFormatTarget::Buffers,
460 false, // Don't push to history since the tool did it.
461 FormatTrigger::Save,
462 cx,
463 )
464 });
465 format_task.await.log_err();
466 }
467
468 project
469 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
470 .await?;
471
472 action_log.update(cx, |log, cx| {
473 log.buffer_edited(buffer.clone(), cx);
474 });
475
476 // Update the recorded read time after a successful edit so consecutive edits work
477 if let Some(abs_path) = abs_path.as_ref() {
478 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
479 buffer.file().and_then(|file| file.disk_state().mtime())
480 }) {
481 self.thread.update(cx, |thread, _| {
482 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
483 })?;
484 }
485 }
486
487 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
488 let (new_text, unified_diff) = cx
489 .background_spawn({
490 let new_snapshot = new_snapshot.clone();
491 let old_text = old_text.clone();
492 async move {
493 let new_text = new_snapshot.text();
494 let diff = language::unified_diff(&old_text, &new_text);
495 (new_text, diff)
496 }
497 })
498 .await;
499
500 let input_path = input.path.display();
501 if unified_diff.is_empty() {
502 anyhow::ensure!(
503 !hallucinated_old_text,
504 formatdoc! {"
505 Some edits were produced but none of them could be applied.
506 Read the relevant sections of {input_path} again so that
507 I can perform the requested edits.
508 "}
509 );
510 anyhow::ensure!(
511 ambiguous_ranges.is_empty(),
512 {
513 let line_numbers = ambiguous_ranges
514 .iter()
515 .map(|range| range.start.to_string())
516 .collect::<Vec<_>>()
517 .join(", ");
518 formatdoc! {"
519 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
520 relevant sections of {input_path} again and extend <old_text> so
521 that I can perform the requested edits.
522 "}
523 }
524 );
525 }
526
527 Ok(EditFileToolOutput {
528 input_path: input.path,
529 new_text,
530 old_text,
531 diff: unified_diff,
532 edit_agent_output,
533 })
534 })
535 }
536
537 fn replay(
538 &self,
539 _input: Self::Input,
540 output: Self::Output,
541 event_stream: ToolCallEventStream,
542 cx: &mut App,
543 ) -> Result<()> {
544 event_stream.update_diff(cx.new(|cx| {
545 Diff::finalized(
546 output.input_path.to_string_lossy().into_owned(),
547 Some(output.old_text.to_string()),
548 output.new_text,
549 self.language_registry.clone(),
550 cx,
551 )
552 }));
553 Ok(())
554 }
555}
556
557/// Validate that the file path is valid, meaning:
558///
559/// - For `edit` and `overwrite`, the path must point to an existing file.
560/// - For `create`, the file must not already exist, but it's parent dir must exist.
561fn resolve_path(
562 input: &EditFileToolInput,
563 project: Entity<Project>,
564 cx: &mut App,
565) -> Result<ProjectPath> {
566 let project = project.read(cx);
567
568 match input.mode {
569 EditFileMode::Edit | EditFileMode::Overwrite => {
570 let path = project
571 .find_project_path(&input.path, cx)
572 .context("Can't edit file: path not found")?;
573
574 let entry = project
575 .entry_for_path(&path, cx)
576 .context("Can't edit file: path not found")?;
577
578 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
579 Ok(path)
580 }
581
582 EditFileMode::Create => {
583 if let Some(path) = project.find_project_path(&input.path, cx) {
584 anyhow::ensure!(
585 project.entry_for_path(&path, cx).is_none(),
586 "Can't create file: file already exists"
587 );
588 }
589
590 let parent_path = input
591 .path
592 .parent()
593 .context("Can't create file: incorrect path")?;
594
595 let parent_project_path = project.find_project_path(&parent_path, cx);
596
597 let parent_entry = parent_project_path
598 .as_ref()
599 .and_then(|path| project.entry_for_path(path, cx))
600 .context("Can't create file: parent directory doesn't exist")?;
601
602 anyhow::ensure!(
603 parent_entry.is_dir(),
604 "Can't create file: parent is not a directory"
605 );
606
607 let file_name = input
608 .path
609 .file_name()
610 .and_then(|file_name| file_name.to_str())
611 .and_then(|file_name| RelPath::unix(file_name).ok())
612 .context("Can't create file: invalid filename")?;
613
614 let new_file_path = parent_project_path.map(|parent| ProjectPath {
615 path: parent.path.join(file_name),
616 ..parent
617 });
618
619 new_file_path.context("Can't create file")
620 }
621 }
622}
623
624#[cfg(test)]
625mod tests {
626 use super::*;
627 use crate::{ContextServerRegistry, Templates};
628 use fs::Fs;
629 use gpui::{TestAppContext, UpdateGlobal};
630 use language_model::fake_provider::FakeLanguageModel;
631 use prompt_store::ProjectContext;
632 use serde_json::json;
633 use settings::SettingsStore;
634 use util::{path, rel_path::rel_path};
635
636 #[gpui::test]
637 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
638 init_test(cx);
639
640 let fs = project::FakeFs::new(cx.executor());
641 fs.insert_tree("/root", json!({})).await;
642 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
643 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
644 let context_server_registry =
645 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
646 let model = Arc::new(FakeLanguageModel::default());
647 let thread = cx.new(|cx| {
648 Thread::new(
649 project.clone(),
650 cx.new(|_cx| ProjectContext::default()),
651 context_server_registry,
652 Templates::new(),
653 Some(model),
654 cx,
655 )
656 });
657 let result = cx
658 .update(|cx| {
659 let input = EditFileToolInput {
660 display_description: "Some edit".into(),
661 path: "root/nonexistent_file.txt".into(),
662 mode: EditFileMode::Edit,
663 };
664 Arc::new(EditFileTool::new(
665 project,
666 thread.downgrade(),
667 language_registry,
668 Templates::new(),
669 ))
670 .run(input, ToolCallEventStream::test().0, cx)
671 })
672 .await;
673 assert_eq!(
674 result.unwrap_err().to_string(),
675 "Can't edit file: path not found"
676 );
677 }
678
679 #[gpui::test]
680 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
681 let mode = &EditFileMode::Create;
682
683 let result = test_resolve_path(mode, "root/new.txt", cx);
684 assert_resolved_path_eq(result.await, rel_path("new.txt"));
685
686 let result = test_resolve_path(mode, "new.txt", cx);
687 assert_resolved_path_eq(result.await, rel_path("new.txt"));
688
689 let result = test_resolve_path(mode, "dir/new.txt", cx);
690 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
691
692 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
693 assert_eq!(
694 result.await.unwrap_err().to_string(),
695 "Can't create file: file already exists"
696 );
697
698 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
699 assert_eq!(
700 result.await.unwrap_err().to_string(),
701 "Can't create file: parent directory doesn't exist"
702 );
703 }
704
705 #[gpui::test]
706 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
707 let mode = &EditFileMode::Edit;
708
709 let path_with_root = "root/dir/subdir/existing.txt";
710 let path_without_root = "dir/subdir/existing.txt";
711 let result = test_resolve_path(mode, path_with_root, cx);
712 assert_resolved_path_eq(result.await, rel_path(path_without_root));
713
714 let result = test_resolve_path(mode, path_without_root, cx);
715 assert_resolved_path_eq(result.await, rel_path(path_without_root));
716
717 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
718 assert_eq!(
719 result.await.unwrap_err().to_string(),
720 "Can't edit file: path not found"
721 );
722
723 let result = test_resolve_path(mode, "root/dir", cx);
724 assert_eq!(
725 result.await.unwrap_err().to_string(),
726 "Can't edit file: path is a directory"
727 );
728 }
729
730 async fn test_resolve_path(
731 mode: &EditFileMode,
732 path: &str,
733 cx: &mut TestAppContext,
734 ) -> anyhow::Result<ProjectPath> {
735 init_test(cx);
736
737 let fs = project::FakeFs::new(cx.executor());
738 fs.insert_tree(
739 "/root",
740 json!({
741 "dir": {
742 "subdir": {
743 "existing.txt": "hello"
744 }
745 }
746 }),
747 )
748 .await;
749 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
750
751 let input = EditFileToolInput {
752 display_description: "Some edit".into(),
753 path: path.into(),
754 mode: mode.clone(),
755 };
756
757 cx.update(|cx| resolve_path(&input, project, cx))
758 }
759
760 #[track_caller]
761 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
762 let actual = path.expect("Should return valid path").path;
763 assert_eq!(actual.as_ref(), expected);
764 }
765
766 #[gpui::test]
767 async fn test_format_on_save(cx: &mut TestAppContext) {
768 init_test(cx);
769
770 let fs = project::FakeFs::new(cx.executor());
771 fs.insert_tree("/root", json!({"src": {}})).await;
772
773 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
774
775 // Set up a Rust language with LSP formatting support
776 let rust_language = Arc::new(language::Language::new(
777 language::LanguageConfig {
778 name: "Rust".into(),
779 matcher: language::LanguageMatcher {
780 path_suffixes: vec!["rs".to_string()],
781 ..Default::default()
782 },
783 ..Default::default()
784 },
785 None,
786 ));
787
788 // Register the language and fake LSP
789 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
790 language_registry.add(rust_language);
791
792 let mut fake_language_servers = language_registry.register_fake_lsp(
793 "Rust",
794 language::FakeLspAdapter {
795 capabilities: lsp::ServerCapabilities {
796 document_formatting_provider: Some(lsp::OneOf::Left(true)),
797 ..Default::default()
798 },
799 ..Default::default()
800 },
801 );
802
803 // Create the file
804 fs.save(
805 path!("/root/src/main.rs").as_ref(),
806 &"initial content".into(),
807 language::LineEnding::Unix,
808 )
809 .await
810 .unwrap();
811
812 // Open the buffer to trigger LSP initialization
813 let buffer = project
814 .update(cx, |project, cx| {
815 project.open_local_buffer(path!("/root/src/main.rs"), cx)
816 })
817 .await
818 .unwrap();
819
820 // Register the buffer with language servers
821 let _handle = project.update(cx, |project, cx| {
822 project.register_buffer_with_language_servers(&buffer, cx)
823 });
824
825 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
826 const FORMATTED_CONTENT: &str =
827 "This file was formatted by the fake formatter in the test.\n";
828
829 // Get the fake language server and set up formatting handler
830 let fake_language_server = fake_language_servers.next().await.unwrap();
831 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
832 |_, _| async move {
833 Ok(Some(vec![lsp::TextEdit {
834 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
835 new_text: FORMATTED_CONTENT.to_string(),
836 }]))
837 }
838 });
839
840 let context_server_registry =
841 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
842 let model = Arc::new(FakeLanguageModel::default());
843 let thread = cx.new(|cx| {
844 Thread::new(
845 project.clone(),
846 cx.new(|_cx| ProjectContext::default()),
847 context_server_registry,
848 Templates::new(),
849 Some(model.clone()),
850 cx,
851 )
852 });
853
854 // First, test with format_on_save enabled
855 cx.update(|cx| {
856 SettingsStore::update_global(cx, |store, cx| {
857 store.update_user_settings(cx, |settings| {
858 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
859 settings.project.all_languages.defaults.formatter =
860 Some(language::language_settings::FormatterList::default());
861 });
862 });
863 });
864
865 // Have the model stream unformatted content
866 let edit_result = {
867 let edit_task = cx.update(|cx| {
868 let input = EditFileToolInput {
869 display_description: "Create main function".into(),
870 path: "root/src/main.rs".into(),
871 mode: EditFileMode::Overwrite,
872 };
873 Arc::new(EditFileTool::new(
874 project.clone(),
875 thread.downgrade(),
876 language_registry.clone(),
877 Templates::new(),
878 ))
879 .run(input, ToolCallEventStream::test().0, cx)
880 });
881
882 // Stream the unformatted content
883 cx.executor().run_until_parked();
884 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
885 model.end_last_completion_stream();
886
887 edit_task.await
888 };
889 assert!(edit_result.is_ok());
890
891 // Wait for any async operations (e.g. formatting) to complete
892 cx.executor().run_until_parked();
893
894 // Read the file to verify it was formatted automatically
895 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
896 assert_eq!(
897 // Ignore carriage returns on Windows
898 new_content.replace("\r\n", "\n"),
899 FORMATTED_CONTENT,
900 "Code should be formatted when format_on_save is enabled"
901 );
902
903 let stale_buffer_count = thread
904 .read_with(cx, |thread, _cx| thread.action_log.clone())
905 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
906
907 assert_eq!(
908 stale_buffer_count, 0,
909 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
910 This causes the agent to think the file was modified externally when it was just formatted.",
911 stale_buffer_count
912 );
913
914 // Next, test with format_on_save disabled
915 cx.update(|cx| {
916 SettingsStore::update_global(cx, |store, cx| {
917 store.update_user_settings(cx, |settings| {
918 settings.project.all_languages.defaults.format_on_save =
919 Some(FormatOnSave::Off);
920 });
921 });
922 });
923
924 // Stream unformatted edits again
925 let edit_result = {
926 let edit_task = cx.update(|cx| {
927 let input = EditFileToolInput {
928 display_description: "Update main function".into(),
929 path: "root/src/main.rs".into(),
930 mode: EditFileMode::Overwrite,
931 };
932 Arc::new(EditFileTool::new(
933 project.clone(),
934 thread.downgrade(),
935 language_registry,
936 Templates::new(),
937 ))
938 .run(input, ToolCallEventStream::test().0, cx)
939 });
940
941 // Stream the unformatted content
942 cx.executor().run_until_parked();
943 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
944 model.end_last_completion_stream();
945
946 edit_task.await
947 };
948 assert!(edit_result.is_ok());
949
950 // Wait for any async operations (e.g. formatting) to complete
951 cx.executor().run_until_parked();
952
953 // Verify the file was not formatted
954 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
955 assert_eq!(
956 // Ignore carriage returns on Windows
957 new_content.replace("\r\n", "\n"),
958 UNFORMATTED_CONTENT,
959 "Code should not be formatted when format_on_save is disabled"
960 );
961 }
962
963 #[gpui::test]
964 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
965 init_test(cx);
966
967 let fs = project::FakeFs::new(cx.executor());
968 fs.insert_tree("/root", json!({"src": {}})).await;
969
970 // Create a simple file with trailing whitespace
971 fs.save(
972 path!("/root/src/main.rs").as_ref(),
973 &"initial content".into(),
974 language::LineEnding::Unix,
975 )
976 .await
977 .unwrap();
978
979 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
980 let context_server_registry =
981 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
982 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
983 let model = Arc::new(FakeLanguageModel::default());
984 let thread = cx.new(|cx| {
985 Thread::new(
986 project.clone(),
987 cx.new(|_cx| ProjectContext::default()),
988 context_server_registry,
989 Templates::new(),
990 Some(model.clone()),
991 cx,
992 )
993 });
994
995 // First, test with remove_trailing_whitespace_on_save enabled
996 cx.update(|cx| {
997 SettingsStore::update_global(cx, |store, cx| {
998 store.update_user_settings(cx, |settings| {
999 settings
1000 .project
1001 .all_languages
1002 .defaults
1003 .remove_trailing_whitespace_on_save = Some(true);
1004 });
1005 });
1006 });
1007
1008 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
1009 "fn main() { \n println!(\"Hello!\"); \n}\n";
1010
1011 // Have the model stream content that contains trailing whitespace
1012 let edit_result = {
1013 let edit_task = cx.update(|cx| {
1014 let input = EditFileToolInput {
1015 display_description: "Create main function".into(),
1016 path: "root/src/main.rs".into(),
1017 mode: EditFileMode::Overwrite,
1018 };
1019 Arc::new(EditFileTool::new(
1020 project.clone(),
1021 thread.downgrade(),
1022 language_registry.clone(),
1023 Templates::new(),
1024 ))
1025 .run(input, ToolCallEventStream::test().0, cx)
1026 });
1027
1028 // Stream the content with trailing whitespace
1029 cx.executor().run_until_parked();
1030 model.send_last_completion_stream_text_chunk(
1031 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1032 );
1033 model.end_last_completion_stream();
1034
1035 edit_task.await
1036 };
1037 assert!(edit_result.is_ok());
1038
1039 // Wait for any async operations (e.g. formatting) to complete
1040 cx.executor().run_until_parked();
1041
1042 // Read the file to verify trailing whitespace was removed automatically
1043 assert_eq!(
1044 // Ignore carriage returns on Windows
1045 fs.load(path!("/root/src/main.rs").as_ref())
1046 .await
1047 .unwrap()
1048 .replace("\r\n", "\n"),
1049 "fn main() {\n println!(\"Hello!\");\n}\n",
1050 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
1051 );
1052
1053 // Next, test with remove_trailing_whitespace_on_save disabled
1054 cx.update(|cx| {
1055 SettingsStore::update_global(cx, |store, cx| {
1056 store.update_user_settings(cx, |settings| {
1057 settings
1058 .project
1059 .all_languages
1060 .defaults
1061 .remove_trailing_whitespace_on_save = Some(false);
1062 });
1063 });
1064 });
1065
1066 // Stream edits again with trailing whitespace
1067 let edit_result = {
1068 let edit_task = cx.update(|cx| {
1069 let input = EditFileToolInput {
1070 display_description: "Update main function".into(),
1071 path: "root/src/main.rs".into(),
1072 mode: EditFileMode::Overwrite,
1073 };
1074 Arc::new(EditFileTool::new(
1075 project.clone(),
1076 thread.downgrade(),
1077 language_registry,
1078 Templates::new(),
1079 ))
1080 .run(input, ToolCallEventStream::test().0, cx)
1081 });
1082
1083 // Stream the content with trailing whitespace
1084 cx.executor().run_until_parked();
1085 model.send_last_completion_stream_text_chunk(
1086 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1087 );
1088 model.end_last_completion_stream();
1089
1090 edit_task.await
1091 };
1092 assert!(edit_result.is_ok());
1093
1094 // Wait for any async operations (e.g. formatting) to complete
1095 cx.executor().run_until_parked();
1096
1097 // Verify the file still has trailing whitespace
1098 // Read the file again - it should still have trailing whitespace
1099 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1100 assert_eq!(
1101 // Ignore carriage returns on Windows
1102 final_content.replace("\r\n", "\n"),
1103 CONTENT_WITH_TRAILING_WHITESPACE,
1104 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1105 );
1106 }
1107
1108 #[gpui::test]
1109 async fn test_authorize(cx: &mut TestAppContext) {
1110 init_test(cx);
1111 let fs = project::FakeFs::new(cx.executor());
1112 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1113 let context_server_registry =
1114 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1115 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1116 let model = Arc::new(FakeLanguageModel::default());
1117 let thread = cx.new(|cx| {
1118 Thread::new(
1119 project.clone(),
1120 cx.new(|_cx| ProjectContext::default()),
1121 context_server_registry,
1122 Templates::new(),
1123 Some(model.clone()),
1124 cx,
1125 )
1126 });
1127 let tool = Arc::new(EditFileTool::new(
1128 project.clone(),
1129 thread.downgrade(),
1130 language_registry,
1131 Templates::new(),
1132 ));
1133 fs.insert_tree("/root", json!({})).await;
1134
1135 // Test 1: Path with .zed component should require confirmation
1136 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1137 let _auth = cx.update(|cx| {
1138 tool.authorize(
1139 &EditFileToolInput {
1140 display_description: "test 1".into(),
1141 path: ".zed/settings.json".into(),
1142 mode: EditFileMode::Edit,
1143 },
1144 &stream_tx,
1145 cx,
1146 )
1147 });
1148
1149 let event = stream_rx.expect_authorization().await;
1150 assert_eq!(
1151 event.tool_call.fields.title,
1152 Some("test 1 (local settings)".into())
1153 );
1154
1155 // Test 2: Path outside project should require confirmation
1156 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1157 let _auth = cx.update(|cx| {
1158 tool.authorize(
1159 &EditFileToolInput {
1160 display_description: "test 2".into(),
1161 path: "/etc/hosts".into(),
1162 mode: EditFileMode::Edit,
1163 },
1164 &stream_tx,
1165 cx,
1166 )
1167 });
1168
1169 let event = stream_rx.expect_authorization().await;
1170 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1171
1172 // Test 3: Relative path without .zed should not require confirmation
1173 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1174 cx.update(|cx| {
1175 tool.authorize(
1176 &EditFileToolInput {
1177 display_description: "test 3".into(),
1178 path: "root/src/main.rs".into(),
1179 mode: EditFileMode::Edit,
1180 },
1181 &stream_tx,
1182 cx,
1183 )
1184 })
1185 .await
1186 .unwrap();
1187 assert!(stream_rx.try_next().is_err());
1188
1189 // Test 4: Path with .zed in the middle should require confirmation
1190 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1191 let _auth = cx.update(|cx| {
1192 tool.authorize(
1193 &EditFileToolInput {
1194 display_description: "test 4".into(),
1195 path: "root/.zed/tasks.json".into(),
1196 mode: EditFileMode::Edit,
1197 },
1198 &stream_tx,
1199 cx,
1200 )
1201 });
1202 let event = stream_rx.expect_authorization().await;
1203 assert_eq!(
1204 event.tool_call.fields.title,
1205 Some("test 4 (local settings)".into())
1206 );
1207
1208 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1209 cx.update(|cx| {
1210 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1211 settings.always_allow_tool_actions = true;
1212 agent_settings::AgentSettings::override_global(settings, cx);
1213 });
1214
1215 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1216 cx.update(|cx| {
1217 tool.authorize(
1218 &EditFileToolInput {
1219 display_description: "test 5.1".into(),
1220 path: ".zed/settings.json".into(),
1221 mode: EditFileMode::Edit,
1222 },
1223 &stream_tx,
1224 cx,
1225 )
1226 })
1227 .await
1228 .unwrap();
1229 assert!(stream_rx.try_next().is_err());
1230
1231 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1232 cx.update(|cx| {
1233 tool.authorize(
1234 &EditFileToolInput {
1235 display_description: "test 5.2".into(),
1236 path: "/etc/hosts".into(),
1237 mode: EditFileMode::Edit,
1238 },
1239 &stream_tx,
1240 cx,
1241 )
1242 })
1243 .await
1244 .unwrap();
1245 assert!(stream_rx.try_next().is_err());
1246 }
1247
1248 #[gpui::test]
1249 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1250 init_test(cx);
1251 let fs = project::FakeFs::new(cx.executor());
1252 fs.insert_tree("/project", json!({})).await;
1253 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1254 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1255 let context_server_registry =
1256 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1257 let model = Arc::new(FakeLanguageModel::default());
1258 let thread = cx.new(|cx| {
1259 Thread::new(
1260 project.clone(),
1261 cx.new(|_cx| ProjectContext::default()),
1262 context_server_registry,
1263 Templates::new(),
1264 Some(model.clone()),
1265 cx,
1266 )
1267 });
1268 let tool = Arc::new(EditFileTool::new(
1269 project.clone(),
1270 thread.downgrade(),
1271 language_registry,
1272 Templates::new(),
1273 ));
1274
1275 // Test global config paths - these should require confirmation if they exist and are outside the project
1276 let test_cases = vec![
1277 (
1278 "/etc/hosts",
1279 true,
1280 "System file should require confirmation",
1281 ),
1282 (
1283 "/usr/local/bin/script",
1284 true,
1285 "System bin file should require confirmation",
1286 ),
1287 (
1288 "project/normal_file.rs",
1289 false,
1290 "Normal project file should not require confirmation",
1291 ),
1292 ];
1293
1294 for (path, should_confirm, description) in test_cases {
1295 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1296 let auth = cx.update(|cx| {
1297 tool.authorize(
1298 &EditFileToolInput {
1299 display_description: "Edit file".into(),
1300 path: path.into(),
1301 mode: EditFileMode::Edit,
1302 },
1303 &stream_tx,
1304 cx,
1305 )
1306 });
1307
1308 if should_confirm {
1309 stream_rx.expect_authorization().await;
1310 } else {
1311 auth.await.unwrap();
1312 assert!(
1313 stream_rx.try_next().is_err(),
1314 "Failed for case: {} - path: {} - expected no confirmation but got one",
1315 description,
1316 path
1317 );
1318 }
1319 }
1320 }
1321
1322 #[gpui::test]
1323 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1324 init_test(cx);
1325 let fs = project::FakeFs::new(cx.executor());
1326
1327 // Create multiple worktree directories
1328 fs.insert_tree(
1329 "/workspace/frontend",
1330 json!({
1331 "src": {
1332 "main.js": "console.log('frontend');"
1333 }
1334 }),
1335 )
1336 .await;
1337 fs.insert_tree(
1338 "/workspace/backend",
1339 json!({
1340 "src": {
1341 "main.rs": "fn main() {}"
1342 }
1343 }),
1344 )
1345 .await;
1346 fs.insert_tree(
1347 "/workspace/shared",
1348 json!({
1349 ".zed": {
1350 "settings.json": "{}"
1351 }
1352 }),
1353 )
1354 .await;
1355
1356 // Create project with multiple worktrees
1357 let project = Project::test(
1358 fs.clone(),
1359 [
1360 path!("/workspace/frontend").as_ref(),
1361 path!("/workspace/backend").as_ref(),
1362 path!("/workspace/shared").as_ref(),
1363 ],
1364 cx,
1365 )
1366 .await;
1367 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1368 let context_server_registry =
1369 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1370 let model = Arc::new(FakeLanguageModel::default());
1371 let thread = cx.new(|cx| {
1372 Thread::new(
1373 project.clone(),
1374 cx.new(|_cx| ProjectContext::default()),
1375 context_server_registry.clone(),
1376 Templates::new(),
1377 Some(model.clone()),
1378 cx,
1379 )
1380 });
1381 let tool = Arc::new(EditFileTool::new(
1382 project.clone(),
1383 thread.downgrade(),
1384 language_registry,
1385 Templates::new(),
1386 ));
1387
1388 // Test files in different worktrees
1389 let test_cases = vec![
1390 ("frontend/src/main.js", false, "File in first worktree"),
1391 ("backend/src/main.rs", false, "File in second worktree"),
1392 (
1393 "shared/.zed/settings.json",
1394 true,
1395 ".zed file in third worktree",
1396 ),
1397 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1398 (
1399 "../outside/file.txt",
1400 true,
1401 "Relative path outside worktrees",
1402 ),
1403 ];
1404
1405 for (path, should_confirm, description) in test_cases {
1406 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1407 let auth = cx.update(|cx| {
1408 tool.authorize(
1409 &EditFileToolInput {
1410 display_description: "Edit file".into(),
1411 path: path.into(),
1412 mode: EditFileMode::Edit,
1413 },
1414 &stream_tx,
1415 cx,
1416 )
1417 });
1418
1419 if should_confirm {
1420 stream_rx.expect_authorization().await;
1421 } else {
1422 auth.await.unwrap();
1423 assert!(
1424 stream_rx.try_next().is_err(),
1425 "Failed for case: {} - path: {} - expected no confirmation but got one",
1426 description,
1427 path
1428 );
1429 }
1430 }
1431 }
1432
/// Exercises `authorize` on unusual path shapes: an empty path, the filesystem
/// root, paths with `..` and `.` components, and (on Windows only)
/// backslash-separated paths. It checks which of them produce an authorization
/// request on the event stream.
#[gpui::test]
async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        "/project",
        json!({
            ".zed": {
                "settings.json": "{}"
            },
            "src": {
                ".zed": {
                    "local.json": "{}"
                }
            }
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });
    let tool = Arc::new(EditFileTool::new(
        project.clone(),
        thread.downgrade(),
        language_registry,
        Templates::new(),
    ));

    // Test edge cases
    // Each entry is (path, should_confirm, description). The Windows-only
    // entries are compiled in via `#[cfg]` on the individual elements.
    let test_cases = vec![
        // Empty path - find_project_path returns Some for empty paths
        ("", false, "Empty path is treated as project root"),
        // Root directory
        ("/", true, "Root directory should be outside project"),
        // Parent directory references - find_project_path resolves these
        (
            "project/../other",
            true,
            "Path with .. that goes outside of root directory",
        ),
        (
            "project/./src/file.rs",
            false,
            "Path with . should work normally",
        ),
        // Windows-style paths (if on Windows)
        #[cfg(target_os = "windows")]
        ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
        #[cfg(target_os = "windows")]
        ("project\\src\\main.rs", false, "Windows-style project path"),
    ];

    for (path, should_confirm, description) in test_cases {
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let auth = cx.update(|cx| {
            tool.authorize(
                &EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path.into(),
                    mode: EditFileMode::Edit,
                },
                &stream_tx,
                cx,
            )
        });

        // Drain pending work first, so that any authorization request has
        // already reached the stream before it is inspected below.
        cx.run_until_parked();

        if should_confirm {
            stream_rx.expect_authorization().await;
        } else {
            // Unlike the other confirmation tests, the stream is checked
            // *before* awaiting `auth`; keep this order when editing.
            assert!(
                stream_rx.try_next().is_err(),
                "Failed for case: {} - path: {} - expected no confirmation but got one",
                description,
                path
            );
            auth.await.unwrap();
        }
    }
}
1526
1527 #[gpui::test]
1528 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1529 init_test(cx);
1530 let fs = project::FakeFs::new(cx.executor());
1531 fs.insert_tree(
1532 "/project",
1533 json!({
1534 "existing.txt": "content",
1535 ".zed": {
1536 "settings.json": "{}"
1537 }
1538 }),
1539 )
1540 .await;
1541 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1542 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1543 let context_server_registry =
1544 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1545 let model = Arc::new(FakeLanguageModel::default());
1546 let thread = cx.new(|cx| {
1547 Thread::new(
1548 project.clone(),
1549 cx.new(|_cx| ProjectContext::default()),
1550 context_server_registry.clone(),
1551 Templates::new(),
1552 Some(model.clone()),
1553 cx,
1554 )
1555 });
1556 let tool = Arc::new(EditFileTool::new(
1557 project.clone(),
1558 thread.downgrade(),
1559 language_registry,
1560 Templates::new(),
1561 ));
1562
1563 // Test different EditFileMode values
1564 let modes = vec![
1565 EditFileMode::Edit,
1566 EditFileMode::Create,
1567 EditFileMode::Overwrite,
1568 ];
1569
1570 for mode in modes {
1571 // Test .zed path with different modes
1572 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1573 let _auth = cx.update(|cx| {
1574 tool.authorize(
1575 &EditFileToolInput {
1576 display_description: "Edit settings".into(),
1577 path: "project/.zed/settings.json".into(),
1578 mode: mode.clone(),
1579 },
1580 &stream_tx,
1581 cx,
1582 )
1583 });
1584
1585 stream_rx.expect_authorization().await;
1586
1587 // Test outside path with different modes
1588 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1589 let _auth = cx.update(|cx| {
1590 tool.authorize(
1591 &EditFileToolInput {
1592 display_description: "Edit file".into(),
1593 path: "/outside/file.txt".into(),
1594 mode: mode.clone(),
1595 },
1596 &stream_tx,
1597 cx,
1598 )
1599 });
1600
1601 stream_rx.expect_authorization().await;
1602
1603 // Test normal path with different modes
1604 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1605 cx.update(|cx| {
1606 tool.authorize(
1607 &EditFileToolInput {
1608 display_description: "Edit file".into(),
1609 path: "project/normal.txt".into(),
1610 mode: mode.clone(),
1611 },
1612 &stream_tx,
1613 cx,
1614 )
1615 })
1616 .await
1617 .unwrap();
1618 assert!(stream_rx.try_next().is_err());
1619 }
1620 }
1621
1622 #[gpui::test]
1623 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1624 init_test(cx);
1625 let fs = project::FakeFs::new(cx.executor());
1626 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1627 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1628 let context_server_registry =
1629 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1630 let model = Arc::new(FakeLanguageModel::default());
1631 let thread = cx.new(|cx| {
1632 Thread::new(
1633 project.clone(),
1634 cx.new(|_cx| ProjectContext::default()),
1635 context_server_registry,
1636 Templates::new(),
1637 Some(model.clone()),
1638 cx,
1639 )
1640 });
1641 let tool = Arc::new(EditFileTool::new(
1642 project,
1643 thread.downgrade(),
1644 language_registry,
1645 Templates::new(),
1646 ));
1647
1648 cx.update(|cx| {
1649 // ...
1650 assert_eq!(
1651 tool.initial_title(
1652 Err(json!({
1653 "path": "src/main.rs",
1654 "display_description": "",
1655 "old_string": "old code",
1656 "new_string": "new code"
1657 })),
1658 cx
1659 ),
1660 "src/main.rs"
1661 );
1662 assert_eq!(
1663 tool.initial_title(
1664 Err(json!({
1665 "path": "",
1666 "display_description": "Fix error handling",
1667 "old_string": "old code",
1668 "new_string": "new code"
1669 })),
1670 cx
1671 ),
1672 "Fix error handling"
1673 );
1674 assert_eq!(
1675 tool.initial_title(
1676 Err(json!({
1677 "path": "src/main.rs",
1678 "display_description": "Fix error handling",
1679 "old_string": "old code",
1680 "new_string": "new code"
1681 })),
1682 cx
1683 ),
1684 "src/main.rs"
1685 );
1686 assert_eq!(
1687 tool.initial_title(
1688 Err(json!({
1689 "path": "",
1690 "display_description": "",
1691 "old_string": "old code",
1692 "new_string": "new code"
1693 })),
1694 cx
1695 ),
1696 DEFAULT_UI_TEXT
1697 );
1698 assert_eq!(
1699 tool.initial_title(Err(serde_json::Value::Null), cx),
1700 DEFAULT_UI_TEXT
1701 );
1702 });
1703 }
1704
/// The `Diff` emitted by a running edit must move from `Diff::Pending` to
/// `Diff::Finalized` on every exit path: normal completion, an error during
/// the edit, and the edit task being dropped.
#[gpui::test]
async fn test_diff_finalization(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree("/", json!({"main.rs": ""})).await;

    let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
    let languages = project.read_with(cx, |project, _cx| project.languages().clone());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // Ensure the diff is finalized after the edit completes.
    {
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        // The tool first emits a fields update, then the diff entity.
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        cx.run_until_parked();
        // Close the fake model's stream with no content so the edit can finish.
        model.end_last_completion_stream();
        edit.await.unwrap();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
    }

    // Ensure the diff is finalized if an error occurs while editing.
    {
        // The fake model is told to forbid requests, and the edit below is
        // expected to fail (`unwrap_err`); re-allowed at the end of the block.
        model.forbid_requests();
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        edit.await.unwrap_err();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        model.allow_requests();
    }

    // Ensure the diff is finalized if the tool call gets dropped.
    {
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        // Drop the edit task without awaiting it, then let the executor
        // process the resulting cleanup before checking the diff.
        drop(edit);
        cx.run_until_parked();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
    }
}
1813
1814 #[gpui::test]
1815 async fn test_file_read_times_tracking(cx: &mut TestAppContext) {
1816 init_test(cx);
1817
1818 let fs = project::FakeFs::new(cx.executor());
1819 fs.insert_tree(
1820 "/root",
1821 json!({
1822 "test.txt": "original content"
1823 }),
1824 )
1825 .await;
1826 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1827 let context_server_registry =
1828 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1829 let model = Arc::new(FakeLanguageModel::default());
1830 let thread = cx.new(|cx| {
1831 Thread::new(
1832 project.clone(),
1833 cx.new(|_cx| ProjectContext::default()),
1834 context_server_registry,
1835 Templates::new(),
1836 Some(model.clone()),
1837 cx,
1838 )
1839 });
1840 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1841
1842 // Initially, file_read_times should be empty
1843 let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
1844 assert!(is_empty, "file_read_times should start empty");
1845
1846 // Create read tool
1847 let read_tool = Arc::new(crate::ReadFileTool::new(
1848 thread.downgrade(),
1849 project.clone(),
1850 action_log,
1851 ));
1852
1853 // Read the file to record the read time
1854 cx.update(|cx| {
1855 read_tool.clone().run(
1856 crate::ReadFileToolInput {
1857 path: "root/test.txt".to_string(),
1858 start_line: None,
1859 end_line: None,
1860 },
1861 ToolCallEventStream::test().0,
1862 cx,
1863 )
1864 })
1865 .await
1866 .unwrap();
1867
1868 // Verify that file_read_times now contains an entry for the file
1869 let has_entry = thread.read_with(cx, |thread, _| {
1870 thread.file_read_times.len() == 1
1871 && thread
1872 .file_read_times
1873 .keys()
1874 .any(|path| path.ends_with("test.txt"))
1875 });
1876 assert!(
1877 has_entry,
1878 "file_read_times should contain an entry after reading the file"
1879 );
1880
1881 // Read the file again - should update the entry
1882 cx.update(|cx| {
1883 read_tool.clone().run(
1884 crate::ReadFileToolInput {
1885 path: "root/test.txt".to_string(),
1886 start_line: None,
1887 end_line: None,
1888 },
1889 ToolCallEventStream::test().0,
1890 cx,
1891 )
1892 })
1893 .await
1894 .unwrap();
1895
1896 // Should still have exactly one entry
1897 let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
1898 assert!(
1899 has_one_entry,
1900 "file_read_times should still have one entry after re-reading"
1901 );
1902 }
1903
1904 fn init_test(cx: &mut TestAppContext) {
1905 cx.update(|cx| {
1906 let settings_store = SettingsStore::test(cx);
1907 cx.set_global(settings_store);
1908 });
1909 }
1910
/// After one read, two edits in a row to the same file must both succeed.
/// The first edit is expected to refresh the recorded read time, so the second
/// edit does not require another read in between.
#[gpui::test]
async fn test_consecutive_edits_work(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "test.txt": "original content"
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });
    let languages = project.read_with(cx, |project, _| project.languages().clone());
    let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

    let read_tool = Arc::new(crate::ReadFileTool::new(
        thread.downgrade(),
        project.clone(),
        action_log,
    ));
    let edit_tool = Arc::new(EditFileTool::new(
        project.clone(),
        thread.downgrade(),
        languages,
        Templates::new(),
    ));

    // Read the file first
    cx.update(|cx| {
        read_tool.clone().run(
            crate::ReadFileToolInput {
                path: "root/test.txt".to_string(),
                start_line: None,
                end_line: None,
            },
            ToolCallEventStream::test().0,
            cx,
        )
    })
    .await
    .unwrap();

    // First edit should work
    let edit_result = {
        let edit_task = cx.update(|cx| {
            edit_tool.clone().run(
                EditFileToolInput {
                    display_description: "First edit".into(),
                    path: "root/test.txt".into(),
                    mode: EditFileMode::Edit,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        });

        // Let the tool issue its completion request to the fake model. Then
        // answer it with an <old_text>/<new_text> replacement and close the
        // stream.
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(
            "<old_text>original content</old_text><new_text>modified content</new_text>"
                .to_string(),
        );
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(
        edit_result.is_ok(),
        "First edit should succeed, got error: {:?}",
        edit_result.as_ref().err()
    );

    // Second edit should also work because the edit updated the recorded read time
    let edit_result = {
        let edit_task = cx.update(|cx| {
            edit_tool.clone().run(
                EditFileToolInput {
                    display_description: "Second edit".into(),
                    path: "root/test.txt".into(),
                    mode: EditFileMode::Edit,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        });

        // Same protocol as above; `old_text` matches the first edit's output.
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(
            "<old_text>modified content</old_text><new_text>further modified content</new_text>".to_string(),
        );
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(
        edit_result.is_ok(),
        "Second consecutive edit should succeed, got error: {:?}",
        edit_result.as_ref().err()
    );
}
2024
/// If the file changes on disk after the agent last read it, the edit tool is
/// expected to refuse to edit it. The error message must mention that the file
/// "has been modified since you last read it".
#[gpui::test]
async fn test_external_modification_detected(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree(
        "/root",
        json!({
            "test.txt": "original content"
        }),
    )
    .await;
    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });
    let languages = project.read_with(cx, |project, _| project.languages().clone());
    let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

    let read_tool = Arc::new(crate::ReadFileTool::new(
        thread.downgrade(),
        project.clone(),
        action_log,
    ));
    let edit_tool = Arc::new(EditFileTool::new(
        project.clone(),
        thread.downgrade(),
        languages,
        Templates::new(),
    ));

    // Read the file first
    cx.update(|cx| {
        read_tool.clone().run(
            crate::ReadFileToolInput {
                path: "root/test.txt".to_string(),
                start_line: None,
                end_line: None,
            },
            ToolCallEventStream::test().0,
            cx,
        )
    })
    .await
    .unwrap();

    // Simulate external modification - advance time and save file
    // (advancing the fake clock first presumably gives the save a later
    // timestamp than the recorded read)
    cx.background_executor
        .advance_clock(std::time::Duration::from_secs(2));
    fs.save(
        path!("/root/test.txt").as_ref(),
        &"externally modified content".into(),
        language::LineEnding::Unix,
    )
    .await
    .unwrap();

    // Reload the buffer to pick up the new mtime
    let project_path = project
        .read_with(cx, |project, cx| {
            project.find_project_path("root/test.txt", cx)
        })
        .expect("Should find project path");
    let buffer = project
        .update(cx, |project, cx| project.open_buffer(project_path, cx))
        .await
        .unwrap();
    buffer
        .update(cx, |buffer, cx| buffer.reload(cx))
        .await
        .unwrap();

    cx.executor().run_until_parked();

    // Try to edit - should fail because file was modified externally.
    // No model output is streamed here; the failure is expected to happen
    // before any edit is produced.
    let result = cx
        .update(|cx| {
            edit_tool.clone().run(
                EditFileToolInput {
                    display_description: "Edit after external change".into(),
                    path: "root/test.txt".into(),
                    mode: EditFileMode::Edit,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await;

    assert!(
        result.is_err(),
        "Edit should fail after external modification"
    );
    let error_msg = result.unwrap_err().to_string();
    assert!(
        error_msg.contains("has been modified since you last read it"),
        "Error should mention file modification, got: {}",
        error_msg
    );
}
2135
2136 #[gpui::test]
2137 async fn test_dirty_buffer_detected(cx: &mut TestAppContext) {
2138 init_test(cx);
2139
2140 let fs = project::FakeFs::new(cx.executor());
2141 fs.insert_tree(
2142 "/root",
2143 json!({
2144 "test.txt": "original content"
2145 }),
2146 )
2147 .await;
2148 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2149 let context_server_registry =
2150 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2151 let model = Arc::new(FakeLanguageModel::default());
2152 let thread = cx.new(|cx| {
2153 Thread::new(
2154 project.clone(),
2155 cx.new(|_cx| ProjectContext::default()),
2156 context_server_registry,
2157 Templates::new(),
2158 Some(model.clone()),
2159 cx,
2160 )
2161 });
2162 let languages = project.read_with(cx, |project, _| project.languages().clone());
2163 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2164
2165 let read_tool = Arc::new(crate::ReadFileTool::new(
2166 thread.downgrade(),
2167 project.clone(),
2168 action_log,
2169 ));
2170 let edit_tool = Arc::new(EditFileTool::new(
2171 project.clone(),
2172 thread.downgrade(),
2173 languages,
2174 Templates::new(),
2175 ));
2176
2177 // Read the file first
2178 cx.update(|cx| {
2179 read_tool.clone().run(
2180 crate::ReadFileToolInput {
2181 path: "root/test.txt".to_string(),
2182 start_line: None,
2183 end_line: None,
2184 },
2185 ToolCallEventStream::test().0,
2186 cx,
2187 )
2188 })
2189 .await
2190 .unwrap();
2191
2192 // Open the buffer and make it dirty by editing without saving
2193 let project_path = project
2194 .read_with(cx, |project, cx| {
2195 project.find_project_path("root/test.txt", cx)
2196 })
2197 .expect("Should find project path");
2198 let buffer = project
2199 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2200 .await
2201 .unwrap();
2202
2203 // Make an in-memory edit to the buffer (making it dirty)
2204 buffer.update(cx, |buffer, cx| {
2205 let end_point = buffer.max_point();
2206 buffer.edit([(end_point..end_point, " added text")], None, cx);
2207 });
2208
2209 // Verify buffer is dirty
2210 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
2211 assert!(is_dirty, "Buffer should be dirty after in-memory edit");
2212
2213 // Try to edit - should fail because buffer has unsaved changes
2214 let result = cx
2215 .update(|cx| {
2216 edit_tool.clone().run(
2217 EditFileToolInput {
2218 display_description: "Edit with dirty buffer".into(),
2219 path: "root/test.txt".into(),
2220 mode: EditFileMode::Edit,
2221 },
2222 ToolCallEventStream::test().0,
2223 cx,
2224 )
2225 })
2226 .await;
2227
2228 assert!(result.is_err(), "Edit should fail when buffer is dirty");
2229 let error_msg = result.unwrap_err().to_string();
2230 assert!(
2231 error_msg.contains("This file has unsaved changes."),
2232 "Error should mention unsaved changes, got: {}",
2233 error_msg
2234 );
2235 assert!(
2236 error_msg.contains("keep or discard"),
2237 "Error should ask whether to keep or discard changes, got: {}",
2238 error_msg
2239 );
2240 // Since save_file and restore_file_from_disk tools aren't added to the thread,
2241 // the error message should ask the user to manually save or revert
2242 assert!(
2243 error_msg.contains("save or revert the file manually"),
2244 "Error should ask user to manually save or revert when tools aren't available, got: {}",
2245 error_msg
2246 );
2247 }
2248}