1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Fallback title for an `edit_file` tool call, used by `initial_title` when the (possibly
/// partially streamed) input yields neither a non-blank path nor a non-blank description.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
31/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
32///
33/// Before using this tool:
34///
35/// 1. Use the `read_file` tool to understand the file's contents and context
36///
37/// 2. Verify the directory path is correct (only applicable when creating new files):
38/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
39#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
40pub struct EditFileToolInput {
41 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
42 ///
43 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
44 ///
45 /// NEVER mention the file path in this description.
46 ///
47 /// <example>Fix API endpoint URLs</example>
48 /// <example>Update copyright year in `page_footer`</example>
49 ///
50 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
51 pub display_description: String,
52
53 /// The full path of the file to create or modify in the project.
54 ///
55 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
56 ///
57 /// The following examples assume we have two root directories in the project:
58 /// - /a/b/backend
59 /// - /c/d/frontend
60 ///
61 /// <example>
62 /// `backend/src/main.rs`
63 ///
64 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
65 /// </example>
66 ///
67 /// <example>
68 /// `frontend/db.js`
69 /// </example>
70 pub path: PathBuf,
71 /// The mode of operation on the file. Possible values:
72 /// - 'edit': Make granular edits to an existing file.
73 /// - 'create': Create a new file if it doesn't exist.
74 /// - 'overwrite': Replace the entire contents of an existing file.
75 ///
76 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
77 pub mode: EditFileMode,
78}
79
// Lenient view of an `EditFileToolInput` that is still being streamed. Every field defaults to
// an empty string, so `initial_title` can derive a title from whatever has arrived so far.
// (Plain `//` comments are used so nothing is added to the derived JSON schema.)
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Raw path text as received; `initial_title` trims it before use.
    #[serde(default)]
    path: String,
    // Raw description text; used as the title only when `path` is blank.
    #[serde(default)]
    display_description: String,
}
87
// How `edit_file` treats the target path; (de)serialized in lowercase ("edit", "create",
// "overwrite"). The model-facing description of each variant lives on `EditFileToolInput::mode`.
// These notes are `//` comments on purpose: `///` docs on the variants would change the
// inlined JSON schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // `resolve_path` requires an existing file; `run` routes this mode through `EditAgent::edit`.
    Edit,
    // `resolve_path` requires the file to be absent and its parent directory to exist;
    // `run` routes this mode through `EditAgent::overwrite`.
    Create,
    // `resolve_path` requires an existing file; `run` routes this mode through `EditAgent::overwrite`.
    Overwrite,
}
96
/// Result of a completed `edit_file` call.
///
/// It is serialized so the call can later be replayed (see `replay`) and is converted into the
/// text the model sees via `From<EditFileToolOutput> for LanguageModelToolResultContent`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as supplied in the tool input.
    ///
    /// `original_path` is also accepted when deserializing.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit (and any formatting) was applied and saved.
    new_text: String,
    /// Full buffer text captured before the edit started.
    old_text: Arc<String>,
    /// Unified diff between `old_text` and `new_text`. It is empty when nothing changed, and
    /// defaults to empty when absent from serialized data.
    #[serde(default)]
    diff: String,
    /// Raw output returned by the `EditAgent`. `raw_output` is also accepted when deserializing.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
