1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Title used for an `edit_file` call when neither a path nor a description can be
/// extracted from its (possibly still streaming) input.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
// NOTE: the `///` doc comments on this struct and its fields are not only rustdoc. The
// `JsonSchema` derive copies them into the schema descriptions for this type, which are
// presumably what the model sees as the tool's input documentation. Treat them as prompt
// text: keep wording changes deliberate.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    // Resolved against the project's worktrees by `resolve_path`; the leading root directory
    // name selects the worktree.
    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
79
// Lenient view of an `EditFileToolInput` whose JSON has not fully arrived yet. It is used by
// `initial_title` to show something meaningful while the tool call streams in.
// `#[serde(default)]` lets absent fields deserialize as empty strings. Fields not listed here
// (such as `mode`) are ignored, which is serde's default for unknown fields.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
87
// How `edit_file` treats the target file. Values are (de)serialized in lowercase
// (`"edit"`, `"create"`, `"overwrite"`) and inlined into the tool's input schema. The
// model-facing meaning of each value is documented on `EditFileToolInput::mode`.
//
// `resolve_path` requires `Create` targets to not exist yet, and `Edit`/`Overwrite` targets
// to be existing files. `run` sends `Edit` to `EditAgent::edit` and the other two modes to
// `EditAgent::overwrite`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    Edit,
    Create,
    Overwrite,
}
96
/// Result of a completed `edit_file` call.
///
/// It is converted into the model-facing tool result by its `From` impl and is what `replay`
/// receives to re-render a past edit. The serde aliases accept alternative (presumably legacy)
/// field names when deserializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as the model supplied it in `EditFileToolInput::path`.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit, captured after any formatting and the save.
    new_text: String,
    /// Full buffer text captured before the edit was applied.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`. It is empty when nothing changed, and
    /// defaults to empty when missing from serialized data.
    #[serde(default)]
    diff: String,
    /// Raw output returned by the edit agent.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
