1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Fallback tool-call title, returned by `initial_title` when the raw input yields
/// neither a non-empty path nor a non-empty description.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
// NOTE(review): this type derives `JsonSchema`, and schemars turns `///` doc comments into
// schema descriptions. The doc text on the struct and its fields is therefore presumably
// what the model sees as the tool's instructions — treat wording changes as behavior
// changes (TODO confirm against where the schema is generated).
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    // Validated and mapped to a `ProjectPath` by `resolve_path` before any edit starts.
    pub path: PathBuf,
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
79
// Lenient view of the tool input: `initial_title` deserializes this from the raw JSON
// value it is handed when a full `EditFileToolInput` is not available. Both fields use
// `#[serde(default)]`, so a missing key becomes an empty string instead of a parse error.
// (Plain `//` comments are used so the derived JSON schema is left untouched.)
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Raw path text as received so far; may be empty.
    #[serde(default)]
    path: String,
    // Raw description text as received so far; may be empty.
    #[serde(default)]
    display_description: String,
}
87
// How the target file is treated. Variants (de)serialize in lowercase
// (`edit`, `create`, `overwrite`) because of `rename_all = "lowercase"`.
// (Plain `//` comments are used so the derived, inlined JSON schema is left untouched.)
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // Granular edits; `run` routes this to `EditAgent::edit`. `resolve_path` requires an
    // existing file.
    Edit,
    // New file; `resolve_path` requires that the file does not exist yet but that its
    // parent directory does. `run` routes this to `EditAgent::overwrite`.
    Create,
    // Whole-file replacement; `run` routes this to `EditAgent::overwrite`. `resolve_path`
    // requires an existing file.
    Overwrite,
}
96
/// Result of a completed `edit_file` call.
///
/// It is converted into the tool-result text for the model (see the `From` impl below) and
/// is also what `replay` receives to rebuild the finalized diff.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as it was given in the tool input.
    /// Also accepted under the key `original_path` when deserializing.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit was applied and saved.
    new_text: String,
    /// Full buffer text captured before the edit started.
    old_text: Arc<String>,
    /// Unified diff between `old_text` and `new_text`; empty when nothing changed.
    /// Defaults to an empty string when the key is missing during deserialization.
    #[serde(default)]
    diff: String,
    /// Output returned by the edit agent.
    /// Also accepted under the key `raw_output` when deserializing.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
/// The `edit_file` agent tool: creates, edits, or overwrites a file in the project by
/// delegating the actual text changes to an `EditAgent`.
pub struct EditFileTool {
    /// Weak handle to the owning thread. `run` and `authorize` read the project, model,
    /// action log, and recorded file read times through it, and fail with
    /// "thread was dropped" if it is gone.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when a previous call is replayed.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to turn the input path into a short display path.
    project: Entity<Project>,
    /// Prompt templates handed to the `EditAgent` created in `run`.
    templates: Arc<Templates>,
}
130
131impl EditFileTool {
132 pub fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
    /// The identifier under which this tool is registered: `"edit_file"`.
    fn name() -> &'static str {
        "edit_file"
    }
