1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Title shown for an `edit_file` tool call when neither a path nor a description can be
/// recovered from the (possibly still-streaming) input; the last fallback in `initial_title`.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
// NOTE: every `///` comment on this struct and its fields is also picked up by the
// `JsonSchema` derive below (presumably the input schema shown to the model for this tool),
// so treat that text as prompt content. Keep maintainer-only notes in `//` comments.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    // Declared first: its own doc asks the model to emit it before the other fields so the UI
    // can show it while the rest of the input is still streaming.
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    // Resolved against the project's worktrees by `resolve_path`; not an absolute path.
    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    // `create` and `overwrite` both rewrite the whole buffer in `run`; they differ only in the
    // existence checks performed by `resolve_path`.
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
79
80#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
81struct EditFileToolPartialInput {
82 #[serde(default)]
83 path: String,
84 #[serde(default)]
85 display_description: String,
86}
87
88#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
89#[serde(rename_all = "lowercase")]
90#[schemars(inline)]
91pub enum EditFileMode {
92 Edit,
93 Create,
94 Overwrite,
95}
96
/// Result of an `edit_file` call: the before/after text used to replay the diff, plus the
/// unified diff that is reported back to the model.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as supplied in the tool input.
    // Also accepts the key `original_path` when deserializing (presumably an older field name).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit, any format-on-save, and the save.
    new_text: String,
    /// Full buffer text captured before the edit was applied.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; an empty diff is reported as "no edits".
    // Defaults to an empty string when absent from serialized data.
    #[serde(default)]
    diff: String,
    /// Raw output produced by the edit agent.
    // Also accepts the key `raw_output` when deserializing.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