/// The `edit_file` agent tool.
///
/// Creates, overwrites, or edits a single project file, delegating the text changes
/// themselves to an [`EditAgent`].
pub struct EditFileTool {
    /// Owning thread, held weakly. Operations fail with "thread was dropped" once it is gone.
    thread: WeakEntity<Thread>,
    /// Passed to the finalized diff built by `replay`.
    language_registry: Arc<LanguageRegistry>,
    /// Project used by `initial_title` to turn paths into short display titles.
    project: Entity<Project>,
    /// Prompt templates handed to the [`EditAgent`].
    templates: Arc<Templates>,
}
130
131impl EditFileTool {
132 pub fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
203 fn name() -> &'static str {
204 "edit_file"
205 }
206
207 fn kind() -> acp::ToolKind {
208 acp::ToolKind::Edit
209 }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(
277 ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
278 );
279 }
280
281 let authorize = self.authorize(&input, &event_stream, cx);
282 cx.spawn(async move |cx: &mut AsyncApp| {
283 authorize.await?;
284
285 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
286 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
287 (request, thread.model().cloned(), thread.action_log().clone())
288 })?;
289 let request = request?;
290 let model = model.context("No language model configured")?;
291
292 let edit_format = EditFormat::from_model(model.clone())?;
293 let edit_agent = EditAgent::new(
294 model,
295 project.clone(),
296 action_log.clone(),
297 self.templates.clone(),
298 edit_format,
299 );
300
301 let buffer = project
302 .update(cx, |project, cx| {
303 project.open_buffer(project_path.clone(), cx)
304 })?
305 .await?;
306
307 // Check if the file has been modified since the agent last read it
308 if let Some(abs_path) = abs_path.as_ref() {
309 let (last_read_mtime, current_mtime, is_dirty) = self.thread.update(cx, |thread, cx| {
310 let last_read = thread.file_read_times.get(abs_path).copied();
311 let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
312 let dirty = buffer.read(cx).is_dirty();
313 (last_read, current, dirty)
314 })?;
315
316 // Check for unsaved changes first - these indicate modifications we don't know about
317 if is_dirty {
318 anyhow::bail!(
319 "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
320 If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
321 If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
322 );
323 }
324
325 // Check if the file was modified on disk since we last read it
326 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
327 // MTime can be unreliable for comparisons, so our newtype intentionally
328 // doesn't support comparing them. If the mtime at all different
329 // (which could be because of a modification or because e.g. system clock changed),
330 // we pessimistically assume it was modified.
331 if current != last_read {
332 anyhow::bail!(
333 "The file {} has been modified since you last read it. \
334 Please read the file again to get the current state before editing it.",
335 input.path.display()
336 );
337 }
338 }
339 }
340
341 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
342 event_stream.update_diff(diff.clone());
343 let _finalize_diff = util::defer({
344 let diff = diff.downgrade();
345 let mut cx = cx.clone();
346 move || {
347 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
348 }
349 });
350
351 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
352 let old_text = cx
353 .background_spawn({
354 let old_snapshot = old_snapshot.clone();
355 async move { Arc::new(old_snapshot.text()) }
356 })
357 .await;
358
359
360 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
361 edit_agent.edit(
362 buffer.clone(),
363 input.display_description.clone(),
364 &request,
365 cx,
366 )
367 } else {
368 edit_agent.overwrite(
369 buffer.clone(),
370 input.display_description.clone(),
371 &request,
372 cx,
373 )
374 };
375
376 let mut hallucinated_old_text = false;
377 let mut ambiguous_ranges = Vec::new();
378 let mut emitted_location = false;
379 while let Some(event) = events.next().await {
380 match event {
381 EditAgentOutputEvent::Edited(range) => {
382 if !emitted_location {
383 let line = buffer.update(cx, |buffer, _cx| {
384 range.start.to_point(&buffer.snapshot()).row
385 }).ok();
386 if let Some(abs_path) = abs_path.clone() {
387 event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)]));
388 }
389 emitted_location = true;
390 }
391 },
392 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
393 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
394 EditAgentOutputEvent::ResolvingEditRange(range) => {
395 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
396 // if !emitted_location {
397 // let line = buffer.update(cx, |buffer, _cx| {
398 // range.start.to_point(&buffer.snapshot()).row
399 // }).ok();
400 // if let Some(abs_path) = abs_path.clone() {
401 // event_stream.update_fields(ToolCallUpdateFields {
402 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
403 // ..Default::default()
404 // });
405 // }
406 // }
407 }
408 }
409 }
410
411 // If format_on_save is enabled, format the buffer
412 let format_on_save_enabled = buffer
413 .read_with(cx, |buffer, cx| {
414 let settings = language_settings::language_settings(
415 buffer.language().map(|l| l.name()),
416 buffer.file(),
417 cx,
418 );
419 settings.format_on_save != FormatOnSave::Off
420 })
421 .unwrap_or(false);
422
423 let edit_agent_output = output.await?;
424
425 if format_on_save_enabled {
426 action_log.update(cx, |log, cx| {
427 log.buffer_edited(buffer.clone(), cx);
428 })?;
429
430 let format_task = project.update(cx, |project, cx| {
431 project.format(
432 HashSet::from_iter([buffer.clone()]),
433 LspFormatTarget::Buffers,
434 false, // Don't push to history since the tool did it.
435 FormatTrigger::Save,
436 cx,
437 )
438 })?;
439 format_task.await.log_err();
440 }
441
442 project
443 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
444 .await?;
445
446 action_log.update(cx, |log, cx| {
447 log.buffer_edited(buffer.clone(), cx);
448 })?;
449
450 // Update the recorded read time after a successful edit so consecutive edits work
451 if let Some(abs_path) = abs_path.as_ref() {
452 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
453 buffer.file().and_then(|file| file.disk_state().mtime())
454 })? {
455 self.thread.update(cx, |thread, _| {
456 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
457 })?;
458 }
459 }
460
461 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
462 let (new_text, unified_diff) = cx
463 .background_spawn({
464 let new_snapshot = new_snapshot.clone();
465 let old_text = old_text.clone();
466 async move {
467 let new_text = new_snapshot.text();
468 let diff = language::unified_diff(&old_text, &new_text);
469 (new_text, diff)
470 }
471 })
472 .await;
473
474 let input_path = input.path.display();
475 if unified_diff.is_empty() {
476 anyhow::ensure!(
477 !hallucinated_old_text,
478 formatdoc! {"
479 Some edits were produced but none of them could be applied.
480 Read the relevant sections of {input_path} again so that
481 I can perform the requested edits.
482 "}
483 );
484 anyhow::ensure!(
485 ambiguous_ranges.is_empty(),
486 {
487 let line_numbers = ambiguous_ranges
488 .iter()
489 .map(|range| range.start.to_string())
490 .collect::<Vec<_>>()
491 .join(", ");
492 formatdoc! {"
493 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
494 relevant sections of {input_path} again and extend <old_text> so
495 that I can perform the requested edits.
496 "}
497 }
498 );
499 }
500
501 Ok(EditFileToolOutput {
502 input_path: input.path,
503 new_text,
504 old_text,
505 diff: unified_diff,
506 edit_agent_output,
507 })
508 })
509 }
510
511 fn replay(
512 &self,
513 _input: Self::Input,
514 output: Self::Output,
515 event_stream: ToolCallEventStream,
516 cx: &mut App,
517 ) -> Result<()> {
518 event_stream.update_diff(cx.new(|cx| {
519 Diff::finalized(
520 output.input_path.to_string_lossy().into_owned(),
521 Some(output.old_text.to_string()),
522 output.new_text,
523 self.language_registry.clone(),
524 cx,
525 )
526 }));
527 Ok(())
528 }
529}
530
531/// Validate that the file path is valid, meaning:
532///
533/// - For `edit` and `overwrite`, the path must point to an existing file.
534/// - For `create`, the file must not already exist, but it's parent dir must exist.
535fn resolve_path(
536 input: &EditFileToolInput,
537 project: Entity<Project>,
538 cx: &mut App,
539) -> Result<ProjectPath> {
540 let project = project.read(cx);
541
542 match input.mode {
543 EditFileMode::Edit | EditFileMode::Overwrite => {
544 let path = project
545 .find_project_path(&input.path, cx)
546 .context("Can't edit file: path not found")?;
547
548 let entry = project
549 .entry_for_path(&path, cx)
550 .context("Can't edit file: path not found")?;
551
552 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
553 Ok(path)
554 }
555
556 EditFileMode::Create => {
557 if let Some(path) = project.find_project_path(&input.path, cx) {
558 anyhow::ensure!(
559 project.entry_for_path(&path, cx).is_none(),
560 "Can't create file: file already exists"
561 );
562 }
563
564 let parent_path = input
565 .path
566 .parent()
567 .context("Can't create file: incorrect path")?;
568
569 let parent_project_path = project.find_project_path(&parent_path, cx);
570
571 let parent_entry = parent_project_path
572 .as_ref()
573 .and_then(|path| project.entry_for_path(path, cx))
574 .context("Can't create file: parent directory doesn't exist")?;
575
576 anyhow::ensure!(
577 parent_entry.is_dir(),
578 "Can't create file: parent is not a directory"
579 );
580
581 let file_name = input
582 .path
583 .file_name()
584 .and_then(|file_name| file_name.to_str())
585 .and_then(|file_name| RelPath::unix(file_name).ok())
586 .context("Can't create file: invalid filename")?;
587
588 let new_file_path = parent_project_path.map(|parent| ProjectPath {
589 path: parent.path.join(file_name),
590 ..parent
591 });
592
593 new_file_path.context("Can't create file")
594 }
595 }
596}
597
598#[cfg(test)]
599mod tests {
600 use super::*;
601 use crate::{ContextServerRegistry, Templates};
602 use fs::Fs;
603 use gpui::{TestAppContext, UpdateGlobal};
604 use language_model::fake_provider::FakeLanguageModel;
605 use prompt_store::ProjectContext;
606 use serde_json::json;
607 use settings::SettingsStore;
608 use util::{path, rel_path::rel_path};
609
    /// Running the tool in `edit` mode on a path missing from the project must fail with
    /// `resolve_path`'s "path not found" error. `run` returns that error before spawning any
    /// work, so no completion is ever requested from the fake model.
    #[gpui::test]
    async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
        init_test(cx);

        // Empty project: nothing under /root, so any `edit` target is unresolvable.
        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model),
                cx,
            )
        });
        let result = cx
            .update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Some edit".into(),
                    path: "root/nonexistent_file.txt".into(),
                    mode: EditFileMode::Edit,
                };
                Arc::new(EditFileTool::new(
                    project,
                    thread.downgrade(),
                    language_registry,
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            })
            .await;
        assert_eq!(
            result.unwrap_err().to_string(),
            "Can't edit file: path not found"
        );
    }