206
    /// Reports this tool as an edit-kind tool call (`acp::ToolKind::Edit`).
    fn kind() -> acp::ToolKind {
        acp::ToolKind::Edit
    }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(ToolCallUpdateFields {
277 locations: Some(vec![acp::ToolCallLocation {
278 path: abs_path,
279 line: None,
280 meta: None,
281 }]),
282 ..Default::default()
283 });
284 }
285
286 let authorize = self.authorize(&input, &event_stream, cx);
287 cx.spawn(async move |cx: &mut AsyncApp| {
288 authorize.await?;
289
290 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
291 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
292 (request, thread.model().cloned(), thread.action_log().clone())
293 })?;
294 let request = request?;
295 let model = model.context("No language model configured")?;
296
297 let edit_format = EditFormat::from_model(model.clone())?;
298 let edit_agent = EditAgent::new(
299 model,
300 project.clone(),
301 action_log.clone(),
302 self.templates.clone(),
303 edit_format,
304 );
305
306 let buffer = project
307 .update(cx, |project, cx| {
308 project.open_buffer(project_path.clone(), cx)
309 })?
310 .await?;
311
312 // Check if the file has been modified since the agent last read it
313 if let Some(abs_path) = abs_path.as_ref() {
314 let (last_read_mtime, current_mtime, is_dirty) = self.thread.update(cx, |thread, cx| {
315 let last_read = thread.file_read_times.get(abs_path).copied();
316 let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
317 let dirty = buffer.read(cx).is_dirty();
318 (last_read, current, dirty)
319 })?;
320
321 // Check for unsaved changes first - these indicate modifications we don't know about
322 if is_dirty {
323 anyhow::bail!(
324 "This file cannot be written to because it has unsaved changes. \
325 Please end the current conversation immediately by telling the user you want to write to this file (mention its path explicitly) but you can't write to it because it has unsaved changes. \
326 Ask the user to save that buffer's changes and to inform you when it's ok to proceed."
327 );
328 }
329
330 // Check if the file was modified on disk since we last read it
331 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
332 // MTime can be unreliable for comparisons, so our newtype intentionally
333 // doesn't support comparing them. If the mtime at all different
334 // (which could be because of a modification or because e.g. system clock changed),
335 // we pessimistically assume it was modified.
336 if current != last_read {
337 anyhow::bail!(
338 "The file {} has been modified since you last read it. \
339 Please read the file again to get the current state before editing it.",
340 input.path.display()
341 );
342 }
343 }
344 }
345
346 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
347 event_stream.update_diff(diff.clone());
348 let _finalize_diff = util::defer({
349 let diff = diff.downgrade();
350 let mut cx = cx.clone();
351 move || {
352 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
353 }
354 });
355
356 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
357 let old_text = cx
358 .background_spawn({
359 let old_snapshot = old_snapshot.clone();
360 async move { Arc::new(old_snapshot.text()) }
361 })
362 .await;
363
364
365 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
366 edit_agent.edit(
367 buffer.clone(),
368 input.display_description.clone(),
369 &request,
370 cx,
371 )
372 } else {
373 edit_agent.overwrite(
374 buffer.clone(),
375 input.display_description.clone(),
376 &request,
377 cx,
378 )
379 };
380
381 let mut hallucinated_old_text = false;
382 let mut ambiguous_ranges = Vec::new();
383 let mut emitted_location = false;
384 while let Some(event) = events.next().await {
385 match event {
386 EditAgentOutputEvent::Edited(range) => {
387 if !emitted_location {
388 let line = buffer.update(cx, |buffer, _cx| {
389 range.start.to_point(&buffer.snapshot()).row
390 }).ok();
391 if let Some(abs_path) = abs_path.clone() {
392 event_stream.update_fields(ToolCallUpdateFields {
393 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
394 ..Default::default()
395 });
396 }
397 emitted_location = true;
398 }
399 },
400 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
401 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
402 EditAgentOutputEvent::ResolvingEditRange(range) => {
403 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
404 // if !emitted_location {
405 // let line = buffer.update(cx, |buffer, _cx| {
406 // range.start.to_point(&buffer.snapshot()).row
407 // }).ok();
408 // if let Some(abs_path) = abs_path.clone() {
409 // event_stream.update_fields(ToolCallUpdateFields {
410 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
411 // ..Default::default()
412 // });
413 // }
414 // }
415 }
416 }
417 }
418
419 // If format_on_save is enabled, format the buffer
420 let format_on_save_enabled = buffer
421 .read_with(cx, |buffer, cx| {
422 let settings = language_settings::language_settings(
423 buffer.language().map(|l| l.name()),
424 buffer.file(),
425 cx,
426 );
427 settings.format_on_save != FormatOnSave::Off
428 })
429 .unwrap_or(false);
430
431 let edit_agent_output = output.await?;
432
433 if format_on_save_enabled {
434 action_log.update(cx, |log, cx| {
435 log.buffer_edited(buffer.clone(), cx);
436 })?;
437
438 let format_task = project.update(cx, |project, cx| {
439 project.format(
440 HashSet::from_iter([buffer.clone()]),
441 LspFormatTarget::Buffers,
442 false, // Don't push to history since the tool did it.
443 FormatTrigger::Save,
444 cx,
445 )
446 })?;
447 format_task.await.log_err();
448 }
449
450 project
451 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
452 .await?;
453
454 action_log.update(cx, |log, cx| {
455 log.buffer_edited(buffer.clone(), cx);
456 })?;
457
458 // Update the recorded read time after a successful edit so consecutive edits work
459 if let Some(abs_path) = abs_path.as_ref() {
460 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
461 buffer.file().and_then(|file| file.disk_state().mtime())
462 })? {
463 self.thread.update(cx, |thread, _| {
464 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
465 })?;
466 }
467 }
468
469 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
470 let (new_text, unified_diff) = cx
471 .background_spawn({
472 let new_snapshot = new_snapshot.clone();
473 let old_text = old_text.clone();
474 async move {
475 let new_text = new_snapshot.text();
476 let diff = language::unified_diff(&old_text, &new_text);
477 (new_text, diff)
478 }
479 })
480 .await;
481
482 let input_path = input.path.display();
483 if unified_diff.is_empty() {
484 anyhow::ensure!(
485 !hallucinated_old_text,
486 formatdoc! {"
487 Some edits were produced but none of them could be applied.
488 Read the relevant sections of {input_path} again so that
489 I can perform the requested edits.
490 "}
491 );
492 anyhow::ensure!(
493 ambiguous_ranges.is_empty(),
494 {
495 let line_numbers = ambiguous_ranges
496 .iter()
497 .map(|range| range.start.to_string())
498 .collect::<Vec<_>>()
499 .join(", ");
500 formatdoc! {"
501 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
502 relevant sections of {input_path} again and extend <old_text> so
503 that I can perform the requested edits.
504 "}
505 }
506 );
507 }
508
509 Ok(EditFileToolOutput {
510 input_path: input.path,
511 new_text,
512 old_text,
513 diff: unified_diff,
514 edit_agent_output,
515 })
516 })
517 }
518
519 fn replay(
520 &self,
521 _input: Self::Input,
522 output: Self::Output,
523 event_stream: ToolCallEventStream,
524 cx: &mut App,
525 ) -> Result<()> {
526 event_stream.update_diff(cx.new(|cx| {
527 Diff::finalized(
528 output.input_path.to_string_lossy().into_owned(),
529 Some(output.old_text.to_string()),
530 output.new_text,
531 self.language_registry.clone(),
532 cx,
533 )
534 }));
535 Ok(())
536 }
537}
538
539/// Validate that the file path is valid, meaning:
540///
541/// - For `edit` and `overwrite`, the path must point to an existing file.
542/// - For `create`, the file must not already exist, but it's parent dir must exist.
543fn resolve_path(
544 input: &EditFileToolInput,
545 project: Entity<Project>,
546 cx: &mut App,
547) -> Result<ProjectPath> {
548 let project = project.read(cx);
549
550 match input.mode {
551 EditFileMode::Edit | EditFileMode::Overwrite => {
552 let path = project
553 .find_project_path(&input.path, cx)
554 .context("Can't edit file: path not found")?;
555
556 let entry = project
557 .entry_for_path(&path, cx)
558 .context("Can't edit file: path not found")?;
559
560 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
561 Ok(path)
562 }
563
564 EditFileMode::Create => {
565 if let Some(path) = project.find_project_path(&input.path, cx) {
566 anyhow::ensure!(
567 project.entry_for_path(&path, cx).is_none(),
568 "Can't create file: file already exists"
569 );
570 }
571
572 let parent_path = input
573 .path
574 .parent()
575 .context("Can't create file: incorrect path")?;
576
577 let parent_project_path = project.find_project_path(&parent_path, cx);
578
579 let parent_entry = parent_project_path
580 .as_ref()
581 .and_then(|path| project.entry_for_path(path, cx))
582 .context("Can't create file: parent directory doesn't exist")?;
583
584 anyhow::ensure!(
585 parent_entry.is_dir(),
586 "Can't create file: parent is not a directory"
587 );
588
589 let file_name = input
590 .path
591 .file_name()
592 .and_then(|file_name| file_name.to_str())
593 .and_then(|file_name| RelPath::unix(file_name).ok())
594 .context("Can't create file: invalid filename")?;
595
596 let new_file_path = parent_project_path.map(|parent| ProjectPath {
597 path: parent.path.join(file_name),
598 ..parent
599 });
600
601 new_file_path.context("Can't create file")
602 }
603 }
604}
605
606#[cfg(test)]
607mod tests {
608 use super::*;
609 use crate::{ContextServerRegistry, Templates};
610 use fs::Fs;
611 use gpui::{TestAppContext, UpdateGlobal};
612 use language_model::fake_provider::FakeLanguageModel;
613 use prompt_store::ProjectContext;
614 use serde_json::json;
615 use settings::SettingsStore;
616 use util::{path, rel_path::rel_path};
617
618 #[gpui::test]
619 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
620 init_test(cx);
621
622 let fs = project::FakeFs::new(cx.executor());
623 fs.insert_tree("/root", json!({})).await;
624 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
625 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
626 let context_server_registry =
627 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
628 let model = Arc::new(FakeLanguageModel::default());
629 let thread = cx.new(|cx| {
630 Thread::new(
631 project.clone(),
632 cx.new(|_cx| ProjectContext::default()),
633 context_server_registry,
634 Templates::new(),
635 Some(model),
636 cx,
637 )
638 });
639 let result = cx
640 .update(|cx| {
641 let input = EditFileToolInput {
642 display_description: "Some edit".into(),
643 path: "root/nonexistent_file.txt".into(),
644 mode: EditFileMode::Edit,
645 };
646 Arc::new(EditFileTool::new(
647 project,
648 thread.downgrade(),
649 language_registry,
650 Templates::new(),
651 ))
652 .run(input, ToolCallEventStream::test().0, cx)
653 })
654 .await;
655 assert_eq!(
656 result.unwrap_err().to_string(),
657 "Can't edit file: path not found"
658 );
659 }
660
661 #[gpui::test]
662 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
663 let mode = &EditFileMode::Create;
664
665 let result = test_resolve_path(mode, "root/new.txt", cx);
666 assert_resolved_path_eq(result.await, rel_path("new.txt"));
667
668 let result = test_resolve_path(mode, "new.txt", cx);
669 assert_resolved_path_eq(result.await, rel_path("new.txt"));
670
671 let result = test_resolve_path(mode, "dir/new.txt", cx);
672 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
673
674 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
675 assert_eq!(
676 result.await.unwrap_err().to_string(),
677 "Can't create file: file already exists"
678 );
679
680 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
681 assert_eq!(
682 result.await.unwrap_err().to_string(),
683 "Can't create file: parent directory doesn't exist"
684 );
685 }
686
687 #[gpui::test]
688 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
689 let mode = &EditFileMode::Edit;
690
691 let path_with_root = "root/dir/subdir/existing.txt";
692 let path_without_root = "dir/subdir/existing.txt";
693 let result = test_resolve_path(mode, path_with_root, cx);
694 assert_resolved_path_eq(result.await, rel_path(path_without_root));
695
696 let result = test_resolve_path(mode, path_without_root, cx);
697 assert_resolved_path_eq(result.await, rel_path(path_without_root));
698
699 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
700 assert_eq!(
701 result.await.unwrap_err().to_string(),
702 "Can't edit file: path not found"
703 );
704
705 let result = test_resolve_path(mode, "root/dir", cx);
706 assert_eq!(
707 result.await.unwrap_err().to_string(),
708 "Can't edit file: path is a directory"
709 );
710 }
711
712 async fn test_resolve_path(
713 mode: &EditFileMode,
714 path: &str,
715 cx: &mut TestAppContext,
716 ) -> anyhow::Result<ProjectPath> {
717 init_test(cx);
718
719 let fs = project::FakeFs::new(cx.executor());
720 fs.insert_tree(
721 "/root",
722 json!({
723 "dir": {
724 "subdir": {
725 "existing.txt": "hello"
726 }
727 }
728 }),
729 )
730 .await;
731 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
732
733 let input = EditFileToolInput {
734 display_description: "Some edit".into(),
735 path: path.into(),
736 mode: mode.clone(),
737 };
738
739 cx.update(|cx| resolve_path(&input, project, cx))
740 }
741
742 #[track_caller]
743 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
744 let actual = path.expect("Should return valid path").path;
745 assert_eq!(actual.as_ref(), expected);
746 }
747
    /// Verifies that the tool honors `format_on_save`:
    ///
    /// 1. with the setting `On`, content streamed by the model is replaced on disk by the
    ///    output of a fake LSP formatter, and the action log reports no stale buffers;
    /// 2. with the setting `Off`, the streamed content is written to disk unchanged.
    #[gpui::test]
    async fn test_format_on_save(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Set up a Rust language with LSP formatting support
        let rust_language = Arc::new(language::Language::new(
            language::LanguageConfig {
                name: "Rust".into(),
                matcher: language::LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            None,
        ));

        // Register the language and fake LSP
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(rust_language);

        let mut fake_language_servers = language_registry.register_fake_lsp(
            "Rust",
            language::FakeLspAdapter {
                capabilities: lsp::ServerCapabilities {
                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
                    ..Default::default()
                },
                ..Default::default()
            },
        );

        // Create the file
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        // Open the buffer to trigger LSP initialization
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/root/src/main.rs"), cx)
            })
            .await
            .unwrap();

        // Register the buffer with language servers
        let _handle = project.update(cx, |project, cx| {
            project.register_buffer_with_language_servers(&buffer, cx)
        });

        // What the fake model streams, and what the fake formatter replaces it with.
        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
        const FORMATTED_CONTENT: &str =
            "This file was formatted by the fake formatter in the test.\n";

        // Get the fake language server and set up formatting handler
        let fake_language_server = fake_language_servers.next().await.unwrap();
        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
            |_, _| async move {
                // Replace the whole first line with the marker text.
                Ok(Some(vec![lsp::TextEdit {
                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
                    new_text: FORMATTED_CONTENT.to_string(),
                }]))
            }
        });

        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // First, test with format_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
                    settings.project.all_languages.defaults.formatter =
                        Some(language::language_settings::FormatterList::default());
                });
            });
        });

        // Have the model stream unformatted content
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry.clone(),
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the unformatted content
            // (let the tool run until it is waiting on the model before feeding it)
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify it was formatted automatically
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            new_content.replace("\r\n", "\n"),
            FORMATTED_CONTENT,
            "Code should be formatted when format_on_save is enabled"
        );

        // Formatting must not leave the buffer marked as stale in the action log.
        let stale_buffer_count = thread
            .read_with(cx, |thread, _cx| thread.action_log.clone())
            .read_with(cx, |log, cx| log.stale_buffers(cx).count());

        assert_eq!(
            stale_buffer_count, 0,
            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
             This causes the agent to think the file was modified externally when it was just formatted.",
            stale_buffer_count
        );

        // Next, test with format_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings.project.all_languages.defaults.format_on_save =
                        Some(FormatOnSave::Off);
                });
            });
        });

        // Stream unformatted edits again
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry,
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the unformatted content
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Verify the file was not formatted
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            new_content.replace("\r\n", "\n"),
            UNFORMATTED_CONTENT,
            "Code should not be formatted when format_on_save is disabled"
        );
    }