/// Agent tool that creates, overwrites, or edits a single project file, delegating the actual
/// text changes to an `EditAgent`.
pub struct EditFileTool {
    /// The owning thread (held weakly). Supplies the project, completion request, model,
    /// action log, and the per-file read times used for the staleness check in `run`.
    thread: WeakEntity<Thread>,
    /// Used to build the diff view when replaying a past tool call.
    language_registry: Arc<LanguageRegistry>,
    /// Used to resolve the target path into a short display form for the tool call title.
    project: Entity<Project>,
    /// Prompt templates handed to the `EditAgent`.
    templates: Arc<Templates>,
}
130
131impl EditFileTool {
132 pub fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
203 fn name() -> &'static str {
204 "edit_file"
205 }
206
207 fn kind() -> acp::ToolKind {
208 acp::ToolKind::Edit
209 }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(
277 ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
278 );
279 }
280
281 let authorize = self.authorize(&input, &event_stream, cx);
282 cx.spawn(async move |cx: &mut AsyncApp| {
283 authorize.await?;
284
285 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
286 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
287 (request, thread.model().cloned(), thread.action_log().clone())
288 })?;
289 let request = request?;
290 let model = model.context("No language model configured")?;
291
292 let edit_format = EditFormat::from_model(model.clone())?;
293 let edit_agent = EditAgent::new(
294 model,
295 project.clone(),
296 action_log.clone(),
297 self.templates.clone(),
298 edit_format,
299 );
300
301 let buffer = project
302 .update(cx, |project, cx| {
303 project.open_buffer(project_path.clone(), cx)
304 })?
305 .await?;
306
307 // Check if the file has been modified since the agent last read it
308 if let Some(abs_path) = abs_path.as_ref() {
309 let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.update(cx, |thread, cx| {
310 let last_read = thread.file_read_times.get(abs_path).copied();
311 let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
312 let dirty = buffer.read(cx).is_dirty();
313 let has_save = thread.has_tool("save_file");
314 let has_restore = thread.has_tool("restore_file_from_disk");
315 (last_read, current, dirty, has_save, has_restore)
316 })?;
317
318 // Check for unsaved changes first - these indicate modifications we don't know about
319 if is_dirty {
320 let message = match (has_save_tool, has_restore_tool) {
321 (true, true) => {
322 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
323 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
324 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
325 }
326 (true, false) => {
327 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
328 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
329 If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
330 }
331 (false, true) => {
332 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
333 If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
334 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
335 }
336 (false, false) => {
337 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
338 then ask them to save or revert the file manually and inform you when it's ok to proceed."
339 }
340 };
341 anyhow::bail!("{}", message);
342 }
343
344 // Check if the file was modified on disk since we last read it
345 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
346 // MTime can be unreliable for comparisons, so our newtype intentionally
347 // doesn't support comparing them. If the mtime at all different
348 // (which could be because of a modification or because e.g. system clock changed),
349 // we pessimistically assume it was modified.
350 if current != last_read {
351 anyhow::bail!(
352 "The file {} has been modified since you last read it. \
353 Please read the file again to get the current state before editing it.",
354 input.path.display()
355 );
356 }
357 }
358 }
359
360 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
361 event_stream.update_diff(diff.clone());
362 let _finalize_diff = util::defer({
363 let diff = diff.downgrade();
364 let mut cx = cx.clone();
365 move || {
366 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
367 }
368 });
369
370 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
371 let old_text = cx
372 .background_spawn({
373 let old_snapshot = old_snapshot.clone();
374 async move { Arc::new(old_snapshot.text()) }
375 })
376 .await;
377
378
379 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
380 edit_agent.edit(
381 buffer.clone(),
382 input.display_description.clone(),
383 &request,
384 cx,
385 )
386 } else {
387 edit_agent.overwrite(
388 buffer.clone(),
389 input.display_description.clone(),
390 &request,
391 cx,
392 )
393 };
394
395 let mut hallucinated_old_text = false;
396 let mut ambiguous_ranges = Vec::new();
397 let mut emitted_location = false;
398 while let Some(event) = events.next().await {
399 match event {
400 EditAgentOutputEvent::Edited(range) => {
401 if !emitted_location {
402 let line = buffer.update(cx, |buffer, _cx| {
403 range.start.to_point(&buffer.snapshot()).row
404 }).ok();
405 if let Some(abs_path) = abs_path.clone() {
406 event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)]));
407 }
408 emitted_location = true;
409 }
410 },
411 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
412 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
413 EditAgentOutputEvent::ResolvingEditRange(range) => {
414 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
415 // if !emitted_location {
416 // let line = buffer.update(cx, |buffer, _cx| {
417 // range.start.to_point(&buffer.snapshot()).row
418 // }).ok();
419 // if let Some(abs_path) = abs_path.clone() {
420 // event_stream.update_fields(ToolCallUpdateFields {
421 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
422 // ..Default::default()
423 // });
424 // }
425 // }
426 }
427 }
428 }
429
430 // If format_on_save is enabled, format the buffer
431 let format_on_save_enabled = buffer
432 .read_with(cx, |buffer, cx| {
433 let settings = language_settings::language_settings(
434 buffer.language().map(|l| l.name()),
435 buffer.file(),
436 cx,
437 );
438 settings.format_on_save != FormatOnSave::Off
439 })
440 .unwrap_or(false);
441
442 let edit_agent_output = output.await?;
443
444 if format_on_save_enabled {
445 action_log.update(cx, |log, cx| {
446 log.buffer_edited(buffer.clone(), cx);
447 })?;
448
449 let format_task = project.update(cx, |project, cx| {
450 project.format(
451 HashSet::from_iter([buffer.clone()]),
452 LspFormatTarget::Buffers,
453 false, // Don't push to history since the tool did it.
454 FormatTrigger::Save,
455 cx,
456 )
457 })?;
458 format_task.await.log_err();
459 }
460
461 project
462 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
463 .await?;
464
465 action_log.update(cx, |log, cx| {
466 log.buffer_edited(buffer.clone(), cx);
467 })?;
468
469 // Update the recorded read time after a successful edit so consecutive edits work
470 if let Some(abs_path) = abs_path.as_ref() {
471 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
472 buffer.file().and_then(|file| file.disk_state().mtime())
473 })? {
474 self.thread.update(cx, |thread, _| {
475 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
476 })?;
477 }
478 }
479
480 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
481 let (new_text, unified_diff) = cx
482 .background_spawn({
483 let new_snapshot = new_snapshot.clone();
484 let old_text = old_text.clone();
485 async move {
486 let new_text = new_snapshot.text();
487 let diff = language::unified_diff(&old_text, &new_text);
488 (new_text, diff)
489 }
490 })
491 .await;
492
493 let input_path = input.path.display();
494 if unified_diff.is_empty() {
495 anyhow::ensure!(
496 !hallucinated_old_text,
497 formatdoc! {"
498 Some edits were produced but none of them could be applied.
499 Read the relevant sections of {input_path} again so that
500 I can perform the requested edits.
501 "}
502 );
503 anyhow::ensure!(
504 ambiguous_ranges.is_empty(),
505 {
506 let line_numbers = ambiguous_ranges
507 .iter()
508 .map(|range| range.start.to_string())
509 .collect::<Vec<_>>()
510 .join(", ");
511 formatdoc! {"
512 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
513 relevant sections of {input_path} again and extend <old_text> so
514 that I can perform the requested edits.
515 "}
516 }
517 );
518 }
519
520 Ok(EditFileToolOutput {
521 input_path: input.path,
522 new_text,
523 old_text,
524 diff: unified_diff,
525 edit_agent_output,
526 })
527 })
528 }
529
530 fn replay(
531 &self,
532 _input: Self::Input,
533 output: Self::Output,
534 event_stream: ToolCallEventStream,
535 cx: &mut App,
536 ) -> Result<()> {
537 event_stream.update_diff(cx.new(|cx| {
538 Diff::finalized(
539 output.input_path.to_string_lossy().into_owned(),
540 Some(output.old_text.to_string()),
541 output.new_text,
542 self.language_registry.clone(),
543 cx,
544 )
545 }));
546 Ok(())
547 }
548}
549
550/// Validate that the file path is valid, meaning:
551///
552/// - For `edit` and `overwrite`, the path must point to an existing file.
553/// - For `create`, the file must not already exist, but it's parent dir must exist.
554fn resolve_path(
555 input: &EditFileToolInput,
556 project: Entity<Project>,
557 cx: &mut App,
558) -> Result<ProjectPath> {
559 let project = project.read(cx);
560
561 match input.mode {
562 EditFileMode::Edit | EditFileMode::Overwrite => {
563 let path = project
564 .find_project_path(&input.path, cx)
565 .context("Can't edit file: path not found")?;
566
567 let entry = project
568 .entry_for_path(&path, cx)
569 .context("Can't edit file: path not found")?;
570
571 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
572 Ok(path)
573 }
574
575 EditFileMode::Create => {
576 if let Some(path) = project.find_project_path(&input.path, cx) {
577 anyhow::ensure!(
578 project.entry_for_path(&path, cx).is_none(),
579 "Can't create file: file already exists"
580 );
581 }
582
583 let parent_path = input
584 .path
585 .parent()
586 .context("Can't create file: incorrect path")?;
587
588 let parent_project_path = project.find_project_path(&parent_path, cx);
589
590 let parent_entry = parent_project_path
591 .as_ref()
592 .and_then(|path| project.entry_for_path(path, cx))
593 .context("Can't create file: parent directory doesn't exist")?;
594
595 anyhow::ensure!(
596 parent_entry.is_dir(),
597 "Can't create file: parent is not a directory"
598 );
599
600 let file_name = input
601 .path
602 .file_name()
603 .and_then(|file_name| file_name.to_str())
604 .and_then(|file_name| RelPath::unix(file_name).ok())
605 .context("Can't create file: invalid filename")?;
606
607 let new_file_path = parent_project_path.map(|parent| ProjectPath {
608 path: parent.path.join(file_name),
609 ..parent
610 });
611
612 new_file_path.context("Can't create file")
613 }
614 }
615}
616
617#[cfg(test)]
618mod tests {
619 use super::*;
620 use crate::{ContextServerRegistry, Templates};
621 use fs::Fs;
622 use gpui::{TestAppContext, UpdateGlobal};
623 use language_model::fake_provider::FakeLanguageModel;
624 use prompt_store::ProjectContext;
625 use serde_json::json;
626 use settings::SettingsStore;
627 use util::{path, rel_path::rel_path};
628
629 #[gpui::test]
630 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
631 init_test(cx);
632
633 let fs = project::FakeFs::new(cx.executor());
634 fs.insert_tree("/root", json!({})).await;
635 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
636 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
637 let context_server_registry =
638 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
639 let model = Arc::new(FakeLanguageModel::default());
640 let thread = cx.new(|cx| {
641 Thread::new(
642 project.clone(),
643 cx.new(|_cx| ProjectContext::default()),
644 context_server_registry,
645 Templates::new(),
646 Some(model),
647 cx,
648 )
649 });
650 let result = cx
651 .update(|cx| {
652 let input = EditFileToolInput {
653 display_description: "Some edit".into(),
654 path: "root/nonexistent_file.txt".into(),
655 mode: EditFileMode::Edit,
656 };
657 Arc::new(EditFileTool::new(
658 project,
659 thread.downgrade(),
660 language_registry,
661 Templates::new(),
662 ))
663 .run(input, ToolCallEventStream::test().0, cx)
664 })
665 .await;
666 assert_eq!(
667 result.unwrap_err().to_string(),
668 "Can't edit file: path not found"
669 );
670 }
671
672 #[gpui::test]
673 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
674 let mode = &EditFileMode::Create;
675
676 let result = test_resolve_path(mode, "root/new.txt", cx);
677 assert_resolved_path_eq(result.await, rel_path("new.txt"));
678
679 let result = test_resolve_path(mode, "new.txt", cx);
680 assert_resolved_path_eq(result.await, rel_path("new.txt"));
681
682 let result = test_resolve_path(mode, "dir/new.txt", cx);
683 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
684
685 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
686 assert_eq!(
687 result.await.unwrap_err().to_string(),
688 "Can't create file: file already exists"
689 );
690
691 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
692 assert_eq!(
693 result.await.unwrap_err().to_string(),
694 "Can't create file: parent directory doesn't exist"
695 );
696 }
697
698 #[gpui::test]
699 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
700 let mode = &EditFileMode::Edit;
701
702 let path_with_root = "root/dir/subdir/existing.txt";
703 let path_without_root = "dir/subdir/existing.txt";
704 let result = test_resolve_path(mode, path_with_root, cx);
705 assert_resolved_path_eq(result.await, rel_path(path_without_root));
706
707 let result = test_resolve_path(mode, path_without_root, cx);
708 assert_resolved_path_eq(result.await, rel_path(path_without_root));
709
710 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
711 assert_eq!(
712 result.await.unwrap_err().to_string(),
713 "Can't edit file: path not found"
714 );
715
716 let result = test_resolve_path(mode, "root/dir", cx);
717 assert_eq!(
718 result.await.unwrap_err().to_string(),
719 "Can't edit file: path is a directory"
720 );
721 }
722
723 async fn test_resolve_path(
724 mode: &EditFileMode,
725 path: &str,
726 cx: &mut TestAppContext,
727 ) -> anyhow::Result<ProjectPath> {
728 init_test(cx);
729
730 let fs = project::FakeFs::new(cx.executor());
731 fs.insert_tree(
732 "/root",
733 json!({
734 "dir": {
735 "subdir": {
736 "existing.txt": "hello"
737 }
738 }
739 }),
740 )
741 .await;
742 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
743
744 let input = EditFileToolInput {
745 display_description: "Some edit".into(),
746 path: path.into(),
747 mode: mode.clone(),
748 };
749
750 cx.update(|cx| resolve_path(&input, project, cx))
751 }
752
753 #[track_caller]
754 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
755 let actual = path.expect("Should return valid path").path;
756 assert_eq!(actual.as_ref(), expected);
757 }
758
759 #[gpui::test]
760 async fn test_format_on_save(cx: &mut TestAppContext) {
761 init_test(cx);
762
763 let fs = project::FakeFs::new(cx.executor());
764 fs.insert_tree("/root", json!({"src": {}})).await;
765
766 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
767
768 // Set up a Rust language with LSP formatting support
769 let rust_language = Arc::new(language::Language::new(
770 language::LanguageConfig {
771 name: "Rust".into(),
772 matcher: language::LanguageMatcher {
773 path_suffixes: vec!["rs".to_string()],
774 ..Default::default()
775 },
776 ..Default::default()
777 },
778 None,
779 ));
780
781 // Register the language and fake LSP
782 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
783 language_registry.add(rust_language);
784
785 let mut fake_language_servers = language_registry.register_fake_lsp(
786 "Rust",
787 language::FakeLspAdapter {
788 capabilities: lsp::ServerCapabilities {
789 document_formatting_provider: Some(lsp::OneOf::Left(true)),
790 ..Default::default()
791 },
792 ..Default::default()
793 },
794 );
795
796 // Create the file
797 fs.save(
798 path!("/root/src/main.rs").as_ref(),
799 &"initial content".into(),
800 language::LineEnding::Unix,
801 )
802 .await
803 .unwrap();
804
805 // Open the buffer to trigger LSP initialization
806 let buffer = project
807 .update(cx, |project, cx| {
808 project.open_local_buffer(path!("/root/src/main.rs"), cx)
809 })
810 .await
811 .unwrap();
812
813 // Register the buffer with language servers
814 let _handle = project.update(cx, |project, cx| {
815 project.register_buffer_with_language_servers(&buffer, cx)
816 });
817
818 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
819 const FORMATTED_CONTENT: &str =
820 "This file was formatted by the fake formatter in the test.\n";
821
822 // Get the fake language server and set up formatting handler
823 let fake_language_server = fake_language_servers.next().await.unwrap();
824 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
825 |_, _| async move {
826 Ok(Some(vec![lsp::TextEdit {
827 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
828 new_text: FORMATTED_CONTENT.to_string(),
829 }]))
830 }
831 });
832
833 let context_server_registry =
834 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
835 let model = Arc::new(FakeLanguageModel::default());
836 let thread = cx.new(|cx| {
837 Thread::new(
838 project.clone(),
839 cx.new(|_cx| ProjectContext::default()),
840 context_server_registry,
841 Templates::new(),
842 Some(model.clone()),
843 cx,
844 )
845 });
846
847 // First, test with format_on_save enabled
848 cx.update(|cx| {
849 SettingsStore::update_global(cx, |store, cx| {
850 store.update_user_settings(cx, |settings| {
851 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
852 settings.project.all_languages.defaults.formatter =
853 Some(language::language_settings::FormatterList::default());
854 });
855 });
856 });
857
858 // Have the model stream unformatted content
859 let edit_result = {
860 let edit_task = cx.update(|cx| {
861 let input = EditFileToolInput {
862 display_description: "Create main function".into(),
863 path: "root/src/main.rs".into(),
864 mode: EditFileMode::Overwrite,
865 };
866 Arc::new(EditFileTool::new(
867 project.clone(),
868 thread.downgrade(),
869 language_registry.clone(),
870 Templates::new(),
871 ))
872 .run(input, ToolCallEventStream::test().0, cx)
873 });
874
875 // Stream the unformatted content
876 cx.executor().run_until_parked();
877 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
878 model.end_last_completion_stream();
879
880 edit_task.await
881 };
882 assert!(edit_result.is_ok());
883
884 // Wait for any async operations (e.g. formatting) to complete
885 cx.executor().run_until_parked();
886
887 // Read the file to verify it was formatted automatically
888 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
889 assert_eq!(
890 // Ignore carriage returns on Windows
891 new_content.replace("\r\n", "\n"),
892 FORMATTED_CONTENT,
893 "Code should be formatted when format_on_save is enabled"
894 );
895
896 let stale_buffer_count = thread
897 .read_with(cx, |thread, _cx| thread.action_log.clone())
898 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
899
900 assert_eq!(
901 stale_buffer_count, 0,
902 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
903 This causes the agent to think the file was modified externally when it was just formatted.",
904 stale_buffer_count
905 );
906
907 // Next, test with format_on_save disabled
908 cx.update(|cx| {
909 SettingsStore::update_global(cx, |store, cx| {
910 store.update_user_settings(cx, |settings| {
911 settings.project.all_languages.defaults.format_on_save =
912 Some(FormatOnSave::Off);
913 });
914 });
915 });
916
917 // Stream unformatted edits again
918 let edit_result = {
919 let edit_task = cx.update(|cx| {
920 let input = EditFileToolInput {
921 display_description: "Update main function".into(),
922 path: "root/src/main.rs".into(),
923 mode: EditFileMode::Overwrite,
924 };
925 Arc::new(EditFileTool::new(
926 project.clone(),
927 thread.downgrade(),
928 language_registry,
929 Templates::new(),
930 ))
931 .run(input, ToolCallEventStream::test().0, cx)
932 });
933
934 // Stream the unformatted content
935 cx.executor().run_until_parked();
936 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
937 model.end_last_completion_stream();
938
939 edit_task.await
940 };
941 assert!(edit_result.is_ok());
942
943 // Wait for any async operations (e.g. formatting) to complete
944 cx.executor().run_until_parked();
945
946 // Verify the file was not formatted
947 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
948 assert_eq!(
949 // Ignore carriage returns on Windows
950 new_content.replace("\r\n", "\n"),
951 UNFORMATTED_CONTENT,
952 "Code should not be formatted when format_on_save is disabled"
953 );
954 }
955
956 #[gpui::test]
957 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
958 init_test(cx);
959
960 let fs = project::FakeFs::new(cx.executor());
961 fs.insert_tree("/root", json!({"src": {}})).await;
962
963 // Create a simple file with trailing whitespace
964 fs.save(
965 path!("/root/src/main.rs").as_ref(),
966 &"initial content".into(),
967 language::LineEnding::Unix,
968 )
969 .await
970 .unwrap();
971
972 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
973 let context_server_registry =
974 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
975 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
976 let model = Arc::new(FakeLanguageModel::default());
977 let thread = cx.new(|cx| {
978 Thread::new(
979 project.clone(),
980 cx.new(|_cx| ProjectContext::default()),
981 context_server_registry,
982 Templates::new(),
983 Some(model.clone()),
984 cx,
985 )
986 });
987
988 // First, test with remove_trailing_whitespace_on_save enabled
989 cx.update(|cx| {
990 SettingsStore::update_global(cx, |store, cx| {
991 store.update_user_settings(cx, |settings| {
992 settings
993 .project
994 .all_languages
995 .defaults
996 .remove_trailing_whitespace_on_save = Some(true);
997 });
998 });
999 });
1000
1001 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
1002 "fn main() { \n println!(\"Hello!\"); \n}\n";
1003
1004 // Have the model stream content that contains trailing whitespace
1005 let edit_result = {
1006 let edit_task = cx.update(|cx| {
1007 let input = EditFileToolInput {
1008 display_description: "Create main function".into(),
1009 path: "root/src/main.rs".into(),
1010 mode: EditFileMode::Overwrite,
1011 };
1012 Arc::new(EditFileTool::new(
1013 project.clone(),
1014 thread.downgrade(),
1015 language_registry.clone(),
1016 Templates::new(),
1017 ))
1018 .run(input, ToolCallEventStream::test().0, cx)
1019 });
1020
1021 // Stream the content with trailing whitespace
1022 cx.executor().run_until_parked();
1023 model.send_last_completion_stream_text_chunk(
1024 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1025 );
1026 model.end_last_completion_stream();
1027
1028 edit_task.await
1029 };
1030 assert!(edit_result.is_ok());
1031
1032 // Wait for any async operations (e.g. formatting) to complete
1033 cx.executor().run_until_parked();
1034
1035 // Read the file to verify trailing whitespace was removed automatically
1036 assert_eq!(
1037 // Ignore carriage returns on Windows
1038 fs.load(path!("/root/src/main.rs").as_ref())
1039 .await
1040 .unwrap()
1041 .replace("\r\n", "\n"),
1042 "fn main() {\n println!(\"Hello!\");\n}\n",
1043 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
1044 );
1045
1046 // Next, test with remove_trailing_whitespace_on_save disabled
1047 cx.update(|cx| {
1048 SettingsStore::update_global(cx, |store, cx| {
1049 store.update_user_settings(cx, |settings| {
1050 settings
1051 .project
1052 .all_languages
1053 .defaults
1054 .remove_trailing_whitespace_on_save = Some(false);
1055 });
1056 });
1057 });
1058
1059 // Stream edits again with trailing whitespace
1060 let edit_result = {
1061 let edit_task = cx.update(|cx| {
1062 let input = EditFileToolInput {
1063 display_description: "Update main function".into(),
1064 path: "root/src/main.rs".into(),
1065 mode: EditFileMode::Overwrite,
1066 };
1067 Arc::new(EditFileTool::new(
1068 project.clone(),
1069 thread.downgrade(),
1070 language_registry,
1071 Templates::new(),
1072 ))
1073 .run(input, ToolCallEventStream::test().0, cx)
1074 });
1075
1076 // Stream the content with trailing whitespace
1077 cx.executor().run_until_parked();
1078 model.send_last_completion_stream_text_chunk(
1079 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1080 );
1081 model.end_last_completion_stream();
1082
1083 edit_task.await
1084 };
1085 assert!(edit_result.is_ok());
1086
1087 // Wait for any async operations (e.g. formatting) to complete
1088 cx.executor().run_until_parked();
1089
1090 // Verify the file still has trailing whitespace
1091 // Read the file again - it should still have trailing whitespace
1092 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1093 assert_eq!(
1094 // Ignore carriage returns on Windows
1095 final_content.replace("\r\n", "\n"),
1096 CONTENT_WITH_TRAILING_WHITESPACE,
1097 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1098 );
1099 }
1100
1101 #[gpui::test]
1102 async fn test_authorize(cx: &mut TestAppContext) {
1103 init_test(cx);
1104 let fs = project::FakeFs::new(cx.executor());
1105 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1106 let context_server_registry =
1107 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1108 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1109 let model = Arc::new(FakeLanguageModel::default());
1110 let thread = cx.new(|cx| {
1111 Thread::new(
1112 project.clone(),
1113 cx.new(|_cx| ProjectContext::default()),
1114 context_server_registry,
1115 Templates::new(),
1116 Some(model.clone()),
1117 cx,
1118 )
1119 });
1120 let tool = Arc::new(EditFileTool::new(
1121 project.clone(),
1122 thread.downgrade(),
1123 language_registry,
1124 Templates::new(),
1125 ));
1126 fs.insert_tree("/root", json!({})).await;
1127
1128 // Test 1: Path with .zed component should require confirmation
1129 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1130 let _auth = cx.update(|cx| {
1131 tool.authorize(
1132 &EditFileToolInput {
1133 display_description: "test 1".into(),
1134 path: ".zed/settings.json".into(),
1135 mode: EditFileMode::Edit,
1136 },
1137 &stream_tx,
1138 cx,
1139 )
1140 });
1141
1142 let event = stream_rx.expect_authorization().await;
1143 assert_eq!(
1144 event.tool_call.fields.title,
1145 Some("test 1 (local settings)".into())
1146 );
1147
1148 // Test 2: Path outside project should require confirmation
1149 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1150 let _auth = cx.update(|cx| {
1151 tool.authorize(
1152 &EditFileToolInput {
1153 display_description: "test 2".into(),
1154 path: "/etc/hosts".into(),
1155 mode: EditFileMode::Edit,
1156 },
1157 &stream_tx,
1158 cx,
1159 )
1160 });
1161
1162 let event = stream_rx.expect_authorization().await;
1163 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1164
1165 // Test 3: Relative path without .zed should not require confirmation
1166 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1167 cx.update(|cx| {
1168 tool.authorize(
1169 &EditFileToolInput {
1170 display_description: "test 3".into(),
1171 path: "root/src/main.rs".into(),
1172 mode: EditFileMode::Edit,
1173 },
1174 &stream_tx,
1175 cx,
1176 )
1177 })
1178 .await
1179 .unwrap();
1180 assert!(stream_rx.try_next().is_err());
1181
1182 // Test 4: Path with .zed in the middle should require confirmation
1183 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1184 let _auth = cx.update(|cx| {
1185 tool.authorize(
1186 &EditFileToolInput {
1187 display_description: "test 4".into(),
1188 path: "root/.zed/tasks.json".into(),
1189 mode: EditFileMode::Edit,
1190 },
1191 &stream_tx,
1192 cx,
1193 )
1194 });
1195 let event = stream_rx.expect_authorization().await;
1196 assert_eq!(
1197 event.tool_call.fields.title,
1198 Some("test 4 (local settings)".into())
1199 );
1200
1201 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1202 cx.update(|cx| {
1203 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1204 settings.always_allow_tool_actions = true;
1205 agent_settings::AgentSettings::override_global(settings, cx);
1206 });
1207
1208 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1209 cx.update(|cx| {
1210 tool.authorize(
1211 &EditFileToolInput {
1212 display_description: "test 5.1".into(),
1213 path: ".zed/settings.json".into(),
1214 mode: EditFileMode::Edit,
1215 },
1216 &stream_tx,
1217 cx,
1218 )
1219 })
1220 .await
1221 .unwrap();
1222 assert!(stream_rx.try_next().is_err());
1223
1224 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1225 cx.update(|cx| {
1226 tool.authorize(
1227 &EditFileToolInput {
1228 display_description: "test 5.2".into(),
1229 path: "/etc/hosts".into(),
1230 mode: EditFileMode::Edit,
1231 },
1232 &stream_tx,
1233 cx,
1234 )
1235 })
1236 .await
1237 .unwrap();
1238 assert!(stream_rx.try_next().is_err());
1239 }
1240
1241 #[gpui::test]
1242 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1243 init_test(cx);
1244 let fs = project::FakeFs::new(cx.executor());
1245 fs.insert_tree("/project", json!({})).await;
1246 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1247 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1248 let context_server_registry =
1249 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1250 let model = Arc::new(FakeLanguageModel::default());
1251 let thread = cx.new(|cx| {
1252 Thread::new(
1253 project.clone(),
1254 cx.new(|_cx| ProjectContext::default()),
1255 context_server_registry,
1256 Templates::new(),
1257 Some(model.clone()),
1258 cx,
1259 )
1260 });
1261 let tool = Arc::new(EditFileTool::new(
1262 project.clone(),
1263 thread.downgrade(),
1264 language_registry,
1265 Templates::new(),
1266 ));
1267
1268 // Test global config paths - these should require confirmation if they exist and are outside the project
1269 let test_cases = vec![
1270 (
1271 "/etc/hosts",
1272 true,
1273 "System file should require confirmation",
1274 ),
1275 (
1276 "/usr/local/bin/script",
1277 true,
1278 "System bin file should require confirmation",
1279 ),
1280 (
1281 "project/normal_file.rs",
1282 false,
1283 "Normal project file should not require confirmation",
1284 ),
1285 ];
1286
1287 for (path, should_confirm, description) in test_cases {
1288 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1289 let auth = cx.update(|cx| {
1290 tool.authorize(
1291 &EditFileToolInput {
1292 display_description: "Edit file".into(),
1293 path: path.into(),
1294 mode: EditFileMode::Edit,
1295 },
1296 &stream_tx,
1297 cx,
1298 )
1299 });
1300
1301 if should_confirm {
1302 stream_rx.expect_authorization().await;
1303 } else {
1304 auth.await.unwrap();
1305 assert!(
1306 stream_rx.try_next().is_err(),
1307 "Failed for case: {} - path: {} - expected no confirmation but got one",
1308 description,
1309 path
1310 );
1311 }
1312 }
1313 }
1314
1315 #[gpui::test]
1316 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1317 init_test(cx);
1318 let fs = project::FakeFs::new(cx.executor());
1319
1320 // Create multiple worktree directories
1321 fs.insert_tree(
1322 "/workspace/frontend",
1323 json!({
1324 "src": {
1325 "main.js": "console.log('frontend');"
1326 }
1327 }),
1328 )
1329 .await;
1330 fs.insert_tree(
1331 "/workspace/backend",
1332 json!({
1333 "src": {
1334 "main.rs": "fn main() {}"
1335 }
1336 }),
1337 )
1338 .await;
1339 fs.insert_tree(
1340 "/workspace/shared",
1341 json!({
1342 ".zed": {
1343 "settings.json": "{}"
1344 }
1345 }),
1346 )
1347 .await;
1348
1349 // Create project with multiple worktrees
1350 let project = Project::test(
1351 fs.clone(),
1352 [
1353 path!("/workspace/frontend").as_ref(),
1354 path!("/workspace/backend").as_ref(),
1355 path!("/workspace/shared").as_ref(),
1356 ],
1357 cx,
1358 )
1359 .await;
1360 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1361 let context_server_registry =
1362 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1363 let model = Arc::new(FakeLanguageModel::default());
1364 let thread = cx.new(|cx| {
1365 Thread::new(
1366 project.clone(),
1367 cx.new(|_cx| ProjectContext::default()),
1368 context_server_registry.clone(),
1369 Templates::new(),
1370 Some(model.clone()),
1371 cx,
1372 )
1373 });
1374 let tool = Arc::new(EditFileTool::new(
1375 project.clone(),
1376 thread.downgrade(),
1377 language_registry,
1378 Templates::new(),
1379 ));
1380
1381 // Test files in different worktrees
1382 let test_cases = vec![
1383 ("frontend/src/main.js", false, "File in first worktree"),
1384 ("backend/src/main.rs", false, "File in second worktree"),
1385 (
1386 "shared/.zed/settings.json",
1387 true,
1388 ".zed file in third worktree",
1389 ),
1390 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1391 (
1392 "../outside/file.txt",
1393 true,
1394 "Relative path outside worktrees",
1395 ),
1396 ];
1397
1398 for (path, should_confirm, description) in test_cases {
1399 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1400 let auth = cx.update(|cx| {
1401 tool.authorize(
1402 &EditFileToolInput {
1403 display_description: "Edit file".into(),
1404 path: path.into(),
1405 mode: EditFileMode::Edit,
1406 },
1407 &stream_tx,
1408 cx,
1409 )
1410 });
1411
1412 if should_confirm {
1413 stream_rx.expect_authorization().await;
1414 } else {
1415 auth.await.unwrap();
1416 assert!(
1417 stream_rx.try_next().is_err(),
1418 "Failed for case: {} - path: {} - expected no confirmation but got one",
1419 description,
1420 path
1421 );
1422 }
1423 }
1424 }
1425
1426 #[gpui::test]
1427 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1428 init_test(cx);
1429 let fs = project::FakeFs::new(cx.executor());
1430 fs.insert_tree(
1431 "/project",
1432 json!({
1433 ".zed": {
1434 "settings.json": "{}"
1435 },
1436 "src": {
1437 ".zed": {
1438 "local.json": "{}"
1439 }
1440 }
1441 }),
1442 )
1443 .await;
1444 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1445 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1446 let context_server_registry =
1447 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1448 let model = Arc::new(FakeLanguageModel::default());
1449 let thread = cx.new(|cx| {
1450 Thread::new(
1451 project.clone(),
1452 cx.new(|_cx| ProjectContext::default()),
1453 context_server_registry.clone(),
1454 Templates::new(),
1455 Some(model.clone()),
1456 cx,
1457 )
1458 });
1459 let tool = Arc::new(EditFileTool::new(
1460 project.clone(),
1461 thread.downgrade(),
1462 language_registry,
1463 Templates::new(),
1464 ));
1465
1466 // Test edge cases
1467 let test_cases = vec![
1468 // Empty path - find_project_path returns Some for empty paths
1469 ("", false, "Empty path is treated as project root"),
1470 // Root directory
1471 ("/", true, "Root directory should be outside project"),
1472 // Parent directory references - find_project_path resolves these
1473 (
1474 "project/../other",
1475 true,
1476 "Path with .. that goes outside of root directory",
1477 ),
1478 (
1479 "project/./src/file.rs",
1480 false,
1481 "Path with . should work normally",
1482 ),
1483 // Windows-style paths (if on Windows)
1484 #[cfg(target_os = "windows")]
1485 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1486 #[cfg(target_os = "windows")]
1487 ("project\\src\\main.rs", false, "Windows-style project path"),
1488 ];
1489
1490 for (path, should_confirm, description) in test_cases {
1491 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1492 let auth = cx.update(|cx| {
1493 tool.authorize(
1494 &EditFileToolInput {
1495 display_description: "Edit file".into(),
1496 path: path.into(),
1497 mode: EditFileMode::Edit,
1498 },
1499 &stream_tx,
1500 cx,
1501 )
1502 });
1503
1504 cx.run_until_parked();
1505
1506 if should_confirm {
1507 stream_rx.expect_authorization().await;
1508 } else {
1509 assert!(
1510 stream_rx.try_next().is_err(),
1511 "Failed for case: {} - path: {} - expected no confirmation but got one",
1512 description,
1513 path
1514 );
1515 auth.await.unwrap();
1516 }
1517 }
1518 }
1519
1520 #[gpui::test]
1521 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1522 init_test(cx);
1523 let fs = project::FakeFs::new(cx.executor());
1524 fs.insert_tree(
1525 "/project",
1526 json!({
1527 "existing.txt": "content",
1528 ".zed": {
1529 "settings.json": "{}"
1530 }
1531 }),
1532 )
1533 .await;
1534 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1535 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1536 let context_server_registry =
1537 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1538 let model = Arc::new(FakeLanguageModel::default());
1539 let thread = cx.new(|cx| {
1540 Thread::new(
1541 project.clone(),
1542 cx.new(|_cx| ProjectContext::default()),
1543 context_server_registry.clone(),
1544 Templates::new(),
1545 Some(model.clone()),
1546 cx,
1547 )
1548 });
1549 let tool = Arc::new(EditFileTool::new(
1550 project.clone(),
1551 thread.downgrade(),
1552 language_registry,
1553 Templates::new(),
1554 ));
1555
1556 // Test different EditFileMode values
1557 let modes = vec![
1558 EditFileMode::Edit,
1559 EditFileMode::Create,
1560 EditFileMode::Overwrite,
1561 ];
1562
1563 for mode in modes {
1564 // Test .zed path with different modes
1565 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1566 let _auth = cx.update(|cx| {
1567 tool.authorize(
1568 &EditFileToolInput {
1569 display_description: "Edit settings".into(),
1570 path: "project/.zed/settings.json".into(),
1571 mode: mode.clone(),
1572 },
1573 &stream_tx,
1574 cx,
1575 )
1576 });
1577
1578 stream_rx.expect_authorization().await;
1579
1580 // Test outside path with different modes
1581 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1582 let _auth = cx.update(|cx| {
1583 tool.authorize(
1584 &EditFileToolInput {
1585 display_description: "Edit file".into(),
1586 path: "/outside/file.txt".into(),
1587 mode: mode.clone(),
1588 },
1589 &stream_tx,
1590 cx,
1591 )
1592 });
1593
1594 stream_rx.expect_authorization().await;
1595
1596 // Test normal path with different modes
1597 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1598 cx.update(|cx| {
1599 tool.authorize(
1600 &EditFileToolInput {
1601 display_description: "Edit file".into(),
1602 path: "project/normal.txt".into(),
1603 mode: mode.clone(),
1604 },
1605 &stream_tx,
1606 cx,
1607 )
1608 })
1609 .await
1610 .unwrap();
1611 assert!(stream_rx.try_next().is_err());
1612 }
1613 }
1614
1615 #[gpui::test]
1616 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1617 init_test(cx);
1618 let fs = project::FakeFs::new(cx.executor());
1619 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1620 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1621 let context_server_registry =
1622 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1623 let model = Arc::new(FakeLanguageModel::default());
1624 let thread = cx.new(|cx| {
1625 Thread::new(
1626 project.clone(),
1627 cx.new(|_cx| ProjectContext::default()),
1628 context_server_registry,
1629 Templates::new(),
1630 Some(model.clone()),
1631 cx,
1632 )
1633 });
1634 let tool = Arc::new(EditFileTool::new(
1635 project,
1636 thread.downgrade(),
1637 language_registry,
1638 Templates::new(),
1639 ));
1640
1641 cx.update(|cx| {
1642 // ...
1643 assert_eq!(
1644 tool.initial_title(
1645 Err(json!({
1646 "path": "src/main.rs",
1647 "display_description": "",
1648 "old_string": "old code",
1649 "new_string": "new code"
1650 })),
1651 cx
1652 ),
1653 "src/main.rs"
1654 );
1655 assert_eq!(
1656 tool.initial_title(
1657 Err(json!({
1658 "path": "",
1659 "display_description": "Fix error handling",
1660 "old_string": "old code",
1661 "new_string": "new code"
1662 })),
1663 cx
1664 ),
1665 "Fix error handling"
1666 );
1667 assert_eq!(
1668 tool.initial_title(
1669 Err(json!({
1670 "path": "src/main.rs",
1671 "display_description": "Fix error handling",
1672 "old_string": "old code",
1673 "new_string": "new code"
1674 })),
1675 cx
1676 ),
1677 "src/main.rs"
1678 );
1679 assert_eq!(
1680 tool.initial_title(
1681 Err(json!({
1682 "path": "",
1683 "display_description": "",
1684 "old_string": "old code",
1685 "new_string": "new code"
1686 })),
1687 cx
1688 ),
1689 DEFAULT_UI_TEXT
1690 );
1691 assert_eq!(
1692 tool.initial_title(Err(serde_json::Value::Null), cx),
1693 DEFAULT_UI_TEXT
1694 );
1695 });
1696 }
1697
    /// Verifies that the `Diff` emitted by the edit tool starts out `Diff::Pending`
    /// and always ends up `Diff::Finalized`. Three scenarios are checked: the edit
    /// completes, the edit fails, and the edit task is dropped.
    #[gpui::test]
    async fn test_diff_finalization(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({"main.rs": ""})).await;

        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // Ensure the diff is finalized after the edit completes.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            // The tool first emits a fields update, then the diff, which must
            // still be pending while the edit is in flight.
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            cx.run_until_parked();
            // Close the completion stream without sending any chunks so the edit can finish.
            model.end_last_completion_stream();
            edit.await.unwrap();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }

        // Ensure the diff is finalized if an error occurs while editing.
        {
            // Presumably makes the fake model reject completion requests, which
            // drives the edit into its error path (the `unwrap_err` below).
            model.forbid_requests();
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            edit.await.unwrap_err();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
            // Restore normal model behavior for the next scenario.
            model.allow_requests();
        }

        // Ensure the diff is finalized if the tool call gets dropped.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            // Drop the task mid-edit, then let pending work settle before
            // checking that the diff was still finalized.
            drop(edit);
            cx.run_until_parked();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }
    }