652
    /// `resolve_path` in `create` mode.
    ///
    /// New files resolve whether or not the path starts with the worktree root name. Existing
    /// files, and files whose parent directory is missing, are rejected.
    #[gpui::test]
    async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
        let mode = &EditFileMode::Create;

        // With the worktree root prefix.
        let result = test_resolve_path(mode, "root/new.txt", cx);
        assert_resolved_path_eq(result.await, rel_path("new.txt"));

        // Without the worktree root prefix.
        let result = test_resolve_path(mode, "new.txt", cx);
        assert_resolved_path_eq(result.await, rel_path("new.txt"));

        let result = test_resolve_path(mode, "dir/new.txt", cx);
        assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));

        let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
        assert_eq!(
            result.await.unwrap_err().to_string(),
            "Can't create file: file already exists"
        );

        let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
        assert_eq!(
            result.await.unwrap_err().to_string(),
            "Can't create file: parent directory doesn't exist"
        );
    }
678
    /// `resolve_path` in `edit` mode.
    ///
    /// An existing file resolves with or without the worktree root prefix. A missing file and
    /// a directory are both rejected.
    #[gpui::test]
    async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
        let mode = &EditFileMode::Edit;

        let path_with_root = "root/dir/subdir/existing.txt";
        let path_without_root = "dir/subdir/existing.txt";
        let result = test_resolve_path(mode, path_with_root, cx);
        assert_resolved_path_eq(result.await, rel_path(path_without_root));

        let result = test_resolve_path(mode, path_without_root, cx);
        assert_resolved_path_eq(result.await, rel_path(path_without_root));

        let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
        assert_eq!(
            result.await.unwrap_err().to_string(),
            "Can't edit file: path not found"
        );

        let result = test_resolve_path(mode, "root/dir", cx);
        assert_eq!(
            result.await.unwrap_err().to_string(),
            "Can't edit file: path is a directory"
        );
    }
703
    /// Test helper: builds a fresh fake project rooted at `/root` that contains
    /// `dir/subdir/existing.txt`, then runs `resolve_path` on `path` with `mode`.
    ///
    /// Each call re-runs `init_test` and creates a new filesystem and project, so the calls are
    /// independent of each other.
    async fn test_resolve_path(
        mode: &EditFileMode,
        path: &str,
        cx: &mut TestAppContext,
    ) -> anyhow::Result<ProjectPath> {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dir": {
                    "subdir": {
                        "existing.txt": "hello"
                    }
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        let input = EditFileToolInput {
            display_description: "Some edit".into(),
            path: path.into(),
            mode: mode.clone(),
        };

        cx.update(|cx| resolve_path(&input, project, cx))
    }
733
734 #[track_caller]
735 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
736 let actual = path.expect("Should return valid path").path;
737 assert_eq!(actual.as_ref(), expected);
738 }
739
    /// An `overwrite` edit is formatted on save only when `format_on_save` is enabled. The
    /// formatting must also not leave the buffer marked stale in the thread's action log.
    ///
    /// The fake LSP formatter replaces the first line with `FORMATTED_CONTENT` whatever the
    /// input. Seeing that text on disk therefore proves the formatter ran.
    #[gpui::test]
    async fn test_format_on_save(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Set up a Rust language with LSP formatting support
        let rust_language = Arc::new(language::Language::new(
            language::LanguageConfig {
                name: "Rust".into(),
                matcher: language::LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            None,
        ));

        // Register the language and fake LSP
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(rust_language);

        let mut fake_language_servers = language_registry.register_fake_lsp(
            "Rust",
            language::FakeLspAdapter {
                capabilities: lsp::ServerCapabilities {
                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
                    ..Default::default()
                },
                ..Default::default()
            },
        );

        // Create the file
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        // Open the buffer to trigger LSP initialization
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/root/src/main.rs"), cx)
            })
            .await
            .unwrap();

        // Register the buffer with language servers. The handle is bound for the rest of the
        // test (presumably dropping it would unregister the buffer — TODO confirm).
        let _handle = project.update(cx, |project, cx| {
            project.register_buffer_with_language_servers(&buffer, cx)
        });

        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
        const FORMATTED_CONTENT: &str =
            "This file was formatted by the fake formatter in the test.\n";

        // Get the fake language server and set up formatting handler
        let fake_language_server = fake_language_servers.next().await.unwrap();
        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
            |_, _| async move {
                Ok(Some(vec![lsp::TextEdit {
                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
                    new_text: FORMATTED_CONTENT.to_string(),
                }]))
            }
        });

        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // First, test with format_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
                    settings.project.all_languages.defaults.formatter =
                        Some(language::language_settings::FormatterList::default());
                });
            });
        });

        // Have the model stream unformatted content
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry.clone(),
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the unformatted content. Running until parked first presumably lets the
            // tool issue its completion request, so the chunk goes to that request.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify it was formatted automatically
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            new_content.replace("\r\n", "\n"),
            FORMATTED_CONTENT,
            "Code should be formatted when format_on_save is enabled"
        );

        // Regression check: formatting performed by the tool itself must not make the action
        // log consider the buffer stale.
        let stale_buffer_count = thread
            .read_with(cx, |thread, _cx| thread.action_log.clone())
            .read_with(cx, |log, cx| log.stale_buffers(cx).count());

        assert_eq!(
            stale_buffer_count, 0,
            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
            This causes the agent to think the file was modified externally when it was just formatted.",
            stale_buffer_count
        );

        // Next, test with format_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings.project.all_languages.defaults.format_on_save =
                        Some(FormatOnSave::Off);
                });
            });
        });

        // Stream unformatted edits again
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry,
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the unformatted content
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Verify the file was not formatted
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            new_content.replace("\r\n", "\n"),
            UNFORMATTED_CONTENT,
            "Code should not be formatted when format_on_save is disabled"
        );
    }
936
    /// Trailing whitespace in content written by an `overwrite` edit is stripped on save only
    /// when `remove_trailing_whitespace_on_save` is enabled.
    #[gpui::test]
    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        // Seed the file with placeholder content; each edit below overwrites it entirely.
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // First, test with remove_trailing_whitespace_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings
                        .project
                        .all_languages
                        .defaults
                        .remove_trailing_whitespace_on_save = Some(true);
                });
            });
        });

        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
            "fn main() { \n println!(\"Hello!\"); \n}\n";

        // Have the model stream content that contains trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry.clone(),
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the content with trailing whitespace
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify trailing whitespace was removed automatically
        assert_eq!(
            // Ignore carriage returns on Windows
            fs.load(path!("/root/src/main.rs").as_ref())
                .await
                .unwrap()
                .replace("\r\n", "\n"),
            "fn main() {\n println!(\"Hello!\");\n}\n",
            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
        );

        // Next, test with remove_trailing_whitespace_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings
                        .project
                        .all_languages
                        .defaults
                        .remove_trailing_whitespace_on_save = Some(false);
                });
            });
        });

        // Stream edits again with trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry,
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the content with trailing whitespace
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file again: with the setting off, the trailing whitespace must survive the save
        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            final_content.replace("\r\n", "\n"),
            CONTENT_WITH_TRAILING_WHITESPACE,
            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
        );
    }