944
    /// Verifies that saving through the tool honors `remove_trailing_whitespace_on_save`:
    /// with the setting enabled, trailing whitespace streamed by the model is stripped
    /// from the file on disk; with it disabled, the streamed content is kept verbatim.
    #[gpui::test]
    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        // Create a simple placeholder file; the tool overwrites its contents below
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // First, test with remove_trailing_whitespace_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings
                        .project
                        .all_languages
                        .defaults
                        .remove_trailing_whitespace_on_save = Some(true);
                });
            });
        });

        // Every line before the closing brace ends with whitespace.
        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
            "fn main() { \n println!(\"Hello!\"); \n}\n";

        // Have the model stream content that contains trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry.clone(),
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the content with trailing whitespace
            // (let the tool run until it is waiting on the model before feeding it)
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify trailing whitespace was removed automatically
        assert_eq!(
            // Ignore carriage returns on Windows
            fs.load(path!("/root/src/main.rs").as_ref())
                .await
                .unwrap()
                .replace("\r\n", "\n"),
            "fn main() {\n println!(\"Hello!\");\n}\n",
            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
        );

        // Next, test with remove_trailing_whitespace_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings
                        .project
                        .all_languages
                        .defaults
                        .remove_trailing_whitespace_on_save = Some(false);
                });
            });
        });

        // Stream edits again with trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry,
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the content with trailing whitespace
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file again - it should be exactly what the model streamed,
        // trailing whitespace included
        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            final_content.replace("\r\n", "\n"),
            CONTENT_WITH_TRAILING_WHITESPACE,
            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
        );
    }
