1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Fallback title for a tool call whose (partial) input carries neither a path nor a
/// display description yet.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
// NOTE(review): The `///` doc comments on this struct and its fields are picked up by the
// `JsonSchema` derive and emitted as schema descriptions, so they presumably reach the model as
// part of the tool definition. Treat them as prompt text: any wording change alters what the
// model is told. Keep maintainer-only notes in plain `//` comments like this one.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    // Keep this field declared first: the model is asked to emit it first so the UI can show it
    // while the rest of the input is still streaming.
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    // Resolved against the project's worktrees by `resolve_path`; `authorize` additionally
    // inspects its components for the local settings folder.
    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
79
// Lenient view of a tool call whose JSON input is still streaming in. `initial_title` parses the
// raw JSON into this shape to show a useful title before a full `EditFileToolInput` is
// available. Missing fields deserialize as empty strings (`#[serde(default)]`).
// Plain `//` comments are used because `///` docs would be emitted by the `JsonSchema` derive.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
87
// How `edit_file` should treat the target file. Serialized in lowercase ("edit", "create",
// "overwrite") and inlined into the input schema. Plain `//` comments are used here because `///`
// docs would be emitted into that schema.
//
// In `run`, `Edit` is routed to `EditAgent::edit`. `Create` and `Overwrite` both go through
// `EditAgent::overwrite`; they differ only in how `resolve_path` validates the target
// (`Create` requires a new file in an existing directory, the others require an existing file).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    Edit,
    Create,
    Overwrite,
}
96
/// Result of a completed `edit_file` call.
///
/// It is serializable so a call can later be re-rendered via `replay`. It is also converted into
/// the text the model sees through the `From<EditFileToolOutput>` impl for
/// `LanguageModelToolResultContent`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as supplied in the tool input.
    // The alias presumably keeps outputs serialized under an older field name readable.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit, taken after any format-on-save and the save itself.
    new_text: String,
    /// Full buffer text captured before the edit agent ran.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`. An empty diff is reported to the model as
    /// "No edits were made.".
    // `default` lets outputs that lack this field deserialize (with an empty diff).
    #[serde(default)]
    diff: String,
    /// Output of the underlying edit agent.
    // The alias presumably keeps outputs serialized under an older field name readable.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