1806
1807 #[gpui::test]
1808 async fn test_file_read_times_tracking(cx: &mut TestAppContext) {
1809 init_test(cx);
1810
1811 let fs = project::FakeFs::new(cx.executor());
1812 fs.insert_tree(
1813 "/root",
1814 json!({
1815 "test.txt": "original content"
1816 }),
1817 )
1818 .await;
1819 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1820 let context_server_registry =
1821 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1822 let model = Arc::new(FakeLanguageModel::default());
1823 let thread = cx.new(|cx| {
1824 Thread::new(
1825 project.clone(),
1826 cx.new(|_cx| ProjectContext::default()),
1827 context_server_registry,
1828 Templates::new(),
1829 Some(model.clone()),
1830 cx,
1831 )
1832 });
1833 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1834
1835 // Initially, file_read_times should be empty
1836 let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
1837 assert!(is_empty, "file_read_times should start empty");
1838
1839 // Create read tool
1840 let read_tool = Arc::new(crate::ReadFileTool::new(
1841 thread.downgrade(),
1842 project.clone(),
1843 action_log,
1844 ));
1845
1846 // Read the file to record the read time
1847 cx.update(|cx| {
1848 read_tool.clone().run(
1849 crate::ReadFileToolInput {
1850 path: "root/test.txt".to_string(),
1851 start_line: None,
1852 end_line: None,
1853 },
1854 ToolCallEventStream::test().0,
1855 cx,
1856 )
1857 })
1858 .await
1859 .unwrap();
1860
1861 // Verify that file_read_times now contains an entry for the file
1862 let has_entry = thread.read_with(cx, |thread, _| {
1863 thread.file_read_times.len() == 1
1864 && thread
1865 .file_read_times
1866 .keys()
1867 .any(|path| path.ends_with("test.txt"))
1868 });
1869 assert!(
1870 has_entry,
1871 "file_read_times should contain an entry after reading the file"
1872 );
1873
1874 // Read the file again - should update the entry
1875 cx.update(|cx| {
1876 read_tool.clone().run(
1877 crate::ReadFileToolInput {
1878 path: "root/test.txt".to_string(),
1879 start_line: None,
1880 end_line: None,
1881 },
1882 ToolCallEventStream::test().0,
1883 cx,
1884 )
1885 })
1886 .await
1887 .unwrap();
1888
1889 // Should still have exactly one entry
1890 let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
1891 assert!(
1892 has_one_entry,
1893 "file_read_times should still have one entry after re-reading"
1894 );
1895 }
1896
1897 fn init_test(cx: &mut TestAppContext) {
1898 cx.update(|cx| {
1899 let settings_store = SettingsStore::test(cx);
1900 cx.set_global(settings_store);
1901 });
1902 }
1903
    /// After a file is read once, two back-to-back edits must both succeed. The
    /// first edit must not leave the file looking "modified since last read" to
    /// the second one.
    #[gpui::test]
    async fn test_consecutive_edits_work(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "test.txt": "original content"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let languages = project.read_with(cx, |project, _| project.languages().clone());
        let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

        let read_tool = Arc::new(crate::ReadFileTool::new(
            thread.downgrade(),
            project.clone(),
            action_log,
        ));
        let edit_tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages,
            Templates::new(),
        ));

        // Read the file first
        cx.update(|cx| {
            read_tool.clone().run(
                crate::ReadFileToolInput {
                    path: "root/test.txt".to_string(),
                    start_line: None,
                    end_line: None,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap();

        // First edit should work
        let edit_result = {
            let edit_task = cx.update(|cx| {
                edit_tool.clone().run(
                    EditFileToolInput {
                        display_description: "First edit".into(),
                        path: "root/test.txt".into(),
                        mode: EditFileMode::Edit,
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            });

            // Let the tool issue its completion request, then have the fake model
            // reply with a single old_text/new_text replacement and close the stream.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                "<old_text>original content</old_text><new_text>modified content</new_text>"
                    .to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(
            edit_result.is_ok(),
            "First edit should succeed, got error: {:?}",
            edit_result.as_ref().err()
        );

        // Second edit should also work because the edit updated the recorded read time
        let edit_result = {
            let edit_task = cx.update(|cx| {
                edit_tool.clone().run(
                    EditFileToolInput {
                        display_description: "Second edit".into(),
                        path: "root/test.txt".into(),
                        mode: EditFileMode::Edit,
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            });

            // The old_text here is the content written by the first edit.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                "<old_text>modified content</old_text><new_text>further modified content</new_text>".to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(
            edit_result.is_ok(),
            "Second consecutive edit should succeed, got error: {:?}",
            edit_result.as_ref().err()
        );
    }
2017
    /// An edit must be rejected when the file changed on disk after the agent last
    /// read it. The error must mention that the file "has been modified since you
    /// last read it".
    #[gpui::test]
    async fn test_external_modification_detected(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "test.txt": "original content"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let languages = project.read_with(cx, |project, _| project.languages().clone());
        let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

        let read_tool = Arc::new(crate::ReadFileTool::new(
            thread.downgrade(),
            project.clone(),
            action_log,
        ));
        let edit_tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages,
            Templates::new(),
        ));

        // Read the file first
        cx.update(|cx| {
            read_tool.clone().run(
                crate::ReadFileToolInput {
                    path: "root/test.txt".to_string(),
                    start_line: None,
                    end_line: None,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap();

        // Simulate external modification - advance time and save file.
        // The clock advance presumably gives the external write a later mtime
        // than the recorded read.
        cx.background_executor
            .advance_clock(std::time::Duration::from_secs(2));
        fs.save(
            path!("/root/test.txt").as_ref(),
            &"externally modified content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        // Reload the buffer to pick up the new mtime
        let project_path = project
            .read_with(cx, |project, cx| {
                project.find_project_path("root/test.txt", cx)
            })
            .expect("Should find project path");
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap();
        buffer
            .update(cx, |buffer, cx| buffer.reload(cx))
            .await
            .unwrap();

        cx.executor().run_until_parked();

        // Try to edit - should fail because file was modified externally.
        // No model output is streamed; the run is expected to fail on its own.
        let result = cx
            .update(|cx| {
                edit_tool.clone().run(
                    EditFileToolInput {
                        display_description: "Edit after external change".into(),
                        path: "root/test.txt".into(),
                        mode: EditFileMode::Edit,
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await;

        assert!(
            result.is_err(),
            "Edit should fail after external modification"
        );
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("has been modified since you last read it"),
            "Error should mention file modification, got: {}",
            error_msg
        );
    }
2128
2129 #[gpui::test]
2130 async fn test_dirty_buffer_detected(cx: &mut TestAppContext) {
2131 init_test(cx);
2132
2133 let fs = project::FakeFs::new(cx.executor());
2134 fs.insert_tree(
2135 "/root",
2136 json!({
2137 "test.txt": "original content"
2138 }),
2139 )
2140 .await;
2141 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2142 let context_server_registry =
2143 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2144 let model = Arc::new(FakeLanguageModel::default());
2145 let thread = cx.new(|cx| {
2146 Thread::new(
2147 project.clone(),
2148 cx.new(|_cx| ProjectContext::default()),
2149 context_server_registry,
2150 Templates::new(),
2151 Some(model.clone()),
2152 cx,
2153 )
2154 });
2155 let languages = project.read_with(cx, |project, _| project.languages().clone());
2156 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2157
2158 let read_tool = Arc::new(crate::ReadFileTool::new(
2159 thread.downgrade(),
2160 project.clone(),
2161 action_log,
2162 ));
2163 let edit_tool = Arc::new(EditFileTool::new(
2164 project.clone(),
2165 thread.downgrade(),
2166 languages,
2167 Templates::new(),
2168 ));
2169
2170 // Read the file first
2171 cx.update(|cx| {
2172 read_tool.clone().run(
2173 crate::ReadFileToolInput {
2174 path: "root/test.txt".to_string(),
2175 start_line: None,
2176 end_line: None,
2177 },
2178 ToolCallEventStream::test().0,
2179 cx,
2180 )
2181 })
2182 .await
2183 .unwrap();
2184
2185 // Open the buffer and make it dirty by editing without saving
2186 let project_path = project
2187 .read_with(cx, |project, cx| {
2188 project.find_project_path("root/test.txt", cx)
2189 })
2190 .expect("Should find project path");
2191 let buffer = project
2192 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2193 .await
2194 .unwrap();
2195
2196 // Make an in-memory edit to the buffer (making it dirty)
2197 buffer.update(cx, |buffer, cx| {
2198 let end_point = buffer.max_point();
2199 buffer.edit([(end_point..end_point, " added text")], None, cx);
2200 });
2201
2202 // Verify buffer is dirty
2203 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
2204 assert!(is_dirty, "Buffer should be dirty after in-memory edit");
2205
2206 // Try to edit - should fail because buffer has unsaved changes
2207 let result = cx
2208 .update(|cx| {
2209 edit_tool.clone().run(
2210 EditFileToolInput {
2211 display_description: "Edit with dirty buffer".into(),
2212 path: "root/test.txt".into(),
2213 mode: EditFileMode::Edit,
2214 },
2215 ToolCallEventStream::test().0,
2216 cx,
2217 )
2218 })
2219 .await;
2220
2221 assert!(result.is_err(), "Edit should fail when buffer is dirty");
2222 let error_msg = result.unwrap_err().to_string();
2223 assert!(
2224 error_msg.contains("This file has unsaved changes."),
2225 "Error should mention unsaved changes, got: {}",
2226 error_msg
2227 );
2228 assert!(
2229 error_msg.contains("keep or discard"),
2230 "Error should ask whether to keep or discard changes, got: {}",
2231 error_msg
2232 );
2233 // Since save_file and restore_file_from_disk tools aren't added to the thread,
2234 // the error message should ask the user to manually save or revert
2235 assert!(
2236 error_msg.contains("save or revert the file manually"),
2237 "Error should ask user to manually save or revert when tools aren't available, got: {}",
2238 error_msg
2239 );
2240 }
2241}