1081
1082 #[gpui::test]
1083 async fn test_authorize(cx: &mut TestAppContext) {
1084 init_test(cx);
1085 let fs = project::FakeFs::new(cx.executor());
1086 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1087 let context_server_registry =
1088 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1089 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1090 let model = Arc::new(FakeLanguageModel::default());
1091 let thread = cx.new(|cx| {
1092 Thread::new(
1093 project.clone(),
1094 cx.new(|_cx| ProjectContext::default()),
1095 context_server_registry,
1096 Templates::new(),
1097 Some(model.clone()),
1098 cx,
1099 )
1100 });
1101 let tool = Arc::new(EditFileTool::new(
1102 project.clone(),
1103 thread.downgrade(),
1104 language_registry,
1105 Templates::new(),
1106 ));
1107 fs.insert_tree("/root", json!({})).await;
1108
1109 // Test 1: Path with .zed component should require confirmation
1110 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1111 let _auth = cx.update(|cx| {
1112 tool.authorize(
1113 &EditFileToolInput {
1114 display_description: "test 1".into(),
1115 path: ".zed/settings.json".into(),
1116 mode: EditFileMode::Edit,
1117 },
1118 &stream_tx,
1119 cx,
1120 )
1121 });
1122
1123 let event = stream_rx.expect_authorization().await;
1124 assert_eq!(
1125 event.tool_call.fields.title,
1126 Some("test 1 (local settings)".into())
1127 );
1128
1129 // Test 2: Path outside project should require confirmation
1130 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1131 let _auth = cx.update(|cx| {
1132 tool.authorize(
1133 &EditFileToolInput {
1134 display_description: "test 2".into(),
1135 path: "/etc/hosts".into(),
1136 mode: EditFileMode::Edit,
1137 },
1138 &stream_tx,
1139 cx,
1140 )
1141 });
1142
1143 let event = stream_rx.expect_authorization().await;
1144 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1145
1146 // Test 3: Relative path without .zed should not require confirmation
1147 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1148 cx.update(|cx| {
1149 tool.authorize(
1150 &EditFileToolInput {
1151 display_description: "test 3".into(),
1152 path: "root/src/main.rs".into(),
1153 mode: EditFileMode::Edit,
1154 },
1155 &stream_tx,
1156 cx,
1157 )
1158 })
1159 .await
1160 .unwrap();
1161 assert!(stream_rx.try_next().is_err());
1162
1163 // Test 4: Path with .zed in the middle should require confirmation
1164 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1165 let _auth = cx.update(|cx| {
1166 tool.authorize(
1167 &EditFileToolInput {
1168 display_description: "test 4".into(),
1169 path: "root/.zed/tasks.json".into(),
1170 mode: EditFileMode::Edit,
1171 },
1172 &stream_tx,
1173 cx,
1174 )
1175 });
1176 let event = stream_rx.expect_authorization().await;
1177 assert_eq!(
1178 event.tool_call.fields.title,
1179 Some("test 4 (local settings)".into())
1180 );
1181
1182 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1183 cx.update(|cx| {
1184 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1185 settings.always_allow_tool_actions = true;
1186 agent_settings::AgentSettings::override_global(settings, cx);
1187 });
1188
1189 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1190 cx.update(|cx| {
1191 tool.authorize(
1192 &EditFileToolInput {
1193 display_description: "test 5.1".into(),
1194 path: ".zed/settings.json".into(),
1195 mode: EditFileMode::Edit,
1196 },
1197 &stream_tx,
1198 cx,
1199 )
1200 })
1201 .await
1202 .unwrap();
1203 assert!(stream_rx.try_next().is_err());
1204
1205 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1206 cx.update(|cx| {
1207 tool.authorize(
1208 &EditFileToolInput {
1209 display_description: "test 5.2".into(),
1210 path: "/etc/hosts".into(),
1211 mode: EditFileMode::Edit,
1212 },
1213 &stream_tx,
1214 cx,
1215 )
1216 })
1217 .await
1218 .unwrap();
1219 assert!(stream_rx.try_next().is_err());
1220 }
1221
1222 #[gpui::test]
1223 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1224 init_test(cx);
1225 let fs = project::FakeFs::new(cx.executor());
1226 fs.insert_tree("/project", json!({})).await;
1227 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1228 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1229 let context_server_registry =
1230 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1231 let model = Arc::new(FakeLanguageModel::default());
1232 let thread = cx.new(|cx| {
1233 Thread::new(
1234 project.clone(),
1235 cx.new(|_cx| ProjectContext::default()),
1236 context_server_registry,
1237 Templates::new(),
1238 Some(model.clone()),
1239 cx,
1240 )
1241 });
1242 let tool = Arc::new(EditFileTool::new(
1243 project.clone(),
1244 thread.downgrade(),
1245 language_registry,
1246 Templates::new(),
1247 ));
1248
1249 // Test global config paths - these should require confirmation if they exist and are outside the project
1250 let test_cases = vec![
1251 (
1252 "/etc/hosts",
1253 true,
1254 "System file should require confirmation",
1255 ),
1256 (
1257 "/usr/local/bin/script",
1258 true,
1259 "System bin file should require confirmation",
1260 ),
1261 (
1262 "project/normal_file.rs",
1263 false,
1264 "Normal project file should not require confirmation",
1265 ),
1266 ];
1267
1268 for (path, should_confirm, description) in test_cases {
1269 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1270 let auth = cx.update(|cx| {
1271 tool.authorize(
1272 &EditFileToolInput {
1273 display_description: "Edit file".into(),
1274 path: path.into(),
1275 mode: EditFileMode::Edit,
1276 },
1277 &stream_tx,
1278 cx,
1279 )
1280 });
1281
1282 if should_confirm {
1283 stream_rx.expect_authorization().await;
1284 } else {
1285 auth.await.unwrap();
1286 assert!(
1287 stream_rx.try_next().is_err(),
1288 "Failed for case: {} - path: {} - expected no confirmation but got one",
1289 description,
1290 path
1291 );
1292 }
1293 }
1294 }
1295
1296 #[gpui::test]
1297 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1298 init_test(cx);
1299 let fs = project::FakeFs::new(cx.executor());
1300
1301 // Create multiple worktree directories
1302 fs.insert_tree(
1303 "/workspace/frontend",
1304 json!({
1305 "src": {
1306 "main.js": "console.log('frontend');"
1307 }
1308 }),
1309 )
1310 .await;
1311 fs.insert_tree(
1312 "/workspace/backend",
1313 json!({
1314 "src": {
1315 "main.rs": "fn main() {}"
1316 }
1317 }),
1318 )
1319 .await;
1320 fs.insert_tree(
1321 "/workspace/shared",
1322 json!({
1323 ".zed": {
1324 "settings.json": "{}"
1325 }
1326 }),
1327 )
1328 .await;
1329
1330 // Create project with multiple worktrees
1331 let project = Project::test(
1332 fs.clone(),
1333 [
1334 path!("/workspace/frontend").as_ref(),
1335 path!("/workspace/backend").as_ref(),
1336 path!("/workspace/shared").as_ref(),
1337 ],
1338 cx,
1339 )
1340 .await;
1341 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1342 let context_server_registry =
1343 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1344 let model = Arc::new(FakeLanguageModel::default());
1345 let thread = cx.new(|cx| {
1346 Thread::new(
1347 project.clone(),
1348 cx.new(|_cx| ProjectContext::default()),
1349 context_server_registry.clone(),
1350 Templates::new(),
1351 Some(model.clone()),
1352 cx,
1353 )
1354 });
1355 let tool = Arc::new(EditFileTool::new(
1356 project.clone(),
1357 thread.downgrade(),
1358 language_registry,
1359 Templates::new(),
1360 ));
1361
1362 // Test files in different worktrees
1363 let test_cases = vec![
1364 ("frontend/src/main.js", false, "File in first worktree"),
1365 ("backend/src/main.rs", false, "File in second worktree"),
1366 (
1367 "shared/.zed/settings.json",
1368 true,
1369 ".zed file in third worktree",
1370 ),
1371 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1372 (
1373 "../outside/file.txt",
1374 true,
1375 "Relative path outside worktrees",
1376 ),
1377 ];
1378
1379 for (path, should_confirm, description) in test_cases {
1380 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1381 let auth = cx.update(|cx| {
1382 tool.authorize(
1383 &EditFileToolInput {
1384 display_description: "Edit file".into(),
1385 path: path.into(),
1386 mode: EditFileMode::Edit,
1387 },
1388 &stream_tx,
1389 cx,
1390 )
1391 });
1392
1393 if should_confirm {
1394 stream_rx.expect_authorization().await;
1395 } else {
1396 auth.await.unwrap();
1397 assert!(
1398 stream_rx.try_next().is_err(),
1399 "Failed for case: {} - path: {} - expected no confirmation but got one",
1400 description,
1401 path
1402 );
1403 }
1404 }
1405 }
1406
1407 #[gpui::test]
1408 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1409 init_test(cx);
1410 let fs = project::FakeFs::new(cx.executor());
1411 fs.insert_tree(
1412 "/project",
1413 json!({
1414 ".zed": {
1415 "settings.json": "{}"
1416 },
1417 "src": {
1418 ".zed": {
1419 "local.json": "{}"
1420 }
1421 }
1422 }),
1423 )
1424 .await;
1425 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1426 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1427 let context_server_registry =
1428 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1429 let model = Arc::new(FakeLanguageModel::default());
1430 let thread = cx.new(|cx| {
1431 Thread::new(
1432 project.clone(),
1433 cx.new(|_cx| ProjectContext::default()),
1434 context_server_registry.clone(),
1435 Templates::new(),
1436 Some(model.clone()),
1437 cx,
1438 )
1439 });
1440 let tool = Arc::new(EditFileTool::new(
1441 project.clone(),
1442 thread.downgrade(),
1443 language_registry,
1444 Templates::new(),
1445 ));
1446
1447 // Test edge cases
1448 let test_cases = vec![
1449 // Empty path - find_project_path returns Some for empty paths
1450 ("", false, "Empty path is treated as project root"),
1451 // Root directory
1452 ("/", true, "Root directory should be outside project"),
1453 // Parent directory references - find_project_path resolves these
1454 (
1455 "project/../other",
1456 true,
1457 "Path with .. that goes outside of root directory",
1458 ),
1459 (
1460 "project/./src/file.rs",
1461 false,
1462 "Path with . should work normally",
1463 ),
1464 // Windows-style paths (if on Windows)
1465 #[cfg(target_os = "windows")]
1466 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1467 #[cfg(target_os = "windows")]
1468 ("project\\src\\main.rs", false, "Windows-style project path"),
1469 ];
1470
1471 for (path, should_confirm, description) in test_cases {
1472 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1473 let auth = cx.update(|cx| {
1474 tool.authorize(
1475 &EditFileToolInput {
1476 display_description: "Edit file".into(),
1477 path: path.into(),
1478 mode: EditFileMode::Edit,
1479 },
1480 &stream_tx,
1481 cx,
1482 )
1483 });
1484
1485 cx.run_until_parked();
1486
1487 if should_confirm {
1488 stream_rx.expect_authorization().await;
1489 } else {
1490 assert!(
1491 stream_rx.try_next().is_err(),
1492 "Failed for case: {} - path: {} - expected no confirmation but got one",
1493 description,
1494 path
1495 );
1496 auth.await.unwrap();
1497 }
1498 }
1499 }
1500
1501 #[gpui::test]
1502 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1503 init_test(cx);
1504 let fs = project::FakeFs::new(cx.executor());
1505 fs.insert_tree(
1506 "/project",
1507 json!({
1508 "existing.txt": "content",
1509 ".zed": {
1510 "settings.json": "{}"
1511 }
1512 }),
1513 )
1514 .await;
1515 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1516 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1517 let context_server_registry =
1518 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1519 let model = Arc::new(FakeLanguageModel::default());
1520 let thread = cx.new(|cx| {
1521 Thread::new(
1522 project.clone(),
1523 cx.new(|_cx| ProjectContext::default()),
1524 context_server_registry.clone(),
1525 Templates::new(),
1526 Some(model.clone()),
1527 cx,
1528 )
1529 });
1530 let tool = Arc::new(EditFileTool::new(
1531 project.clone(),
1532 thread.downgrade(),
1533 language_registry,
1534 Templates::new(),
1535 ));
1536
1537 // Test different EditFileMode values
1538 let modes = vec![
1539 EditFileMode::Edit,
1540 EditFileMode::Create,
1541 EditFileMode::Overwrite,
1542 ];
1543
1544 for mode in modes {
1545 // Test .zed path with different modes
1546 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1547 let _auth = cx.update(|cx| {
1548 tool.authorize(
1549 &EditFileToolInput {
1550 display_description: "Edit settings".into(),
1551 path: "project/.zed/settings.json".into(),
1552 mode: mode.clone(),
1553 },
1554 &stream_tx,
1555 cx,
1556 )
1557 });
1558
1559 stream_rx.expect_authorization().await;
1560
1561 // Test outside path with different modes
1562 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1563 let _auth = cx.update(|cx| {
1564 tool.authorize(
1565 &EditFileToolInput {
1566 display_description: "Edit file".into(),
1567 path: "/outside/file.txt".into(),
1568 mode: mode.clone(),
1569 },
1570 &stream_tx,
1571 cx,
1572 )
1573 });
1574
1575 stream_rx.expect_authorization().await;
1576
1577 // Test normal path with different modes
1578 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1579 cx.update(|cx| {
1580 tool.authorize(
1581 &EditFileToolInput {
1582 display_description: "Edit file".into(),
1583 path: "project/normal.txt".into(),
1584 mode: mode.clone(),
1585 },
1586 &stream_tx,
1587 cx,
1588 )
1589 })
1590 .await
1591 .unwrap();
1592 assert!(stream_rx.try_next().is_err());
1593 }
1594 }
1595
1596 #[gpui::test]
1597 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1598 init_test(cx);
1599 let fs = project::FakeFs::new(cx.executor());
1600 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1601 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1602 let context_server_registry =
1603 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1604 let model = Arc::new(FakeLanguageModel::default());
1605 let thread = cx.new(|cx| {
1606 Thread::new(
1607 project.clone(),
1608 cx.new(|_cx| ProjectContext::default()),
1609 context_server_registry,
1610 Templates::new(),
1611 Some(model.clone()),
1612 cx,
1613 )
1614 });
1615 let tool = Arc::new(EditFileTool::new(
1616 project,
1617 thread.downgrade(),
1618 language_registry,
1619 Templates::new(),
1620 ));
1621
1622 cx.update(|cx| {
1623 // ...
1624 assert_eq!(
1625 tool.initial_title(
1626 Err(json!({
1627 "path": "src/main.rs",
1628 "display_description": "",
1629 "old_string": "old code",
1630 "new_string": "new code"
1631 })),
1632 cx
1633 ),
1634 "src/main.rs"
1635 );
1636 assert_eq!(
1637 tool.initial_title(
1638 Err(json!({
1639 "path": "",
1640 "display_description": "Fix error handling",
1641 "old_string": "old code",
1642 "new_string": "new code"
1643 })),
1644 cx
1645 ),
1646 "Fix error handling"
1647 );
1648 assert_eq!(
1649 tool.initial_title(
1650 Err(json!({
1651 "path": "src/main.rs",
1652 "display_description": "Fix error handling",
1653 "old_string": "old code",
1654 "new_string": "new code"
1655 })),
1656 cx
1657 ),
1658 "src/main.rs"
1659 );
1660 assert_eq!(
1661 tool.initial_title(
1662 Err(json!({
1663 "path": "",
1664 "display_description": "",
1665 "old_string": "old code",
1666 "new_string": "new code"
1667 })),
1668 cx
1669 ),
1670 DEFAULT_UI_TEXT
1671 );
1672 assert_eq!(
1673 tool.initial_title(Err(serde_json::Value::Null), cx),
1674 DEFAULT_UI_TEXT
1675 );
1676 });
1677 }
1678
1679 #[gpui::test]
1680 async fn test_diff_finalization(cx: &mut TestAppContext) {
1681 init_test(cx);
1682 let fs = project::FakeFs::new(cx.executor());
1683 fs.insert_tree("/", json!({"main.rs": ""})).await;
1684
1685 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1686 let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1687 let context_server_registry =
1688 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1689 let model = Arc::new(FakeLanguageModel::default());
1690 let thread = cx.new(|cx| {
1691 Thread::new(
1692 project.clone(),
1693 cx.new(|_cx| ProjectContext::default()),
1694 context_server_registry.clone(),
1695 Templates::new(),
1696 Some(model.clone()),
1697 cx,
1698 )
1699 });
1700
1701 // Ensure the diff is finalized after the edit completes.
1702 {
1703 let tool = Arc::new(EditFileTool::new(
1704 project.clone(),
1705 thread.downgrade(),
1706 languages.clone(),
1707 Templates::new(),
1708 ));
1709 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1710 let edit = cx.update(|cx| {
1711 tool.run(
1712 EditFileToolInput {
1713 display_description: "Edit file".into(),
1714 path: path!("/main.rs").into(),
1715 mode: EditFileMode::Edit,
1716 },
1717 stream_tx,
1718 cx,
1719 )
1720 });
1721 stream_rx.expect_update_fields().await;
1722 let diff = stream_rx.expect_diff().await;
1723 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1724 cx.run_until_parked();
1725 model.end_last_completion_stream();
1726 edit.await.unwrap();
1727 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1728 }
1729
1730 // Ensure the diff is finalized if an error occurs while editing.
1731 {
1732 model.forbid_requests();
1733 let tool = Arc::new(EditFileTool::new(
1734 project.clone(),
1735 thread.downgrade(),
1736 languages.clone(),
1737 Templates::new(),
1738 ));
1739 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1740 let edit = cx.update(|cx| {
1741 tool.run(
1742 EditFileToolInput {
1743 display_description: "Edit file".into(),
1744 path: path!("/main.rs").into(),
1745 mode: EditFileMode::Edit,
1746 },
1747 stream_tx,
1748 cx,
1749 )
1750 });
1751 stream_rx.expect_update_fields().await;
1752 let diff = stream_rx.expect_diff().await;
1753 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1754 edit.await.unwrap_err();
1755 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1756 model.allow_requests();
1757 }
1758
1759 // Ensure the diff is finalized if the tool call gets dropped.
1760 {
1761 let tool = Arc::new(EditFileTool::new(
1762 project.clone(),
1763 thread.downgrade(),
1764 languages.clone(),
1765 Templates::new(),
1766 ));
1767 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1768 let edit = cx.update(|cx| {
1769 tool.run(
1770 EditFileToolInput {
1771 display_description: "Edit file".into(),
1772 path: path!("/main.rs").into(),
1773 mode: EditFileMode::Edit,
1774 },
1775 stream_tx,
1776 cx,
1777 )
1778 });
1779 stream_rx.expect_update_fields().await;
1780 let diff = stream_rx.expect_diff().await;
1781 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1782 drop(edit);
1783 cx.run_until_parked();
1784 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1785 }
1786 }
1787
1788 #[gpui::test]
1789 async fn test_file_read_times_tracking(cx: &mut TestAppContext) {
1790 init_test(cx);
1791
1792 let fs = project::FakeFs::new(cx.executor());
1793 fs.insert_tree(
1794 "/root",
1795 json!({
1796 "test.txt": "original content"
1797 }),
1798 )
1799 .await;
1800 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1801 let context_server_registry =
1802 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1803 let model = Arc::new(FakeLanguageModel::default());
1804 let thread = cx.new(|cx| {
1805 Thread::new(
1806 project.clone(),
1807 cx.new(|_cx| ProjectContext::default()),
1808 context_server_registry,
1809 Templates::new(),
1810 Some(model.clone()),
1811 cx,
1812 )
1813 });
1814 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1815
1816 // Initially, file_read_times should be empty
1817 let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
1818 assert!(is_empty, "file_read_times should start empty");
1819
1820 // Create read tool
1821 let read_tool = Arc::new(crate::ReadFileTool::new(
1822 thread.downgrade(),
1823 project.clone(),
1824 action_log,
1825 ));
1826
1827 // Read the file to record the read time
1828 cx.update(|cx| {
1829 read_tool.clone().run(
1830 crate::ReadFileToolInput {
1831 path: "root/test.txt".to_string(),
1832 start_line: None,
1833 end_line: None,
1834 },
1835 ToolCallEventStream::test().0,
1836 cx,
1837 )
1838 })
1839 .await
1840 .unwrap();
1841
1842 // Verify that file_read_times now contains an entry for the file
1843 let has_entry = thread.read_with(cx, |thread, _| {
1844 thread.file_read_times.len() == 1
1845 && thread
1846 .file_read_times
1847 .keys()
1848 .any(|path| path.ends_with("test.txt"))
1849 });
1850 assert!(
1851 has_entry,
1852 "file_read_times should contain an entry after reading the file"
1853 );
1854
1855 // Read the file again - should update the entry
1856 cx.update(|cx| {
1857 read_tool.clone().run(
1858 crate::ReadFileToolInput {
1859 path: "root/test.txt".to_string(),
1860 start_line: None,
1861 end_line: None,
1862 },
1863 ToolCallEventStream::test().0,
1864 cx,
1865 )
1866 })
1867 .await
1868 .unwrap();
1869
1870 // Should still have exactly one entry
1871 let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
1872 assert!(
1873 has_one_entry,
1874 "file_read_times should still have one entry after re-reading"
1875 );
1876 }
1877
1878 fn init_test(cx: &mut TestAppContext) {
1879 cx.update(|cx| {
1880 let settings_store = SettingsStore::test(cx);
1881 cx.set_global(settings_store);
1882 });
1883 }
1884
1885 #[gpui::test]
1886 async fn test_consecutive_edits_work(cx: &mut TestAppContext) {
1887 init_test(cx);
1888
1889 let fs = project::FakeFs::new(cx.executor());
1890 fs.insert_tree(
1891 "/root",
1892 json!({
1893 "test.txt": "original content"
1894 }),
1895 )
1896 .await;
1897 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1898 let context_server_registry =
1899 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1900 let model = Arc::new(FakeLanguageModel::default());
1901 let thread = cx.new(|cx| {
1902 Thread::new(
1903 project.clone(),
1904 cx.new(|_cx| ProjectContext::default()),
1905 context_server_registry,
1906 Templates::new(),
1907 Some(model.clone()),
1908 cx,
1909 )
1910 });
1911 let languages = project.read_with(cx, |project, _| project.languages().clone());
1912 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1913
1914 let read_tool = Arc::new(crate::ReadFileTool::new(
1915 thread.downgrade(),
1916 project.clone(),
1917 action_log,
1918 ));
1919 let edit_tool = Arc::new(EditFileTool::new(
1920 project.clone(),
1921 thread.downgrade(),
1922 languages,
1923 Templates::new(),
1924 ));
1925
1926 // Read the file first
1927 cx.update(|cx| {
1928 read_tool.clone().run(
1929 crate::ReadFileToolInput {
1930 path: "root/test.txt".to_string(),
1931 start_line: None,
1932 end_line: None,
1933 },
1934 ToolCallEventStream::test().0,
1935 cx,
1936 )
1937 })
1938 .await
1939 .unwrap();
1940
1941 // First edit should work
1942 let edit_result = {
1943 let edit_task = cx.update(|cx| {
1944 edit_tool.clone().run(
1945 EditFileToolInput {
1946 display_description: "First edit".into(),
1947 path: "root/test.txt".into(),
1948 mode: EditFileMode::Edit,
1949 },
1950 ToolCallEventStream::test().0,
1951 cx,
1952 )
1953 });
1954
1955 cx.executor().run_until_parked();
1956 model.send_last_completion_stream_text_chunk(
1957 "<old_text>original content</old_text><new_text>modified content</new_text>"
1958 .to_string(),
1959 );
1960 model.end_last_completion_stream();
1961
1962 edit_task.await
1963 };
1964 assert!(
1965 edit_result.is_ok(),
1966 "First edit should succeed, got error: {:?}",
1967 edit_result.as_ref().err()
1968 );
1969
1970 // Second edit should also work because the edit updated the recorded read time
1971 let edit_result = {
1972 let edit_task = cx.update(|cx| {
1973 edit_tool.clone().run(
1974 EditFileToolInput {
1975 display_description: "Second edit".into(),
1976 path: "root/test.txt".into(),
1977 mode: EditFileMode::Edit,
1978 },
1979 ToolCallEventStream::test().0,
1980 cx,
1981 )
1982 });
1983
1984 cx.executor().run_until_parked();
1985 model.send_last_completion_stream_text_chunk(
1986 "<old_text>modified content</old_text><new_text>further modified content</new_text>".to_string(),
1987 );
1988 model.end_last_completion_stream();
1989
1990 edit_task.await
1991 };
1992 assert!(
1993 edit_result.is_ok(),
1994 "Second consecutive edit should succeed, got error: {:?}",
1995 edit_result.as_ref().err()
1996 );
1997 }
1998
1999 #[gpui::test]
2000 async fn test_external_modification_detected(cx: &mut TestAppContext) {
2001 init_test(cx);
2002
2003 let fs = project::FakeFs::new(cx.executor());
2004 fs.insert_tree(
2005 "/root",
2006 json!({
2007 "test.txt": "original content"
2008 }),
2009 )
2010 .await;
2011 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2012 let context_server_registry =
2013 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2014 let model = Arc::new(FakeLanguageModel::default());
2015 let thread = cx.new(|cx| {
2016 Thread::new(
2017 project.clone(),
2018 cx.new(|_cx| ProjectContext::default()),
2019 context_server_registry,
2020 Templates::new(),
2021 Some(model.clone()),
2022 cx,
2023 )
2024 });
2025 let languages = project.read_with(cx, |project, _| project.languages().clone());
2026 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2027
2028 let read_tool = Arc::new(crate::ReadFileTool::new(
2029 thread.downgrade(),
2030 project.clone(),
2031 action_log,
2032 ));
2033 let edit_tool = Arc::new(EditFileTool::new(
2034 project.clone(),
2035 thread.downgrade(),
2036 languages,
2037 Templates::new(),
2038 ));
2039
2040 // Read the file first
2041 cx.update(|cx| {
2042 read_tool.clone().run(
2043 crate::ReadFileToolInput {
2044 path: "root/test.txt".to_string(),
2045 start_line: None,
2046 end_line: None,
2047 },
2048 ToolCallEventStream::test().0,
2049 cx,
2050 )
2051 })
2052 .await
2053 .unwrap();
2054
2055 // Simulate external modification - advance time and save file
2056 cx.background_executor
2057 .advance_clock(std::time::Duration::from_secs(2));
2058 fs.save(
2059 path!("/root/test.txt").as_ref(),
2060 &"externally modified content".into(),
2061 language::LineEnding::Unix,
2062 )
2063 .await
2064 .unwrap();
2065
2066 // Reload the buffer to pick up the new mtime
2067 let project_path = project
2068 .read_with(cx, |project, cx| {
2069 project.find_project_path("root/test.txt", cx)
2070 })
2071 .expect("Should find project path");
2072 let buffer = project
2073 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2074 .await
2075 .unwrap();
2076 buffer
2077 .update(cx, |buffer, cx| buffer.reload(cx))
2078 .await
2079 .unwrap();
2080
2081 cx.executor().run_until_parked();
2082
2083 // Try to edit - should fail because file was modified externally
2084 let result = cx
2085 .update(|cx| {
2086 edit_tool.clone().run(
2087 EditFileToolInput {
2088 display_description: "Edit after external change".into(),
2089 path: "root/test.txt".into(),
2090 mode: EditFileMode::Edit,
2091 },
2092 ToolCallEventStream::test().0,
2093 cx,
2094 )
2095 })
2096 .await;
2097
2098 assert!(
2099 result.is_err(),
2100 "Edit should fail after external modification"
2101 );
2102 let error_msg = result.unwrap_err().to_string();
2103 assert!(
2104 error_msg.contains("has been modified since you last read it"),
2105 "Error should mention file modification, got: {}",
2106 error_msg
2107 );
2108 }
2109
2110 #[gpui::test]
2111 async fn test_dirty_buffer_detected(cx: &mut TestAppContext) {
2112 init_test(cx);
2113
2114 let fs = project::FakeFs::new(cx.executor());
2115 fs.insert_tree(
2116 "/root",
2117 json!({
2118 "test.txt": "original content"
2119 }),
2120 )
2121 .await;
2122 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2123 let context_server_registry =
2124 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2125 let model = Arc::new(FakeLanguageModel::default());
2126 let thread = cx.new(|cx| {
2127 Thread::new(
2128 project.clone(),
2129 cx.new(|_cx| ProjectContext::default()),
2130 context_server_registry,
2131 Templates::new(),
2132 Some(model.clone()),
2133 cx,
2134 )
2135 });
2136 let languages = project.read_with(cx, |project, _| project.languages().clone());
2137 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2138
2139 let read_tool = Arc::new(crate::ReadFileTool::new(
2140 thread.downgrade(),
2141 project.clone(),
2142 action_log,
2143 ));
2144 let edit_tool = Arc::new(EditFileTool::new(
2145 project.clone(),
2146 thread.downgrade(),
2147 languages,
2148 Templates::new(),
2149 ));
2150
2151 // Read the file first
2152 cx.update(|cx| {
2153 read_tool.clone().run(
2154 crate::ReadFileToolInput {
2155 path: "root/test.txt".to_string(),
2156 start_line: None,
2157 end_line: None,
2158 },
2159 ToolCallEventStream::test().0,
2160 cx,
2161 )
2162 })
2163 .await
2164 .unwrap();
2165
2166 // Open the buffer and make it dirty by editing without saving
2167 let project_path = project
2168 .read_with(cx, |project, cx| {
2169 project.find_project_path("root/test.txt", cx)
2170 })
2171 .expect("Should find project path");
2172 let buffer = project
2173 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2174 .await
2175 .unwrap();
2176
2177 // Make an in-memory edit to the buffer (making it dirty)
2178 buffer.update(cx, |buffer, cx| {
2179 let end_point = buffer.max_point();
2180 buffer.edit([(end_point..end_point, " added text")], None, cx);
2181 });
2182
2183 // Verify buffer is dirty
2184 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
2185 assert!(is_dirty, "Buffer should be dirty after in-memory edit");
2186
2187 // Try to edit - should fail because buffer has unsaved changes
2188 let result = cx
2189 .update(|cx| {
2190 edit_tool.clone().run(
2191 EditFileToolInput {
2192 display_description: "Edit with dirty buffer".into(),
2193 path: "root/test.txt".into(),
2194 mode: EditFileMode::Edit,
2195 },
2196 ToolCallEventStream::test().0,
2197 cx,
2198 )
2199 })
2200 .await;
2201
2202 assert!(result.is_err(), "Edit should fail when buffer is dirty");
2203 let error_msg = result.unwrap_err().to_string();
2204 assert!(
2205 error_msg.contains("This file has unsaved changes."),
2206 "Error should mention unsaved changes, got: {}",
2207 error_msg
2208 );
2209 assert!(
2210 error_msg.contains("keep or discard"),
2211 "Error should ask whether to keep or discard changes, got: {}",
2212 error_msg
2213 );
2214 assert!(
2215 error_msg.contains("save_file"),
2216 "Error should reference save_file tool, got: {}",
2217 error_msg
2218 );
2219 assert!(
2220 error_msg.contains("restore_file_from_disk"),
2221 "Error should reference restore_file_from_disk tool, got: {}",
2222 error_msg
2223 );
2224 }
2225}