/// Agent tool that creates, overwrites, or edits a file in the project. The new text itself is
/// produced by an [`EditAgent`].
pub struct EditFileTool {
    /// The thread this tool belongs to. It is held weakly, so operations fail with
    /// "thread was dropped" once the thread is gone.
    thread: WeakEntity<Thread>,
    /// Used by `replay` to build a finalized [`Diff`].
    language_registry: Arc<LanguageRegistry>,
    /// Project used to resolve paths for titles and authorization.
    project: Entity<Project>,
    /// Templates handed to the [`EditAgent`] (presumably its prompt templates).
    templates: Arc<Templates>,
}
130
131impl EditFileTool {
132 pub fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
203 fn name() -> &'static str {
204 "edit_file"
205 }
206
207 fn kind() -> acp::ToolKind {
208 acp::ToolKind::Edit
209 }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(
277 ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
278 );
279 }
280
281 let authorize = self.authorize(&input, &event_stream, cx);
282 cx.spawn(async move |cx: &mut AsyncApp| {
283 authorize.await?;
284
285 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
286 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
287 (request, thread.model().cloned(), thread.action_log().clone())
288 })?;
289 let request = request?;
290 let model = model.context("No language model configured")?;
291
292 let edit_format = EditFormat::from_model(model.clone())?;
293 let edit_agent = EditAgent::new(
294 model,
295 project.clone(),
296 action_log.clone(),
297 self.templates.clone(),
298 edit_format,
299 );
300
301 let buffer = project
302 .update(cx, |project, cx| {
303 project.open_buffer(project_path.clone(), cx)
304 })?
305 .await?;
306
307 // Check if the file has been modified since the agent last read it
308 if let Some(abs_path) = abs_path.as_ref() {
309 let (last_read_mtime, current_mtime, is_dirty) = self.thread.update(cx, |thread, cx| {
310 let last_read = thread.file_read_times.get(abs_path).copied();
311 let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
312 let dirty = buffer.read(cx).is_dirty();
313 (last_read, current, dirty)
314 })?;
315
316 // Check for unsaved changes first - these indicate modifications we don't know about
317 if is_dirty {
318 anyhow::bail!(
319 "This file cannot be written to because it has unsaved changes. \
320 Please end the current conversation immediately by telling the user you want to write to this file (mention its path explicitly) but you can't write to it because it has unsaved changes. \
321 Ask the user to save that buffer's changes and to inform you when it's ok to proceed."
322 );
323 }
324
325 // Check if the file was modified on disk since we last read it
326 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
327 // MTime can be unreliable for comparisons, so our newtype intentionally
328 // doesn't support comparing them. If the mtime at all different
329 // (which could be because of a modification or because e.g. system clock changed),
330 // we pessimistically assume it was modified.
331 if current != last_read {
332 anyhow::bail!(
333 "The file {} has been modified since you last read it. \
334 Please read the file again to get the current state before editing it.",
335 input.path.display()
336 );
337 }
338 }
339 }
340
341 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
342 event_stream.update_diff(diff.clone());
343 let _finalize_diff = util::defer({
344 let diff = diff.downgrade();
345 let mut cx = cx.clone();
346 move || {
347 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
348 }
349 });
350
351 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
352 let old_text = cx
353 .background_spawn({
354 let old_snapshot = old_snapshot.clone();
355 async move { Arc::new(old_snapshot.text()) }
356 })
357 .await;
358
359
360 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
361 edit_agent.edit(
362 buffer.clone(),
363 input.display_description.clone(),
364 &request,
365 cx,
366 )
367 } else {
368 edit_agent.overwrite(
369 buffer.clone(),
370 input.display_description.clone(),
371 &request,
372 cx,
373 )
374 };
375
376 let mut hallucinated_old_text = false;
377 let mut ambiguous_ranges = Vec::new();
378 let mut emitted_location = false;
379 while let Some(event) = events.next().await {
380 match event {
381 EditAgentOutputEvent::Edited(range) => {
382 if !emitted_location {
383 let line = buffer.update(cx, |buffer, _cx| {
384 range.start.to_point(&buffer.snapshot()).row
385 }).ok();
386 if let Some(abs_path) = abs_path.clone() {
387 let mut location = ToolCallLocation::new(abs_path);
388 if let Some(line) = line {
389 location = location.line(line);
390 }
391 event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![location]));
392 }
393 emitted_location = true;
394 }
395 },
396 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
397 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
398 EditAgentOutputEvent::ResolvingEditRange(range) => {
399 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
400 // if !emitted_location {
401 // let line = buffer.update(cx, |buffer, _cx| {
402 // range.start.to_point(&buffer.snapshot()).row
403 // }).ok();
404 // if let Some(abs_path) = abs_path.clone() {
405 // event_stream.update_fields(ToolCallUpdateFields {
406 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
407 // ..Default::default()
408 // });
409 // }
410 // }
411 }
412 }
413 }
414
415 // If format_on_save is enabled, format the buffer
416 let format_on_save_enabled = buffer
417 .read_with(cx, |buffer, cx| {
418 let settings = language_settings::language_settings(
419 buffer.language().map(|l| l.name()),
420 buffer.file(),
421 cx,
422 );
423 settings.format_on_save != FormatOnSave::Off
424 })
425 .unwrap_or(false);
426
427 let edit_agent_output = output.await?;
428
429 if format_on_save_enabled {
430 action_log.update(cx, |log, cx| {
431 log.buffer_edited(buffer.clone(), cx);
432 })?;
433
434 let format_task = project.update(cx, |project, cx| {
435 project.format(
436 HashSet::from_iter([buffer.clone()]),
437 LspFormatTarget::Buffers,
438 false, // Don't push to history since the tool did it.
439 FormatTrigger::Save,
440 cx,
441 )
442 })?;
443 format_task.await.log_err();
444 }
445
446 project
447 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
448 .await?;
449
450 action_log.update(cx, |log, cx| {
451 log.buffer_edited(buffer.clone(), cx);
452 })?;
453
454 // Update the recorded read time after a successful edit so consecutive edits work
455 if let Some(abs_path) = abs_path.as_ref() {
456 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
457 buffer.file().and_then(|file| file.disk_state().mtime())
458 })? {
459 self.thread.update(cx, |thread, _| {
460 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
461 })?;
462 }
463 }
464
465 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
466 let (new_text, unified_diff) = cx
467 .background_spawn({
468 let new_snapshot = new_snapshot.clone();
469 let old_text = old_text.clone();
470 async move {
471 let new_text = new_snapshot.text();
472 let diff = language::unified_diff(&old_text, &new_text);
473 (new_text, diff)
474 }
475 })
476 .await;
477
478 let input_path = input.path.display();
479 if unified_diff.is_empty() {
480 anyhow::ensure!(
481 !hallucinated_old_text,
482 formatdoc! {"
483 Some edits were produced but none of them could be applied.
484 Read the relevant sections of {input_path} again so that
485 I can perform the requested edits.
486 "}
487 );
488 anyhow::ensure!(
489 ambiguous_ranges.is_empty(),
490 {
491 let line_numbers = ambiguous_ranges
492 .iter()
493 .map(|range| range.start.to_string())
494 .collect::<Vec<_>>()
495 .join(", ");
496 formatdoc! {"
497 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
498 relevant sections of {input_path} again and extend <old_text> so
499 that I can perform the requested edits.
500 "}
501 }
502 );
503 }
504
505 Ok(EditFileToolOutput {
506 input_path: input.path,
507 new_text,
508 old_text,
509 diff: unified_diff,
510 edit_agent_output,
511 })
512 })
513 }
514
515 fn replay(
516 &self,
517 _input: Self::Input,
518 output: Self::Output,
519 event_stream: ToolCallEventStream,
520 cx: &mut App,
521 ) -> Result<()> {
522 event_stream.update_diff(cx.new(|cx| {
523 Diff::finalized(
524 output.input_path.to_string_lossy().into_owned(),
525 Some(output.old_text.to_string()),
526 output.new_text,
527 self.language_registry.clone(),
528 cx,
529 )
530 }));
531 Ok(())
532 }
533}
534
535/// Validate that the file path is valid, meaning:
536///
537/// - For `edit` and `overwrite`, the path must point to an existing file.
538/// - For `create`, the file must not already exist, but it's parent dir must exist.
539fn resolve_path(
540 input: &EditFileToolInput,
541 project: Entity<Project>,
542 cx: &mut App,
543) -> Result<ProjectPath> {
544 let project = project.read(cx);
545
546 match input.mode {
547 EditFileMode::Edit | EditFileMode::Overwrite => {
548 let path = project
549 .find_project_path(&input.path, cx)
550 .context("Can't edit file: path not found")?;
551
552 let entry = project
553 .entry_for_path(&path, cx)
554 .context("Can't edit file: path not found")?;
555
556 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
557 Ok(path)
558 }
559
560 EditFileMode::Create => {
561 if let Some(path) = project.find_project_path(&input.path, cx) {
562 anyhow::ensure!(
563 project.entry_for_path(&path, cx).is_none(),
564 "Can't create file: file already exists"
565 );
566 }
567
568 let parent_path = input
569 .path
570 .parent()
571 .context("Can't create file: incorrect path")?;
572
573 let parent_project_path = project.find_project_path(&parent_path, cx);
574
575 let parent_entry = parent_project_path
576 .as_ref()
577 .and_then(|path| project.entry_for_path(path, cx))
578 .context("Can't create file: parent directory doesn't exist")?;
579
580 anyhow::ensure!(
581 parent_entry.is_dir(),
582 "Can't create file: parent is not a directory"
583 );
584
585 let file_name = input
586 .path
587 .file_name()
588 .and_then(|file_name| file_name.to_str())
589 .and_then(|file_name| RelPath::unix(file_name).ok())
590 .context("Can't create file: invalid filename")?;
591
592 let new_file_path = parent_project_path.map(|parent| ProjectPath {
593 path: parent.path.join(file_name),
594 ..parent
595 });
596
597 new_file_path.context("Can't create file")
598 }
599 }
600}
601
602#[cfg(test)]
603mod tests {
604 use super::*;
605 use crate::{ContextServerRegistry, Templates};
606 use fs::Fs;
607 use gpui::{TestAppContext, UpdateGlobal};
608 use language_model::fake_provider::FakeLanguageModel;
609 use prompt_store::ProjectContext;
610 use serde_json::json;
611 use settings::SettingsStore;
612 use util::{path, rel_path::rel_path};
613
614 #[gpui::test]
615 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
616 init_test(cx);
617
618 let fs = project::FakeFs::new(cx.executor());
619 fs.insert_tree("/root", json!({})).await;
620 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
621 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
622 let context_server_registry =
623 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
624 let model = Arc::new(FakeLanguageModel::default());
625 let thread = cx.new(|cx| {
626 Thread::new(
627 project.clone(),
628 cx.new(|_cx| ProjectContext::default()),
629 context_server_registry,
630 Templates::new(),
631 Some(model),
632 cx,
633 )
634 });
635 let result = cx
636 .update(|cx| {
637 let input = EditFileToolInput {
638 display_description: "Some edit".into(),
639 path: "root/nonexistent_file.txt".into(),
640 mode: EditFileMode::Edit,
641 };
642 Arc::new(EditFileTool::new(
643 project,
644 thread.downgrade(),
645 language_registry,
646 Templates::new(),
647 ))
648 .run(input, ToolCallEventStream::test().0, cx)
649 })
650 .await;
651 assert_eq!(
652 result.unwrap_err().to_string(),
653 "Can't edit file: path not found"
654 );
655 }
656
657 #[gpui::test]
658 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
659 let mode = &EditFileMode::Create;
660
661 let result = test_resolve_path(mode, "root/new.txt", cx);
662 assert_resolved_path_eq(result.await, rel_path("new.txt"));
663
664 let result = test_resolve_path(mode, "new.txt", cx);
665 assert_resolved_path_eq(result.await, rel_path("new.txt"));
666
667 let result = test_resolve_path(mode, "dir/new.txt", cx);
668 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
669
670 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
671 assert_eq!(
672 result.await.unwrap_err().to_string(),
673 "Can't create file: file already exists"
674 );
675
676 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
677 assert_eq!(
678 result.await.unwrap_err().to_string(),
679 "Can't create file: parent directory doesn't exist"
680 );
681 }
682
683 #[gpui::test]
684 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
685 let mode = &EditFileMode::Edit;
686
687 let path_with_root = "root/dir/subdir/existing.txt";
688 let path_without_root = "dir/subdir/existing.txt";
689 let result = test_resolve_path(mode, path_with_root, cx);
690 assert_resolved_path_eq(result.await, rel_path(path_without_root));
691
692 let result = test_resolve_path(mode, path_without_root, cx);
693 assert_resolved_path_eq(result.await, rel_path(path_without_root));
694
695 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
696 assert_eq!(
697 result.await.unwrap_err().to_string(),
698 "Can't edit file: path not found"
699 );
700
701 let result = test_resolve_path(mode, "root/dir", cx);
702 assert_eq!(
703 result.await.unwrap_err().to_string(),
704 "Can't edit file: path is a directory"
705 );
706 }
707
708 async fn test_resolve_path(
709 mode: &EditFileMode,
710 path: &str,
711 cx: &mut TestAppContext,
712 ) -> anyhow::Result<ProjectPath> {
713 init_test(cx);
714
715 let fs = project::FakeFs::new(cx.executor());
716 fs.insert_tree(
717 "/root",
718 json!({
719 "dir": {
720 "subdir": {
721 "existing.txt": "hello"
722 }
723 }
724 }),
725 )
726 .await;
727 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
728
729 let input = EditFileToolInput {
730 display_description: "Some edit".into(),
731 path: path.into(),
732 mode: mode.clone(),
733 };
734
735 cx.update(|cx| resolve_path(&input, project, cx))
736 }
737
738 #[track_caller]
739 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
740 let actual = path.expect("Should return valid path").path;
741 assert_eq!(actual.as_ref(), expected);
742 }
743
744 #[gpui::test]
745 async fn test_format_on_save(cx: &mut TestAppContext) {
746 init_test(cx);
747
748 let fs = project::FakeFs::new(cx.executor());
749 fs.insert_tree("/root", json!({"src": {}})).await;
750
751 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
752
753 // Set up a Rust language with LSP formatting support
754 let rust_language = Arc::new(language::Language::new(
755 language::LanguageConfig {
756 name: "Rust".into(),
757 matcher: language::LanguageMatcher {
758 path_suffixes: vec!["rs".to_string()],
759 ..Default::default()
760 },
761 ..Default::default()
762 },
763 None,
764 ));
765
766 // Register the language and fake LSP
767 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
768 language_registry.add(rust_language);
769
770 let mut fake_language_servers = language_registry.register_fake_lsp(
771 "Rust",
772 language::FakeLspAdapter {
773 capabilities: lsp::ServerCapabilities {
774 document_formatting_provider: Some(lsp::OneOf::Left(true)),
775 ..Default::default()
776 },
777 ..Default::default()
778 },
779 );
780
781 // Create the file
782 fs.save(
783 path!("/root/src/main.rs").as_ref(),
784 &"initial content".into(),
785 language::LineEnding::Unix,
786 )
787 .await
788 .unwrap();
789
790 // Open the buffer to trigger LSP initialization
791 let buffer = project
792 .update(cx, |project, cx| {
793 project.open_local_buffer(path!("/root/src/main.rs"), cx)
794 })
795 .await
796 .unwrap();
797
798 // Register the buffer with language servers
799 let _handle = project.update(cx, |project, cx| {
800 project.register_buffer_with_language_servers(&buffer, cx)
801 });
802
803 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
804 const FORMATTED_CONTENT: &str =
805 "This file was formatted by the fake formatter in the test.\n";
806
807 // Get the fake language server and set up formatting handler
808 let fake_language_server = fake_language_servers.next().await.unwrap();
809 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
810 |_, _| async move {
811 Ok(Some(vec![lsp::TextEdit {
812 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
813 new_text: FORMATTED_CONTENT.to_string(),
814 }]))
815 }
816 });
817
818 let context_server_registry =
819 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
820 let model = Arc::new(FakeLanguageModel::default());
821 let thread = cx.new(|cx| {
822 Thread::new(
823 project.clone(),
824 cx.new(|_cx| ProjectContext::default()),
825 context_server_registry,
826 Templates::new(),
827 Some(model.clone()),
828 cx,
829 )
830 });
831
832 // First, test with format_on_save enabled
833 cx.update(|cx| {
834 SettingsStore::update_global(cx, |store, cx| {
835 store.update_user_settings(cx, |settings| {
836 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
837 settings.project.all_languages.defaults.formatter =
838 Some(language::language_settings::FormatterList::default());
839 });
840 });
841 });
842
843 // Have the model stream unformatted content
844 let edit_result = {
845 let edit_task = cx.update(|cx| {
846 let input = EditFileToolInput {
847 display_description: "Create main function".into(),
848 path: "root/src/main.rs".into(),
849 mode: EditFileMode::Overwrite,
850 };
851 Arc::new(EditFileTool::new(
852 project.clone(),
853 thread.downgrade(),
854 language_registry.clone(),
855 Templates::new(),
856 ))
857 .run(input, ToolCallEventStream::test().0, cx)
858 });
859
860 // Stream the unformatted content
861 cx.executor().run_until_parked();
862 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
863 model.end_last_completion_stream();
864
865 edit_task.await
866 };
867 assert!(edit_result.is_ok());
868
869 // Wait for any async operations (e.g. formatting) to complete
870 cx.executor().run_until_parked();
871
872 // Read the file to verify it was formatted automatically
873 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
874 assert_eq!(
875 // Ignore carriage returns on Windows
876 new_content.replace("\r\n", "\n"),
877 FORMATTED_CONTENT,
878 "Code should be formatted when format_on_save is enabled"
879 );
880
881 let stale_buffer_count = thread
882 .read_with(cx, |thread, _cx| thread.action_log.clone())
883 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
884
885 assert_eq!(
886 stale_buffer_count, 0,
887 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
888 This causes the agent to think the file was modified externally when it was just formatted.",
889 stale_buffer_count
890 );
891
892 // Next, test with format_on_save disabled
893 cx.update(|cx| {
894 SettingsStore::update_global(cx, |store, cx| {
895 store.update_user_settings(cx, |settings| {
896 settings.project.all_languages.defaults.format_on_save =
897 Some(FormatOnSave::Off);
898 });
899 });
900 });
901
902 // Stream unformatted edits again
903 let edit_result = {
904 let edit_task = cx.update(|cx| {
905 let input = EditFileToolInput {
906 display_description: "Update main function".into(),
907 path: "root/src/main.rs".into(),
908 mode: EditFileMode::Overwrite,
909 };
910 Arc::new(EditFileTool::new(
911 project.clone(),
912 thread.downgrade(),
913 language_registry,
914 Templates::new(),
915 ))
916 .run(input, ToolCallEventStream::test().0, cx)
917 });
918
919 // Stream the unformatted content
920 cx.executor().run_until_parked();
921 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
922 model.end_last_completion_stream();
923
924 edit_task.await
925 };
926 assert!(edit_result.is_ok());
927
928 // Wait for any async operations (e.g. formatting) to complete
929 cx.executor().run_until_parked();
930
931 // Verify the file was not formatted
932 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
933 assert_eq!(
934 // Ignore carriage returns on Windows
935 new_content.replace("\r\n", "\n"),
936 UNFORMATTED_CONTENT,
937 "Code should not be formatted when format_on_save is disabled"
938 );
939 }
940
941 #[gpui::test]
942 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
943 init_test(cx);
944
945 let fs = project::FakeFs::new(cx.executor());
946 fs.insert_tree("/root", json!({"src": {}})).await;
947
948 // Create a simple file with trailing whitespace
949 fs.save(
950 path!("/root/src/main.rs").as_ref(),
951 &"initial content".into(),
952 language::LineEnding::Unix,
953 )
954 .await
955 .unwrap();
956
957 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
958 let context_server_registry =
959 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
960 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
961 let model = Arc::new(FakeLanguageModel::default());
962 let thread = cx.new(|cx| {
963 Thread::new(
964 project.clone(),
965 cx.new(|_cx| ProjectContext::default()),
966 context_server_registry,
967 Templates::new(),
968 Some(model.clone()),
969 cx,
970 )
971 });
972
973 // First, test with remove_trailing_whitespace_on_save enabled
974 cx.update(|cx| {
975 SettingsStore::update_global(cx, |store, cx| {
976 store.update_user_settings(cx, |settings| {
977 settings
978 .project
979 .all_languages
980 .defaults
981 .remove_trailing_whitespace_on_save = Some(true);
982 });
983 });
984 });
985
986 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
987 "fn main() { \n println!(\"Hello!\"); \n}\n";
988
989 // Have the model stream content that contains trailing whitespace
990 let edit_result = {
991 let edit_task = cx.update(|cx| {
992 let input = EditFileToolInput {
993 display_description: "Create main function".into(),
994 path: "root/src/main.rs".into(),
995 mode: EditFileMode::Overwrite,
996 };
997 Arc::new(EditFileTool::new(
998 project.clone(),
999 thread.downgrade(),
1000 language_registry.clone(),
1001 Templates::new(),
1002 ))
1003 .run(input, ToolCallEventStream::test().0, cx)
1004 });
1005
1006 // Stream the content with trailing whitespace
1007 cx.executor().run_until_parked();
1008 model.send_last_completion_stream_text_chunk(
1009 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1010 );
1011 model.end_last_completion_stream();
1012
1013 edit_task.await
1014 };
1015 assert!(edit_result.is_ok());
1016
1017 // Wait for any async operations (e.g. formatting) to complete
1018 cx.executor().run_until_parked();
1019
1020 // Read the file to verify trailing whitespace was removed automatically
1021 assert_eq!(
1022 // Ignore carriage returns on Windows
1023 fs.load(path!("/root/src/main.rs").as_ref())
1024 .await
1025 .unwrap()
1026 .replace("\r\n", "\n"),
1027 "fn main() {\n println!(\"Hello!\");\n}\n",
1028 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
1029 );
1030
1031 // Next, test with remove_trailing_whitespace_on_save disabled
1032 cx.update(|cx| {
1033 SettingsStore::update_global(cx, |store, cx| {
1034 store.update_user_settings(cx, |settings| {
1035 settings
1036 .project
1037 .all_languages
1038 .defaults
1039 .remove_trailing_whitespace_on_save = Some(false);
1040 });
1041 });
1042 });
1043
1044 // Stream edits again with trailing whitespace
1045 let edit_result = {
1046 let edit_task = cx.update(|cx| {
1047 let input = EditFileToolInput {
1048 display_description: "Update main function".into(),
1049 path: "root/src/main.rs".into(),
1050 mode: EditFileMode::Overwrite,
1051 };
1052 Arc::new(EditFileTool::new(
1053 project.clone(),
1054 thread.downgrade(),
1055 language_registry,
1056 Templates::new(),
1057 ))
1058 .run(input, ToolCallEventStream::test().0, cx)
1059 });
1060
1061 // Stream the content with trailing whitespace
1062 cx.executor().run_until_parked();
1063 model.send_last_completion_stream_text_chunk(
1064 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1065 );
1066 model.end_last_completion_stream();
1067
1068 edit_task.await
1069 };
1070 assert!(edit_result.is_ok());
1071
1072 // Wait for any async operations (e.g. formatting) to complete
1073 cx.executor().run_until_parked();
1074
1075 // Verify the file still has trailing whitespace
1076 // Read the file again - it should still have trailing whitespace
1077 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1078 assert_eq!(
1079 // Ignore carriage returns on Windows
1080 final_content.replace("\r\n", "\n"),
1081 CONTENT_WITH_TRAILING_WHITESPACE,
1082 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1083 );
1084 }
1085
1086 #[gpui::test]
1087 async fn test_authorize(cx: &mut TestAppContext) {
1088 init_test(cx);
1089 let fs = project::FakeFs::new(cx.executor());
1090 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1091 let context_server_registry =
1092 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1093 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1094 let model = Arc::new(FakeLanguageModel::default());
1095 let thread = cx.new(|cx| {
1096 Thread::new(
1097 project.clone(),
1098 cx.new(|_cx| ProjectContext::default()),
1099 context_server_registry,
1100 Templates::new(),
1101 Some(model.clone()),
1102 cx,
1103 )
1104 });
1105 let tool = Arc::new(EditFileTool::new(
1106 project.clone(),
1107 thread.downgrade(),
1108 language_registry,
1109 Templates::new(),
1110 ));
1111 fs.insert_tree("/root", json!({})).await;
1112
1113 // Test 1: Path with .zed component should require confirmation
1114 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1115 let _auth = cx.update(|cx| {
1116 tool.authorize(
1117 &EditFileToolInput {
1118 display_description: "test 1".into(),
1119 path: ".zed/settings.json".into(),
1120 mode: EditFileMode::Edit,
1121 },
1122 &stream_tx,
1123 cx,
1124 )
1125 });
1126
1127 let event = stream_rx.expect_authorization().await;
1128 assert_eq!(
1129 event.tool_call.fields.title,
1130 Some("test 1 (local settings)".into())
1131 );
1132
1133 // Test 2: Path outside project should require confirmation
1134 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1135 let _auth = cx.update(|cx| {
1136 tool.authorize(
1137 &EditFileToolInput {
1138 display_description: "test 2".into(),
1139 path: "/etc/hosts".into(),
1140 mode: EditFileMode::Edit,
1141 },
1142 &stream_tx,
1143 cx,
1144 )
1145 });
1146
1147 let event = stream_rx.expect_authorization().await;
1148 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1149
1150 // Test 3: Relative path without .zed should not require confirmation
1151 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1152 cx.update(|cx| {
1153 tool.authorize(
1154 &EditFileToolInput {
1155 display_description: "test 3".into(),
1156 path: "root/src/main.rs".into(),
1157 mode: EditFileMode::Edit,
1158 },
1159 &stream_tx,
1160 cx,
1161 )
1162 })
1163 .await
1164 .unwrap();
1165 assert!(stream_rx.try_next().is_err());
1166
1167 // Test 4: Path with .zed in the middle should require confirmation
1168 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1169 let _auth = cx.update(|cx| {
1170 tool.authorize(
1171 &EditFileToolInput {
1172 display_description: "test 4".into(),
1173 path: "root/.zed/tasks.json".into(),
1174 mode: EditFileMode::Edit,
1175 },
1176 &stream_tx,
1177 cx,
1178 )
1179 });
1180 let event = stream_rx.expect_authorization().await;
1181 assert_eq!(
1182 event.tool_call.fields.title,
1183 Some("test 4 (local settings)".into())
1184 );
1185
1186 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1187 cx.update(|cx| {
1188 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1189 settings.always_allow_tool_actions = true;
1190 agent_settings::AgentSettings::override_global(settings, cx);
1191 });
1192
1193 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1194 cx.update(|cx| {
1195 tool.authorize(
1196 &EditFileToolInput {
1197 display_description: "test 5.1".into(),
1198 path: ".zed/settings.json".into(),
1199 mode: EditFileMode::Edit,
1200 },
1201 &stream_tx,
1202 cx,
1203 )
1204 })
1205 .await
1206 .unwrap();
1207 assert!(stream_rx.try_next().is_err());
1208
1209 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1210 cx.update(|cx| {
1211 tool.authorize(
1212 &EditFileToolInput {
1213 display_description: "test 5.2".into(),
1214 path: "/etc/hosts".into(),
1215 mode: EditFileMode::Edit,
1216 },
1217 &stream_tx,
1218 cx,
1219 )
1220 })
1221 .await
1222 .unwrap();
1223 assert!(stream_rx.try_next().is_err());
1224 }
1225
1226 #[gpui::test]
1227 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1228 init_test(cx);
1229 let fs = project::FakeFs::new(cx.executor());
1230 fs.insert_tree("/project", json!({})).await;
1231 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1232 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1233 let context_server_registry =
1234 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1235 let model = Arc::new(FakeLanguageModel::default());
1236 let thread = cx.new(|cx| {
1237 Thread::new(
1238 project.clone(),
1239 cx.new(|_cx| ProjectContext::default()),
1240 context_server_registry,
1241 Templates::new(),
1242 Some(model.clone()),
1243 cx,
1244 )
1245 });
1246 let tool = Arc::new(EditFileTool::new(
1247 project.clone(),
1248 thread.downgrade(),
1249 language_registry,
1250 Templates::new(),
1251 ));
1252
1253 // Test global config paths - these should require confirmation if they exist and are outside the project
1254 let test_cases = vec![
1255 (
1256 "/etc/hosts",
1257 true,
1258 "System file should require confirmation",
1259 ),
1260 (
1261 "/usr/local/bin/script",
1262 true,
1263 "System bin file should require confirmation",
1264 ),
1265 (
1266 "project/normal_file.rs",
1267 false,
1268 "Normal project file should not require confirmation",
1269 ),
1270 ];
1271
1272 for (path, should_confirm, description) in test_cases {
1273 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1274 let auth = cx.update(|cx| {
1275 tool.authorize(
1276 &EditFileToolInput {
1277 display_description: "Edit file".into(),
1278 path: path.into(),
1279 mode: EditFileMode::Edit,
1280 },
1281 &stream_tx,
1282 cx,
1283 )
1284 });
1285
1286 if should_confirm {
1287 stream_rx.expect_authorization().await;
1288 } else {
1289 auth.await.unwrap();
1290 assert!(
1291 stream_rx.try_next().is_err(),
1292 "Failed for case: {} - path: {} - expected no confirmation but got one",
1293 description,
1294 path
1295 );
1296 }
1297 }
1298 }
1299
1300 #[gpui::test]
1301 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1302 init_test(cx);
1303 let fs = project::FakeFs::new(cx.executor());
1304
1305 // Create multiple worktree directories
1306 fs.insert_tree(
1307 "/workspace/frontend",
1308 json!({
1309 "src": {
1310 "main.js": "console.log('frontend');"
1311 }
1312 }),
1313 )
1314 .await;
1315 fs.insert_tree(
1316 "/workspace/backend",
1317 json!({
1318 "src": {
1319 "main.rs": "fn main() {}"
1320 }
1321 }),
1322 )
1323 .await;
1324 fs.insert_tree(
1325 "/workspace/shared",
1326 json!({
1327 ".zed": {
1328 "settings.json": "{}"
1329 }
1330 }),
1331 )
1332 .await;
1333
1334 // Create project with multiple worktrees
1335 let project = Project::test(
1336 fs.clone(),
1337 [
1338 path!("/workspace/frontend").as_ref(),
1339 path!("/workspace/backend").as_ref(),
1340 path!("/workspace/shared").as_ref(),
1341 ],
1342 cx,
1343 )
1344 .await;
1345 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1346 let context_server_registry =
1347 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1348 let model = Arc::new(FakeLanguageModel::default());
1349 let thread = cx.new(|cx| {
1350 Thread::new(
1351 project.clone(),
1352 cx.new(|_cx| ProjectContext::default()),
1353 context_server_registry.clone(),
1354 Templates::new(),
1355 Some(model.clone()),
1356 cx,
1357 )
1358 });
1359 let tool = Arc::new(EditFileTool::new(
1360 project.clone(),
1361 thread.downgrade(),
1362 language_registry,
1363 Templates::new(),
1364 ));
1365
1366 // Test files in different worktrees
1367 let test_cases = vec![
1368 ("frontend/src/main.js", false, "File in first worktree"),
1369 ("backend/src/main.rs", false, "File in second worktree"),
1370 (
1371 "shared/.zed/settings.json",
1372 true,
1373 ".zed file in third worktree",
1374 ),
1375 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1376 (
1377 "../outside/file.txt",
1378 true,
1379 "Relative path outside worktrees",
1380 ),
1381 ];
1382
1383 for (path, should_confirm, description) in test_cases {
1384 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1385 let auth = cx.update(|cx| {
1386 tool.authorize(
1387 &EditFileToolInput {
1388 display_description: "Edit file".into(),
1389 path: path.into(),
1390 mode: EditFileMode::Edit,
1391 },
1392 &stream_tx,
1393 cx,
1394 )
1395 });
1396
1397 if should_confirm {
1398 stream_rx.expect_authorization().await;
1399 } else {
1400 auth.await.unwrap();
1401 assert!(
1402 stream_rx.try_next().is_err(),
1403 "Failed for case: {} - path: {} - expected no confirmation but got one",
1404 description,
1405 path
1406 );
1407 }
1408 }
1409 }
1410
1411 #[gpui::test]
1412 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1413 init_test(cx);
1414 let fs = project::FakeFs::new(cx.executor());
1415 fs.insert_tree(
1416 "/project",
1417 json!({
1418 ".zed": {
1419 "settings.json": "{}"
1420 },
1421 "src": {
1422 ".zed": {
1423 "local.json": "{}"
1424 }
1425 }
1426 }),
1427 )
1428 .await;
1429 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1430 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1431 let context_server_registry =
1432 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1433 let model = Arc::new(FakeLanguageModel::default());
1434 let thread = cx.new(|cx| {
1435 Thread::new(
1436 project.clone(),
1437 cx.new(|_cx| ProjectContext::default()),
1438 context_server_registry.clone(),
1439 Templates::new(),
1440 Some(model.clone()),
1441 cx,
1442 )
1443 });
1444 let tool = Arc::new(EditFileTool::new(
1445 project.clone(),
1446 thread.downgrade(),
1447 language_registry,
1448 Templates::new(),
1449 ));
1450
1451 // Test edge cases
1452 let test_cases = vec![
1453 // Empty path - find_project_path returns Some for empty paths
1454 ("", false, "Empty path is treated as project root"),
1455 // Root directory
1456 ("/", true, "Root directory should be outside project"),
1457 // Parent directory references - find_project_path resolves these
1458 (
1459 "project/../other",
1460 true,
1461 "Path with .. that goes outside of root directory",
1462 ),
1463 (
1464 "project/./src/file.rs",
1465 false,
1466 "Path with . should work normally",
1467 ),
1468 // Windows-style paths (if on Windows)
1469 #[cfg(target_os = "windows")]
1470 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1471 #[cfg(target_os = "windows")]
1472 ("project\\src\\main.rs", false, "Windows-style project path"),
1473 ];
1474
1475 for (path, should_confirm, description) in test_cases {
1476 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1477 let auth = cx.update(|cx| {
1478 tool.authorize(
1479 &EditFileToolInput {
1480 display_description: "Edit file".into(),
1481 path: path.into(),
1482 mode: EditFileMode::Edit,
1483 },
1484 &stream_tx,
1485 cx,
1486 )
1487 });
1488
1489 cx.run_until_parked();
1490
1491 if should_confirm {
1492 stream_rx.expect_authorization().await;
1493 } else {
1494 assert!(
1495 stream_rx.try_next().is_err(),
1496 "Failed for case: {} - path: {} - expected no confirmation but got one",
1497 description,
1498 path
1499 );
1500 auth.await.unwrap();
1501 }
1502 }
1503 }
1504
1505 #[gpui::test]
1506 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1507 init_test(cx);
1508 let fs = project::FakeFs::new(cx.executor());
1509 fs.insert_tree(
1510 "/project",
1511 json!({
1512 "existing.txt": "content",
1513 ".zed": {
1514 "settings.json": "{}"
1515 }
1516 }),
1517 )
1518 .await;
1519 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1520 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1521 let context_server_registry =
1522 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1523 let model = Arc::new(FakeLanguageModel::default());
1524 let thread = cx.new(|cx| {
1525 Thread::new(
1526 project.clone(),
1527 cx.new(|_cx| ProjectContext::default()),
1528 context_server_registry.clone(),
1529 Templates::new(),
1530 Some(model.clone()),
1531 cx,
1532 )
1533 });
1534 let tool = Arc::new(EditFileTool::new(
1535 project.clone(),
1536 thread.downgrade(),
1537 language_registry,
1538 Templates::new(),
1539 ));
1540
1541 // Test different EditFileMode values
1542 let modes = vec![
1543 EditFileMode::Edit,
1544 EditFileMode::Create,
1545 EditFileMode::Overwrite,
1546 ];
1547
1548 for mode in modes {
1549 // Test .zed path with different modes
1550 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1551 let _auth = cx.update(|cx| {
1552 tool.authorize(
1553 &EditFileToolInput {
1554 display_description: "Edit settings".into(),
1555 path: "project/.zed/settings.json".into(),
1556 mode: mode.clone(),
1557 },
1558 &stream_tx,
1559 cx,
1560 )
1561 });
1562
1563 stream_rx.expect_authorization().await;
1564
1565 // Test outside path with different modes
1566 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1567 let _auth = cx.update(|cx| {
1568 tool.authorize(
1569 &EditFileToolInput {
1570 display_description: "Edit file".into(),
1571 path: "/outside/file.txt".into(),
1572 mode: mode.clone(),
1573 },
1574 &stream_tx,
1575 cx,
1576 )
1577 });
1578
1579 stream_rx.expect_authorization().await;
1580
1581 // Test normal path with different modes
1582 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1583 cx.update(|cx| {
1584 tool.authorize(
1585 &EditFileToolInput {
1586 display_description: "Edit file".into(),
1587 path: "project/normal.txt".into(),
1588 mode: mode.clone(),
1589 },
1590 &stream_tx,
1591 cx,
1592 )
1593 })
1594 .await
1595 .unwrap();
1596 assert!(stream_rx.try_next().is_err());
1597 }
1598 }
1599
1600 #[gpui::test]
1601 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1602 init_test(cx);
1603 let fs = project::FakeFs::new(cx.executor());
1604 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1605 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1606 let context_server_registry =
1607 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1608 let model = Arc::new(FakeLanguageModel::default());
1609 let thread = cx.new(|cx| {
1610 Thread::new(
1611 project.clone(),
1612 cx.new(|_cx| ProjectContext::default()),
1613 context_server_registry,
1614 Templates::new(),
1615 Some(model.clone()),
1616 cx,
1617 )
1618 });
1619 let tool = Arc::new(EditFileTool::new(
1620 project,
1621 thread.downgrade(),
1622 language_registry,
1623 Templates::new(),
1624 ));
1625
1626 cx.update(|cx| {
1627 // ...
1628 assert_eq!(
1629 tool.initial_title(
1630 Err(json!({
1631 "path": "src/main.rs",
1632 "display_description": "",
1633 "old_string": "old code",
1634 "new_string": "new code"
1635 })),
1636 cx
1637 ),
1638 "src/main.rs"
1639 );
1640 assert_eq!(
1641 tool.initial_title(
1642 Err(json!({
1643 "path": "",
1644 "display_description": "Fix error handling",
1645 "old_string": "old code",
1646 "new_string": "new code"
1647 })),
1648 cx
1649 ),
1650 "Fix error handling"
1651 );
1652 assert_eq!(
1653 tool.initial_title(
1654 Err(json!({
1655 "path": "src/main.rs",
1656 "display_description": "Fix error handling",
1657 "old_string": "old code",
1658 "new_string": "new code"
1659 })),
1660 cx
1661 ),
1662 "src/main.rs"
1663 );
1664 assert_eq!(
1665 tool.initial_title(
1666 Err(json!({
1667 "path": "",
1668 "display_description": "",
1669 "old_string": "old code",
1670 "new_string": "new code"
1671 })),
1672 cx
1673 ),
1674 DEFAULT_UI_TEXT
1675 );
1676 assert_eq!(
1677 tool.initial_title(Err(serde_json::Value::Null), cx),
1678 DEFAULT_UI_TEXT
1679 );
1680 });
1681 }
1682
1683 #[gpui::test]
1684 async fn test_diff_finalization(cx: &mut TestAppContext) {
1685 init_test(cx);
1686 let fs = project::FakeFs::new(cx.executor());
1687 fs.insert_tree("/", json!({"main.rs": ""})).await;
1688
1689 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1690 let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1691 let context_server_registry =
1692 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1693 let model = Arc::new(FakeLanguageModel::default());
1694 let thread = cx.new(|cx| {
1695 Thread::new(
1696 project.clone(),
1697 cx.new(|_cx| ProjectContext::default()),
1698 context_server_registry.clone(),
1699 Templates::new(),
1700 Some(model.clone()),
1701 cx,
1702 )
1703 });
1704
1705 // Ensure the diff is finalized after the edit completes.
1706 {
1707 let tool = Arc::new(EditFileTool::new(
1708 project.clone(),
1709 thread.downgrade(),
1710 languages.clone(),
1711 Templates::new(),
1712 ));
1713 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1714 let edit = cx.update(|cx| {
1715 tool.run(
1716 EditFileToolInput {
1717 display_description: "Edit file".into(),
1718 path: path!("/main.rs").into(),
1719 mode: EditFileMode::Edit,
1720 },
1721 stream_tx,
1722 cx,
1723 )
1724 });
1725 stream_rx.expect_update_fields().await;
1726 let diff = stream_rx.expect_diff().await;
1727 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1728 cx.run_until_parked();
1729 model.end_last_completion_stream();
1730 edit.await.unwrap();
1731 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1732 }
1733
1734 // Ensure the diff is finalized if an error occurs while editing.
1735 {
1736 model.forbid_requests();
1737 let tool = Arc::new(EditFileTool::new(
1738 project.clone(),
1739 thread.downgrade(),
1740 languages.clone(),
1741 Templates::new(),
1742 ));
1743 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1744 let edit = cx.update(|cx| {
1745 tool.run(
1746 EditFileToolInput {
1747 display_description: "Edit file".into(),
1748 path: path!("/main.rs").into(),
1749 mode: EditFileMode::Edit,
1750 },
1751 stream_tx,
1752 cx,
1753 )
1754 });
1755 stream_rx.expect_update_fields().await;
1756 let diff = stream_rx.expect_diff().await;
1757 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1758 edit.await.unwrap_err();
1759 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1760 model.allow_requests();
1761 }
1762
1763 // Ensure the diff is finalized if the tool call gets dropped.
1764 {
1765 let tool = Arc::new(EditFileTool::new(
1766 project.clone(),
1767 thread.downgrade(),
1768 languages.clone(),
1769 Templates::new(),
1770 ));
1771 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1772 let edit = cx.update(|cx| {
1773 tool.run(
1774 EditFileToolInput {
1775 display_description: "Edit file".into(),
1776 path: path!("/main.rs").into(),
1777 mode: EditFileMode::Edit,
1778 },
1779 stream_tx,
1780 cx,
1781 )
1782 });
1783 stream_rx.expect_update_fields().await;
1784 let diff = stream_rx.expect_diff().await;
1785 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1786 drop(edit);
1787 cx.run_until_parked();
1788 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1789 }
1790 }
1791
1792 #[gpui::test]
1793 async fn test_file_read_times_tracking(cx: &mut TestAppContext) {
1794 init_test(cx);
1795
1796 let fs = project::FakeFs::new(cx.executor());
1797 fs.insert_tree(
1798 "/root",
1799 json!({
1800 "test.txt": "original content"
1801 }),
1802 )
1803 .await;
1804 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1805 let context_server_registry =
1806 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1807 let model = Arc::new(FakeLanguageModel::default());
1808 let thread = cx.new(|cx| {
1809 Thread::new(
1810 project.clone(),
1811 cx.new(|_cx| ProjectContext::default()),
1812 context_server_registry,
1813 Templates::new(),
1814 Some(model.clone()),
1815 cx,
1816 )
1817 });
1818 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1819
1820 // Initially, file_read_times should be empty
1821 let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
1822 assert!(is_empty, "file_read_times should start empty");
1823
1824 // Create read tool
1825 let read_tool = Arc::new(crate::ReadFileTool::new(
1826 thread.downgrade(),
1827 project.clone(),
1828 action_log,
1829 ));
1830
1831 // Read the file to record the read time
1832 cx.update(|cx| {
1833 read_tool.clone().run(
1834 crate::ReadFileToolInput {
1835 path: "root/test.txt".to_string(),
1836 start_line: None,
1837 end_line: None,
1838 },
1839 ToolCallEventStream::test().0,
1840 cx,
1841 )
1842 })
1843 .await
1844 .unwrap();
1845
1846 // Verify that file_read_times now contains an entry for the file
1847 let has_entry = thread.read_with(cx, |thread, _| {
1848 thread.file_read_times.len() == 1
1849 && thread
1850 .file_read_times
1851 .keys()
1852 .any(|path| path.ends_with("test.txt"))
1853 });
1854 assert!(
1855 has_entry,
1856 "file_read_times should contain an entry after reading the file"
1857 );
1858
1859 // Read the file again - should update the entry
1860 cx.update(|cx| {
1861 read_tool.clone().run(
1862 crate::ReadFileToolInput {
1863 path: "root/test.txt".to_string(),
1864 start_line: None,
1865 end_line: None,
1866 },
1867 ToolCallEventStream::test().0,
1868 cx,
1869 )
1870 })
1871 .await
1872 .unwrap();
1873
1874 // Should still have exactly one entry
1875 let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
1876 assert!(
1877 has_one_entry,
1878 "file_read_times should still have one entry after re-reading"
1879 );
1880 }
1881
1882 fn init_test(cx: &mut TestAppContext) {
1883 cx.update(|cx| {
1884 let settings_store = SettingsStore::test(cx);
1885 cx.set_global(settings_store);
1886 });
1887 }
1888
1889 #[gpui::test]
1890 async fn test_consecutive_edits_work(cx: &mut TestAppContext) {
1891 init_test(cx);
1892
1893 let fs = project::FakeFs::new(cx.executor());
1894 fs.insert_tree(
1895 "/root",
1896 json!({
1897 "test.txt": "original content"
1898 }),
1899 )
1900 .await;
1901 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1902 let context_server_registry =
1903 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1904 let model = Arc::new(FakeLanguageModel::default());
1905 let thread = cx.new(|cx| {
1906 Thread::new(
1907 project.clone(),
1908 cx.new(|_cx| ProjectContext::default()),
1909 context_server_registry,
1910 Templates::new(),
1911 Some(model.clone()),
1912 cx,
1913 )
1914 });
1915 let languages = project.read_with(cx, |project, _| project.languages().clone());
1916 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1917
1918 let read_tool = Arc::new(crate::ReadFileTool::new(
1919 thread.downgrade(),
1920 project.clone(),
1921 action_log,
1922 ));
1923 let edit_tool = Arc::new(EditFileTool::new(
1924 project.clone(),
1925 thread.downgrade(),
1926 languages,
1927 Templates::new(),
1928 ));
1929
1930 // Read the file first
1931 cx.update(|cx| {
1932 read_tool.clone().run(
1933 crate::ReadFileToolInput {
1934 path: "root/test.txt".to_string(),
1935 start_line: None,
1936 end_line: None,
1937 },
1938 ToolCallEventStream::test().0,
1939 cx,
1940 )
1941 })
1942 .await
1943 .unwrap();
1944
1945 // First edit should work
1946 let edit_result = {
1947 let edit_task = cx.update(|cx| {
1948 edit_tool.clone().run(
1949 EditFileToolInput {
1950 display_description: "First edit".into(),
1951 path: "root/test.txt".into(),
1952 mode: EditFileMode::Edit,
1953 },
1954 ToolCallEventStream::test().0,
1955 cx,
1956 )
1957 });
1958
1959 cx.executor().run_until_parked();
1960 model.send_last_completion_stream_text_chunk(
1961 "<old_text>original content</old_text><new_text>modified content</new_text>"
1962 .to_string(),
1963 );
1964 model.end_last_completion_stream();
1965
1966 edit_task.await
1967 };
1968 assert!(
1969 edit_result.is_ok(),
1970 "First edit should succeed, got error: {:?}",
1971 edit_result.as_ref().err()
1972 );
1973
1974 // Second edit should also work because the edit updated the recorded read time
1975 let edit_result = {
1976 let edit_task = cx.update(|cx| {
1977 edit_tool.clone().run(
1978 EditFileToolInput {
1979 display_description: "Second edit".into(),
1980 path: "root/test.txt".into(),
1981 mode: EditFileMode::Edit,
1982 },
1983 ToolCallEventStream::test().0,
1984 cx,
1985 )
1986 });
1987
1988 cx.executor().run_until_parked();
1989 model.send_last_completion_stream_text_chunk(
1990 "<old_text>modified content</old_text><new_text>further modified content</new_text>".to_string(),
1991 );
1992 model.end_last_completion_stream();
1993
1994 edit_task.await
1995 };
1996 assert!(
1997 edit_result.is_ok(),
1998 "Second consecutive edit should succeed, got error: {:?}",
1999 edit_result.as_ref().err()
2000 );
2001 }
2002
2003 #[gpui::test]
2004 async fn test_external_modification_detected(cx: &mut TestAppContext) {
2005 init_test(cx);
2006
2007 let fs = project::FakeFs::new(cx.executor());
2008 fs.insert_tree(
2009 "/root",
2010 json!({
2011 "test.txt": "original content"
2012 }),
2013 )
2014 .await;
2015 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2016 let context_server_registry =
2017 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2018 let model = Arc::new(FakeLanguageModel::default());
2019 let thread = cx.new(|cx| {
2020 Thread::new(
2021 project.clone(),
2022 cx.new(|_cx| ProjectContext::default()),
2023 context_server_registry,
2024 Templates::new(),
2025 Some(model.clone()),
2026 cx,
2027 )
2028 });
2029 let languages = project.read_with(cx, |project, _| project.languages().clone());
2030 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2031
2032 let read_tool = Arc::new(crate::ReadFileTool::new(
2033 thread.downgrade(),
2034 project.clone(),
2035 action_log,
2036 ));
2037 let edit_tool = Arc::new(EditFileTool::new(
2038 project.clone(),
2039 thread.downgrade(),
2040 languages,
2041 Templates::new(),
2042 ));
2043
2044 // Read the file first
2045 cx.update(|cx| {
2046 read_tool.clone().run(
2047 crate::ReadFileToolInput {
2048 path: "root/test.txt".to_string(),
2049 start_line: None,
2050 end_line: None,
2051 },
2052 ToolCallEventStream::test().0,
2053 cx,
2054 )
2055 })
2056 .await
2057 .unwrap();
2058
2059 // Simulate external modification - advance time and save file
2060 cx.background_executor
2061 .advance_clock(std::time::Duration::from_secs(2));
2062 fs.save(
2063 path!("/root/test.txt").as_ref(),
2064 &"externally modified content".into(),
2065 language::LineEnding::Unix,
2066 )
2067 .await
2068 .unwrap();
2069
2070 // Reload the buffer to pick up the new mtime
2071 let project_path = project
2072 .read_with(cx, |project, cx| {
2073 project.find_project_path("root/test.txt", cx)
2074 })
2075 .expect("Should find project path");
2076 let buffer = project
2077 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2078 .await
2079 .unwrap();
2080 buffer
2081 .update(cx, |buffer, cx| buffer.reload(cx))
2082 .await
2083 .unwrap();
2084
2085 cx.executor().run_until_parked();
2086
2087 // Try to edit - should fail because file was modified externally
2088 let result = cx
2089 .update(|cx| {
2090 edit_tool.clone().run(
2091 EditFileToolInput {
2092 display_description: "Edit after external change".into(),
2093 path: "root/test.txt".into(),
2094 mode: EditFileMode::Edit,
2095 },
2096 ToolCallEventStream::test().0,
2097 cx,
2098 )
2099 })
2100 .await;
2101
2102 assert!(
2103 result.is_err(),
2104 "Edit should fail after external modification"
2105 );
2106 let error_msg = result.unwrap_err().to_string();
2107 assert!(
2108 error_msg.contains("has been modified since you last read it"),
2109 "Error should mention file modification, got: {}",
2110 error_msg
2111 );
2112 }
2113
2114 #[gpui::test]
2115 async fn test_dirty_buffer_detected(cx: &mut TestAppContext) {
2116 init_test(cx);
2117
2118 let fs = project::FakeFs::new(cx.executor());
2119 fs.insert_tree(
2120 "/root",
2121 json!({
2122 "test.txt": "original content"
2123 }),
2124 )
2125 .await;
2126 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2127 let context_server_registry =
2128 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2129 let model = Arc::new(FakeLanguageModel::default());
2130 let thread = cx.new(|cx| {
2131 Thread::new(
2132 project.clone(),
2133 cx.new(|_cx| ProjectContext::default()),
2134 context_server_registry,
2135 Templates::new(),
2136 Some(model.clone()),
2137 cx,
2138 )
2139 });
2140 let languages = project.read_with(cx, |project, _| project.languages().clone());
2141 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2142
2143 let read_tool = Arc::new(crate::ReadFileTool::new(
2144 thread.downgrade(),
2145 project.clone(),
2146 action_log,
2147 ));
2148 let edit_tool = Arc::new(EditFileTool::new(
2149 project.clone(),
2150 thread.downgrade(),
2151 languages,
2152 Templates::new(),
2153 ));
2154
2155 // Read the file first
2156 cx.update(|cx| {
2157 read_tool.clone().run(
2158 crate::ReadFileToolInput {
2159 path: "root/test.txt".to_string(),
2160 start_line: None,
2161 end_line: None,
2162 },
2163 ToolCallEventStream::test().0,
2164 cx,
2165 )
2166 })
2167 .await
2168 .unwrap();
2169
2170 // Open the buffer and make it dirty by editing without saving
2171 let project_path = project
2172 .read_with(cx, |project, cx| {
2173 project.find_project_path("root/test.txt", cx)
2174 })
2175 .expect("Should find project path");
2176 let buffer = project
2177 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2178 .await
2179 .unwrap();
2180
2181 // Make an in-memory edit to the buffer (making it dirty)
2182 buffer.update(cx, |buffer, cx| {
2183 let end_point = buffer.max_point();
2184 buffer.edit([(end_point..end_point, " added text")], None, cx);
2185 });
2186
2187 // Verify buffer is dirty
2188 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
2189 assert!(is_dirty, "Buffer should be dirty after in-memory edit");
2190
2191 // Try to edit - should fail because buffer has unsaved changes
2192 let result = cx
2193 .update(|cx| {
2194 edit_tool.clone().run(
2195 EditFileToolInput {
2196 display_description: "Edit with dirty buffer".into(),
2197 path: "root/test.txt".into(),
2198 mode: EditFileMode::Edit,
2199 },
2200 ToolCallEventStream::test().0,
2201 cx,
2202 )
2203 })
2204 .await;
2205
2206 assert!(result.is_err(), "Edit should fail when buffer is dirty");
2207 let error_msg = result.unwrap_err().to_string();
2208 assert!(
2209 error_msg.contains("cannot be written to because it has unsaved changes"),
2210 "Error should mention unsaved changes, got: {}",
2211 error_msg
2212 );
2213 }
2214}