1089
1090 #[gpui::test]
1091 async fn test_authorize(cx: &mut TestAppContext) {
1092 init_test(cx);
1093 let fs = project::FakeFs::new(cx.executor());
1094 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1095 let context_server_registry =
1096 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1097 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1098 let model = Arc::new(FakeLanguageModel::default());
1099 let thread = cx.new(|cx| {
1100 Thread::new(
1101 project.clone(),
1102 cx.new(|_cx| ProjectContext::default()),
1103 context_server_registry,
1104 Templates::new(),
1105 Some(model.clone()),
1106 cx,
1107 )
1108 });
1109 let tool = Arc::new(EditFileTool::new(
1110 project.clone(),
1111 thread.downgrade(),
1112 language_registry,
1113 Templates::new(),
1114 ));
1115 fs.insert_tree("/root", json!({})).await;
1116
1117 // Test 1: Path with .zed component should require confirmation
1118 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1119 let _auth = cx.update(|cx| {
1120 tool.authorize(
1121 &EditFileToolInput {
1122 display_description: "test 1".into(),
1123 path: ".zed/settings.json".into(),
1124 mode: EditFileMode::Edit,
1125 },
1126 &stream_tx,
1127 cx,
1128 )
1129 });
1130
1131 let event = stream_rx.expect_authorization().await;
1132 assert_eq!(
1133 event.tool_call.fields.title,
1134 Some("test 1 (local settings)".into())
1135 );
1136
1137 // Test 2: Path outside project should require confirmation
1138 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1139 let _auth = cx.update(|cx| {
1140 tool.authorize(
1141 &EditFileToolInput {
1142 display_description: "test 2".into(),
1143 path: "/etc/hosts".into(),
1144 mode: EditFileMode::Edit,
1145 },
1146 &stream_tx,
1147 cx,
1148 )
1149 });
1150
1151 let event = stream_rx.expect_authorization().await;
1152 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1153
1154 // Test 3: Relative path without .zed should not require confirmation
1155 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1156 cx.update(|cx| {
1157 tool.authorize(
1158 &EditFileToolInput {
1159 display_description: "test 3".into(),
1160 path: "root/src/main.rs".into(),
1161 mode: EditFileMode::Edit,
1162 },
1163 &stream_tx,
1164 cx,
1165 )
1166 })
1167 .await
1168 .unwrap();
1169 assert!(stream_rx.try_next().is_err());
1170
1171 // Test 4: Path with .zed in the middle should require confirmation
1172 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1173 let _auth = cx.update(|cx| {
1174 tool.authorize(
1175 &EditFileToolInput {
1176 display_description: "test 4".into(),
1177 path: "root/.zed/tasks.json".into(),
1178 mode: EditFileMode::Edit,
1179 },
1180 &stream_tx,
1181 cx,
1182 )
1183 });
1184 let event = stream_rx.expect_authorization().await;
1185 assert_eq!(
1186 event.tool_call.fields.title,
1187 Some("test 4 (local settings)".into())
1188 );
1189
1190 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1191 cx.update(|cx| {
1192 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1193 settings.always_allow_tool_actions = true;
1194 agent_settings::AgentSettings::override_global(settings, cx);
1195 });
1196
1197 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1198 cx.update(|cx| {
1199 tool.authorize(
1200 &EditFileToolInput {
1201 display_description: "test 5.1".into(),
1202 path: ".zed/settings.json".into(),
1203 mode: EditFileMode::Edit,
1204 },
1205 &stream_tx,
1206 cx,
1207 )
1208 })
1209 .await
1210 .unwrap();
1211 assert!(stream_rx.try_next().is_err());
1212
1213 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1214 cx.update(|cx| {
1215 tool.authorize(
1216 &EditFileToolInput {
1217 display_description: "test 5.2".into(),
1218 path: "/etc/hosts".into(),
1219 mode: EditFileMode::Edit,
1220 },
1221 &stream_tx,
1222 cx,
1223 )
1224 })
1225 .await
1226 .unwrap();
1227 assert!(stream_rx.try_next().is_err());
1228 }
1229
1230 #[gpui::test]
1231 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1232 init_test(cx);
1233 let fs = project::FakeFs::new(cx.executor());
1234 fs.insert_tree("/project", json!({})).await;
1235 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1236 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1237 let context_server_registry =
1238 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1239 let model = Arc::new(FakeLanguageModel::default());
1240 let thread = cx.new(|cx| {
1241 Thread::new(
1242 project.clone(),
1243 cx.new(|_cx| ProjectContext::default()),
1244 context_server_registry,
1245 Templates::new(),
1246 Some(model.clone()),
1247 cx,
1248 )
1249 });
1250 let tool = Arc::new(EditFileTool::new(
1251 project.clone(),
1252 thread.downgrade(),
1253 language_registry,
1254 Templates::new(),
1255 ));
1256
1257 // Test global config paths - these should require confirmation if they exist and are outside the project
1258 let test_cases = vec![
1259 (
1260 "/etc/hosts",
1261 true,
1262 "System file should require confirmation",
1263 ),
1264 (
1265 "/usr/local/bin/script",
1266 true,
1267 "System bin file should require confirmation",
1268 ),
1269 (
1270 "project/normal_file.rs",
1271 false,
1272 "Normal project file should not require confirmation",
1273 ),
1274 ];
1275
1276 for (path, should_confirm, description) in test_cases {
1277 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1278 let auth = cx.update(|cx| {
1279 tool.authorize(
1280 &EditFileToolInput {
1281 display_description: "Edit file".into(),
1282 path: path.into(),
1283 mode: EditFileMode::Edit,
1284 },
1285 &stream_tx,
1286 cx,
1287 )
1288 });
1289
1290 if should_confirm {
1291 stream_rx.expect_authorization().await;
1292 } else {
1293 auth.await.unwrap();
1294 assert!(
1295 stream_rx.try_next().is_err(),
1296 "Failed for case: {} - path: {} - expected no confirmation but got one",
1297 description,
1298 path
1299 );
1300 }
1301 }
1302 }
1303
1304 #[gpui::test]
1305 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1306 init_test(cx);
1307 let fs = project::FakeFs::new(cx.executor());
1308
1309 // Create multiple worktree directories
1310 fs.insert_tree(
1311 "/workspace/frontend",
1312 json!({
1313 "src": {
1314 "main.js": "console.log('frontend');"
1315 }
1316 }),
1317 )
1318 .await;
1319 fs.insert_tree(
1320 "/workspace/backend",
1321 json!({
1322 "src": {
1323 "main.rs": "fn main() {}"
1324 }
1325 }),
1326 )
1327 .await;
1328 fs.insert_tree(
1329 "/workspace/shared",
1330 json!({
1331 ".zed": {
1332 "settings.json": "{}"
1333 }
1334 }),
1335 )
1336 .await;
1337
1338 // Create project with multiple worktrees
1339 let project = Project::test(
1340 fs.clone(),
1341 [
1342 path!("/workspace/frontend").as_ref(),
1343 path!("/workspace/backend").as_ref(),
1344 path!("/workspace/shared").as_ref(),
1345 ],
1346 cx,
1347 )
1348 .await;
1349 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1350 let context_server_registry =
1351 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1352 let model = Arc::new(FakeLanguageModel::default());
1353 let thread = cx.new(|cx| {
1354 Thread::new(
1355 project.clone(),
1356 cx.new(|_cx| ProjectContext::default()),
1357 context_server_registry.clone(),
1358 Templates::new(),
1359 Some(model.clone()),
1360 cx,
1361 )
1362 });
1363 let tool = Arc::new(EditFileTool::new(
1364 project.clone(),
1365 thread.downgrade(),
1366 language_registry,
1367 Templates::new(),
1368 ));
1369
1370 // Test files in different worktrees
1371 let test_cases = vec![
1372 ("frontend/src/main.js", false, "File in first worktree"),
1373 ("backend/src/main.rs", false, "File in second worktree"),
1374 (
1375 "shared/.zed/settings.json",
1376 true,
1377 ".zed file in third worktree",
1378 ),
1379 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1380 (
1381 "../outside/file.txt",
1382 true,
1383 "Relative path outside worktrees",
1384 ),
1385 ];
1386
1387 for (path, should_confirm, description) in test_cases {
1388 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1389 let auth = cx.update(|cx| {
1390 tool.authorize(
1391 &EditFileToolInput {
1392 display_description: "Edit file".into(),
1393 path: path.into(),
1394 mode: EditFileMode::Edit,
1395 },
1396 &stream_tx,
1397 cx,
1398 )
1399 });
1400
1401 if should_confirm {
1402 stream_rx.expect_authorization().await;
1403 } else {
1404 auth.await.unwrap();
1405 assert!(
1406 stream_rx.try_next().is_err(),
1407 "Failed for case: {} - path: {} - expected no confirmation but got one",
1408 description,
1409 path
1410 );
1411 }
1412 }
1413 }
1414
1415 #[gpui::test]
1416 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1417 init_test(cx);
1418 let fs = project::FakeFs::new(cx.executor());
1419 fs.insert_tree(
1420 "/project",
1421 json!({
1422 ".zed": {
1423 "settings.json": "{}"
1424 },
1425 "src": {
1426 ".zed": {
1427 "local.json": "{}"
1428 }
1429 }
1430 }),
1431 )
1432 .await;
1433 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1434 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1435 let context_server_registry =
1436 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1437 let model = Arc::new(FakeLanguageModel::default());
1438 let thread = cx.new(|cx| {
1439 Thread::new(
1440 project.clone(),
1441 cx.new(|_cx| ProjectContext::default()),
1442 context_server_registry.clone(),
1443 Templates::new(),
1444 Some(model.clone()),
1445 cx,
1446 )
1447 });
1448 let tool = Arc::new(EditFileTool::new(
1449 project.clone(),
1450 thread.downgrade(),
1451 language_registry,
1452 Templates::new(),
1453 ));
1454
1455 // Test edge cases
1456 let test_cases = vec![
1457 // Empty path - find_project_path returns Some for empty paths
1458 ("", false, "Empty path is treated as project root"),
1459 // Root directory
1460 ("/", true, "Root directory should be outside project"),
1461 // Parent directory references - find_project_path resolves these
1462 (
1463 "project/../other",
1464 true,
1465 "Path with .. that goes outside of root directory",
1466 ),
1467 (
1468 "project/./src/file.rs",
1469 false,
1470 "Path with . should work normally",
1471 ),
1472 // Windows-style paths (if on Windows)
1473 #[cfg(target_os = "windows")]
1474 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1475 #[cfg(target_os = "windows")]
1476 ("project\\src\\main.rs", false, "Windows-style project path"),
1477 ];
1478
1479 for (path, should_confirm, description) in test_cases {
1480 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1481 let auth = cx.update(|cx| {
1482 tool.authorize(
1483 &EditFileToolInput {
1484 display_description: "Edit file".into(),
1485 path: path.into(),
1486 mode: EditFileMode::Edit,
1487 },
1488 &stream_tx,
1489 cx,
1490 )
1491 });
1492
1493 cx.run_until_parked();
1494
1495 if should_confirm {
1496 stream_rx.expect_authorization().await;
1497 } else {
1498 assert!(
1499 stream_rx.try_next().is_err(),
1500 "Failed for case: {} - path: {} - expected no confirmation but got one",
1501 description,
1502 path
1503 );
1504 auth.await.unwrap();
1505 }
1506 }
1507 }
1508
1509 #[gpui::test]
1510 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1511 init_test(cx);
1512 let fs = project::FakeFs::new(cx.executor());
1513 fs.insert_tree(
1514 "/project",
1515 json!({
1516 "existing.txt": "content",
1517 ".zed": {
1518 "settings.json": "{}"
1519 }
1520 }),
1521 )
1522 .await;
1523 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1524 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1525 let context_server_registry =
1526 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1527 let model = Arc::new(FakeLanguageModel::default());
1528 let thread = cx.new(|cx| {
1529 Thread::new(
1530 project.clone(),
1531 cx.new(|_cx| ProjectContext::default()),
1532 context_server_registry.clone(),
1533 Templates::new(),
1534 Some(model.clone()),
1535 cx,
1536 )
1537 });
1538 let tool = Arc::new(EditFileTool::new(
1539 project.clone(),
1540 thread.downgrade(),
1541 language_registry,
1542 Templates::new(),
1543 ));
1544
1545 // Test different EditFileMode values
1546 let modes = vec![
1547 EditFileMode::Edit,
1548 EditFileMode::Create,
1549 EditFileMode::Overwrite,
1550 ];
1551
1552 for mode in modes {
1553 // Test .zed path with different modes
1554 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1555 let _auth = cx.update(|cx| {
1556 tool.authorize(
1557 &EditFileToolInput {
1558 display_description: "Edit settings".into(),
1559 path: "project/.zed/settings.json".into(),
1560 mode: mode.clone(),
1561 },
1562 &stream_tx,
1563 cx,
1564 )
1565 });
1566
1567 stream_rx.expect_authorization().await;
1568
1569 // Test outside path with different modes
1570 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1571 let _auth = cx.update(|cx| {
1572 tool.authorize(
1573 &EditFileToolInput {
1574 display_description: "Edit file".into(),
1575 path: "/outside/file.txt".into(),
1576 mode: mode.clone(),
1577 },
1578 &stream_tx,
1579 cx,
1580 )
1581 });
1582
1583 stream_rx.expect_authorization().await;
1584
1585 // Test normal path with different modes
1586 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1587 cx.update(|cx| {
1588 tool.authorize(
1589 &EditFileToolInput {
1590 display_description: "Edit file".into(),
1591 path: "project/normal.txt".into(),
1592 mode: mode.clone(),
1593 },
1594 &stream_tx,
1595 cx,
1596 )
1597 })
1598 .await
1599 .unwrap();
1600 assert!(stream_rx.try_next().is_err());
1601 }
1602 }
1603
1604 #[gpui::test]
1605 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1606 init_test(cx);
1607 let fs = project::FakeFs::new(cx.executor());
1608 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1609 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1610 let context_server_registry =
1611 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1612 let model = Arc::new(FakeLanguageModel::default());
1613 let thread = cx.new(|cx| {
1614 Thread::new(
1615 project.clone(),
1616 cx.new(|_cx| ProjectContext::default()),
1617 context_server_registry,
1618 Templates::new(),
1619 Some(model.clone()),
1620 cx,
1621 )
1622 });
1623 let tool = Arc::new(EditFileTool::new(
1624 project,
1625 thread.downgrade(),
1626 language_registry,
1627 Templates::new(),
1628 ));
1629
1630 cx.update(|cx| {
1631 // ...
1632 assert_eq!(
1633 tool.initial_title(
1634 Err(json!({
1635 "path": "src/main.rs",
1636 "display_description": "",
1637 "old_string": "old code",
1638 "new_string": "new code"
1639 })),
1640 cx
1641 ),
1642 "src/main.rs"
1643 );
1644 assert_eq!(
1645 tool.initial_title(
1646 Err(json!({
1647 "path": "",
1648 "display_description": "Fix error handling",
1649 "old_string": "old code",
1650 "new_string": "new code"
1651 })),
1652 cx
1653 ),
1654 "Fix error handling"
1655 );
1656 assert_eq!(
1657 tool.initial_title(
1658 Err(json!({
1659 "path": "src/main.rs",
1660 "display_description": "Fix error handling",
1661 "old_string": "old code",
1662 "new_string": "new code"
1663 })),
1664 cx
1665 ),
1666 "src/main.rs"
1667 );
1668 assert_eq!(
1669 tool.initial_title(
1670 Err(json!({
1671 "path": "",
1672 "display_description": "",
1673 "old_string": "old code",
1674 "new_string": "new code"
1675 })),
1676 cx
1677 ),
1678 DEFAULT_UI_TEXT
1679 );
1680 assert_eq!(
1681 tool.initial_title(Err(serde_json::Value::Null), cx),
1682 DEFAULT_UI_TEXT
1683 );
1684 });
1685 }
1686
1687 #[gpui::test]
1688 async fn test_diff_finalization(cx: &mut TestAppContext) {
1689 init_test(cx);
1690 let fs = project::FakeFs::new(cx.executor());
1691 fs.insert_tree("/", json!({"main.rs": ""})).await;
1692
1693 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1694 let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1695 let context_server_registry =
1696 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1697 let model = Arc::new(FakeLanguageModel::default());
1698 let thread = cx.new(|cx| {
1699 Thread::new(
1700 project.clone(),
1701 cx.new(|_cx| ProjectContext::default()),
1702 context_server_registry.clone(),
1703 Templates::new(),
1704 Some(model.clone()),
1705 cx,
1706 )
1707 });
1708
1709 // Ensure the diff is finalized after the edit completes.
1710 {
1711 let tool = Arc::new(EditFileTool::new(
1712 project.clone(),
1713 thread.downgrade(),
1714 languages.clone(),
1715 Templates::new(),
1716 ));
1717 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1718 let edit = cx.update(|cx| {
1719 tool.run(
1720 EditFileToolInput {
1721 display_description: "Edit file".into(),
1722 path: path!("/main.rs").into(),
1723 mode: EditFileMode::Edit,
1724 },
1725 stream_tx,
1726 cx,
1727 )
1728 });
1729 stream_rx.expect_update_fields().await;
1730 let diff = stream_rx.expect_diff().await;
1731 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1732 cx.run_until_parked();
1733 model.end_last_completion_stream();
1734 edit.await.unwrap();
1735 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1736 }
1737
1738 // Ensure the diff is finalized if an error occurs while editing.
1739 {
1740 model.forbid_requests();
1741 let tool = Arc::new(EditFileTool::new(
1742 project.clone(),
1743 thread.downgrade(),
1744 languages.clone(),
1745 Templates::new(),
1746 ));
1747 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1748 let edit = cx.update(|cx| {
1749 tool.run(
1750 EditFileToolInput {
1751 display_description: "Edit file".into(),
1752 path: path!("/main.rs").into(),
1753 mode: EditFileMode::Edit,
1754 },
1755 stream_tx,
1756 cx,
1757 )
1758 });
1759 stream_rx.expect_update_fields().await;
1760 let diff = stream_rx.expect_diff().await;
1761 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1762 edit.await.unwrap_err();
1763 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1764 model.allow_requests();
1765 }
1766
1767 // Ensure the diff is finalized if the tool call gets dropped.
1768 {
1769 let tool = Arc::new(EditFileTool::new(
1770 project.clone(),
1771 thread.downgrade(),
1772 languages.clone(),
1773 Templates::new(),
1774 ));
1775 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1776 let edit = cx.update(|cx| {
1777 tool.run(
1778 EditFileToolInput {
1779 display_description: "Edit file".into(),
1780 path: path!("/main.rs").into(),
1781 mode: EditFileMode::Edit,
1782 },
1783 stream_tx,
1784 cx,
1785 )
1786 });
1787 stream_rx.expect_update_fields().await;
1788 let diff = stream_rx.expect_diff().await;
1789 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1790 drop(edit);
1791 cx.run_until_parked();
1792 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1793 }
1794 }
1795
1796 #[gpui::test]
1797 async fn test_file_read_times_tracking(cx: &mut TestAppContext) {
1798 init_test(cx);
1799
1800 let fs = project::FakeFs::new(cx.executor());
1801 fs.insert_tree(
1802 "/root",
1803 json!({
1804 "test.txt": "original content"
1805 }),
1806 )
1807 .await;
1808 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1809 let context_server_registry =
1810 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1811 let model = Arc::new(FakeLanguageModel::default());
1812 let thread = cx.new(|cx| {
1813 Thread::new(
1814 project.clone(),
1815 cx.new(|_cx| ProjectContext::default()),
1816 context_server_registry,
1817 Templates::new(),
1818 Some(model.clone()),
1819 cx,
1820 )
1821 });
1822 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1823
1824 // Initially, file_read_times should be empty
1825 let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
1826 assert!(is_empty, "file_read_times should start empty");
1827
1828 // Create read tool
1829 let read_tool = Arc::new(crate::ReadFileTool::new(
1830 thread.downgrade(),
1831 project.clone(),
1832 action_log,
1833 ));
1834
1835 // Read the file to record the read time
1836 cx.update(|cx| {
1837 read_tool.clone().run(
1838 crate::ReadFileToolInput {
1839 path: "root/test.txt".to_string(),
1840 start_line: None,
1841 end_line: None,
1842 },
1843 ToolCallEventStream::test().0,
1844 cx,
1845 )
1846 })
1847 .await
1848 .unwrap();
1849
1850 // Verify that file_read_times now contains an entry for the file
1851 let has_entry = thread.read_with(cx, |thread, _| {
1852 thread.file_read_times.len() == 1
1853 && thread
1854 .file_read_times
1855 .keys()
1856 .any(|path| path.ends_with("test.txt"))
1857 });
1858 assert!(
1859 has_entry,
1860 "file_read_times should contain an entry after reading the file"
1861 );
1862
1863 // Read the file again - should update the entry
1864 cx.update(|cx| {
1865 read_tool.clone().run(
1866 crate::ReadFileToolInput {
1867 path: "root/test.txt".to_string(),
1868 start_line: None,
1869 end_line: None,
1870 },
1871 ToolCallEventStream::test().0,
1872 cx,
1873 )
1874 })
1875 .await
1876 .unwrap();
1877
1878 // Should still have exactly one entry
1879 let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
1880 assert!(
1881 has_one_entry,
1882 "file_read_times should still have one entry after re-reading"
1883 );
1884 }
1885
1886 fn init_test(cx: &mut TestAppContext) {
1887 cx.update(|cx| {
1888 let settings_store = SettingsStore::test(cx);
1889 cx.set_global(settings_store);
1890 });
1891 }
1892
1893 #[gpui::test]
1894 async fn test_consecutive_edits_work(cx: &mut TestAppContext) {
1895 init_test(cx);
1896
1897 let fs = project::FakeFs::new(cx.executor());
1898 fs.insert_tree(
1899 "/root",
1900 json!({
1901 "test.txt": "original content"
1902 }),
1903 )
1904 .await;
1905 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1906 let context_server_registry =
1907 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1908 let model = Arc::new(FakeLanguageModel::default());
1909 let thread = cx.new(|cx| {
1910 Thread::new(
1911 project.clone(),
1912 cx.new(|_cx| ProjectContext::default()),
1913 context_server_registry,
1914 Templates::new(),
1915 Some(model.clone()),
1916 cx,
1917 )
1918 });
1919 let languages = project.read_with(cx, |project, _| project.languages().clone());
1920 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1921
1922 let read_tool = Arc::new(crate::ReadFileTool::new(
1923 thread.downgrade(),
1924 project.clone(),
1925 action_log,
1926 ));
1927 let edit_tool = Arc::new(EditFileTool::new(
1928 project.clone(),
1929 thread.downgrade(),
1930 languages,
1931 Templates::new(),
1932 ));
1933
1934 // Read the file first
1935 cx.update(|cx| {
1936 read_tool.clone().run(
1937 crate::ReadFileToolInput {
1938 path: "root/test.txt".to_string(),
1939 start_line: None,
1940 end_line: None,
1941 },
1942 ToolCallEventStream::test().0,
1943 cx,
1944 )
1945 })
1946 .await
1947 .unwrap();
1948
1949 // First edit should work
1950 let edit_result = {
1951 let edit_task = cx.update(|cx| {
1952 edit_tool.clone().run(
1953 EditFileToolInput {
1954 display_description: "First edit".into(),
1955 path: "root/test.txt".into(),
1956 mode: EditFileMode::Edit,
1957 },
1958 ToolCallEventStream::test().0,
1959 cx,
1960 )
1961 });
1962
1963 cx.executor().run_until_parked();
1964 model.send_last_completion_stream_text_chunk(
1965 "<old_text>original content</old_text><new_text>modified content</new_text>"
1966 .to_string(),
1967 );
1968 model.end_last_completion_stream();
1969
1970 edit_task.await
1971 };
1972 assert!(
1973 edit_result.is_ok(),
1974 "First edit should succeed, got error: {:?}",
1975 edit_result.as_ref().err()
1976 );
1977
1978 // Second edit should also work because the edit updated the recorded read time
1979 let edit_result = {
1980 let edit_task = cx.update(|cx| {
1981 edit_tool.clone().run(
1982 EditFileToolInput {
1983 display_description: "Second edit".into(),
1984 path: "root/test.txt".into(),
1985 mode: EditFileMode::Edit,
1986 },
1987 ToolCallEventStream::test().0,
1988 cx,
1989 )
1990 });
1991
1992 cx.executor().run_until_parked();
1993 model.send_last_completion_stream_text_chunk(
1994 "<old_text>modified content</old_text><new_text>further modified content</new_text>".to_string(),
1995 );
1996 model.end_last_completion_stream();
1997
1998 edit_task.await
1999 };
2000 assert!(
2001 edit_result.is_ok(),
2002 "Second consecutive edit should succeed, got error: {:?}",
2003 edit_result.as_ref().err()
2004 );
2005 }
2006
2007 #[gpui::test]
2008 async fn test_external_modification_detected(cx: &mut TestAppContext) {
2009 init_test(cx);
2010
2011 let fs = project::FakeFs::new(cx.executor());
2012 fs.insert_tree(
2013 "/root",
2014 json!({
2015 "test.txt": "original content"
2016 }),
2017 )
2018 .await;
2019 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2020 let context_server_registry =
2021 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2022 let model = Arc::new(FakeLanguageModel::default());
2023 let thread = cx.new(|cx| {
2024 Thread::new(
2025 project.clone(),
2026 cx.new(|_cx| ProjectContext::default()),
2027 context_server_registry,
2028 Templates::new(),
2029 Some(model.clone()),
2030 cx,
2031 )
2032 });
2033 let languages = project.read_with(cx, |project, _| project.languages().clone());
2034 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2035
2036 let read_tool = Arc::new(crate::ReadFileTool::new(
2037 thread.downgrade(),
2038 project.clone(),
2039 action_log,
2040 ));
2041 let edit_tool = Arc::new(EditFileTool::new(
2042 project.clone(),
2043 thread.downgrade(),
2044 languages,
2045 Templates::new(),
2046 ));
2047
2048 // Read the file first
2049 cx.update(|cx| {
2050 read_tool.clone().run(
2051 crate::ReadFileToolInput {
2052 path: "root/test.txt".to_string(),
2053 start_line: None,
2054 end_line: None,
2055 },
2056 ToolCallEventStream::test().0,
2057 cx,
2058 )
2059 })
2060 .await
2061 .unwrap();
2062
2063 // Simulate external modification - advance time and save file
2064 cx.background_executor
2065 .advance_clock(std::time::Duration::from_secs(2));
2066 fs.save(
2067 path!("/root/test.txt").as_ref(),
2068 &"externally modified content".into(),
2069 language::LineEnding::Unix,
2070 )
2071 .await
2072 .unwrap();
2073
2074 // Reload the buffer to pick up the new mtime
2075 let project_path = project
2076 .read_with(cx, |project, cx| {
2077 project.find_project_path("root/test.txt", cx)
2078 })
2079 .expect("Should find project path");
2080 let buffer = project
2081 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2082 .await
2083 .unwrap();
2084 buffer
2085 .update(cx, |buffer, cx| buffer.reload(cx))
2086 .await
2087 .unwrap();
2088
2089 cx.executor().run_until_parked();
2090
2091 // Try to edit - should fail because file was modified externally
2092 let result = cx
2093 .update(|cx| {
2094 edit_tool.clone().run(
2095 EditFileToolInput {
2096 display_description: "Edit after external change".into(),
2097 path: "root/test.txt".into(),
2098 mode: EditFileMode::Edit,
2099 },
2100 ToolCallEventStream::test().0,
2101 cx,
2102 )
2103 })
2104 .await;
2105
2106 assert!(
2107 result.is_err(),
2108 "Edit should fail after external modification"
2109 );
2110 let error_msg = result.unwrap_err().to_string();
2111 assert!(
2112 error_msg.contains("has been modified since you last read it"),
2113 "Error should mention file modification, got: {}",
2114 error_msg
2115 );
2116 }
2117
2118 #[gpui::test]
2119 async fn test_dirty_buffer_detected(cx: &mut TestAppContext) {
2120 init_test(cx);
2121
2122 let fs = project::FakeFs::new(cx.executor());
2123 fs.insert_tree(
2124 "/root",
2125 json!({
2126 "test.txt": "original content"
2127 }),
2128 )
2129 .await;
2130 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2131 let context_server_registry =
2132 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2133 let model = Arc::new(FakeLanguageModel::default());
2134 let thread = cx.new(|cx| {
2135 Thread::new(
2136 project.clone(),
2137 cx.new(|_cx| ProjectContext::default()),
2138 context_server_registry,
2139 Templates::new(),
2140 Some(model.clone()),
2141 cx,
2142 )
2143 });
2144 let languages = project.read_with(cx, |project, _| project.languages().clone());
2145 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2146
2147 let read_tool = Arc::new(crate::ReadFileTool::new(
2148 thread.downgrade(),
2149 project.clone(),
2150 action_log,
2151 ));
2152 let edit_tool = Arc::new(EditFileTool::new(
2153 project.clone(),
2154 thread.downgrade(),
2155 languages,
2156 Templates::new(),
2157 ));
2158
2159 // Read the file first
2160 cx.update(|cx| {
2161 read_tool.clone().run(
2162 crate::ReadFileToolInput {
2163 path: "root/test.txt".to_string(),
2164 start_line: None,
2165 end_line: None,
2166 },
2167 ToolCallEventStream::test().0,
2168 cx,
2169 )
2170 })
2171 .await
2172 .unwrap();
2173
2174 // Open the buffer and make it dirty by editing without saving
2175 let project_path = project
2176 .read_with(cx, |project, cx| {
2177 project.find_project_path("root/test.txt", cx)
2178 })
2179 .expect("Should find project path");
2180 let buffer = project
2181 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2182 .await
2183 .unwrap();
2184
2185 // Make an in-memory edit to the buffer (making it dirty)
2186 buffer.update(cx, |buffer, cx| {
2187 let end_point = buffer.max_point();
2188 buffer.edit([(end_point..end_point, " added text")], None, cx);
2189 });
2190
2191 // Verify buffer is dirty
2192 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
2193 assert!(is_dirty, "Buffer should be dirty after in-memory edit");
2194
2195 // Try to edit - should fail because buffer has unsaved changes
2196 let result = cx
2197 .update(|cx| {
2198 edit_tool.clone().run(
2199 EditFileToolInput {
2200 display_description: "Edit with dirty buffer".into(),
2201 path: "root/test.txt".into(),
2202 mode: EditFileMode::Edit,
2203 },
2204 ToolCallEventStream::test().0,
2205 cx,
2206 )
2207 })
2208 .await;
2209
2210 assert!(result.is_err(), "Edit should fail when buffer is dirty");
2211 let error_msg = result.unwrap_err().to_string();
2212 assert!(
2213 error_msg.contains("cannot be written to because it has unsaved changes"),
2214 "Error should mention unsaved changes, got: {}",
2215 error_msg
2216 );
2217 }
2218}