1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Fallback tool-call title, used by `initial_title` when neither a path nor a
/// description can be recovered from the (possibly unparseable) tool input.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
31/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
32///
33/// Before using this tool:
34///
35/// 1. Use the `read_file` tool to understand the file's contents and context
36///
37/// 2. Verify the directory path is correct (only applicable when creating new files):
38/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
39#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
40pub struct EditFileToolInput {
41 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
42 ///
43 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
44 ///
45 /// NEVER mention the file path in this description.
46 ///
47 /// IMPORTANT: Do NOT include markdown code fences (```) or other markdown formatting in this description.
48 /// Just describe what should be done - another model will generate the actual content.
49 ///
50 /// <example>Fix API endpoint URLs</example>
51 /// <example>Update copyright year in `page_footer`</example>
52 /// <example>Create a Python script that prints hello world</example>
53 ///
54 /// INCORRECT examples (do not do this):
55 /// <example>Create a file with:\n```python\nprint('hello')\n```</example>
56 /// <example>Add this code:\n```\nif err:\n return\n```</example>
57 ///
58 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
59 pub display_description: String,
60
61 /// The full path of the file to create or modify in the project.
62 ///
63 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
64 ///
65 /// The following examples assume we have two root directories in the project:
66 /// - /a/b/backend
67 /// - /c/d/frontend
68 ///
69 /// <example>
70 /// `backend/src/main.rs`
71 ///
72 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
73 /// </example>
74 ///
75 /// <example>
76 /// `frontend/db.js`
77 /// </example>
78 pub path: PathBuf,
79 /// The mode of operation on the file. Possible values:
80 /// - 'edit': Make granular edits to an existing file.
81 /// - 'create': Create a new file if it doesn't exist.
82 /// - 'overwrite': Replace the entire contents of an existing file.
83 ///
84 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
85 pub mode: EditFileMode,
86}
87
// Lenient view of the tool input. `initial_title` deserializes the raw JSON into this
// when parsing it as a full `EditFileToolInput` fails (presumably while the input is
// still streaming in). Every field falls back to an empty string when absent.
// Plain `//` comments on purpose: `///` doc comments would end up in the derived schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Target path as written by the model; may be empty or padded with whitespace.
    #[serde(default)]
    path: String,
    // Description of the edit; shown as the title when no path is available yet.
    #[serde(default)]
    display_description: String,
}
95
// How `edit_file` treats the target file. Serialized in lowercase ("edit", "create",
// "overwrite"). The model-facing description lives on the `mode` field of
// `EditFileToolInput`. Plain `//` comments on purpose: `///` doc comments would become
// part of the derived, inlined JSON schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // Granular edits to an existing file (`resolve_path` requires an existing file).
    Edit,
    // Create a new file (`resolve_path` requires that it does not exist yet).
    Create,
    // Replace the whole contents of an existing file.
    Overwrite,
}
104
/// Result of a successful `edit_file` call.
///
/// Serializable, presumably so it can be stored and later handed back to `replay`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as given in the tool input. Deserialization also accepts the
    /// key `original_path`.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text taken after the edit (and any format-on-save) was saved.
    new_text: String,
    /// Full buffer text captured before the edit agent started editing.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    /// Defaults to an empty string when missing from serialized data.
    #[serde(default)]
    diff: String,
    /// Output of the edit agent. Deserialization also accepts the key `raw_output`.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
