1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Fallback tool-call title, used when the (possibly partial) input has neither a
/// usable path nor a display description.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
// NOTE(review): the `///` doc comments on this struct and its fields are turned into the
// JSON schema's descriptions by `#[derive(JsonSchema)]`, and presumably reach the model as
// the tool's input schema. Treat edits to them as prompt changes, not ordinary rustdoc.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// IMPORTANT: Do NOT include markdown code fences (```) or other markdown formatting in this description.
    /// Just describe what should be done - another model will generate the actual content.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    /// <example>Create a Python script that prints hello world</example>
    ///
    /// INCORRECT examples (do not do this):
    /// <example>Create a file with:\n```python\nprint('hello')\n```</example>
    /// <example>Add this code:\n```\nif err:\n return\n```</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    // Validated and resolved to a `ProjectPath` by `resolve_path` before any edit is made.
    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    // `create` and `overwrite` are both carried out by `EditAgent::overwrite`; only `edit`
    // goes through `EditAgent::edit` (see `run`).
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
87
// Lenient view of `EditFileToolInput` used by `initial_title` while the tool call input
// is still incomplete: both fields default to empty strings, so a JSON value missing either
// one still deserializes. Plain `//` comments are used here so the derived schema is unchanged.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Raw path text as received so far; may be empty.
    #[serde(default)]
    path: String,
    // Raw description text as received so far; may be empty.
    #[serde(default)]
    display_description: String,
}
95
// How `edit_file` should treat the target path. Serialized in lowercase ("edit", "create",
// "overwrite") and inlined into the parent schema; the model-facing description of each
// mode lives on `EditFileToolInput::mode`. Plain `//` comments keep the schema unchanged.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // Existing file; granular edits via `EditAgent::edit`.
    Edit,
    // File must not exist yet, but its parent directory must (see `resolve_path`);
    // contents are written via `EditAgent::overwrite`.
    Create,
    // Existing file whose whole contents are replaced via `EditAgent::overwrite`.
    Overwrite,
}
104
/// Result of a completed `edit_file` call.
///
/// It is serializable so that `replay` can rebuild the diff card later (presumably from a
/// persisted thread — confirm with the thread storage code).
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as given in the tool input, not the resolved project path.
    // The alias presumably keeps outputs serialized under an older field name readable.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edits, optional formatting, and save.
    new_text: String,
    /// Full buffer text captured before any edits were applied.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    /// Defaults to empty when absent from serialized data.
    #[serde(default)]
    diff: String,
    /// Output returned by the edit agent for this call.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