/// The `edit_file` agent tool: creates, edits, or overwrites a single project file by
/// delegating the actual text changes to an [`EditAgent`].
pub struct EditFileTool {
    /// Owning thread, held weakly. Operations that need it fail with "thread was dropped"
    /// once it is gone.
    thread: WeakEntity<Thread>,
    /// Used by `replay` to build a finalized [`Diff`] for display.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to turn the input path into a short display path.
    project: Entity<Project>,
    /// Prompt templates forwarded to the [`EditAgent`] created in `run`.
    templates: Arc<Templates>,
}
130
131impl EditFileTool {
132 pub fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
203 fn name() -> &'static str {
204 "edit_file"
205 }
206
207 fn kind() -> acp::ToolKind {
208 acp::ToolKind::Edit
209 }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(
277 ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
278 );
279 }
280
281 let authorize = self.authorize(&input, &event_stream, cx);
282 cx.spawn(async move |cx: &mut AsyncApp| {
283 authorize.await?;
284
285 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
286 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
287 (request, thread.model().cloned(), thread.action_log().clone())
288 })?;
289 let request = request?;
290 let model = model.context("No language model configured")?;
291
292 let edit_format = EditFormat::from_model(model.clone())?;
293 let edit_agent = EditAgent::new(
294 model,
295 project.clone(),
296 action_log.clone(),
297 self.templates.clone(),
298 edit_format,
299 );
300
301 let buffer = project
302 .update(cx, |project, cx| {
303 project.open_buffer(project_path.clone(), cx)
304 })?
305 .await?;
306
307 // Check if the file has been modified since the agent last read it
308 if let Some(abs_path) = abs_path.as_ref() {
309 let (last_read_mtime, current_mtime, is_dirty) = self.thread.update(cx, |thread, cx| {
310 let last_read = thread.file_read_times.get(abs_path).copied();
311 let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
312 let dirty = buffer.read(cx).is_dirty();
313 (last_read, current, dirty)
314 })?;
315
316 // Check for unsaved changes first - these indicate modifications we don't know about
317 if is_dirty {
318 anyhow::bail!(
319 "This file cannot be written to because it has unsaved changes. \
320 Please end the current conversation immediately by telling the user you want to write to this file (mention its path explicitly) but you can't write to it because it has unsaved changes. \
321 Ask the user to save that buffer's changes and to inform you when it's ok to proceed."
322 );
323 }
324
325 // Check if the file was modified on disk since we last read it
326 if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
327 // MTime can be unreliable for comparisons, so our newtype intentionally
328 // doesn't support comparing them. If the mtime at all different
329 // (which could be because of a modification or because e.g. system clock changed),
330 // we pessimistically assume it was modified.
331 if current != last_read {
332 anyhow::bail!(
333 "The file {} has been modified since you last read it. \
334 Please read the file again to get the current state before editing it.",
335 input.path.display()
336 );
337 }
338 }
339 }
340
341 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
342 event_stream.update_diff(diff.clone());
343 let _finalize_diff = util::defer({
344 let diff = diff.downgrade();
345 let mut cx = cx.clone();
346 move || {
347 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
348 }
349 });
350
351 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
352 let old_text = cx
353 .background_spawn({
354 let old_snapshot = old_snapshot.clone();
355 async move { Arc::new(old_snapshot.text()) }
356 })
357 .await;
358
359
360 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
361 edit_agent.edit(
362 buffer.clone(),
363 input.display_description.clone(),
364 &request,
365 cx,
366 )
367 } else {
368 edit_agent.overwrite(
369 buffer.clone(),
370 input.display_description.clone(),
371 &request,
372 cx,
373 )
374 };
375
376 let mut hallucinated_old_text = false;
377 let mut ambiguous_ranges = Vec::new();
378 let mut emitted_location = false;
379 while let Some(event) = events.next().await {
380 match event {
381 EditAgentOutputEvent::Edited(range) => {
382 if !emitted_location {
383 let line = buffer.update(cx, |buffer, _cx| {
384 range.start.to_point(&buffer.snapshot()).row
385 }).ok();
386 if let Some(abs_path) = abs_path.clone() {
387 event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)]));
388 }
389 emitted_location = true;
390 }
391 },
392 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
393 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
394 EditAgentOutputEvent::ResolvingEditRange(range) => {
395 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
396 // if !emitted_location {
397 // let line = buffer.update(cx, |buffer, _cx| {
398 // range.start.to_point(&buffer.snapshot()).row
399 // }).ok();
400 // if let Some(abs_path) = abs_path.clone() {
401 // event_stream.update_fields(ToolCallUpdateFields {
402 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
403 // ..Default::default()
404 // });
405 // }
406 // }
407 }
408 }
409 }
410
411 // If format_on_save is enabled, format the buffer
412 let format_on_save_enabled = buffer
413 .read_with(cx, |buffer, cx| {
414 let settings = language_settings::language_settings(
415 buffer.language().map(|l| l.name()),
416 buffer.file(),
417 cx,
418 );
419 settings.format_on_save != FormatOnSave::Off
420 })
421 .unwrap_or(false);
422
423 let edit_agent_output = output.await?;
424
425 if format_on_save_enabled {
426 action_log.update(cx, |log, cx| {
427 log.buffer_edited(buffer.clone(), cx);
428 })?;
429
430 let format_task = project.update(cx, |project, cx| {
431 project.format(
432 HashSet::from_iter([buffer.clone()]),
433 LspFormatTarget::Buffers,
434 false, // Don't push to history since the tool did it.
435 FormatTrigger::Save,
436 cx,
437 )
438 })?;
439 format_task.await.log_err();
440 }
441
442 project
443 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
444 .await?;
445
446 action_log.update(cx, |log, cx| {
447 log.buffer_edited(buffer.clone(), cx);
448 })?;
449
450 // Update the recorded read time after a successful edit so consecutive edits work
451 if let Some(abs_path) = abs_path.as_ref() {
452 if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
453 buffer.file().and_then(|file| file.disk_state().mtime())
454 })? {
455 self.thread.update(cx, |thread, _| {
456 thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
457 })?;
458 }
459 }
460
461 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
462 let (new_text, unified_diff) = cx
463 .background_spawn({
464 let new_snapshot = new_snapshot.clone();
465 let old_text = old_text.clone();
466 async move {
467 let new_text = new_snapshot.text();
468 let diff = language::unified_diff(&old_text, &new_text);
469 (new_text, diff)
470 }
471 })
472 .await;
473
474 let input_path = input.path.display();
475 if unified_diff.is_empty() {
476 anyhow::ensure!(
477 !hallucinated_old_text,
478 formatdoc! {"
479 Some edits were produced but none of them could be applied.
480 Read the relevant sections of {input_path} again so that
481 I can perform the requested edits.
482 "}
483 );
484 anyhow::ensure!(
485 ambiguous_ranges.is_empty(),
486 {
487 let line_numbers = ambiguous_ranges
488 .iter()
489 .map(|range| range.start.to_string())
490 .collect::<Vec<_>>()
491 .join(", ");
492 formatdoc! {"
493 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
494 relevant sections of {input_path} again and extend <old_text> so
495 that I can perform the requested edits.
496 "}
497 }
498 );
499 }
500
501 Ok(EditFileToolOutput {
502 input_path: input.path,
503 new_text,
504 old_text,
505 diff: unified_diff,
506 edit_agent_output,
507 })
508 })
509 }
510
511 fn replay(
512 &self,
513 _input: Self::Input,
514 output: Self::Output,
515 event_stream: ToolCallEventStream,
516 cx: &mut App,
517 ) -> Result<()> {
518 event_stream.update_diff(cx.new(|cx| {
519 Diff::finalized(
520 output.input_path.to_string_lossy().into_owned(),
521 Some(output.old_text.to_string()),
522 output.new_text,
523 self.language_registry.clone(),
524 cx,
525 )
526 }));
527 Ok(())
528 }
529}
530
531/// Validate that the file path is valid, meaning:
532///
533/// - For `edit` and `overwrite`, the path must point to an existing file.
534/// - For `create`, the file must not already exist, but it's parent dir must exist.
535fn resolve_path(
536 input: &EditFileToolInput,
537 project: Entity<Project>,
538 cx: &mut App,
539) -> Result<ProjectPath> {
540 let project = project.read(cx);
541
542 match input.mode {
543 EditFileMode::Edit | EditFileMode::Overwrite => {
544 let path = project
545 .find_project_path(&input.path, cx)
546 .context("Can't edit file: path not found")?;
547
548 let entry = project
549 .entry_for_path(&path, cx)
550 .context("Can't edit file: path not found")?;
551
552 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
553 Ok(path)
554 }
555
556 EditFileMode::Create => {
557 if let Some(path) = project.find_project_path(&input.path, cx) {
558 anyhow::ensure!(
559 project.entry_for_path(&path, cx).is_none(),
560 "Can't create file: file already exists"
561 );
562 }
563
564 let parent_path = input
565 .path
566 .parent()
567 .context("Can't create file: incorrect path")?;
568
569 let parent_project_path = project.find_project_path(&parent_path, cx);
570
571 let parent_entry = parent_project_path
572 .as_ref()
573 .and_then(|path| project.entry_for_path(path, cx))
574 .context("Can't create file: parent directory doesn't exist")?;
575
576 anyhow::ensure!(
577 parent_entry.is_dir(),
578 "Can't create file: parent is not a directory"
579 );
580
581 let file_name = input
582 .path
583 .file_name()
584 .and_then(|file_name| file_name.to_str())
585 .and_then(|file_name| RelPath::unix(file_name).ok())
586 .context("Can't create file: invalid filename")?;
587
588 let new_file_path = parent_project_path.map(|parent| ProjectPath {
589 path: parent.path.join(file_name),
590 ..parent
591 });
592
593 new_file_path.context("Can't create file")
594 }
595 }
596}
597
598#[cfg(test)]
599mod tests {
600 use super::*;
601 use crate::{ContextServerRegistry, Templates};
602 use fs::Fs;
603 use gpui::{TestAppContext, UpdateGlobal};
604 use language_model::fake_provider::FakeLanguageModel;
605 use prompt_store::ProjectContext;
606 use serde_json::json;
607 use settings::SettingsStore;
608 use util::{path, rel_path::rel_path};
609
610 #[gpui::test]
611 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
612 init_test(cx);
613
614 let fs = project::FakeFs::new(cx.executor());
615 fs.insert_tree("/root", json!({})).await;
616 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
617 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
618 let context_server_registry =
619 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
620 let model = Arc::new(FakeLanguageModel::default());
621 let thread = cx.new(|cx| {
622 Thread::new(
623 project.clone(),
624 cx.new(|_cx| ProjectContext::default()),
625 context_server_registry,
626 Templates::new(),
627 Some(model),
628 cx,
629 )
630 });
631 let result = cx
632 .update(|cx| {
633 let input = EditFileToolInput {
634 display_description: "Some edit".into(),
635 path: "root/nonexistent_file.txt".into(),
636 mode: EditFileMode::Edit,
637 };
638 Arc::new(EditFileTool::new(
639 project,
640 thread.downgrade(),
641 language_registry,
642 Templates::new(),
643 ))
644 .run(input, ToolCallEventStream::test().0, cx)
645 })
646 .await;
647 assert_eq!(
648 result.unwrap_err().to_string(),
649 "Can't edit file: path not found"
650 );
651 }
652
653 #[gpui::test]
654 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
655 let mode = &EditFileMode::Create;
656
657 let result = test_resolve_path(mode, "root/new.txt", cx);
658 assert_resolved_path_eq(result.await, rel_path("new.txt"));
659
660 let result = test_resolve_path(mode, "new.txt", cx);
661 assert_resolved_path_eq(result.await, rel_path("new.txt"));
662
663 let result = test_resolve_path(mode, "dir/new.txt", cx);
664 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
665
666 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
667 assert_eq!(
668 result.await.unwrap_err().to_string(),
669 "Can't create file: file already exists"
670 );
671
672 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
673 assert_eq!(
674 result.await.unwrap_err().to_string(),
675 "Can't create file: parent directory doesn't exist"
676 );
677 }
678
679 #[gpui::test]
680 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
681 let mode = &EditFileMode::Edit;
682
683 let path_with_root = "root/dir/subdir/existing.txt";
684 let path_without_root = "dir/subdir/existing.txt";
685 let result = test_resolve_path(mode, path_with_root, cx);
686 assert_resolved_path_eq(result.await, rel_path(path_without_root));
687
688 let result = test_resolve_path(mode, path_without_root, cx);
689 assert_resolved_path_eq(result.await, rel_path(path_without_root));
690
691 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
692 assert_eq!(
693 result.await.unwrap_err().to_string(),
694 "Can't edit file: path not found"
695 );
696
697 let result = test_resolve_path(mode, "root/dir", cx);
698 assert_eq!(
699 result.await.unwrap_err().to_string(),
700 "Can't edit file: path is a directory"
701 );
702 }
703
704 async fn test_resolve_path(
705 mode: &EditFileMode,
706 path: &str,
707 cx: &mut TestAppContext,
708 ) -> anyhow::Result<ProjectPath> {
709 init_test(cx);
710
711 let fs = project::FakeFs::new(cx.executor());
712 fs.insert_tree(
713 "/root",
714 json!({
715 "dir": {
716 "subdir": {
717 "existing.txt": "hello"
718 }
719 }
720 }),
721 )
722 .await;
723 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
724
725 let input = EditFileToolInput {
726 display_description: "Some edit".into(),
727 path: path.into(),
728 mode: mode.clone(),
729 };
730
731 cx.update(|cx| resolve_path(&input, project, cx))
732 }
733
734 #[track_caller]
735 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
736 let actual = path.expect("Should return valid path").path;
737 assert_eq!(actual.as_ref(), expected);
738 }
739
740 #[gpui::test]
741 async fn test_format_on_save(cx: &mut TestAppContext) {
742 init_test(cx);
743
744 let fs = project::FakeFs::new(cx.executor());
745 fs.insert_tree("/root", json!({"src": {}})).await;
746
747 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
748
749 // Set up a Rust language with LSP formatting support
750 let rust_language = Arc::new(language::Language::new(
751 language::LanguageConfig {
752 name: "Rust".into(),
753 matcher: language::LanguageMatcher {
754 path_suffixes: vec!["rs".to_string()],
755 ..Default::default()
756 },
757 ..Default::default()
758 },
759 None,
760 ));
761
762 // Register the language and fake LSP
763 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
764 language_registry.add(rust_language);
765
766 let mut fake_language_servers = language_registry.register_fake_lsp(
767 "Rust",
768 language::FakeLspAdapter {
769 capabilities: lsp::ServerCapabilities {
770 document_formatting_provider: Some(lsp::OneOf::Left(true)),
771 ..Default::default()
772 },
773 ..Default::default()
774 },
775 );
776
777 // Create the file
778 fs.save(
779 path!("/root/src/main.rs").as_ref(),
780 &"initial content".into(),
781 language::LineEnding::Unix,
782 )
783 .await
784 .unwrap();
785
786 // Open the buffer to trigger LSP initialization
787 let buffer = project
788 .update(cx, |project, cx| {
789 project.open_local_buffer(path!("/root/src/main.rs"), cx)
790 })
791 .await
792 .unwrap();
793
794 // Register the buffer with language servers
795 let _handle = project.update(cx, |project, cx| {
796 project.register_buffer_with_language_servers(&buffer, cx)
797 });
798
799 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
800 const FORMATTED_CONTENT: &str =
801 "This file was formatted by the fake formatter in the test.\n";
802
803 // Get the fake language server and set up formatting handler
804 let fake_language_server = fake_language_servers.next().await.unwrap();
805 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
806 |_, _| async move {
807 Ok(Some(vec![lsp::TextEdit {
808 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
809 new_text: FORMATTED_CONTENT.to_string(),
810 }]))
811 }
812 });
813
814 let context_server_registry =
815 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
816 let model = Arc::new(FakeLanguageModel::default());
817 let thread = cx.new(|cx| {
818 Thread::new(
819 project.clone(),
820 cx.new(|_cx| ProjectContext::default()),
821 context_server_registry,
822 Templates::new(),
823 Some(model.clone()),
824 cx,
825 )
826 });
827
828 // First, test with format_on_save enabled
829 cx.update(|cx| {
830 SettingsStore::update_global(cx, |store, cx| {
831 store.update_user_settings(cx, |settings| {
832 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
833 settings.project.all_languages.defaults.formatter =
834 Some(language::language_settings::FormatterList::default());
835 });
836 });
837 });
838
839 // Have the model stream unformatted content
840 let edit_result = {
841 let edit_task = cx.update(|cx| {
842 let input = EditFileToolInput {
843 display_description: "Create main function".into(),
844 path: "root/src/main.rs".into(),
845 mode: EditFileMode::Overwrite,
846 };
847 Arc::new(EditFileTool::new(
848 project.clone(),
849 thread.downgrade(),
850 language_registry.clone(),
851 Templates::new(),
852 ))
853 .run(input, ToolCallEventStream::test().0, cx)
854 });
855
856 // Stream the unformatted content
857 cx.executor().run_until_parked();
858 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
859 model.end_last_completion_stream();
860
861 edit_task.await
862 };
863 assert!(edit_result.is_ok());
864
865 // Wait for any async operations (e.g. formatting) to complete
866 cx.executor().run_until_parked();
867
868 // Read the file to verify it was formatted automatically
869 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
870 assert_eq!(
871 // Ignore carriage returns on Windows
872 new_content.replace("\r\n", "\n"),
873 FORMATTED_CONTENT,
874 "Code should be formatted when format_on_save is enabled"
875 );
876
877 let stale_buffer_count = thread
878 .read_with(cx, |thread, _cx| thread.action_log.clone())
879 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
880
881 assert_eq!(
882 stale_buffer_count, 0,
883 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
884 This causes the agent to think the file was modified externally when it was just formatted.",
885 stale_buffer_count
886 );
887
888 // Next, test with format_on_save disabled
889 cx.update(|cx| {
890 SettingsStore::update_global(cx, |store, cx| {
891 store.update_user_settings(cx, |settings| {
892 settings.project.all_languages.defaults.format_on_save =
893 Some(FormatOnSave::Off);
894 });
895 });
896 });
897
898 // Stream unformatted edits again
899 let edit_result = {
900 let edit_task = cx.update(|cx| {
901 let input = EditFileToolInput {
902 display_description: "Update main function".into(),
903 path: "root/src/main.rs".into(),
904 mode: EditFileMode::Overwrite,
905 };
906 Arc::new(EditFileTool::new(
907 project.clone(),
908 thread.downgrade(),
909 language_registry,
910 Templates::new(),
911 ))
912 .run(input, ToolCallEventStream::test().0, cx)
913 });
914
915 // Stream the unformatted content
916 cx.executor().run_until_parked();
917 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
918 model.end_last_completion_stream();
919
920 edit_task.await
921 };
922 assert!(edit_result.is_ok());
923
924 // Wait for any async operations (e.g. formatting) to complete
925 cx.executor().run_until_parked();
926
927 // Verify the file was not formatted
928 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
929 assert_eq!(
930 // Ignore carriage returns on Windows
931 new_content.replace("\r\n", "\n"),
932 UNFORMATTED_CONTENT,
933 "Code should not be formatted when format_on_save is disabled"
934 );
935 }
936
937 #[gpui::test]
938 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
939 init_test(cx);
940
941 let fs = project::FakeFs::new(cx.executor());
942 fs.insert_tree("/root", json!({"src": {}})).await;
943
944 // Create a simple file with trailing whitespace
945 fs.save(
946 path!("/root/src/main.rs").as_ref(),
947 &"initial content".into(),
948 language::LineEnding::Unix,
949 )
950 .await
951 .unwrap();
952
953 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
954 let context_server_registry =
955 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
956 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
957 let model = Arc::new(FakeLanguageModel::default());
958 let thread = cx.new(|cx| {
959 Thread::new(
960 project.clone(),
961 cx.new(|_cx| ProjectContext::default()),
962 context_server_registry,
963 Templates::new(),
964 Some(model.clone()),
965 cx,
966 )
967 });
968
969 // First, test with remove_trailing_whitespace_on_save enabled
970 cx.update(|cx| {
971 SettingsStore::update_global(cx, |store, cx| {
972 store.update_user_settings(cx, |settings| {
973 settings
974 .project
975 .all_languages
976 .defaults
977 .remove_trailing_whitespace_on_save = Some(true);
978 });
979 });
980 });
981
982 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
983 "fn main() { \n println!(\"Hello!\"); \n}\n";
984
985 // Have the model stream content that contains trailing whitespace
986 let edit_result = {
987 let edit_task = cx.update(|cx| {
988 let input = EditFileToolInput {
989 display_description: "Create main function".into(),
990 path: "root/src/main.rs".into(),
991 mode: EditFileMode::Overwrite,
992 };
993 Arc::new(EditFileTool::new(
994 project.clone(),
995 thread.downgrade(),
996 language_registry.clone(),
997 Templates::new(),
998 ))
999 .run(input, ToolCallEventStream::test().0, cx)
1000 });
1001
1002 // Stream the content with trailing whitespace
1003 cx.executor().run_until_parked();
1004 model.send_last_completion_stream_text_chunk(
1005 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1006 );
1007 model.end_last_completion_stream();
1008
1009 edit_task.await
1010 };
1011 assert!(edit_result.is_ok());
1012
1013 // Wait for any async operations (e.g. formatting) to complete
1014 cx.executor().run_until_parked();
1015
1016 // Read the file to verify trailing whitespace was removed automatically
1017 assert_eq!(
1018 // Ignore carriage returns on Windows
1019 fs.load(path!("/root/src/main.rs").as_ref())
1020 .await
1021 .unwrap()
1022 .replace("\r\n", "\n"),
1023 "fn main() {\n println!(\"Hello!\");\n}\n",
1024 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
1025 );
1026
1027 // Next, test with remove_trailing_whitespace_on_save disabled
1028 cx.update(|cx| {
1029 SettingsStore::update_global(cx, |store, cx| {
1030 store.update_user_settings(cx, |settings| {
1031 settings
1032 .project
1033 .all_languages
1034 .defaults
1035 .remove_trailing_whitespace_on_save = Some(false);
1036 });
1037 });
1038 });
1039
1040 // Stream edits again with trailing whitespace
1041 let edit_result = {
1042 let edit_task = cx.update(|cx| {
1043 let input = EditFileToolInput {
1044 display_description: "Update main function".into(),
1045 path: "root/src/main.rs".into(),
1046 mode: EditFileMode::Overwrite,
1047 };
1048 Arc::new(EditFileTool::new(
1049 project.clone(),
1050 thread.downgrade(),
1051 language_registry,
1052 Templates::new(),
1053 ))
1054 .run(input, ToolCallEventStream::test().0, cx)
1055 });
1056
1057 // Stream the content with trailing whitespace
1058 cx.executor().run_until_parked();
1059 model.send_last_completion_stream_text_chunk(
1060 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1061 );
1062 model.end_last_completion_stream();
1063
1064 edit_task.await
1065 };
1066 assert!(edit_result.is_ok());
1067
1068 // Wait for any async operations (e.g. formatting) to complete
1069 cx.executor().run_until_parked();
1070
1071 // Verify the file still has trailing whitespace
1072 // Read the file again - it should still have trailing whitespace
1073 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1074 assert_eq!(
1075 // Ignore carriage returns on Windows
1076 final_content.replace("\r\n", "\n"),
1077 CONTENT_WITH_TRAILING_WHITESPACE,
1078 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1079 );
1080 }
1081
1082 #[gpui::test]
1083 async fn test_authorize(cx: &mut TestAppContext) {
1084 init_test(cx);
1085 let fs = project::FakeFs::new(cx.executor());
1086 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1087 let context_server_registry =
1088 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1089 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1090 let model = Arc::new(FakeLanguageModel::default());
1091 let thread = cx.new(|cx| {
1092 Thread::new(
1093 project.clone(),
1094 cx.new(|_cx| ProjectContext::default()),
1095 context_server_registry,
1096 Templates::new(),
1097 Some(model.clone()),
1098 cx,
1099 )
1100 });
1101 let tool = Arc::new(EditFileTool::new(
1102 project.clone(),
1103 thread.downgrade(),
1104 language_registry,
1105 Templates::new(),
1106 ));
1107 fs.insert_tree("/root", json!({})).await;
1108
1109 // Test 1: Path with .zed component should require confirmation
1110 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1111 let _auth = cx.update(|cx| {
1112 tool.authorize(
1113 &EditFileToolInput {
1114 display_description: "test 1".into(),
1115 path: ".zed/settings.json".into(),
1116 mode: EditFileMode::Edit,
1117 },
1118 &stream_tx,
1119 cx,
1120 )
1121 });
1122
1123 let event = stream_rx.expect_authorization().await;
1124 assert_eq!(
1125 event.tool_call.fields.title,
1126 Some("test 1 (local settings)".into())
1127 );
1128
1129 // Test 2: Path outside project should require confirmation
1130 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1131 let _auth = cx.update(|cx| {
1132 tool.authorize(
1133 &EditFileToolInput {
1134 display_description: "test 2".into(),
1135 path: "/etc/hosts".into(),
1136 mode: EditFileMode::Edit,
1137 },
1138 &stream_tx,
1139 cx,
1140 )
1141 });
1142
1143 let event = stream_rx.expect_authorization().await;
1144 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1145
1146 // Test 3: Relative path without .zed should not require confirmation
1147 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1148 cx.update(|cx| {
1149 tool.authorize(
1150 &EditFileToolInput {
1151 display_description: "test 3".into(),
1152 path: "root/src/main.rs".into(),
1153 mode: EditFileMode::Edit,
1154 },
1155 &stream_tx,
1156 cx,
1157 )
1158 })
1159 .await
1160 .unwrap();
1161 assert!(stream_rx.try_next().is_err());
1162
1163 // Test 4: Path with .zed in the middle should require confirmation
1164 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1165 let _auth = cx.update(|cx| {
1166 tool.authorize(
1167 &EditFileToolInput {
1168 display_description: "test 4".into(),
1169 path: "root/.zed/tasks.json".into(),
1170 mode: EditFileMode::Edit,
1171 },
1172 &stream_tx,
1173 cx,
1174 )
1175 });
1176 let event = stream_rx.expect_authorization().await;
1177 assert_eq!(
1178 event.tool_call.fields.title,
1179 Some("test 4 (local settings)".into())
1180 );
1181
1182 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1183 cx.update(|cx| {
1184 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1185 settings.always_allow_tool_actions = true;
1186 agent_settings::AgentSettings::override_global(settings, cx);
1187 });
1188
1189 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1190 cx.update(|cx| {
1191 tool.authorize(
1192 &EditFileToolInput {
1193 display_description: "test 5.1".into(),
1194 path: ".zed/settings.json".into(),
1195 mode: EditFileMode::Edit,
1196 },
1197 &stream_tx,
1198 cx,
1199 )
1200 })
1201 .await
1202 .unwrap();
1203 assert!(stream_rx.try_next().is_err());
1204
1205 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1206 cx.update(|cx| {
1207 tool.authorize(
1208 &EditFileToolInput {
1209 display_description: "test 5.2".into(),
1210 path: "/etc/hosts".into(),
1211 mode: EditFileMode::Edit,
1212 },
1213 &stream_tx,
1214 cx,
1215 )
1216 })
1217 .await
1218 .unwrap();
1219 assert!(stream_rx.try_next().is_err());
1220 }
1221
1222 #[gpui::test]
1223 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1224 init_test(cx);
1225 let fs = project::FakeFs::new(cx.executor());
1226 fs.insert_tree("/project", json!({})).await;
1227 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1228 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1229 let context_server_registry =
1230 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1231 let model = Arc::new(FakeLanguageModel::default());
1232 let thread = cx.new(|cx| {
1233 Thread::new(
1234 project.clone(),
1235 cx.new(|_cx| ProjectContext::default()),
1236 context_server_registry,
1237 Templates::new(),
1238 Some(model.clone()),
1239 cx,
1240 )
1241 });
1242 let tool = Arc::new(EditFileTool::new(
1243 project.clone(),
1244 thread.downgrade(),
1245 language_registry,
1246 Templates::new(),
1247 ));
1248
1249 // Test global config paths - these should require confirmation if they exist and are outside the project
1250 let test_cases = vec![
1251 (
1252 "/etc/hosts",
1253 true,
1254 "System file should require confirmation",
1255 ),
1256 (
1257 "/usr/local/bin/script",
1258 true,
1259 "System bin file should require confirmation",
1260 ),
1261 (
1262 "project/normal_file.rs",
1263 false,
1264 "Normal project file should not require confirmation",
1265 ),
1266 ];
1267
1268 for (path, should_confirm, description) in test_cases {
1269 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1270 let auth = cx.update(|cx| {
1271 tool.authorize(
1272 &EditFileToolInput {
1273 display_description: "Edit file".into(),
1274 path: path.into(),
1275 mode: EditFileMode::Edit,
1276 },
1277 &stream_tx,
1278 cx,
1279 )
1280 });
1281
1282 if should_confirm {
1283 stream_rx.expect_authorization().await;
1284 } else {
1285 auth.await.unwrap();
1286 assert!(
1287 stream_rx.try_next().is_err(),
1288 "Failed for case: {} - path: {} - expected no confirmation but got one",
1289 description,
1290 path
1291 );
1292 }
1293 }
1294 }
1295
1296 #[gpui::test]
1297 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1298 init_test(cx);
1299 let fs = project::FakeFs::new(cx.executor());
1300
1301 // Create multiple worktree directories
1302 fs.insert_tree(
1303 "/workspace/frontend",
1304 json!({
1305 "src": {
1306 "main.js": "console.log('frontend');"
1307 }
1308 }),
1309 )
1310 .await;
1311 fs.insert_tree(
1312 "/workspace/backend",
1313 json!({
1314 "src": {
1315 "main.rs": "fn main() {}"
1316 }
1317 }),
1318 )
1319 .await;
1320 fs.insert_tree(
1321 "/workspace/shared",
1322 json!({
1323 ".zed": {
1324 "settings.json": "{}"
1325 }
1326 }),
1327 )
1328 .await;
1329
1330 // Create project with multiple worktrees
1331 let project = Project::test(
1332 fs.clone(),
1333 [
1334 path!("/workspace/frontend").as_ref(),
1335 path!("/workspace/backend").as_ref(),
1336 path!("/workspace/shared").as_ref(),
1337 ],
1338 cx,
1339 )
1340 .await;
1341 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1342 let context_server_registry =
1343 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1344 let model = Arc::new(FakeLanguageModel::default());
1345 let thread = cx.new(|cx| {
1346 Thread::new(
1347 project.clone(),
1348 cx.new(|_cx| ProjectContext::default()),
1349 context_server_registry.clone(),
1350 Templates::new(),
1351 Some(model.clone()),
1352 cx,
1353 )
1354 });
1355 let tool = Arc::new(EditFileTool::new(
1356 project.clone(),
1357 thread.downgrade(),
1358 language_registry,
1359 Templates::new(),
1360 ));
1361
1362 // Test files in different worktrees
1363 let test_cases = vec![
1364 ("frontend/src/main.js", false, "File in first worktree"),
1365 ("backend/src/main.rs", false, "File in second worktree"),
1366 (
1367 "shared/.zed/settings.json",
1368 true,
1369 ".zed file in third worktree",
1370 ),
1371 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1372 (
1373 "../outside/file.txt",
1374 true,
1375 "Relative path outside worktrees",
1376 ),
1377 ];
1378
1379 for (path, should_confirm, description) in test_cases {
1380 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1381 let auth = cx.update(|cx| {
1382 tool.authorize(
1383 &EditFileToolInput {
1384 display_description: "Edit file".into(),
1385 path: path.into(),
1386 mode: EditFileMode::Edit,
1387 },
1388 &stream_tx,
1389 cx,
1390 )
1391 });
1392
1393 if should_confirm {
1394 stream_rx.expect_authorization().await;
1395 } else {
1396 auth.await.unwrap();
1397 assert!(
1398 stream_rx.try_next().is_err(),
1399 "Failed for case: {} - path: {} - expected no confirmation but got one",
1400 description,
1401 path
1402 );
1403 }
1404 }
1405 }
1406
1407 #[gpui::test]
1408 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1409 init_test(cx);
1410 let fs = project::FakeFs::new(cx.executor());
1411 fs.insert_tree(
1412 "/project",
1413 json!({
1414 ".zed": {
1415 "settings.json": "{}"
1416 },
1417 "src": {
1418 ".zed": {
1419 "local.json": "{}"
1420 }
1421 }
1422 }),
1423 )
1424 .await;
1425 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1426 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1427 let context_server_registry =
1428 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1429 let model = Arc::new(FakeLanguageModel::default());
1430 let thread = cx.new(|cx| {
1431 Thread::new(
1432 project.clone(),
1433 cx.new(|_cx| ProjectContext::default()),
1434 context_server_registry.clone(),
1435 Templates::new(),
1436 Some(model.clone()),
1437 cx,
1438 )
1439 });
1440 let tool = Arc::new(EditFileTool::new(
1441 project.clone(),
1442 thread.downgrade(),
1443 language_registry,
1444 Templates::new(),
1445 ));
1446
1447 // Test edge cases
1448 let test_cases = vec![
1449 // Empty path - find_project_path returns Some for empty paths
1450 ("", false, "Empty path is treated as project root"),
1451 // Root directory
1452 ("/", true, "Root directory should be outside project"),
1453 // Parent directory references - find_project_path resolves these
1454 (
1455 "project/../other",
1456 true,
1457 "Path with .. that goes outside of root directory",
1458 ),
1459 (
1460 "project/./src/file.rs",
1461 false,
1462 "Path with . should work normally",
1463 ),
1464 // Windows-style paths (if on Windows)
1465 #[cfg(target_os = "windows")]
1466 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1467 #[cfg(target_os = "windows")]
1468 ("project\\src\\main.rs", false, "Windows-style project path"),
1469 ];
1470
1471 for (path, should_confirm, description) in test_cases {
1472 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1473 let auth = cx.update(|cx| {
1474 tool.authorize(
1475 &EditFileToolInput {
1476 display_description: "Edit file".into(),
1477 path: path.into(),
1478 mode: EditFileMode::Edit,
1479 },
1480 &stream_tx,
1481 cx,
1482 )
1483 });
1484
1485 cx.run_until_parked();
1486
1487 if should_confirm {
1488 stream_rx.expect_authorization().await;
1489 } else {
1490 assert!(
1491 stream_rx.try_next().is_err(),
1492 "Failed for case: {} - path: {} - expected no confirmation but got one",
1493 description,
1494 path
1495 );
1496 auth.await.unwrap();
1497 }
1498 }
1499 }
1500
1501 #[gpui::test]
1502 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1503 init_test(cx);
1504 let fs = project::FakeFs::new(cx.executor());
1505 fs.insert_tree(
1506 "/project",
1507 json!({
1508 "existing.txt": "content",
1509 ".zed": {
1510 "settings.json": "{}"
1511 }
1512 }),
1513 )
1514 .await;
1515 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1516 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1517 let context_server_registry =
1518 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1519 let model = Arc::new(FakeLanguageModel::default());
1520 let thread = cx.new(|cx| {
1521 Thread::new(
1522 project.clone(),
1523 cx.new(|_cx| ProjectContext::default()),
1524 context_server_registry.clone(),
1525 Templates::new(),
1526 Some(model.clone()),
1527 cx,
1528 )
1529 });
1530 let tool = Arc::new(EditFileTool::new(
1531 project.clone(),
1532 thread.downgrade(),
1533 language_registry,
1534 Templates::new(),
1535 ));
1536
1537 // Test different EditFileMode values
1538 let modes = vec![
1539 EditFileMode::Edit,
1540 EditFileMode::Create,
1541 EditFileMode::Overwrite,
1542 ];
1543
1544 for mode in modes {
1545 // Test .zed path with different modes
1546 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1547 let _auth = cx.update(|cx| {
1548 tool.authorize(
1549 &EditFileToolInput {
1550 display_description: "Edit settings".into(),
1551 path: "project/.zed/settings.json".into(),
1552 mode: mode.clone(),
1553 },
1554 &stream_tx,
1555 cx,
1556 )
1557 });
1558
1559 stream_rx.expect_authorization().await;
1560
1561 // Test outside path with different modes
1562 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1563 let _auth = cx.update(|cx| {
1564 tool.authorize(
1565 &EditFileToolInput {
1566 display_description: "Edit file".into(),
1567 path: "/outside/file.txt".into(),
1568 mode: mode.clone(),
1569 },
1570 &stream_tx,
1571 cx,
1572 )
1573 });
1574
1575 stream_rx.expect_authorization().await;
1576
1577 // Test normal path with different modes
1578 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1579 cx.update(|cx| {
1580 tool.authorize(
1581 &EditFileToolInput {
1582 display_description: "Edit file".into(),
1583 path: "project/normal.txt".into(),
1584 mode: mode.clone(),
1585 },
1586 &stream_tx,
1587 cx,
1588 )
1589 })
1590 .await
1591 .unwrap();
1592 assert!(stream_rx.try_next().is_err());
1593 }
1594 }
1595
1596 #[gpui::test]
1597 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1598 init_test(cx);
1599 let fs = project::FakeFs::new(cx.executor());
1600 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1601 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1602 let context_server_registry =
1603 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1604 let model = Arc::new(FakeLanguageModel::default());
1605 let thread = cx.new(|cx| {
1606 Thread::new(
1607 project.clone(),
1608 cx.new(|_cx| ProjectContext::default()),
1609 context_server_registry,
1610 Templates::new(),
1611 Some(model.clone()),
1612 cx,
1613 )
1614 });
1615 let tool = Arc::new(EditFileTool::new(
1616 project,
1617 thread.downgrade(),
1618 language_registry,
1619 Templates::new(),
1620 ));
1621
1622 cx.update(|cx| {
1623 // ...
1624 assert_eq!(
1625 tool.initial_title(
1626 Err(json!({
1627 "path": "src/main.rs",
1628 "display_description": "",
1629 "old_string": "old code",
1630 "new_string": "new code"
1631 })),
1632 cx
1633 ),
1634 "src/main.rs"
1635 );
1636 assert_eq!(
1637 tool.initial_title(
1638 Err(json!({
1639 "path": "",
1640 "display_description": "Fix error handling",
1641 "old_string": "old code",
1642 "new_string": "new code"
1643 })),
1644 cx
1645 ),
1646 "Fix error handling"
1647 );
1648 assert_eq!(
1649 tool.initial_title(
1650 Err(json!({
1651 "path": "src/main.rs",
1652 "display_description": "Fix error handling",
1653 "old_string": "old code",
1654 "new_string": "new code"
1655 })),
1656 cx
1657 ),
1658 "src/main.rs"
1659 );
1660 assert_eq!(
1661 tool.initial_title(
1662 Err(json!({
1663 "path": "",
1664 "display_description": "",
1665 "old_string": "old code",
1666 "new_string": "new code"
1667 })),
1668 cx
1669 ),
1670 DEFAULT_UI_TEXT
1671 );
1672 assert_eq!(
1673 tool.initial_title(Err(serde_json::Value::Null), cx),
1674 DEFAULT_UI_TEXT
1675 );
1676 });
1677 }
1678
    // Verifies that the `Diff` emitted by `EditFileTool` goes from
    // `Diff::Pending` to `Diff::Finalized` however the edit ends: successful
    // completion, an error, or the running task being dropped.
    #[gpui::test]
    async fn test_diff_finalization(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({"main.rs": ""})).await;

        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // Ensure the diff is finalized after the edit completes.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            // The tool first emits a fields update, then a diff that is still
            // pending while the fake model's completion is open.
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            cx.run_until_parked();
            // Closing the fake completion stream lets the edit finish successfully.
            model.end_last_completion_stream();
            edit.await.unwrap();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }

        // Ensure the diff is finalized if an error occurs while editing.
        {
            // With requests forbidden on the fake model, this edit is expected to
            // fail (see `unwrap_err` below).
            model.forbid_requests();
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            edit.await.unwrap_err();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
            // Re-enable requests so the next scenario can reach the model.
            model.allow_requests();
        }

        // Ensure the diff is finalized if the tool call gets dropped.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            // Abandon the in-flight edit, then let pending work settle before checking.
            drop(edit);
            cx.run_until_parked();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }
    }