116
117impl From<EditFileToolOutput> for LanguageModelToolResultContent {
118 fn from(output: EditFileToolOutput) -> Self {
119 if output.diff.is_empty() {
120 "No edits were made.".into()
121 } else {
122 format!(
123 "Edited {}:\n\n```diff\n{}\n```",
124 output.input_path.display(),
125 output.diff
126 )
127 .into()
128 }
129 }
130}
131
/// Agent tool that creates, edits, or overwrites a single project file. The actual new
/// text is produced by an [`EditAgent`].
pub struct EditFileTool {
    /// Thread that issued the tool call. It is read for its project, model, action log, and
    /// recorded file read times. Held weakly, so it may already have been dropped.
    thread: WeakEntity<Thread>,
    /// Used by `replay` to rebuild diffs of past tool calls.
    language_registry: Arc<LanguageRegistry>,
    /// Project used to resolve display paths (e.g. in `initial_title`).
    project: Entity<Project>,
    /// Prompt templates handed to the [`EditAgent`].
    templates: Arc<Templates>,
}
138
139impl EditFileTool {
140 pub fn new(
141 project: Entity<Project>,
142 thread: WeakEntity<Thread>,
143 language_registry: Arc<LanguageRegistry>,
144 templates: Arc<Templates>,
145 ) -> Self {
146 Self {
147 project,
148 thread,
149 language_registry,
150 templates,
151 }
152 }
153
154 fn authorize(
155 &self,
156 input: &EditFileToolInput,
157 event_stream: &ToolCallEventStream,
158 cx: &mut App,
159 ) -> Task<Result<()>> {
160 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
161 return Task::ready(Ok(()));
162 }
163
164 // If any path component matches the local settings folder, then this could affect
165 // the editor in ways beyond the project source, so prompt.
166 let local_settings_folder = paths::local_settings_folder_name();
167 let path = Path::new(&input.path);
168 if path.components().any(|component| {
169 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
170 }) {
171 return event_stream.authorize(
172 format!("{} (local settings)", input.display_description),
173 cx,
174 );
175 }
176
177 // It's also possible that the global config dir is configured to be inside the project,
178 // so check for that edge case too.
179 // TODO this is broken when remoting
180 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
181 && canonical_path.starts_with(paths::config_dir())
182 {
183 return event_stream.authorize(
184 format!("{} (global settings)", input.display_description),
185 cx,
186 );
187 }
188
189 // Check if path is inside the global config directory
190 // First check if it's already inside project - if not, try to canonicalize
191 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
192 thread.project().read(cx).find_project_path(&input.path, cx)
193 }) else {
194 return Task::ready(Err(anyhow!("thread was dropped")));
195 };
196
197 // If the path is inside the project, and it's not one of the above edge cases,
198 // then no confirmation is necessary. Otherwise, confirmation is necessary.
199 if project_path.is_some() {
200 Task::ready(Ok(()))
201 } else {
202 event_stream.authorize(&input.display_description, cx)
203 }
204 }
205}
206
207impl AgentTool for EditFileTool {
208 type Input = EditFileToolInput;
209 type Output = EditFileToolOutput;
210
211 fn name() -> &'static str {
212 "edit_file"
213 }
214
215 fn kind() -> acp::ToolKind {
216 acp::ToolKind::Edit
217 }
218
219 fn initial_title(
220 &self,
221 input: Result<Self::Input, serde_json::Value>,
222 cx: &mut App,
223 ) -> SharedString {
224 match input {
225 Ok(input) => self
226 .project
227 .read(cx)
228 .find_project_path(&input.path, cx)
229 .and_then(|project_path| {
230 self.project
231 .read(cx)
232 .short_full_path_for_project_path(&project_path, cx)
233 })
234 .unwrap_or(input.path.to_string_lossy().into_owned())
235 .into(),
236 Err(raw_input) => {
237 if let Some(input) =
238 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
239 {
240 let path = input.path.trim();
241 if !path.is_empty() {
242 return self
243 .project
244 .read(cx)
245 .find_project_path(&input.path, cx)
246 .and_then(|project_path| {
247 self.project
248 .read(cx)
249 .short_full_path_for_project_path(&project_path, cx)
250 })
251 .unwrap_or(input.path)
252 .into();
253 }
254
255 let description = input.display_description.trim();
256 if !description.is_empty() {
257 return description.to_string().into();
258 }
259 }
260
261 DEFAULT_UI_TEXT.into()
262 }
263 }
264 }
265
266 fn run(
267 self: Arc<Self>,
268 input: Self::Input,
269 event_stream: ToolCallEventStream,
270 cx: &mut App,
271 ) -> Task<Result<Self::Output>> {
272 let Ok(project) = self
273 .thread
274 .read_with(cx, |thread, _cx| thread.project().clone())
275 else {
276 return Task::ready(Err(anyhow!("thread was dropped")));
277 };
278 let project_path = match resolve_path(&input, project.clone(), cx) {
279 Ok(path) => path,
280 Err(err) => return Task::ready(Err(anyhow!(err))),
281 };
282 let abs_path = project.read(cx).absolute_path(&project_path, cx);
283 if let Some(abs_path) = abs_path.clone() {
284 event_stream.update_fields(ToolCallUpdateFields {
285 locations: Some(vec![acp::ToolCallLocation {
286 path: abs_path,
287 line: None,
288 meta: None,
289 }]),
290 ..Default::default()
291 });
292 }
293
294 let authorize = self.authorize(&input, &event_stream, cx);
295 cx.spawn(async move |cx: &mut AsyncApp| {
296 authorize.await?;
297
298 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
299 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
300 (request, thread.model().cloned(), thread.action_log().clone())
301 })?;
302 let request = request?;
303 let model = model.context("No language model configured")?;
304
305 let edit_format = EditFormat::from_model(model.clone())?;
306 let edit_agent = EditAgent::new(
307 model,
308 project.clone(),
309 action_log.clone(),
310 self.templates.clone(),
311 edit_format,
312 );
313
314 let buffer = project
315 .update(cx, |project, cx| {
316 project.open_buffer(project_path.clone(), cx)
317 })?
318 .await?;
319
320 // Check if the file has been modified since the agent last read it
321 if let Some(abs_path) = abs_path.as_ref() {
322 let (last_read_mtime, current_mtime, is_dirty) = self.thread.update(cx, |thread, cx| {
323 let last_read = thread.file_read_times.get(abs_path).copied();
324 let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
325 let dirty = buffer.read(cx).is_dirty();
326 (last_read, current, dirty)
327 })?;
328
329 // Check for unsaved changes first - these indicate modifications we don't know about
330 if is_dirty {
331 anyhow::bail!(
332 "This file cannot be written to because it has unsaved changes. \
333 Please end the current conversation immediately by telling the user you want to write to this file (mention its path explicitly) but you can't write to it because it has unsaved changes. \
334 Ask the user to save that buffer's changes and to inform you when it's ok to proceed."
335 );
336 }
337
338 // Check if the file was modified on disk since we last read it
339 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
340 // MTime can be unreliable for comparisons, so our newtype intentionally
341 // doesn't support comparing them. If the mtime at all different
342 // (which could be because of a modification or because e.g. system clock changed),
343 // we pessimistically assume it was modified.
344 if current != last_read {
345 anyhow::bail!(
346 "The file {} has been modified since you last read it. \
347 Please read the file again to get the current state before editing it.",
348 input.path.display()
349 );
350 }
351 }
352 }
353
354 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
355 event_stream.update_diff(diff.clone());
356 let _finalize_diff = util::defer({
357 let diff = diff.downgrade();
358 let mut cx = cx.clone();
359 move || {
360 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
361 }
362 });
363
364 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
365 let old_text = cx
366 .background_spawn({
367 let old_snapshot = old_snapshot.clone();
368 async move { Arc::new(old_snapshot.text()) }
369 })
370 .await;
371
372
373 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
374 edit_agent.edit(
375 buffer.clone(),
376 input.display_description.clone(),
377 &request,
378 cx,
379 )
380 } else {
381 edit_agent.overwrite(
382 buffer.clone(),
383 input.display_description.clone(),
384 &request,
385 cx,
386 )
387 };
388
389 let mut hallucinated_old_text = false;
390 let mut ambiguous_ranges = Vec::new();
391 let mut emitted_location = false;
392 while let Some(event) = events.next().await {
393 match event {
394 EditAgentOutputEvent::Edited(range) => {
395 if !emitted_location {
396 let line = buffer.update(cx, |buffer, _cx| {
397 range.start.to_point(&buffer.snapshot()).row
398 }).ok();
399 if let Some(abs_path) = abs_path.clone() {
400 event_stream.update_fields(ToolCallUpdateFields {
401 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
402 ..Default::default()
403 });
404 }
405 emitted_location = true;
406 }
407 },
408 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
409 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
410 EditAgentOutputEvent::ResolvingEditRange(range) => {
411 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
412 // if !emitted_location {
413 // let line = buffer.update(cx, |buffer, _cx| {
414 // range.start.to_point(&buffer.snapshot()).row
415 // }).ok();
416 // if let Some(abs_path) = abs_path.clone() {
417 // event_stream.update_fields(ToolCallUpdateFields {
418 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
419 // ..Default::default()
420 // });
421 // }
422 // }
423 }
424 }
425 }
426
427 // If format_on_save is enabled, format the buffer
428 let format_on_save_enabled = buffer
429 .read_with(cx, |buffer, cx| {
430 let settings = language_settings::language_settings(
431 buffer.language().map(|l| l.name()),
432 buffer.file(),
433 cx,
434 );
435 settings.format_on_save != FormatOnSave::Off
436 })
437 .unwrap_or(false);
438
439 let edit_agent_output = output.await?;
440
441 if format_on_save_enabled {
442 action_log.update(cx, |log, cx| {
443 log.buffer_edited(buffer.clone(), cx);
444 })?;
445
446 let format_task = project.update(cx, |project, cx| {
447 project.format(
448 HashSet::from_iter([buffer.clone()]),
449 LspFormatTarget::Buffers,
450 false, // Don't push to history since the tool did it.
451 FormatTrigger::Save,
452 cx,
453 )
454 })?;
455 format_task.await.log_err();
456 }
457
458 project
459 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
460 .await?;
461
462 action_log.update(cx, |log, cx| {
463 log.buffer_edited(buffer.clone(), cx);
464 })?;
465
466 // Update the recorded read time after a successful edit so consecutive edits work
467 if let Some(abs_path) = abs_path.as_ref() {
468 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
469 buffer.file().and_then(|file| file.disk_state().mtime())
470 })? {
471 self.thread.update(cx, |thread, _| {
472 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
473 })?;
474 }
475 }
476
477 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
478 let (new_text, unified_diff) = cx
479 .background_spawn({
480 let new_snapshot = new_snapshot.clone();
481 let old_text = old_text.clone();
482 async move {
483 let new_text = new_snapshot.text();
484 let diff = language::unified_diff(&old_text, &new_text);
485 (new_text, diff)
486 }
487 })
488 .await;
489
490 let input_path = input.path.display();
491 if unified_diff.is_empty() {
492 anyhow::ensure!(
493 !hallucinated_old_text,
494 formatdoc! {"
495 Some edits were produced but none of them could be applied.
496 Read the relevant sections of {input_path} again so that
497 I can perform the requested edits.
498 "}
499 );
500 anyhow::ensure!(
501 ambiguous_ranges.is_empty(),
502 {
503 let line_numbers = ambiguous_ranges
504 .iter()
505 .map(|range| range.start.to_string())
506 .collect::<Vec<_>>()
507 .join(", ");
508 formatdoc! {"
509 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
510 relevant sections of {input_path} again and extend <old_text> so
511 that I can perform the requested edits.
512 "}
513 }
514 );
515 }
516
517 Ok(EditFileToolOutput {
518 input_path: input.path,
519 new_text,
520 old_text,
521 diff: unified_diff,
522 edit_agent_output,
523 })
524 })
525 }
526
527 fn replay(
528 &self,
529 _input: Self::Input,
530 output: Self::Output,
531 event_stream: ToolCallEventStream,
532 cx: &mut App,
533 ) -> Result<()> {
534 event_stream.update_diff(cx.new(|cx| {
535 Diff::finalized(
536 output.input_path.to_string_lossy().into_owned(),
537 Some(output.old_text.to_string()),
538 output.new_text,
539 self.language_registry.clone(),
540 cx,
541 )
542 }));
543 Ok(())
544 }
545}
546
547/// Validate that the file path is valid, meaning:
548///
549/// - For `edit` and `overwrite`, the path must point to an existing file.
550/// - For `create`, the file must not already exist, but it's parent dir must exist.
551fn resolve_path(
552 input: &EditFileToolInput,
553 project: Entity<Project>,
554 cx: &mut App,
555) -> Result<ProjectPath> {
556 let project = project.read(cx);
557
558 match input.mode {
559 EditFileMode::Edit | EditFileMode::Overwrite => {
560 let path = project
561 .find_project_path(&input.path, cx)
562 .context("Can't edit file: path not found")?;
563
564 let entry = project
565 .entry_for_path(&path, cx)
566 .context("Can't edit file: path not found")?;
567
568 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
569 Ok(path)
570 }
571
572 EditFileMode::Create => {
573 if let Some(path) = project.find_project_path(&input.path, cx) {
574 anyhow::ensure!(
575 project.entry_for_path(&path, cx).is_none(),
576 "Can't create file: file already exists"
577 );
578 }
579
580 let parent_path = input
581 .path
582 .parent()
583 .context("Can't create file: incorrect path")?;
584
585 let parent_project_path = project.find_project_path(&parent_path, cx);
586
587 let parent_entry = parent_project_path
588 .as_ref()
589 .and_then(|path| project.entry_for_path(path, cx))
590 .context("Can't create file: parent directory doesn't exist")?;
591
592 anyhow::ensure!(
593 parent_entry.is_dir(),
594 "Can't create file: parent is not a directory"
595 );
596
597 let file_name = input
598 .path
599 .file_name()
600 .and_then(|file_name| file_name.to_str())
601 .and_then(|file_name| RelPath::unix(file_name).ok())
602 .context("Can't create file: invalid filename")?;
603
604 let new_file_path = parent_project_path.map(|parent| ProjectPath {
605 path: parent.path.join(file_name),
606 ..parent
607 });
608
609 new_file_path.context("Can't create file")
610 }
611 }
612}
613
614#[cfg(test)]
615mod tests {
616 use super::*;
617 use crate::{ContextServerRegistry, Templates};
618 use fs::Fs;
619 use gpui::{TestAppContext, UpdateGlobal};
620 use language_model::fake_provider::FakeLanguageModel;
621 use prompt_store::ProjectContext;
622 use serde_json::json;
623 use settings::SettingsStore;
624 use util::{path, rel_path::rel_path};
625
626 #[gpui::test]
627 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
628 init_test(cx);
629
630 let fs = project::FakeFs::new(cx.executor());
631 fs.insert_tree("/root", json!({})).await;
632 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
633 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
634 let context_server_registry =
635 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
636 let model = Arc::new(FakeLanguageModel::default());
637 let thread = cx.new(|cx| {
638 Thread::new(
639 project.clone(),
640 cx.new(|_cx| ProjectContext::default()),
641 context_server_registry,
642 Templates::new(),
643 Some(model),
644 cx,
645 )
646 });
647 let result = cx
648 .update(|cx| {
649 let input = EditFileToolInput {
650 display_description: "Some edit".into(),
651 path: "root/nonexistent_file.txt".into(),
652 mode: EditFileMode::Edit,
653 };
654 Arc::new(EditFileTool::new(
655 project,
656 thread.downgrade(),
657 language_registry,
658 Templates::new(),
659 ))
660 .run(input, ToolCallEventStream::test().0, cx)
661 })
662 .await;
663 assert_eq!(
664 result.unwrap_err().to_string(),
665 "Can't edit file: path not found"
666 );
667 }
668
669 #[gpui::test]
670 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
671 let mode = &EditFileMode::Create;
672
673 let result = test_resolve_path(mode, "root/new.txt", cx);
674 assert_resolved_path_eq(result.await, rel_path("new.txt"));
675
676 let result = test_resolve_path(mode, "new.txt", cx);
677 assert_resolved_path_eq(result.await, rel_path("new.txt"));
678
679 let result = test_resolve_path(mode, "dir/new.txt", cx);
680 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
681
682 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
683 assert_eq!(
684 result.await.unwrap_err().to_string(),
685 "Can't create file: file already exists"
686 );
687
688 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
689 assert_eq!(
690 result.await.unwrap_err().to_string(),
691 "Can't create file: parent directory doesn't exist"
692 );
693 }
694
695 #[gpui::test]
696 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
697 let mode = &EditFileMode::Edit;
698
699 let path_with_root = "root/dir/subdir/existing.txt";
700 let path_without_root = "dir/subdir/existing.txt";
701 let result = test_resolve_path(mode, path_with_root, cx);
702 assert_resolved_path_eq(result.await, rel_path(path_without_root));
703
704 let result = test_resolve_path(mode, path_without_root, cx);
705 assert_resolved_path_eq(result.await, rel_path(path_without_root));
706
707 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
708 assert_eq!(
709 result.await.unwrap_err().to_string(),
710 "Can't edit file: path not found"
711 );
712
713 let result = test_resolve_path(mode, "root/dir", cx);
714 assert_eq!(
715 result.await.unwrap_err().to_string(),
716 "Can't edit file: path is a directory"
717 );
718 }
719
720 async fn test_resolve_path(
721 mode: &EditFileMode,
722 path: &str,
723 cx: &mut TestAppContext,
724 ) -> anyhow::Result<ProjectPath> {
725 init_test(cx);
726
727 let fs = project::FakeFs::new(cx.executor());
728 fs.insert_tree(
729 "/root",
730 json!({
731 "dir": {
732 "subdir": {
733 "existing.txt": "hello"
734 }
735 }
736 }),
737 )
738 .await;
739 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
740
741 let input = EditFileToolInput {
742 display_description: "Some edit".into(),
743 path: path.into(),
744 mode: mode.clone(),
745 };
746
747 cx.update(|cx| resolve_path(&input, project, cx))
748 }
749
750 #[track_caller]
751 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
752 let actual = path.expect("Should return valid path").path;
753 assert_eq!(actual.as_ref(), expected);
754 }
755
756 #[gpui::test]
757 async fn test_format_on_save(cx: &mut TestAppContext) {
758 init_test(cx);
759
760 let fs = project::FakeFs::new(cx.executor());
761 fs.insert_tree("/root", json!({"src": {}})).await;
762
763 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
764
765 // Set up a Rust language with LSP formatting support
766 let rust_language = Arc::new(language::Language::new(
767 language::LanguageConfig {
768 name: "Rust".into(),
769 matcher: language::LanguageMatcher {
770 path_suffixes: vec!["rs".to_string()],
771 ..Default::default()
772 },
773 ..Default::default()
774 },
775 None,
776 ));
777
778 // Register the language and fake LSP
779 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
780 language_registry.add(rust_language);
781
782 let mut fake_language_servers = language_registry.register_fake_lsp(
783 "Rust",
784 language::FakeLspAdapter {
785 capabilities: lsp::ServerCapabilities {
786 document_formatting_provider: Some(lsp::OneOf::Left(true)),
787 ..Default::default()
788 },
789 ..Default::default()
790 },
791 );
792
793 // Create the file
794 fs.save(
795 path!("/root/src/main.rs").as_ref(),
796 &"initial content".into(),
797 language::LineEnding::Unix,
798 )
799 .await
800 .unwrap();
801
802 // Open the buffer to trigger LSP initialization
803 let buffer = project
804 .update(cx, |project, cx| {
805 project.open_local_buffer(path!("/root/src/main.rs"), cx)
806 })
807 .await
808 .unwrap();
809
810 // Register the buffer with language servers
811 let _handle = project.update(cx, |project, cx| {
812 project.register_buffer_with_language_servers(&buffer, cx)
813 });
814
815 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
816 const FORMATTED_CONTENT: &str =
817 "This file was formatted by the fake formatter in the test.\n";
818
819 // Get the fake language server and set up formatting handler
820 let fake_language_server = fake_language_servers.next().await.unwrap();
821 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
822 |_, _| async move {
823 Ok(Some(vec![lsp::TextEdit {
824 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
825 new_text: FORMATTED_CONTENT.to_string(),
826 }]))
827 }
828 });
829
830 let context_server_registry =
831 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
832 let model = Arc::new(FakeLanguageModel::default());
833 let thread = cx.new(|cx| {
834 Thread::new(
835 project.clone(),
836 cx.new(|_cx| ProjectContext::default()),
837 context_server_registry,
838 Templates::new(),
839 Some(model.clone()),
840 cx,
841 )
842 });
843
844 // First, test with format_on_save enabled
845 cx.update(|cx| {
846 SettingsStore::update_global(cx, |store, cx| {
847 store.update_user_settings(cx, |settings| {
848 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
849 settings.project.all_languages.defaults.formatter =
850 Some(language::language_settings::FormatterList::default());
851 });
852 });
853 });
854
855 // Have the model stream unformatted content
856 let edit_result = {
857 let edit_task = cx.update(|cx| {
858 let input = EditFileToolInput {
859 display_description: "Create main function".into(),
860 path: "root/src/main.rs".into(),
861 mode: EditFileMode::Overwrite,
862 };
863 Arc::new(EditFileTool::new(
864 project.clone(),
865 thread.downgrade(),
866 language_registry.clone(),
867 Templates::new(),
868 ))
869 .run(input, ToolCallEventStream::test().0, cx)
870 });
871
872 // Stream the unformatted content
873 cx.executor().run_until_parked();
874 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
875 model.end_last_completion_stream();
876
877 edit_task.await
878 };
879 assert!(edit_result.is_ok());
880
881 // Wait for any async operations (e.g. formatting) to complete
882 cx.executor().run_until_parked();
883
884 // Read the file to verify it was formatted automatically
885 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
886 assert_eq!(
887 // Ignore carriage returns on Windows
888 new_content.replace("\r\n", "\n"),
889 FORMATTED_CONTENT,
890 "Code should be formatted when format_on_save is enabled"
891 );
892
893 let stale_buffer_count = thread
894 .read_with(cx, |thread, _cx| thread.action_log.clone())
895 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
896
897 assert_eq!(
898 stale_buffer_count, 0,
899 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
900 This causes the agent to think the file was modified externally when it was just formatted.",
901 stale_buffer_count
902 );
903
904 // Next, test with format_on_save disabled
905 cx.update(|cx| {
906 SettingsStore::update_global(cx, |store, cx| {
907 store.update_user_settings(cx, |settings| {
908 settings.project.all_languages.defaults.format_on_save =
909 Some(FormatOnSave::Off);
910 });
911 });
912 });
913
914 // Stream unformatted edits again
915 let edit_result = {
916 let edit_task = cx.update(|cx| {
917 let input = EditFileToolInput {
918 display_description: "Update main function".into(),
919 path: "root/src/main.rs".into(),
920 mode: EditFileMode::Overwrite,
921 };
922 Arc::new(EditFileTool::new(
923 project.clone(),
924 thread.downgrade(),
925 language_registry,
926 Templates::new(),
927 ))
928 .run(input, ToolCallEventStream::test().0, cx)
929 });
930
931 // Stream the unformatted content
932 cx.executor().run_until_parked();
933 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
934 model.end_last_completion_stream();
935
936 edit_task.await
937 };
938 assert!(edit_result.is_ok());
939
940 // Wait for any async operations (e.g. formatting) to complete
941 cx.executor().run_until_parked();
942
943 // Verify the file was not formatted
944 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
945 assert_eq!(
946 // Ignore carriage returns on Windows
947 new_content.replace("\r\n", "\n"),
948 UNFORMATTED_CONTENT,
949 "Code should not be formatted when format_on_save is disabled"
950 );
951 }
952
953 #[gpui::test]
954 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
955 init_test(cx);
956
957 let fs = project::FakeFs::new(cx.executor());
958 fs.insert_tree("/root", json!({"src": {}})).await;
959
960 // Create a simple file with trailing whitespace
961 fs.save(
962 path!("/root/src/main.rs").as_ref(),
963 &"initial content".into(),
964 language::LineEnding::Unix,
965 )
966 .await
967 .unwrap();
968
969 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
970 let context_server_registry =
971 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
972 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
973 let model = Arc::new(FakeLanguageModel::default());
974 let thread = cx.new(|cx| {
975 Thread::new(
976 project.clone(),
977 cx.new(|_cx| ProjectContext::default()),
978 context_server_registry,
979 Templates::new(),
980 Some(model.clone()),
981 cx,
982 )
983 });
984
985 // First, test with remove_trailing_whitespace_on_save enabled
986 cx.update(|cx| {
987 SettingsStore::update_global(cx, |store, cx| {
988 store.update_user_settings(cx, |settings| {
989 settings
990 .project
991 .all_languages
992 .defaults
993 .remove_trailing_whitespace_on_save = Some(true);
994 });
995 });
996 });
997
998 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
999 "fn main() { \n println!(\"Hello!\"); \n}\n";
1000
1001 // Have the model stream content that contains trailing whitespace
1002 let edit_result = {
1003 let edit_task = cx.update(|cx| {
1004 let input = EditFileToolInput {
1005 display_description: "Create main function".into(),
1006 path: "root/src/main.rs".into(),
1007 mode: EditFileMode::Overwrite,
1008 };
1009 Arc::new(EditFileTool::new(
1010 project.clone(),
1011 thread.downgrade(),
1012 language_registry.clone(),
1013 Templates::new(),
1014 ))
1015 .run(input, ToolCallEventStream::test().0, cx)
1016 });
1017
1018 // Stream the content with trailing whitespace
1019 cx.executor().run_until_parked();
1020 model.send_last_completion_stream_text_chunk(
1021 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1022 );
1023 model.end_last_completion_stream();
1024
1025 edit_task.await
1026 };
1027 assert!(edit_result.is_ok());
1028
1029 // Wait for any async operations (e.g. formatting) to complete
1030 cx.executor().run_until_parked();
1031
1032 // Read the file to verify trailing whitespace was removed automatically
1033 assert_eq!(
1034 // Ignore carriage returns on Windows
1035 fs.load(path!("/root/src/main.rs").as_ref())
1036 .await
1037 .unwrap()
1038 .replace("\r\n", "\n"),
1039 "fn main() {\n println!(\"Hello!\");\n}\n",
1040 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
1041 );
1042
1043 // Next, test with remove_trailing_whitespace_on_save disabled
1044 cx.update(|cx| {
1045 SettingsStore::update_global(cx, |store, cx| {
1046 store.update_user_settings(cx, |settings| {
1047 settings
1048 .project
1049 .all_languages
1050 .defaults
1051 .remove_trailing_whitespace_on_save = Some(false);
1052 });
1053 });
1054 });
1055
1056 // Stream edits again with trailing whitespace
1057 let edit_result = {
1058 let edit_task = cx.update(|cx| {
1059 let input = EditFileToolInput {
1060 display_description: "Update main function".into(),
1061 path: "root/src/main.rs".into(),
1062 mode: EditFileMode::Overwrite,
1063 };
1064 Arc::new(EditFileTool::new(
1065 project.clone(),
1066 thread.downgrade(),
1067 language_registry,
1068 Templates::new(),
1069 ))
1070 .run(input, ToolCallEventStream::test().0, cx)
1071 });
1072
1073 // Stream the content with trailing whitespace
1074 cx.executor().run_until_parked();
1075 model.send_last_completion_stream_text_chunk(
1076 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1077 );
1078 model.end_last_completion_stream();
1079
1080 edit_task.await
1081 };
1082 assert!(edit_result.is_ok());
1083
1084 // Wait for any async operations (e.g. formatting) to complete
1085 cx.executor().run_until_parked();
1086
1087 // Verify the file still has trailing whitespace
1088 // Read the file again - it should still have trailing whitespace
1089 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1090 assert_eq!(
1091 // Ignore carriage returns on Windows
1092 final_content.replace("\r\n", "\n"),
1093 CONTENT_WITH_TRAILING_WHITESPACE,
1094 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1095 );
1096 }
1097
1098 #[gpui::test]
1099 async fn test_authorize(cx: &mut TestAppContext) {
1100 init_test(cx);
1101 let fs = project::FakeFs::new(cx.executor());
1102 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1103 let context_server_registry =
1104 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1105 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1106 let model = Arc::new(FakeLanguageModel::default());
1107 let thread = cx.new(|cx| {
1108 Thread::new(
1109 project.clone(),
1110 cx.new(|_cx| ProjectContext::default()),
1111 context_server_registry,
1112 Templates::new(),
1113 Some(model.clone()),
1114 cx,
1115 )
1116 });
1117 let tool = Arc::new(EditFileTool::new(
1118 project.clone(),
1119 thread.downgrade(),
1120 language_registry,
1121 Templates::new(),
1122 ));
1123 fs.insert_tree("/root", json!({})).await;
1124
1125 // Test 1: Path with .zed component should require confirmation
1126 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1127 let _auth = cx.update(|cx| {
1128 tool.authorize(
1129 &EditFileToolInput {
1130 display_description: "test 1".into(),
1131 path: ".zed/settings.json".into(),
1132 mode: EditFileMode::Edit,
1133 },
1134 &stream_tx,
1135 cx,
1136 )
1137 });
1138
1139 let event = stream_rx.expect_authorization().await;
1140 assert_eq!(
1141 event.tool_call.fields.title,
1142 Some("test 1 (local settings)".into())
1143 );
1144
1145 // Test 2: Path outside project should require confirmation
1146 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1147 let _auth = cx.update(|cx| {
1148 tool.authorize(
1149 &EditFileToolInput {
1150 display_description: "test 2".into(),
1151 path: "/etc/hosts".into(),
1152 mode: EditFileMode::Edit,
1153 },
1154 &stream_tx,
1155 cx,
1156 )
1157 });
1158
1159 let event = stream_rx.expect_authorization().await;
1160 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1161
1162 // Test 3: Relative path without .zed should not require confirmation
1163 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1164 cx.update(|cx| {
1165 tool.authorize(
1166 &EditFileToolInput {
1167 display_description: "test 3".into(),
1168 path: "root/src/main.rs".into(),
1169 mode: EditFileMode::Edit,
1170 },
1171 &stream_tx,
1172 cx,
1173 )
1174 })
1175 .await
1176 .unwrap();
1177 assert!(stream_rx.try_next().is_err());
1178
1179 // Test 4: Path with .zed in the middle should require confirmation
1180 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1181 let _auth = cx.update(|cx| {
1182 tool.authorize(
1183 &EditFileToolInput {
1184 display_description: "test 4".into(),
1185 path: "root/.zed/tasks.json".into(),
1186 mode: EditFileMode::Edit,
1187 },
1188 &stream_tx,
1189 cx,
1190 )
1191 });
1192 let event = stream_rx.expect_authorization().await;
1193 assert_eq!(
1194 event.tool_call.fields.title,
1195 Some("test 4 (local settings)".into())
1196 );
1197
1198 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1199 cx.update(|cx| {
1200 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1201 settings.always_allow_tool_actions = true;
1202 agent_settings::AgentSettings::override_global(settings, cx);
1203 });
1204
1205 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1206 cx.update(|cx| {
1207 tool.authorize(
1208 &EditFileToolInput {
1209 display_description: "test 5.1".into(),
1210 path: ".zed/settings.json".into(),
1211 mode: EditFileMode::Edit,
1212 },
1213 &stream_tx,
1214 cx,
1215 )
1216 })
1217 .await
1218 .unwrap();
1219 assert!(stream_rx.try_next().is_err());
1220
1221 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1222 cx.update(|cx| {
1223 tool.authorize(
1224 &EditFileToolInput {
1225 display_description: "test 5.2".into(),
1226 path: "/etc/hosts".into(),
1227 mode: EditFileMode::Edit,
1228 },
1229 &stream_tx,
1230 cx,
1231 )
1232 })
1233 .await
1234 .unwrap();
1235 assert!(stream_rx.try_next().is_err());
1236 }
1237
1238 #[gpui::test]
1239 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1240 init_test(cx);
1241 let fs = project::FakeFs::new(cx.executor());
1242 fs.insert_tree("/project", json!({})).await;
1243 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1244 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1245 let context_server_registry =
1246 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1247 let model = Arc::new(FakeLanguageModel::default());
1248 let thread = cx.new(|cx| {
1249 Thread::new(
1250 project.clone(),
1251 cx.new(|_cx| ProjectContext::default()),
1252 context_server_registry,
1253 Templates::new(),
1254 Some(model.clone()),
1255 cx,
1256 )
1257 });
1258 let tool = Arc::new(EditFileTool::new(
1259 project.clone(),
1260 thread.downgrade(),
1261 language_registry,
1262 Templates::new(),
1263 ));
1264
1265 // Test global config paths - these should require confirmation if they exist and are outside the project
1266 let test_cases = vec![
1267 (
1268 "/etc/hosts",
1269 true,
1270 "System file should require confirmation",
1271 ),
1272 (
1273 "/usr/local/bin/script",
1274 true,
1275 "System bin file should require confirmation",
1276 ),
1277 (
1278 "project/normal_file.rs",
1279 false,
1280 "Normal project file should not require confirmation",
1281 ),
1282 ];
1283
1284 for (path, should_confirm, description) in test_cases {
1285 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1286 let auth = cx.update(|cx| {
1287 tool.authorize(
1288 &EditFileToolInput {
1289 display_description: "Edit file".into(),
1290 path: path.into(),
1291 mode: EditFileMode::Edit,
1292 },
1293 &stream_tx,
1294 cx,
1295 )
1296 });
1297
1298 if should_confirm {
1299 stream_rx.expect_authorization().await;
1300 } else {
1301 auth.await.unwrap();
1302 assert!(
1303 stream_rx.try_next().is_err(),
1304 "Failed for case: {} - path: {} - expected no confirmation but got one",
1305 description,
1306 path
1307 );
1308 }
1309 }
1310 }
1311
1312 #[gpui::test]
1313 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1314 init_test(cx);
1315 let fs = project::FakeFs::new(cx.executor());
1316
1317 // Create multiple worktree directories
1318 fs.insert_tree(
1319 "/workspace/frontend",
1320 json!({
1321 "src": {
1322 "main.js": "console.log('frontend');"
1323 }
1324 }),
1325 )
1326 .await;
1327 fs.insert_tree(
1328 "/workspace/backend",
1329 json!({
1330 "src": {
1331 "main.rs": "fn main() {}"
1332 }
1333 }),
1334 )
1335 .await;
1336 fs.insert_tree(
1337 "/workspace/shared",
1338 json!({
1339 ".zed": {
1340 "settings.json": "{}"
1341 }
1342 }),
1343 )
1344 .await;
1345
1346 // Create project with multiple worktrees
1347 let project = Project::test(
1348 fs.clone(),
1349 [
1350 path!("/workspace/frontend").as_ref(),
1351 path!("/workspace/backend").as_ref(),
1352 path!("/workspace/shared").as_ref(),
1353 ],
1354 cx,
1355 )
1356 .await;
1357 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1358 let context_server_registry =
1359 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1360 let model = Arc::new(FakeLanguageModel::default());
1361 let thread = cx.new(|cx| {
1362 Thread::new(
1363 project.clone(),
1364 cx.new(|_cx| ProjectContext::default()),
1365 context_server_registry.clone(),
1366 Templates::new(),
1367 Some(model.clone()),
1368 cx,
1369 )
1370 });
1371 let tool = Arc::new(EditFileTool::new(
1372 project.clone(),
1373 thread.downgrade(),
1374 language_registry,
1375 Templates::new(),
1376 ));
1377
1378 // Test files in different worktrees
1379 let test_cases = vec![
1380 ("frontend/src/main.js", false, "File in first worktree"),
1381 ("backend/src/main.rs", false, "File in second worktree"),
1382 (
1383 "shared/.zed/settings.json",
1384 true,
1385 ".zed file in third worktree",
1386 ),
1387 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1388 (
1389 "../outside/file.txt",
1390 true,
1391 "Relative path outside worktrees",
1392 ),
1393 ];
1394
1395 for (path, should_confirm, description) in test_cases {
1396 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1397 let auth = cx.update(|cx| {
1398 tool.authorize(
1399 &EditFileToolInput {
1400 display_description: "Edit file".into(),
1401 path: path.into(),
1402 mode: EditFileMode::Edit,
1403 },
1404 &stream_tx,
1405 cx,
1406 )
1407 });
1408
1409 if should_confirm {
1410 stream_rx.expect_authorization().await;
1411 } else {
1412 auth.await.unwrap();
1413 assert!(
1414 stream_rx.try_next().is_err(),
1415 "Failed for case: {} - path: {} - expected no confirmation but got one",
1416 description,
1417 path
1418 );
1419 }
1420 }
1421 }
1422
1423 #[gpui::test]
1424 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1425 init_test(cx);
1426 let fs = project::FakeFs::new(cx.executor());
1427 fs.insert_tree(
1428 "/project",
1429 json!({
1430 ".zed": {
1431 "settings.json": "{}"
1432 },
1433 "src": {
1434 ".zed": {
1435 "local.json": "{}"
1436 }
1437 }
1438 }),
1439 )
1440 .await;
1441 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1442 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1443 let context_server_registry =
1444 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1445 let model = Arc::new(FakeLanguageModel::default());
1446 let thread = cx.new(|cx| {
1447 Thread::new(
1448 project.clone(),
1449 cx.new(|_cx| ProjectContext::default()),
1450 context_server_registry.clone(),
1451 Templates::new(),
1452 Some(model.clone()),
1453 cx,
1454 )
1455 });
1456 let tool = Arc::new(EditFileTool::new(
1457 project.clone(),
1458 thread.downgrade(),
1459 language_registry,
1460 Templates::new(),
1461 ));
1462
1463 // Test edge cases
1464 let test_cases = vec![
1465 // Empty path - find_project_path returns Some for empty paths
1466 ("", false, "Empty path is treated as project root"),
1467 // Root directory
1468 ("/", true, "Root directory should be outside project"),
1469 // Parent directory references - find_project_path resolves these
1470 (
1471 "project/../other",
1472 true,
1473 "Path with .. that goes outside of root directory",
1474 ),
1475 (
1476 "project/./src/file.rs",
1477 false,
1478 "Path with . should work normally",
1479 ),
1480 // Windows-style paths (if on Windows)
1481 #[cfg(target_os = "windows")]
1482 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1483 #[cfg(target_os = "windows")]
1484 ("project\\src\\main.rs", false, "Windows-style project path"),
1485 ];
1486
1487 for (path, should_confirm, description) in test_cases {
1488 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1489 let auth = cx.update(|cx| {
1490 tool.authorize(
1491 &EditFileToolInput {
1492 display_description: "Edit file".into(),
1493 path: path.into(),
1494 mode: EditFileMode::Edit,
1495 },
1496 &stream_tx,
1497 cx,
1498 )
1499 });
1500
1501 cx.run_until_parked();
1502
1503 if should_confirm {
1504 stream_rx.expect_authorization().await;
1505 } else {
1506 assert!(
1507 stream_rx.try_next().is_err(),
1508 "Failed for case: {} - path: {} - expected no confirmation but got one",
1509 description,
1510 path
1511 );
1512 auth.await.unwrap();
1513 }
1514 }
1515 }
1516
1517 #[gpui::test]
1518 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1519 init_test(cx);
1520 let fs = project::FakeFs::new(cx.executor());
1521 fs.insert_tree(
1522 "/project",
1523 json!({
1524 "existing.txt": "content",
1525 ".zed": {
1526 "settings.json": "{}"
1527 }
1528 }),
1529 )
1530 .await;
1531 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1532 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1533 let context_server_registry =
1534 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1535 let model = Arc::new(FakeLanguageModel::default());
1536 let thread = cx.new(|cx| {
1537 Thread::new(
1538 project.clone(),
1539 cx.new(|_cx| ProjectContext::default()),
1540 context_server_registry.clone(),
1541 Templates::new(),
1542 Some(model.clone()),
1543 cx,
1544 )
1545 });
1546 let tool = Arc::new(EditFileTool::new(
1547 project.clone(),
1548 thread.downgrade(),
1549 language_registry,
1550 Templates::new(),
1551 ));
1552
1553 // Test different EditFileMode values
1554 let modes = vec![
1555 EditFileMode::Edit,
1556 EditFileMode::Create,
1557 EditFileMode::Overwrite,
1558 ];
1559
1560 for mode in modes {
1561 // Test .zed path with different modes
1562 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1563 let _auth = cx.update(|cx| {
1564 tool.authorize(
1565 &EditFileToolInput {
1566 display_description: "Edit settings".into(),
1567 path: "project/.zed/settings.json".into(),
1568 mode: mode.clone(),
1569 },
1570 &stream_tx,
1571 cx,
1572 )
1573 });
1574
1575 stream_rx.expect_authorization().await;
1576
1577 // Test outside path with different modes
1578 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1579 let _auth = cx.update(|cx| {
1580 tool.authorize(
1581 &EditFileToolInput {
1582 display_description: "Edit file".into(),
1583 path: "/outside/file.txt".into(),
1584 mode: mode.clone(),
1585 },
1586 &stream_tx,
1587 cx,
1588 )
1589 });
1590
1591 stream_rx.expect_authorization().await;
1592
1593 // Test normal path with different modes
1594 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1595 cx.update(|cx| {
1596 tool.authorize(
1597 &EditFileToolInput {
1598 display_description: "Edit file".into(),
1599 path: "project/normal.txt".into(),
1600 mode: mode.clone(),
1601 },
1602 &stream_tx,
1603 cx,
1604 )
1605 })
1606 .await
1607 .unwrap();
1608 assert!(stream_rx.try_next().is_err());
1609 }
1610 }
1611
1612 #[gpui::test]
1613 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1614 init_test(cx);
1615 let fs = project::FakeFs::new(cx.executor());
1616 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1617 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1618 let context_server_registry =
1619 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1620 let model = Arc::new(FakeLanguageModel::default());
1621 let thread = cx.new(|cx| {
1622 Thread::new(
1623 project.clone(),
1624 cx.new(|_cx| ProjectContext::default()),
1625 context_server_registry,
1626 Templates::new(),
1627 Some(model.clone()),
1628 cx,
1629 )
1630 });
1631 let tool = Arc::new(EditFileTool::new(
1632 project,
1633 thread.downgrade(),
1634 language_registry,
1635 Templates::new(),
1636 ));
1637
1638 cx.update(|cx| {
1639 // ...
1640 assert_eq!(
1641 tool.initial_title(
1642 Err(json!({
1643 "path": "src/main.rs",
1644 "display_description": "",
1645 "old_string": "old code",
1646 "new_string": "new code"
1647 })),
1648 cx
1649 ),
1650 "src/main.rs"
1651 );
1652 assert_eq!(
1653 tool.initial_title(
1654 Err(json!({
1655 "path": "",
1656 "display_description": "Fix error handling",
1657 "old_string": "old code",
1658 "new_string": "new code"
1659 })),
1660 cx
1661 ),
1662 "Fix error handling"
1663 );
1664 assert_eq!(
1665 tool.initial_title(
1666 Err(json!({
1667 "path": "src/main.rs",
1668 "display_description": "Fix error handling",
1669 "old_string": "old code",
1670 "new_string": "new code"
1671 })),
1672 cx
1673 ),
1674 "src/main.rs"
1675 );
1676 assert_eq!(
1677 tool.initial_title(
1678 Err(json!({
1679 "path": "",
1680 "display_description": "",
1681 "old_string": "old code",
1682 "new_string": "new code"
1683 })),
1684 cx
1685 ),
1686 DEFAULT_UI_TEXT
1687 );
1688 assert_eq!(
1689 tool.initial_title(Err(serde_json::Value::Null), cx),
1690 DEFAULT_UI_TEXT
1691 );
1692 });
1693 }
1694
1695 #[gpui::test]
1696 async fn test_diff_finalization(cx: &mut TestAppContext) {
1697 init_test(cx);
1698 let fs = project::FakeFs::new(cx.executor());
1699 fs.insert_tree("/", json!({"main.rs": ""})).await;
1700
1701 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1702 let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1703 let context_server_registry =
1704 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1705 let model = Arc::new(FakeLanguageModel::default());
1706 let thread = cx.new(|cx| {
1707 Thread::new(
1708 project.clone(),
1709 cx.new(|_cx| ProjectContext::default()),
1710 context_server_registry.clone(),
1711 Templates::new(),
1712 Some(model.clone()),
1713 cx,
1714 )
1715 });
1716
1717 // Ensure the diff is finalized after the edit completes.
1718 {
1719 let tool = Arc::new(EditFileTool::new(
1720 project.clone(),
1721 thread.downgrade(),
1722 languages.clone(),
1723 Templates::new(),
1724 ));
1725 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1726 let edit = cx.update(|cx| {
1727 tool.run(
1728 EditFileToolInput {
1729 display_description: "Edit file".into(),
1730 path: path!("/main.rs").into(),
1731 mode: EditFileMode::Edit,
1732 },
1733 stream_tx,
1734 cx,
1735 )
1736 });
1737 stream_rx.expect_update_fields().await;
1738 let diff = stream_rx.expect_diff().await;
1739 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1740 cx.run_until_parked();
1741 model.end_last_completion_stream();
1742 edit.await.unwrap();
1743 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1744 }
1745
1746 // Ensure the diff is finalized if an error occurs while editing.
1747 {
1748 model.forbid_requests();
1749 let tool = Arc::new(EditFileTool::new(
1750 project.clone(),
1751 thread.downgrade(),
1752 languages.clone(),
1753 Templates::new(),
1754 ));
1755 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1756 let edit = cx.update(|cx| {
1757 tool.run(
1758 EditFileToolInput {
1759 display_description: "Edit file".into(),
1760 path: path!("/main.rs").into(),
1761 mode: EditFileMode::Edit,
1762 },
1763 stream_tx,
1764 cx,
1765 )
1766 });
1767 stream_rx.expect_update_fields().await;
1768 let diff = stream_rx.expect_diff().await;
1769 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1770 edit.await.unwrap_err();
1771 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1772 model.allow_requests();
1773 }
1774
1775 // Ensure the diff is finalized if the tool call gets dropped.
1776 {
1777 let tool = Arc::new(EditFileTool::new(
1778 project.clone(),
1779 thread.downgrade(),
1780 languages.clone(),
1781 Templates::new(),
1782 ));
1783 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1784 let edit = cx.update(|cx| {
1785 tool.run(
1786 EditFileToolInput {
1787 display_description: "Edit file".into(),
1788 path: path!("/main.rs").into(),
1789 mode: EditFileMode::Edit,
1790 },
1791 stream_tx,
1792 cx,
1793 )
1794 });
1795 stream_rx.expect_update_fields().await;
1796 let diff = stream_rx.expect_diff().await;
1797 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1798 drop(edit);
1799 cx.run_until_parked();
1800 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1801 }
1802 }
1803
1804 #[gpui::test]
1805 async fn test_file_read_times_tracking(cx: &mut TestAppContext) {
1806 init_test(cx);
1807
1808 let fs = project::FakeFs::new(cx.executor());
1809 fs.insert_tree(
1810 "/root",
1811 json!({
1812 "test.txt": "original content"
1813 }),
1814 )
1815 .await;
1816 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1817 let context_server_registry =
1818 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1819 let model = Arc::new(FakeLanguageModel::default());
1820 let thread = cx.new(|cx| {
1821 Thread::new(
1822 project.clone(),
1823 cx.new(|_cx| ProjectContext::default()),
1824 context_server_registry,
1825 Templates::new(),
1826 Some(model.clone()),
1827 cx,
1828 )
1829 });
1830 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1831
1832 // Initially, file_read_times should be empty
1833 let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
1834 assert!(is_empty, "file_read_times should start empty");
1835
1836 // Create read tool
1837 let read_tool = Arc::new(crate::ReadFileTool::new(
1838 thread.downgrade(),
1839 project.clone(),
1840 action_log,
1841 ));
1842
1843 // Read the file to record the read time
1844 cx.update(|cx| {
1845 read_tool.clone().run(
1846 crate::ReadFileToolInput {
1847 path: "root/test.txt".to_string(),
1848 start_line: None,
1849 end_line: None,
1850 },
1851 ToolCallEventStream::test().0,
1852 cx,
1853 )
1854 })
1855 .await
1856 .unwrap();
1857
1858 // Verify that file_read_times now contains an entry for the file
1859 let has_entry = thread.read_with(cx, |thread, _| {
1860 thread.file_read_times.len() == 1
1861 && thread
1862 .file_read_times
1863 .keys()
1864 .any(|path| path.ends_with("test.txt"))
1865 });
1866 assert!(
1867 has_entry,
1868 "file_read_times should contain an entry after reading the file"
1869 );
1870
1871 // Read the file again - should update the entry
1872 cx.update(|cx| {
1873 read_tool.clone().run(
1874 crate::ReadFileToolInput {
1875 path: "root/test.txt".to_string(),
1876 start_line: None,
1877 end_line: None,
1878 },
1879 ToolCallEventStream::test().0,
1880 cx,
1881 )
1882 })
1883 .await
1884 .unwrap();
1885
1886 // Should still have exactly one entry
1887 let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
1888 assert!(
1889 has_one_entry,
1890 "file_read_times should still have one entry after re-reading"
1891 );
1892 }
1893
1894 fn init_test(cx: &mut TestAppContext) {
1895 cx.update(|cx| {
1896 let settings_store = SettingsStore::test(cx);
1897 cx.set_global(settings_store);
1898 });
1899 }
1900
1901 #[gpui::test]
1902 async fn test_consecutive_edits_work(cx: &mut TestAppContext) {
1903 init_test(cx);
1904
1905 let fs = project::FakeFs::new(cx.executor());
1906 fs.insert_tree(
1907 "/root",
1908 json!({
1909 "test.txt": "original content"
1910 }),
1911 )
1912 .await;
1913 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1914 let context_server_registry =
1915 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1916 let model = Arc::new(FakeLanguageModel::default());
1917 let thread = cx.new(|cx| {
1918 Thread::new(
1919 project.clone(),
1920 cx.new(|_cx| ProjectContext::default()),
1921 context_server_registry,
1922 Templates::new(),
1923 Some(model.clone()),
1924 cx,
1925 )
1926 });
1927 let languages = project.read_with(cx, |project, _| project.languages().clone());
1928 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1929
1930 let read_tool = Arc::new(crate::ReadFileTool::new(
1931 thread.downgrade(),
1932 project.clone(),
1933 action_log,
1934 ));
1935 let edit_tool = Arc::new(EditFileTool::new(
1936 project.clone(),
1937 thread.downgrade(),
1938 languages,
1939 Templates::new(),
1940 ));
1941
1942 // Read the file first
1943 cx.update(|cx| {
1944 read_tool.clone().run(
1945 crate::ReadFileToolInput {
1946 path: "root/test.txt".to_string(),
1947 start_line: None,
1948 end_line: None,
1949 },
1950 ToolCallEventStream::test().0,
1951 cx,
1952 )
1953 })
1954 .await
1955 .unwrap();
1956
1957 // First edit should work
1958 let edit_result = {
1959 let edit_task = cx.update(|cx| {
1960 edit_tool.clone().run(
1961 EditFileToolInput {
1962 display_description: "First edit".into(),
1963 path: "root/test.txt".into(),
1964 mode: EditFileMode::Edit,
1965 },
1966 ToolCallEventStream::test().0,
1967 cx,
1968 )
1969 });
1970
1971 cx.executor().run_until_parked();
1972 model.send_last_completion_stream_text_chunk(
1973 "<old_text>original content</old_text><new_text>modified content</new_text>"
1974 .to_string(),
1975 );
1976 model.end_last_completion_stream();
1977
1978 edit_task.await
1979 };
1980 assert!(
1981 edit_result.is_ok(),
1982 "First edit should succeed, got error: {:?}",
1983 edit_result.as_ref().err()
1984 );
1985
1986 // Second edit should also work because the edit updated the recorded read time
1987 let edit_result = {
1988 let edit_task = cx.update(|cx| {
1989 edit_tool.clone().run(
1990 EditFileToolInput {
1991 display_description: "Second edit".into(),
1992 path: "root/test.txt".into(),
1993 mode: EditFileMode::Edit,
1994 },
1995 ToolCallEventStream::test().0,
1996 cx,
1997 )
1998 });
1999
2000 cx.executor().run_until_parked();
2001 model.send_last_completion_stream_text_chunk(
2002 "<old_text>modified content</old_text><new_text>further modified content</new_text>".to_string(),
2003 );
2004 model.end_last_completion_stream();
2005
2006 edit_task.await
2007 };
2008 assert!(
2009 edit_result.is_ok(),
2010 "Second consecutive edit should succeed, got error: {:?}",
2011 edit_result.as_ref().err()
2012 );
2013 }
2014
2015 #[gpui::test]
2016 async fn test_external_modification_detected(cx: &mut TestAppContext) {
2017 init_test(cx);
2018
2019 let fs = project::FakeFs::new(cx.executor());
2020 fs.insert_tree(
2021 "/root",
2022 json!({
2023 "test.txt": "original content"
2024 }),
2025 )
2026 .await;
2027 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2028 let context_server_registry =
2029 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2030 let model = Arc::new(FakeLanguageModel::default());
2031 let thread = cx.new(|cx| {
2032 Thread::new(
2033 project.clone(),
2034 cx.new(|_cx| ProjectContext::default()),
2035 context_server_registry,
2036 Templates::new(),
2037 Some(model.clone()),
2038 cx,
2039 )
2040 });
2041 let languages = project.read_with(cx, |project, _| project.languages().clone());
2042 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2043
2044 let read_tool = Arc::new(crate::ReadFileTool::new(
2045 thread.downgrade(),
2046 project.clone(),
2047 action_log,
2048 ));
2049 let edit_tool = Arc::new(EditFileTool::new(
2050 project.clone(),
2051 thread.downgrade(),
2052 languages,
2053 Templates::new(),
2054 ));
2055
2056 // Read the file first
2057 cx.update(|cx| {
2058 read_tool.clone().run(
2059 crate::ReadFileToolInput {
2060 path: "root/test.txt".to_string(),
2061 start_line: None,
2062 end_line: None,
2063 },
2064 ToolCallEventStream::test().0,
2065 cx,
2066 )
2067 })
2068 .await
2069 .unwrap();
2070
2071 // Simulate external modification - advance time and save file
2072 cx.background_executor
2073 .advance_clock(std::time::Duration::from_secs(2));
2074 fs.save(
2075 path!("/root/test.txt").as_ref(),
2076 &"externally modified content".into(),
2077 language::LineEnding::Unix,
2078 )
2079 .await
2080 .unwrap();
2081
2082 // Reload the buffer to pick up the new mtime
2083 let project_path = project
2084 .read_with(cx, |project, cx| {
2085 project.find_project_path("root/test.txt", cx)
2086 })
2087 .expect("Should find project path");
2088 let buffer = project
2089 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2090 .await
2091 .unwrap();
2092 buffer
2093 .update(cx, |buffer, cx| buffer.reload(cx))
2094 .await
2095 .unwrap();
2096
2097 cx.executor().run_until_parked();
2098
2099 // Try to edit - should fail because file was modified externally
2100 let result = cx
2101 .update(|cx| {
2102 edit_tool.clone().run(
2103 EditFileToolInput {
2104 display_description: "Edit after external change".into(),
2105 path: "root/test.txt".into(),
2106 mode: EditFileMode::Edit,
2107 },
2108 ToolCallEventStream::test().0,
2109 cx,
2110 )
2111 })
2112 .await;
2113
2114 assert!(
2115 result.is_err(),
2116 "Edit should fail after external modification"
2117 );
2118 let error_msg = result.unwrap_err().to_string();
2119 assert!(
2120 error_msg.contains("has been modified since you last read it"),
2121 "Error should mention file modification, got: {}",
2122 error_msg
2123 );
2124 }
2125
2126 #[gpui::test]
2127 async fn test_dirty_buffer_detected(cx: &mut TestAppContext) {
2128 init_test(cx);
2129
2130 let fs = project::FakeFs::new(cx.executor());
2131 fs.insert_tree(
2132 "/root",
2133 json!({
2134 "test.txt": "original content"
2135 }),
2136 )
2137 .await;
2138 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2139 let context_server_registry =
2140 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2141 let model = Arc::new(FakeLanguageModel::default());
2142 let thread = cx.new(|cx| {
2143 Thread::new(
2144 project.clone(),
2145 cx.new(|_cx| ProjectContext::default()),
2146 context_server_registry,
2147 Templates::new(),
2148 Some(model.clone()),
2149 cx,
2150 )
2151 });
2152 let languages = project.read_with(cx, |project, _| project.languages().clone());
2153 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2154
2155 let read_tool = Arc::new(crate::ReadFileTool::new(
2156 thread.downgrade(),
2157 project.clone(),
2158 action_log,
2159 ));
2160 let edit_tool = Arc::new(EditFileTool::new(
2161 project.clone(),
2162 thread.downgrade(),
2163 languages,
2164 Templates::new(),
2165 ));
2166
2167 // Read the file first
2168 cx.update(|cx| {
2169 read_tool.clone().run(
2170 crate::ReadFileToolInput {
2171 path: "root/test.txt".to_string(),
2172 start_line: None,
2173 end_line: None,
2174 },
2175 ToolCallEventStream::test().0,
2176 cx,
2177 )
2178 })
2179 .await
2180 .unwrap();
2181
2182 // Open the buffer and make it dirty by editing without saving
2183 let project_path = project
2184 .read_with(cx, |project, cx| {
2185 project.find_project_path("root/test.txt", cx)
2186 })
2187 .expect("Should find project path");
2188 let buffer = project
2189 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2190 .await
2191 .unwrap();
2192
2193 // Make an in-memory edit to the buffer (making it dirty)
2194 buffer.update(cx, |buffer, cx| {
2195 let end_point = buffer.max_point();
2196 buffer.edit([(end_point..end_point, " added text")], None, cx);
2197 });
2198
2199 // Verify buffer is dirty
2200 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
2201 assert!(is_dirty, "Buffer should be dirty after in-memory edit");
2202
2203 // Try to edit - should fail because buffer has unsaved changes
2204 let result = cx
2205 .update(|cx| {
2206 edit_tool.clone().run(
2207 EditFileToolInput {
2208 display_description: "Edit with dirty buffer".into(),
2209 path: "root/test.txt".into(),
2210 mode: EditFileMode::Edit,
2211 },
2212 ToolCallEventStream::test().0,
2213 cx,
2214 )
2215 })
2216 .await;
2217
2218 assert!(result.is_err(), "Edit should fail when buffer is dirty");
2219 let error_msg = result.unwrap_err().to_string();
2220 assert!(
2221 error_msg.contains("cannot be written to because it has unsaved changes"),
2222 "Error should mention unsaved changes, got: {}",
2223 error_msg
2224 );
2225 }
2226}