116
117impl From<EditFileToolOutput> for LanguageModelToolResultContent {
118 fn from(output: EditFileToolOutput) -> Self {
119 if output.diff.is_empty() {
120 "No edits were made.".into()
121 } else {
122 format!(
123 "Edited {}:\n\n```diff\n{}\n```",
124 output.input_path.display(),
125 output.diff
126 )
127 .into()
128 }
129 }
130}
131
/// Agent tool that creates, overwrites, or edits a single project file, delegating the
/// actual text changes to an [`EditAgent`].
pub struct EditFileTool {
    /// Owning thread, held weakly; `authorize` and `run` fail with "thread was dropped"
    /// once it is gone.
    thread: WeakEntity<Thread>,
    /// Used to build finalized diffs when replaying a past call.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to shorten paths; `authorize` and `run` use the thread's
    /// project instead.
    project: Entity<Project>,
    /// Templates handed to the [`EditAgent`].
    templates: Arc<Templates>,
}
138
139impl EditFileTool {
140 pub fn new(
141 project: Entity<Project>,
142 thread: WeakEntity<Thread>,
143 language_registry: Arc<LanguageRegistry>,
144 templates: Arc<Templates>,
145 ) -> Self {
146 Self {
147 project,
148 thread,
149 language_registry,
150 templates,
151 }
152 }
153
154 fn authorize(
155 &self,
156 input: &EditFileToolInput,
157 event_stream: &ToolCallEventStream,
158 cx: &mut App,
159 ) -> Task<Result<()>> {
160 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
161 return Task::ready(Ok(()));
162 }
163
164 // If any path component matches the local settings folder, then this could affect
165 // the editor in ways beyond the project source, so prompt.
166 let local_settings_folder = paths::local_settings_folder_name();
167 let path = Path::new(&input.path);
168 if path.components().any(|component| {
169 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
170 }) {
171 return event_stream.authorize(
172 format!("{} (local settings)", input.display_description),
173 cx,
174 );
175 }
176
177 // It's also possible that the global config dir is configured to be inside the project,
178 // so check for that edge case too.
179 // TODO this is broken when remoting
180 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
181 && canonical_path.starts_with(paths::config_dir())
182 {
183 return event_stream.authorize(
184 format!("{} (global settings)", input.display_description),
185 cx,
186 );
187 }
188
189 // Check if path is inside the global config directory
190 // First check if it's already inside project - if not, try to canonicalize
191 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
192 thread.project().read(cx).find_project_path(&input.path, cx)
193 }) else {
194 return Task::ready(Err(anyhow!("thread was dropped")));
195 };
196
197 // If the path is inside the project, and it's not one of the above edge cases,
198 // then no confirmation is necessary. Otherwise, confirmation is necessary.
199 if project_path.is_some() {
200 Task::ready(Ok(()))
201 } else {
202 event_stream.authorize(&input.display_description, cx)
203 }
204 }
205}
206
207impl AgentTool for EditFileTool {
208 type Input = EditFileToolInput;
209 type Output = EditFileToolOutput;
210
211 fn name() -> &'static str {
212 "edit_file"
213 }
214
215 fn kind() -> acp::ToolKind {
216 acp::ToolKind::Edit
217 }
218
219 fn initial_title(
220 &self,
221 input: Result<Self::Input, serde_json::Value>,
222 cx: &mut App,
223 ) -> SharedString {
224 match input {
225 Ok(input) => self
226 .project
227 .read(cx)
228 .find_project_path(&input.path, cx)
229 .and_then(|project_path| {
230 self.project
231 .read(cx)
232 .short_full_path_for_project_path(&project_path, cx)
233 })
234 .unwrap_or(input.path.to_string_lossy().into_owned())
235 .into(),
236 Err(raw_input) => {
237 if let Some(input) =
238 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
239 {
240 let path = input.path.trim();
241 if !path.is_empty() {
242 return self
243 .project
244 .read(cx)
245 .find_project_path(&input.path, cx)
246 .and_then(|project_path| {
247 self.project
248 .read(cx)
249 .short_full_path_for_project_path(&project_path, cx)
250 })
251 .unwrap_or(input.path)
252 .into();
253 }
254
255 let description = input.display_description.trim();
256 if !description.is_empty() {
257 return description.to_string().into();
258 }
259 }
260
261 DEFAULT_UI_TEXT.into()
262 }
263 }
264 }
265
266 fn run(
267 self: Arc<Self>,
268 input: Self::Input,
269 event_stream: ToolCallEventStream,
270 cx: &mut App,
271 ) -> Task<Result<Self::Output>> {
272 let Ok(project) = self
273 .thread
274 .read_with(cx, |thread, _cx| thread.project().clone())
275 else {
276 return Task::ready(Err(anyhow!("thread was dropped")));
277 };
278 let project_path = match resolve_path(&input, project.clone(), cx) {
279 Ok(path) => path,
280 Err(err) => return Task::ready(Err(anyhow!(err))),
281 };
282 let abs_path = project.read(cx).absolute_path(&project_path, cx);
283 if let Some(abs_path) = abs_path.clone() {
284 event_stream.update_fields(ToolCallUpdateFields {
285 locations: Some(vec![acp::ToolCallLocation {
286 path: abs_path,
287 line: None,
288 meta: None,
289 }]),
290 ..Default::default()
291 });
292 }
293
294 let authorize = self.authorize(&input, &event_stream, cx);
295 cx.spawn(async move |cx: &mut AsyncApp| {
296 authorize.await?;
297
298 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
299 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
300 (request, thread.model().cloned(), thread.action_log().clone())
301 })?;
302 let request = request?;
303 let model = model.context("No language model configured")?;
304
305 let edit_format = EditFormat::from_model(model.clone())?;
306 let edit_agent = EditAgent::new(
307 model,
308 project.clone(),
309 action_log.clone(),
310 self.templates.clone(),
311 edit_format,
312 );
313
314 let buffer = project
315 .update(cx, |project, cx| {
316 project.open_buffer(project_path.clone(), cx)
317 })?
318 .await?;
319
320 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
321 event_stream.update_diff(diff.clone());
322 let _finalize_diff = util::defer({
323 let diff = diff.downgrade();
324 let mut cx = cx.clone();
325 move || {
326 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
327 }
328 });
329
330 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
331 let old_text = cx
332 .background_spawn({
333 let old_snapshot = old_snapshot.clone();
334 async move { Arc::new(old_snapshot.text()) }
335 })
336 .await;
337
338
339 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
340 edit_agent.edit(
341 buffer.clone(),
342 input.display_description.clone(),
343 &request,
344 cx,
345 )
346 } else {
347 edit_agent.overwrite(
348 buffer.clone(),
349 input.display_description.clone(),
350 &request,
351 cx,
352 )
353 };
354
355 let mut hallucinated_old_text = false;
356 let mut ambiguous_ranges = Vec::new();
357 let mut emitted_location = false;
358 while let Some(event) = events.next().await {
359 match event {
360 EditAgentOutputEvent::Edited(range) => {
361 if !emitted_location {
362 let line = buffer.update(cx, |buffer, _cx| {
363 range.start.to_point(&buffer.snapshot()).row
364 }).ok();
365 if let Some(abs_path) = abs_path.clone() {
366 event_stream.update_fields(ToolCallUpdateFields {
367 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
368 ..Default::default()
369 });
370 }
371 emitted_location = true;
372 }
373 },
374 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
375 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
376 EditAgentOutputEvent::ResolvingEditRange(range) => {
377 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
378 // if !emitted_location {
379 // let line = buffer.update(cx, |buffer, _cx| {
380 // range.start.to_point(&buffer.snapshot()).row
381 // }).ok();
382 // if let Some(abs_path) = abs_path.clone() {
383 // event_stream.update_fields(ToolCallUpdateFields {
384 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
385 // ..Default::default()
386 // });
387 // }
388 // }
389 }
390 }
391 }
392
393 // If format_on_save is enabled, format the buffer
394 let format_on_save_enabled = buffer
395 .read_with(cx, |buffer, cx| {
396 let settings = language_settings::language_settings(
397 buffer.language().map(|l| l.name()),
398 buffer.file(),
399 cx,
400 );
401 settings.format_on_save != FormatOnSave::Off
402 })
403 .unwrap_or(false);
404
405 let edit_agent_output = output.await?;
406
407 if format_on_save_enabled {
408 action_log.update(cx, |log, cx| {
409 log.buffer_edited(buffer.clone(), cx);
410 })?;
411
412 let format_task = project.update(cx, |project, cx| {
413 project.format(
414 HashSet::from_iter([buffer.clone()]),
415 LspFormatTarget::Buffers,
416 false, // Don't push to history since the tool did it.
417 FormatTrigger::Save,
418 cx,
419 )
420 })?;
421 format_task.await.log_err();
422 }
423
424 project
425 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
426 .await?;
427
428 action_log.update(cx, |log, cx| {
429 log.buffer_edited(buffer.clone(), cx);
430 })?;
431
432 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
433 let (new_text, unified_diff) = cx
434 .background_spawn({
435 let new_snapshot = new_snapshot.clone();
436 let old_text = old_text.clone();
437 async move {
438 let new_text = new_snapshot.text();
439 let diff = language::unified_diff(&old_text, &new_text);
440 (new_text, diff)
441 }
442 })
443 .await;
444
445 let input_path = input.path.display();
446 if unified_diff.is_empty() {
447 anyhow::ensure!(
448 !hallucinated_old_text,
449 formatdoc! {"
450 Some edits were produced but none of them could be applied.
451 Read the relevant sections of {input_path} again so that
452 I can perform the requested edits.
453 "}
454 );
455 anyhow::ensure!(
456 ambiguous_ranges.is_empty(),
457 {
458 let line_numbers = ambiguous_ranges
459 .iter()
460 .map(|range| range.start.to_string())
461 .collect::<Vec<_>>()
462 .join(", ");
463 formatdoc! {"
464 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
465 relevant sections of {input_path} again and extend <old_text> so
466 that I can perform the requested edits.
467 "}
468 }
469 );
470 }
471
472 Ok(EditFileToolOutput {
473 input_path: input.path,
474 new_text,
475 old_text,
476 diff: unified_diff,
477 edit_agent_output,
478 })
479 })
480 }
481
482 fn replay(
483 &self,
484 _input: Self::Input,
485 output: Self::Output,
486 event_stream: ToolCallEventStream,
487 cx: &mut App,
488 ) -> Result<()> {
489 event_stream.update_diff(cx.new(|cx| {
490 Diff::finalized(
491 output.input_path.to_string_lossy().into_owned(),
492 Some(output.old_text.to_string()),
493 output.new_text,
494 self.language_registry.clone(),
495 cx,
496 )
497 }));
498 Ok(())
499 }
500}
501
502/// Validate that the file path is valid, meaning:
503///
504/// - For `edit` and `overwrite`, the path must point to an existing file.
505/// - For `create`, the file must not already exist, but it's parent dir must exist.
506fn resolve_path(
507 input: &EditFileToolInput,
508 project: Entity<Project>,
509 cx: &mut App,
510) -> Result<ProjectPath> {
511 let project = project.read(cx);
512
513 match input.mode {
514 EditFileMode::Edit | EditFileMode::Overwrite => {
515 let path = project
516 .find_project_path(&input.path, cx)
517 .context("Can't edit file: path not found")?;
518
519 let entry = project
520 .entry_for_path(&path, cx)
521 .context("Can't edit file: path not found")?;
522
523 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
524 Ok(path)
525 }
526
527 EditFileMode::Create => {
528 if let Some(path) = project.find_project_path(&input.path, cx) {
529 anyhow::ensure!(
530 project.entry_for_path(&path, cx).is_none(),
531 "Can't create file: file already exists"
532 );
533 }
534
535 let parent_path = input
536 .path
537 .parent()
538 .context("Can't create file: incorrect path")?;
539
540 let parent_project_path = project.find_project_path(&parent_path, cx);
541
542 let parent_entry = parent_project_path
543 .as_ref()
544 .and_then(|path| project.entry_for_path(path, cx))
545 .context("Can't create file: parent directory doesn't exist")?;
546
547 anyhow::ensure!(
548 parent_entry.is_dir(),
549 "Can't create file: parent is not a directory"
550 );
551
552 let file_name = input
553 .path
554 .file_name()
555 .and_then(|file_name| file_name.to_str())
556 .and_then(|file_name| RelPath::unix(file_name).ok())
557 .context("Can't create file: invalid filename")?;
558
559 let new_file_path = parent_project_path.map(|parent| ProjectPath {
560 path: parent.path.join(file_name),
561 ..parent
562 });
563
564 new_file_path.context("Can't create file")
565 }
566 }
567}
568
569#[cfg(test)]
570mod tests {
571 use super::*;
572 use crate::{ContextServerRegistry, Templates};
573 use client::TelemetrySettings;
574 use fs::Fs;
575 use gpui::{TestAppContext, UpdateGlobal};
576 use language_model::fake_provider::FakeLanguageModel;
577 use prompt_store::ProjectContext;
578 use serde_json::json;
579 use settings::SettingsStore;
580 use util::{path, rel_path::rel_path};
581
582 #[gpui::test]
583 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
584 init_test(cx);
585
586 let fs = project::FakeFs::new(cx.executor());
587 fs.insert_tree("/root", json!({})).await;
588 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
589 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
590 let context_server_registry =
591 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
592 let model = Arc::new(FakeLanguageModel::default());
593 let thread = cx.new(|cx| {
594 Thread::new(
595 project.clone(),
596 cx.new(|_cx| ProjectContext::default()),
597 context_server_registry,
598 Templates::new(),
599 Some(model),
600 cx,
601 )
602 });
603 let result = cx
604 .update(|cx| {
605 let input = EditFileToolInput {
606 display_description: "Some edit".into(),
607 path: "root/nonexistent_file.txt".into(),
608 mode: EditFileMode::Edit,
609 };
610 Arc::new(EditFileTool::new(
611 project,
612 thread.downgrade(),
613 language_registry,
614 Templates::new(),
615 ))
616 .run(input, ToolCallEventStream::test().0, cx)
617 })
618 .await;
619 assert_eq!(
620 result.unwrap_err().to_string(),
621 "Can't edit file: path not found"
622 );
623 }
624
625 #[gpui::test]
626 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
627 let mode = &EditFileMode::Create;
628
629 let result = test_resolve_path(mode, "root/new.txt", cx);
630 assert_resolved_path_eq(result.await, rel_path("new.txt"));
631
632 let result = test_resolve_path(mode, "new.txt", cx);
633 assert_resolved_path_eq(result.await, rel_path("new.txt"));
634
635 let result = test_resolve_path(mode, "dir/new.txt", cx);
636 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
637
638 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
639 assert_eq!(
640 result.await.unwrap_err().to_string(),
641 "Can't create file: file already exists"
642 );
643
644 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
645 assert_eq!(
646 result.await.unwrap_err().to_string(),
647 "Can't create file: parent directory doesn't exist"
648 );
649 }
650
651 #[gpui::test]
652 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
653 let mode = &EditFileMode::Edit;
654
655 let path_with_root = "root/dir/subdir/existing.txt";
656 let path_without_root = "dir/subdir/existing.txt";
657 let result = test_resolve_path(mode, path_with_root, cx);
658 assert_resolved_path_eq(result.await, rel_path(path_without_root));
659
660 let result = test_resolve_path(mode, path_without_root, cx);
661 assert_resolved_path_eq(result.await, rel_path(path_without_root));
662
663 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
664 assert_eq!(
665 result.await.unwrap_err().to_string(),
666 "Can't edit file: path not found"
667 );
668
669 let result = test_resolve_path(mode, "root/dir", cx);
670 assert_eq!(
671 result.await.unwrap_err().to_string(),
672 "Can't edit file: path is a directory"
673 );
674 }
675
676 async fn test_resolve_path(
677 mode: &EditFileMode,
678 path: &str,
679 cx: &mut TestAppContext,
680 ) -> anyhow::Result<ProjectPath> {
681 init_test(cx);
682
683 let fs = project::FakeFs::new(cx.executor());
684 fs.insert_tree(
685 "/root",
686 json!({
687 "dir": {
688 "subdir": {
689 "existing.txt": "hello"
690 }
691 }
692 }),
693 )
694 .await;
695 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
696
697 let input = EditFileToolInput {
698 display_description: "Some edit".into(),
699 path: path.into(),
700 mode: mode.clone(),
701 };
702
703 cx.update(|cx| resolve_path(&input, project, cx))
704 }
705
706 #[track_caller]
707 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
708 let actual = path.expect("Should return valid path").path;
709 assert_eq!(actual.as_ref(), expected);
710 }
711
712 #[gpui::test]
713 async fn test_format_on_save(cx: &mut TestAppContext) {
714 init_test(cx);
715
716 let fs = project::FakeFs::new(cx.executor());
717 fs.insert_tree("/root", json!({"src": {}})).await;
718
719 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
720
721 // Set up a Rust language with LSP formatting support
722 let rust_language = Arc::new(language::Language::new(
723 language::LanguageConfig {
724 name: "Rust".into(),
725 matcher: language::LanguageMatcher {
726 path_suffixes: vec!["rs".to_string()],
727 ..Default::default()
728 },
729 ..Default::default()
730 },
731 None,
732 ));
733
734 // Register the language and fake LSP
735 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
736 language_registry.add(rust_language);
737
738 let mut fake_language_servers = language_registry.register_fake_lsp(
739 "Rust",
740 language::FakeLspAdapter {
741 capabilities: lsp::ServerCapabilities {
742 document_formatting_provider: Some(lsp::OneOf::Left(true)),
743 ..Default::default()
744 },
745 ..Default::default()
746 },
747 );
748
749 // Create the file
750 fs.save(
751 path!("/root/src/main.rs").as_ref(),
752 &"initial content".into(),
753 language::LineEnding::Unix,
754 )
755 .await
756 .unwrap();
757
758 // Open the buffer to trigger LSP initialization
759 let buffer = project
760 .update(cx, |project, cx| {
761 project.open_local_buffer(path!("/root/src/main.rs"), cx)
762 })
763 .await
764 .unwrap();
765
766 // Register the buffer with language servers
767 let _handle = project.update(cx, |project, cx| {
768 project.register_buffer_with_language_servers(&buffer, cx)
769 });
770
771 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
772 const FORMATTED_CONTENT: &str =
773 "This file was formatted by the fake formatter in the test.\n";
774
775 // Get the fake language server and set up formatting handler
776 let fake_language_server = fake_language_servers.next().await.unwrap();
777 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
778 |_, _| async move {
779 Ok(Some(vec![lsp::TextEdit {
780 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
781 new_text: FORMATTED_CONTENT.to_string(),
782 }]))
783 }
784 });
785
786 let context_server_registry =
787 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
788 let model = Arc::new(FakeLanguageModel::default());
789 let thread = cx.new(|cx| {
790 Thread::new(
791 project.clone(),
792 cx.new(|_cx| ProjectContext::default()),
793 context_server_registry,
794 Templates::new(),
795 Some(model.clone()),
796 cx,
797 )
798 });
799
800 // First, test with format_on_save enabled
801 cx.update(|cx| {
802 SettingsStore::update_global(cx, |store, cx| {
803 store.update_user_settings(cx, |settings| {
804 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
805 settings.project.all_languages.defaults.formatter =
806 Some(language::language_settings::FormatterList::default());
807 });
808 });
809 });
810
811 // Have the model stream unformatted content
812 let edit_result = {
813 let edit_task = cx.update(|cx| {
814 let input = EditFileToolInput {
815 display_description: "Create main function".into(),
816 path: "root/src/main.rs".into(),
817 mode: EditFileMode::Overwrite,
818 };
819 Arc::new(EditFileTool::new(
820 project.clone(),
821 thread.downgrade(),
822 language_registry.clone(),
823 Templates::new(),
824 ))
825 .run(input, ToolCallEventStream::test().0, cx)
826 });
827
828 // Stream the unformatted content
829 cx.executor().run_until_parked();
830 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
831 model.end_last_completion_stream();
832
833 edit_task.await
834 };
835 assert!(edit_result.is_ok());
836
837 // Wait for any async operations (e.g. formatting) to complete
838 cx.executor().run_until_parked();
839
840 // Read the file to verify it was formatted automatically
841 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
842 assert_eq!(
843 // Ignore carriage returns on Windows
844 new_content.replace("\r\n", "\n"),
845 FORMATTED_CONTENT,
846 "Code should be formatted when format_on_save is enabled"
847 );
848
849 let stale_buffer_count = thread
850 .read_with(cx, |thread, _cx| thread.action_log.clone())
851 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
852
853 assert_eq!(
854 stale_buffer_count, 0,
855 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
856 This causes the agent to think the file was modified externally when it was just formatted.",
857 stale_buffer_count
858 );
859
860 // Next, test with format_on_save disabled
861 cx.update(|cx| {
862 SettingsStore::update_global(cx, |store, cx| {
863 store.update_user_settings(cx, |settings| {
864 settings.project.all_languages.defaults.format_on_save =
865 Some(FormatOnSave::Off);
866 });
867 });
868 });
869
870 // Stream unformatted edits again
871 let edit_result = {
872 let edit_task = cx.update(|cx| {
873 let input = EditFileToolInput {
874 display_description: "Update main function".into(),
875 path: "root/src/main.rs".into(),
876 mode: EditFileMode::Overwrite,
877 };
878 Arc::new(EditFileTool::new(
879 project.clone(),
880 thread.downgrade(),
881 language_registry,
882 Templates::new(),
883 ))
884 .run(input, ToolCallEventStream::test().0, cx)
885 });
886
887 // Stream the unformatted content
888 cx.executor().run_until_parked();
889 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
890 model.end_last_completion_stream();
891
892 edit_task.await
893 };
894 assert!(edit_result.is_ok());
895
896 // Wait for any async operations (e.g. formatting) to complete
897 cx.executor().run_until_parked();
898
899 // Verify the file was not formatted
900 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
901 assert_eq!(
902 // Ignore carriage returns on Windows
903 new_content.replace("\r\n", "\n"),
904 UNFORMATTED_CONTENT,
905 "Code should not be formatted when format_on_save is disabled"
906 );
907 }
908
909 #[gpui::test]
910 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
911 init_test(cx);
912
913 let fs = project::FakeFs::new(cx.executor());
914 fs.insert_tree("/root", json!({"src": {}})).await;
915
916 // Create a simple file with trailing whitespace
917 fs.save(
918 path!("/root/src/main.rs").as_ref(),
919 &"initial content".into(),
920 language::LineEnding::Unix,
921 )
922 .await
923 .unwrap();
924
925 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
926 let context_server_registry =
927 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
928 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
929 let model = Arc::new(FakeLanguageModel::default());
930 let thread = cx.new(|cx| {
931 Thread::new(
932 project.clone(),
933 cx.new(|_cx| ProjectContext::default()),
934 context_server_registry,
935 Templates::new(),
936 Some(model.clone()),
937 cx,
938 )
939 });
940
941 // First, test with remove_trailing_whitespace_on_save enabled
942 cx.update(|cx| {
943 SettingsStore::update_global(cx, |store, cx| {
944 store.update_user_settings(cx, |settings| {
945 settings
946 .project
947 .all_languages
948 .defaults
949 .remove_trailing_whitespace_on_save = Some(true);
950 });
951 });
952 });
953
954 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
955 "fn main() { \n println!(\"Hello!\"); \n}\n";
956
957 // Have the model stream content that contains trailing whitespace
958 let edit_result = {
959 let edit_task = cx.update(|cx| {
960 let input = EditFileToolInput {
961 display_description: "Create main function".into(),
962 path: "root/src/main.rs".into(),
963 mode: EditFileMode::Overwrite,
964 };
965 Arc::new(EditFileTool::new(
966 project.clone(),
967 thread.downgrade(),
968 language_registry.clone(),
969 Templates::new(),
970 ))
971 .run(input, ToolCallEventStream::test().0, cx)
972 });
973
974 // Stream the content with trailing whitespace
975 cx.executor().run_until_parked();
976 model.send_last_completion_stream_text_chunk(
977 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
978 );
979 model.end_last_completion_stream();
980
981 edit_task.await
982 };
983 assert!(edit_result.is_ok());
984
985 // Wait for any async operations (e.g. formatting) to complete
986 cx.executor().run_until_parked();
987
988 // Read the file to verify trailing whitespace was removed automatically
989 assert_eq!(
990 // Ignore carriage returns on Windows
991 fs.load(path!("/root/src/main.rs").as_ref())
992 .await
993 .unwrap()
994 .replace("\r\n", "\n"),
995 "fn main() {\n println!(\"Hello!\");\n}\n",
996 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
997 );
998
999 // Next, test with remove_trailing_whitespace_on_save disabled
1000 cx.update(|cx| {
1001 SettingsStore::update_global(cx, |store, cx| {
1002 store.update_user_settings(cx, |settings| {
1003 settings
1004 .project
1005 .all_languages
1006 .defaults
1007 .remove_trailing_whitespace_on_save = Some(false);
1008 });
1009 });
1010 });
1011
1012 // Stream edits again with trailing whitespace
1013 let edit_result = {
1014 let edit_task = cx.update(|cx| {
1015 let input = EditFileToolInput {
1016 display_description: "Update main function".into(),
1017 path: "root/src/main.rs".into(),
1018 mode: EditFileMode::Overwrite,
1019 };
1020 Arc::new(EditFileTool::new(
1021 project.clone(),
1022 thread.downgrade(),
1023 language_registry,
1024 Templates::new(),
1025 ))
1026 .run(input, ToolCallEventStream::test().0, cx)
1027 });
1028
1029 // Stream the content with trailing whitespace
1030 cx.executor().run_until_parked();
1031 model.send_last_completion_stream_text_chunk(
1032 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1033 );
1034 model.end_last_completion_stream();
1035
1036 edit_task.await
1037 };
1038 assert!(edit_result.is_ok());
1039
1040 // Wait for any async operations (e.g. formatting) to complete
1041 cx.executor().run_until_parked();
1042
1043 // Verify the file still has trailing whitespace
1044 // Read the file again - it should still have trailing whitespace
1045 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1046 assert_eq!(
1047 // Ignore carriage returns on Windows
1048 final_content.replace("\r\n", "\n"),
1049 CONTENT_WITH_TRAILING_WHITESPACE,
1050 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1051 );
1052 }
1053
1054 #[gpui::test]
1055 async fn test_authorize(cx: &mut TestAppContext) {
1056 init_test(cx);
1057 let fs = project::FakeFs::new(cx.executor());
1058 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1059 let context_server_registry =
1060 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1061 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1062 let model = Arc::new(FakeLanguageModel::default());
1063 let thread = cx.new(|cx| {
1064 Thread::new(
1065 project.clone(),
1066 cx.new(|_cx| ProjectContext::default()),
1067 context_server_registry,
1068 Templates::new(),
1069 Some(model.clone()),
1070 cx,
1071 )
1072 });
1073 let tool = Arc::new(EditFileTool::new(
1074 project.clone(),
1075 thread.downgrade(),
1076 language_registry,
1077 Templates::new(),
1078 ));
1079 fs.insert_tree("/root", json!({})).await;
1080
1081 // Test 1: Path with .zed component should require confirmation
1082 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1083 let _auth = cx.update(|cx| {
1084 tool.authorize(
1085 &EditFileToolInput {
1086 display_description: "test 1".into(),
1087 path: ".zed/settings.json".into(),
1088 mode: EditFileMode::Edit,
1089 },
1090 &stream_tx,
1091 cx,
1092 )
1093 });
1094
1095 let event = stream_rx.expect_authorization().await;
1096 assert_eq!(
1097 event.tool_call.fields.title,
1098 Some("test 1 (local settings)".into())
1099 );
1100
1101 // Test 2: Path outside project should require confirmation
1102 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1103 let _auth = cx.update(|cx| {
1104 tool.authorize(
1105 &EditFileToolInput {
1106 display_description: "test 2".into(),
1107 path: "/etc/hosts".into(),
1108 mode: EditFileMode::Edit,
1109 },
1110 &stream_tx,
1111 cx,
1112 )
1113 });
1114
1115 let event = stream_rx.expect_authorization().await;
1116 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1117
1118 // Test 3: Relative path without .zed should not require confirmation
1119 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1120 cx.update(|cx| {
1121 tool.authorize(
1122 &EditFileToolInput {
1123 display_description: "test 3".into(),
1124 path: "root/src/main.rs".into(),
1125 mode: EditFileMode::Edit,
1126 },
1127 &stream_tx,
1128 cx,
1129 )
1130 })
1131 .await
1132 .unwrap();
1133 assert!(stream_rx.try_next().is_err());
1134
1135 // Test 4: Path with .zed in the middle should require confirmation
1136 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1137 let _auth = cx.update(|cx| {
1138 tool.authorize(
1139 &EditFileToolInput {
1140 display_description: "test 4".into(),
1141 path: "root/.zed/tasks.json".into(),
1142 mode: EditFileMode::Edit,
1143 },
1144 &stream_tx,
1145 cx,
1146 )
1147 });
1148 let event = stream_rx.expect_authorization().await;
1149 assert_eq!(
1150 event.tool_call.fields.title,
1151 Some("test 4 (local settings)".into())
1152 );
1153
1154 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1155 cx.update(|cx| {
1156 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1157 settings.always_allow_tool_actions = true;
1158 agent_settings::AgentSettings::override_global(settings, cx);
1159 });
1160
1161 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1162 cx.update(|cx| {
1163 tool.authorize(
1164 &EditFileToolInput {
1165 display_description: "test 5.1".into(),
1166 path: ".zed/settings.json".into(),
1167 mode: EditFileMode::Edit,
1168 },
1169 &stream_tx,
1170 cx,
1171 )
1172 })
1173 .await
1174 .unwrap();
1175 assert!(stream_rx.try_next().is_err());
1176
1177 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1178 cx.update(|cx| {
1179 tool.authorize(
1180 &EditFileToolInput {
1181 display_description: "test 5.2".into(),
1182 path: "/etc/hosts".into(),
1183 mode: EditFileMode::Edit,
1184 },
1185 &stream_tx,
1186 cx,
1187 )
1188 })
1189 .await
1190 .unwrap();
1191 assert!(stream_rx.try_next().is_err());
1192 }
1193
1194 #[gpui::test]
1195 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1196 init_test(cx);
1197 let fs = project::FakeFs::new(cx.executor());
1198 fs.insert_tree("/project", json!({})).await;
1199 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1200 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1201 let context_server_registry =
1202 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1203 let model = Arc::new(FakeLanguageModel::default());
1204 let thread = cx.new(|cx| {
1205 Thread::new(
1206 project.clone(),
1207 cx.new(|_cx| ProjectContext::default()),
1208 context_server_registry,
1209 Templates::new(),
1210 Some(model.clone()),
1211 cx,
1212 )
1213 });
1214 let tool = Arc::new(EditFileTool::new(
1215 project.clone(),
1216 thread.downgrade(),
1217 language_registry,
1218 Templates::new(),
1219 ));
1220
1221 // Test global config paths - these should require confirmation if they exist and are outside the project
1222 let test_cases = vec![
1223 (
1224 "/etc/hosts",
1225 true,
1226 "System file should require confirmation",
1227 ),
1228 (
1229 "/usr/local/bin/script",
1230 true,
1231 "System bin file should require confirmation",
1232 ),
1233 (
1234 "project/normal_file.rs",
1235 false,
1236 "Normal project file should not require confirmation",
1237 ),
1238 ];
1239
1240 for (path, should_confirm, description) in test_cases {
1241 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1242 let auth = cx.update(|cx| {
1243 tool.authorize(
1244 &EditFileToolInput {
1245 display_description: "Edit file".into(),
1246 path: path.into(),
1247 mode: EditFileMode::Edit,
1248 },
1249 &stream_tx,
1250 cx,
1251 )
1252 });
1253
1254 if should_confirm {
1255 stream_rx.expect_authorization().await;
1256 } else {
1257 auth.await.unwrap();
1258 assert!(
1259 stream_rx.try_next().is_err(),
1260 "Failed for case: {} - path: {} - expected no confirmation but got one",
1261 description,
1262 path
1263 );
1264 }
1265 }
1266 }
1267
1268 #[gpui::test]
1269 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1270 init_test(cx);
1271 let fs = project::FakeFs::new(cx.executor());
1272
1273 // Create multiple worktree directories
1274 fs.insert_tree(
1275 "/workspace/frontend",
1276 json!({
1277 "src": {
1278 "main.js": "console.log('frontend');"
1279 }
1280 }),
1281 )
1282 .await;
1283 fs.insert_tree(
1284 "/workspace/backend",
1285 json!({
1286 "src": {
1287 "main.rs": "fn main() {}"
1288 }
1289 }),
1290 )
1291 .await;
1292 fs.insert_tree(
1293 "/workspace/shared",
1294 json!({
1295 ".zed": {
1296 "settings.json": "{}"
1297 }
1298 }),
1299 )
1300 .await;
1301
1302 // Create project with multiple worktrees
1303 let project = Project::test(
1304 fs.clone(),
1305 [
1306 path!("/workspace/frontend").as_ref(),
1307 path!("/workspace/backend").as_ref(),
1308 path!("/workspace/shared").as_ref(),
1309 ],
1310 cx,
1311 )
1312 .await;
1313 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1314 let context_server_registry =
1315 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1316 let model = Arc::new(FakeLanguageModel::default());
1317 let thread = cx.new(|cx| {
1318 Thread::new(
1319 project.clone(),
1320 cx.new(|_cx| ProjectContext::default()),
1321 context_server_registry.clone(),
1322 Templates::new(),
1323 Some(model.clone()),
1324 cx,
1325 )
1326 });
1327 let tool = Arc::new(EditFileTool::new(
1328 project.clone(),
1329 thread.downgrade(),
1330 language_registry,
1331 Templates::new(),
1332 ));
1333
1334 // Test files in different worktrees
1335 let test_cases = vec![
1336 ("frontend/src/main.js", false, "File in first worktree"),
1337 ("backend/src/main.rs", false, "File in second worktree"),
1338 (
1339 "shared/.zed/settings.json",
1340 true,
1341 ".zed file in third worktree",
1342 ),
1343 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1344 (
1345 "../outside/file.txt",
1346 true,
1347 "Relative path outside worktrees",
1348 ),
1349 ];
1350
1351 for (path, should_confirm, description) in test_cases {
1352 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1353 let auth = cx.update(|cx| {
1354 tool.authorize(
1355 &EditFileToolInput {
1356 display_description: "Edit file".into(),
1357 path: path.into(),
1358 mode: EditFileMode::Edit,
1359 },
1360 &stream_tx,
1361 cx,
1362 )
1363 });
1364
1365 if should_confirm {
1366 stream_rx.expect_authorization().await;
1367 } else {
1368 auth.await.unwrap();
1369 assert!(
1370 stream_rx.try_next().is_err(),
1371 "Failed for case: {} - path: {} - expected no confirmation but got one",
1372 description,
1373 path
1374 );
1375 }
1376 }
1377 }
1378
1379 #[gpui::test]
1380 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1381 init_test(cx);
1382 let fs = project::FakeFs::new(cx.executor());
1383 fs.insert_tree(
1384 "/project",
1385 json!({
1386 ".zed": {
1387 "settings.json": "{}"
1388 },
1389 "src": {
1390 ".zed": {
1391 "local.json": "{}"
1392 }
1393 }
1394 }),
1395 )
1396 .await;
1397 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1398 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1399 let context_server_registry =
1400 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1401 let model = Arc::new(FakeLanguageModel::default());
1402 let thread = cx.new(|cx| {
1403 Thread::new(
1404 project.clone(),
1405 cx.new(|_cx| ProjectContext::default()),
1406 context_server_registry.clone(),
1407 Templates::new(),
1408 Some(model.clone()),
1409 cx,
1410 )
1411 });
1412 let tool = Arc::new(EditFileTool::new(
1413 project.clone(),
1414 thread.downgrade(),
1415 language_registry,
1416 Templates::new(),
1417 ));
1418
1419 // Test edge cases
1420 let test_cases = vec![
1421 // Empty path - find_project_path returns Some for empty paths
1422 ("", false, "Empty path is treated as project root"),
1423 // Root directory
1424 ("/", true, "Root directory should be outside project"),
1425 // Parent directory references - find_project_path resolves these
1426 (
1427 "project/../other",
1428 true,
1429 "Path with .. that goes outside of root directory",
1430 ),
1431 (
1432 "project/./src/file.rs",
1433 false,
1434 "Path with . should work normally",
1435 ),
1436 // Windows-style paths (if on Windows)
1437 #[cfg(target_os = "windows")]
1438 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1439 #[cfg(target_os = "windows")]
1440 ("project\\src\\main.rs", false, "Windows-style project path"),
1441 ];
1442
1443 for (path, should_confirm, description) in test_cases {
1444 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1445 let auth = cx.update(|cx| {
1446 tool.authorize(
1447 &EditFileToolInput {
1448 display_description: "Edit file".into(),
1449 path: path.into(),
1450 mode: EditFileMode::Edit,
1451 },
1452 &stream_tx,
1453 cx,
1454 )
1455 });
1456
1457 cx.run_until_parked();
1458
1459 if should_confirm {
1460 stream_rx.expect_authorization().await;
1461 } else {
1462 assert!(
1463 stream_rx.try_next().is_err(),
1464 "Failed for case: {} - path: {} - expected no confirmation but got one",
1465 description,
1466 path
1467 );
1468 auth.await.unwrap();
1469 }
1470 }
1471 }
1472
1473 #[gpui::test]
1474 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1475 init_test(cx);
1476 let fs = project::FakeFs::new(cx.executor());
1477 fs.insert_tree(
1478 "/project",
1479 json!({
1480 "existing.txt": "content",
1481 ".zed": {
1482 "settings.json": "{}"
1483 }
1484 }),
1485 )
1486 .await;
1487 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1488 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1489 let context_server_registry =
1490 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1491 let model = Arc::new(FakeLanguageModel::default());
1492 let thread = cx.new(|cx| {
1493 Thread::new(
1494 project.clone(),
1495 cx.new(|_cx| ProjectContext::default()),
1496 context_server_registry.clone(),
1497 Templates::new(),
1498 Some(model.clone()),
1499 cx,
1500 )
1501 });
1502 let tool = Arc::new(EditFileTool::new(
1503 project.clone(),
1504 thread.downgrade(),
1505 language_registry,
1506 Templates::new(),
1507 ));
1508
1509 // Test different EditFileMode values
1510 let modes = vec![
1511 EditFileMode::Edit,
1512 EditFileMode::Create,
1513 EditFileMode::Overwrite,
1514 ];
1515
1516 for mode in modes {
1517 // Test .zed path with different modes
1518 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1519 let _auth = cx.update(|cx| {
1520 tool.authorize(
1521 &EditFileToolInput {
1522 display_description: "Edit settings".into(),
1523 path: "project/.zed/settings.json".into(),
1524 mode: mode.clone(),
1525 },
1526 &stream_tx,
1527 cx,
1528 )
1529 });
1530
1531 stream_rx.expect_authorization().await;
1532
1533 // Test outside path with different modes
1534 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1535 let _auth = cx.update(|cx| {
1536 tool.authorize(
1537 &EditFileToolInput {
1538 display_description: "Edit file".into(),
1539 path: "/outside/file.txt".into(),
1540 mode: mode.clone(),
1541 },
1542 &stream_tx,
1543 cx,
1544 )
1545 });
1546
1547 stream_rx.expect_authorization().await;
1548
1549 // Test normal path with different modes
1550 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1551 cx.update(|cx| {
1552 tool.authorize(
1553 &EditFileToolInput {
1554 display_description: "Edit file".into(),
1555 path: "project/normal.txt".into(),
1556 mode: mode.clone(),
1557 },
1558 &stream_tx,
1559 cx,
1560 )
1561 })
1562 .await
1563 .unwrap();
1564 assert!(stream_rx.try_next().is_err());
1565 }
1566 }
1567
1568 #[gpui::test]
1569 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1570 init_test(cx);
1571 let fs = project::FakeFs::new(cx.executor());
1572 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1573 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1574 let context_server_registry =
1575 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1576 let model = Arc::new(FakeLanguageModel::default());
1577 let thread = cx.new(|cx| {
1578 Thread::new(
1579 project.clone(),
1580 cx.new(|_cx| ProjectContext::default()),
1581 context_server_registry,
1582 Templates::new(),
1583 Some(model.clone()),
1584 cx,
1585 )
1586 });
1587 let tool = Arc::new(EditFileTool::new(
1588 project,
1589 thread.downgrade(),
1590 language_registry,
1591 Templates::new(),
1592 ));
1593
1594 cx.update(|cx| {
1595 // ...
1596 assert_eq!(
1597 tool.initial_title(
1598 Err(json!({
1599 "path": "src/main.rs",
1600 "display_description": "",
1601 "old_string": "old code",
1602 "new_string": "new code"
1603 })),
1604 cx
1605 ),
1606 "src/main.rs"
1607 );
1608 assert_eq!(
1609 tool.initial_title(
1610 Err(json!({
1611 "path": "",
1612 "display_description": "Fix error handling",
1613 "old_string": "old code",
1614 "new_string": "new code"
1615 })),
1616 cx
1617 ),
1618 "Fix error handling"
1619 );
1620 assert_eq!(
1621 tool.initial_title(
1622 Err(json!({
1623 "path": "src/main.rs",
1624 "display_description": "Fix error handling",
1625 "old_string": "old code",
1626 "new_string": "new code"
1627 })),
1628 cx
1629 ),
1630 "src/main.rs"
1631 );
1632 assert_eq!(
1633 tool.initial_title(
1634 Err(json!({
1635 "path": "",
1636 "display_description": "",
1637 "old_string": "old code",
1638 "new_string": "new code"
1639 })),
1640 cx
1641 ),
1642 DEFAULT_UI_TEXT
1643 );
1644 assert_eq!(
1645 tool.initial_title(Err(serde_json::Value::Null), cx),
1646 DEFAULT_UI_TEXT
1647 );
1648 });
1649 }
1650
    /// Verifies that the `Diff` streamed by `EditFileTool::run` moves from `Pending` to
    /// `Finalized` on every exit path: when the edit completes, when the edit fails,
    /// and when the running tool call is dropped.
    #[gpui::test]
    async fn test_diff_finalization(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        // A single empty file at the filesystem root, which is also the worktree root.
        fs.insert_tree("/", json!({"main.rs": ""})).await;

        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // Ensure the diff is finalized after the edit completes.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            // A fields update arrives first, followed by the diff entity.
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            // Let the tool run until it is waiting on the fake model, so that there is a
            // "last completion stream" to end.
            cx.run_until_parked();
            model.end_last_completion_stream();
            edit.await.unwrap();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }

        // Ensure the diff is finalized if an error occurs while editing.
        {
            // Presumably makes the fake model reject completion requests, which is why
            // `edit` below resolves to an error. Re-enabled at the end of this scope.
            model.forbid_requests();
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            edit.await.unwrap_err();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
            model.allow_requests();
        }

        // Ensure the diff is finalized if the tool call gets dropped.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            // Drop the in-flight edit task without awaiting it.
            drop(edit);
            // Give any work triggered by the drop a chance to run before checking the diff.
            cx.run_until_parked();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }
    }
1759
    /// Shared setup for the tests in this module: installs a test `SettingsStore` as a
    /// global, then initializes language support and registers the telemetry, agent,
    /// and project settings these tests read.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The store is installed first. The registration calls below presumably need
            // the global store to exist (NOTE(review): confirm ordering requirement).
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            TelemetrySettings::register(cx);
            agent_settings::AgentSettings::register(cx);
            Project::init_settings(cx);
        });
    }
1770}