1787
1788 #[gpui::test]
1789 async fn test_file_read_times_tracking(cx: &mut TestAppContext) {
1790 init_test(cx);
1791
1792 let fs = project::FakeFs::new(cx.executor());
1793 fs.insert_tree(
1794 "/root",
1795 json!({
1796 "test.txt": "original content"
1797 }),
1798 )
1799 .await;
1800 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1801 let context_server_registry =
1802 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1803 let model = Arc::new(FakeLanguageModel::default());
1804 let thread = cx.new(|cx| {
1805 Thread::new(
1806 project.clone(),
1807 cx.new(|_cx| ProjectContext::default()),
1808 context_server_registry,
1809 Templates::new(),
1810 Some(model.clone()),
1811 cx,
1812 )
1813 });
1814 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1815
1816 // Initially, file_read_times should be empty
1817 let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
1818 assert!(is_empty, "file_read_times should start empty");
1819
1820 // Create read tool
1821 let read_tool = Arc::new(crate::ReadFileTool::new(
1822 thread.downgrade(),
1823 project.clone(),
1824 action_log,
1825 ));
1826
1827 // Read the file to record the read time
1828 cx.update(|cx| {
1829 read_tool.clone().run(
1830 crate::ReadFileToolInput {
1831 path: "root/test.txt".to_string(),
1832 start_line: None,
1833 end_line: None,
1834 },
1835 ToolCallEventStream::test().0,
1836 cx,
1837 )
1838 })
1839 .await
1840 .unwrap();
1841
1842 // Verify that file_read_times now contains an entry for the file
1843 let has_entry = thread.read_with(cx, |thread, _| {
1844 thread.file_read_times.len() == 1
1845 && thread
1846 .file_read_times
1847 .keys()
1848 .any(|path| path.ends_with("test.txt"))
1849 });
1850 assert!(
1851 has_entry,
1852 "file_read_times should contain an entry after reading the file"
1853 );
1854
1855 // Read the file again - should update the entry
1856 cx.update(|cx| {
1857 read_tool.clone().run(
1858 crate::ReadFileToolInput {
1859 path: "root/test.txt".to_string(),
1860 start_line: None,
1861 end_line: None,
1862 },
1863 ToolCallEventStream::test().0,
1864 cx,
1865 )
1866 })
1867 .await
1868 .unwrap();
1869
1870 // Should still have exactly one entry
1871 let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
1872 assert!(
1873 has_one_entry,
1874 "file_read_times should still have one entry after re-reading"
1875 );
1876 }
1877
1878 fn init_test(cx: &mut TestAppContext) {
1879 cx.update(|cx| {
1880 let settings_store = SettingsStore::test(cx);
1881 cx.set_global(settings_store);
1882 });
1883 }
1884
    // After the file has been read once, two back-to-back edits through
    // `EditFileTool` must both succeed. The second edit depends on the first
    // having refreshed the thread's recorded read time for the file.
    #[gpui::test]
    async fn test_consecutive_edits_work(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "test.txt": "original content"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let languages = project.read_with(cx, |project, _| project.languages().clone());
        let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

        let read_tool = Arc::new(crate::ReadFileTool::new(
            thread.downgrade(),
            project.clone(),
            action_log,
        ));
        let edit_tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages,
            Templates::new(),
        ));

        // Read the file first
        cx.update(|cx| {
            read_tool.clone().run(
                crate::ReadFileToolInput {
                    path: "root/test.txt".to_string(),
                    start_line: None,
                    end_line: None,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap();

        // First edit should work
        let edit_result = {
            let edit_task = cx.update(|cx| {
                edit_tool.clone().run(
                    EditFileToolInput {
                        display_description: "First edit".into(),
                        path: "root/test.txt".into(),
                        mode: EditFileMode::Edit,
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            });

            // Let the edit task issue its completion request, then answer it on
            // the fake model with a single old_text/new_text replacement and
            // close the stream.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                "<old_text>original content</old_text><new_text>modified content</new_text>"
                    .to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(
            edit_result.is_ok(),
            "First edit should succeed, got error: {:?}",
            edit_result.as_ref().err()
        );

        // Second edit should also work because the edit updated the recorded read time
        let edit_result = {
            let edit_task = cx.update(|cx| {
                edit_tool.clone().run(
                    EditFileToolInput {
                        display_description: "Second edit".into(),
                        path: "root/test.txt".into(),
                        mode: EditFileMode::Edit,
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            });

            // Same fake-model handshake, replacing the text written by the first edit.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                "<old_text>modified content</old_text><new_text>further modified content</new_text>".to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(
            edit_result.is_ok(),
            "Second consecutive edit should succeed, got error: {:?}",
            edit_result.as_ref().err()
        );
    }
1998
    // If the file changes on disk after the agent last read it, the edit must
    // be rejected with a "has been modified since you last read it" error.
    #[gpui::test]
    async fn test_external_modification_detected(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "test.txt": "original content"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let languages = project.read_with(cx, |project, _| project.languages().clone());
        let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

        let read_tool = Arc::new(crate::ReadFileTool::new(
            thread.downgrade(),
            project.clone(),
            action_log,
        ));
        let edit_tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages,
            Templates::new(),
        ));

        // Read the file first
        cx.update(|cx| {
            read_tool.clone().run(
                crate::ReadFileToolInput {
                    path: "root/test.txt".to_string(),
                    start_line: None,
                    end_line: None,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap();

        // Simulate external modification - advance time and save file
        // (advancing the fake clock first is intended to give the saved file a
        // later mtime than the recorded read).
        cx.background_executor
            .advance_clock(std::time::Duration::from_secs(2));
        fs.save(
            path!("/root/test.txt").as_ref(),
            &"externally modified content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        // Reload the buffer to pick up the new mtime
        let project_path = project
            .read_with(cx, |project, cx| {
                project.find_project_path("root/test.txt", cx)
            })
            .expect("Should find project path");
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap();
        buffer
            .update(cx, |buffer, cx| buffer.reload(cx))
            .await
            .unwrap();

        cx.executor().run_until_parked();

        // Try to edit - should fail because file was modified externally.
        // No model output is streamed here, so the task is expected to error
        // out on its own.
        let result = cx
            .update(|cx| {
                edit_tool.clone().run(
                    EditFileToolInput {
                        display_description: "Edit after external change".into(),
                        path: "root/test.txt".into(),
                        mode: EditFileMode::Edit,
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await;

        assert!(
            result.is_err(),
            "Edit should fail after external modification"
        );
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("has been modified since you last read it"),
            "Error should mention file modification, got: {}",
            error_msg
        );
    }
2109
    // If the open buffer has unsaved in-memory edits, the edit tool must refuse
    // to write. The error must mention that the file "cannot be written to
    // because it has unsaved changes".
    #[gpui::test]
    async fn test_dirty_buffer_detected(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "test.txt": "original content"
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let languages = project.read_with(cx, |project, _| project.languages().clone());
        let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

        let read_tool = Arc::new(crate::ReadFileTool::new(
            thread.downgrade(),
            project.clone(),
            action_log,
        ));
        let edit_tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages,
            Templates::new(),
        ));

        // Read the file first so only the dirty-buffer condition is under test
        cx.update(|cx| {
            read_tool.clone().run(
                crate::ReadFileToolInput {
                    path: "root/test.txt".to_string(),
                    start_line: None,
                    end_line: None,
                },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap();

        // Open the buffer and make it dirty by editing without saving
        let project_path = project
            .read_with(cx, |project, cx| {
                project.find_project_path("root/test.txt", cx)
            })
            .expect("Should find project path");
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap();

        // Make an in-memory edit to the buffer (making it dirty): append text at the end
        buffer.update(cx, |buffer, cx| {
            let end_point = buffer.max_point();
            buffer.edit([(end_point..end_point, " added text")], None, cx);
        });

        // Verify buffer is dirty
        let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
        assert!(is_dirty, "Buffer should be dirty after in-memory edit");

        // Try to edit - should fail because buffer has unsaved changes.
        // No model output is streamed, so the task is expected to error out on its own.
        let result = cx
            .update(|cx| {
                edit_tool.clone().run(
                    EditFileToolInput {
                        display_description: "Edit with dirty buffer".into(),
                        path: "root/test.txt".into(),
                        mode: EditFileMode::Edit,
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await;

        assert!(result.is_err(), "Edit should fail when buffer is dirty");
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("cannot be written to because it has unsaved changes"),
            "Error should mention unsaved changes, got: {}",
            error_msg
        );
    }
2210}