1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Fallback title for an `edit_file` tool call, used when neither a path nor a
/// description can be recovered from the (possibly still streaming) input.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
// NOTE(review): this struct derives `JsonSchema`, so the `///` doc comments
// below (on the struct and on each field) become schema descriptions,
// presumably the text the model reads for this tool. Treat them as prompt
// text: rewording them changes model-facing behavior. Field order matters too:
// `display_description` is requested first so that a partially streamed call
// can already be titled (see `EditFileToolPartialInput`).
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    // Resolved against the project by `resolve_path`; which paths are accepted
    // depends on `mode`.
    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    // `Edit` runs the edit agent's targeted `edit`; both `Create` and
    // `Overwrite` regenerate the whole file via its `overwrite`.
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
79
// Lenient view of `EditFileToolInput` used while the tool call is still
// streaming. Both fields default to an empty string, so a JSON object that is
// missing either of them still deserializes. That lets `initial_title` show a
// path or description before the input is complete.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
87
// How `edit_file` treats the target path. It is serialized in lowercase
// (`"edit"`, `"create"`, `"overwrite"`) and inlined into the input schema.
// Plain `//` comments are used here because `///` docs on the variants of a
// `JsonSchema` enum would become schema descriptions.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // Targeted edits to an existing file.
    Edit,
    // A new file; `resolve_path` requires that it does not exist yet.
    Create,
    // Full replacement of an existing file's contents.
    Overwrite,
}
96
/// Result of an `edit_file` run.
///
/// Derives `Serialize`/`Deserialize`; `replay` rebuilds the diff view from
/// `input_path`, `old_text`, and `new_text` without re-running the edit.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as supplied in the tool input.
    // NOTE(review): the alias presumably keeps data serialized under the older
    // `original_path` name readable — confirm.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit, optional formatting, and save.
    new_text: String,
    /// Full buffer text captured before any edit was applied.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    // Falls back to an empty string when absent from serialized data.
    #[serde(default)]
    diff: String,
    /// Output reported by the edit agent.
    // NOTE(review): the `raw_output` alias presumably supports older
    // serialized data — confirm.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
/// The `edit_file` agent tool. It creates, overwrites, or edits a project file
/// and delegates the content changes to an [`EditAgent`].
pub struct EditFileTool {
    /// Owning thread; supplies the project, model, completion request, and
    /// action log when the tool runs. Held weakly, so operations fail with
    /// "thread was dropped" once it is gone.
    thread: WeakEntity<Thread>,
    /// Used by `replay` to render a recorded diff.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to shorten the displayed path.
    // NOTE(review): `run` and `authorize` use the thread's project rather than
    // this handle — confirm they are always the same project.
    project: Entity<Project>,
    /// Prompt templates forwarded to the edit agent.
    templates: Arc<Templates>,
}
130
131impl EditFileTool {
132 pub fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
203 fn name() -> &'static str {
204 "edit_file"
205 }
206
207 fn kind() -> acp::ToolKind {
208 acp::ToolKind::Edit
209 }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(ToolCallUpdateFields {
277 locations: Some(vec![acp::ToolCallLocation {
278 path: abs_path,
279 line: None,
280 meta: None,
281 }]),
282 ..Default::default()
283 });
284 }
285
286 let authorize = self.authorize(&input, &event_stream, cx);
287 cx.spawn(async move |cx: &mut AsyncApp| {
288 authorize.await?;
289
290 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
291 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
292 (request, thread.model().cloned(), thread.action_log().clone())
293 })?;
294 let request = request?;
295 let model = model.context("No language model configured")?;
296
297 let edit_format = EditFormat::from_model(model.clone())?;
298 let edit_agent = EditAgent::new(
299 model,
300 project.clone(),
301 action_log.clone(),
302 self.templates.clone(),
303 edit_format,
304 );
305
306 let buffer = project
307 .update(cx, |project, cx| {
308 project.open_buffer(project_path.clone(), cx)
309 })?
310 .await?;
311
312 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
313 event_stream.update_diff(diff.clone());
314 let _finalize_diff = util::defer({
315 let diff = diff.downgrade();
316 let mut cx = cx.clone();
317 move || {
318 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
319 }
320 });
321
322 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
323 let old_text = cx
324 .background_spawn({
325 let old_snapshot = old_snapshot.clone();
326 async move { Arc::new(old_snapshot.text()) }
327 })
328 .await;
329
330
331 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
332 edit_agent.edit(
333 buffer.clone(),
334 input.display_description.clone(),
335 &request,
336 cx,
337 )
338 } else {
339 edit_agent.overwrite(
340 buffer.clone(),
341 input.display_description.clone(),
342 &request,
343 cx,
344 )
345 };
346
347 let mut hallucinated_old_text = false;
348 let mut ambiguous_ranges = Vec::new();
349 let mut emitted_location = false;
350 while let Some(event) = events.next().await {
351 match event {
352 EditAgentOutputEvent::Edited(range) => {
353 if !emitted_location {
354 let line = buffer.update(cx, |buffer, _cx| {
355 range.start.to_point(&buffer.snapshot()).row
356 }).ok();
357 if let Some(abs_path) = abs_path.clone() {
358 event_stream.update_fields(ToolCallUpdateFields {
359 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
360 ..Default::default()
361 });
362 }
363 emitted_location = true;
364 }
365 },
366 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
367 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
368 EditAgentOutputEvent::ResolvingEditRange(range) => {
369 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
370 // if !emitted_location {
371 // let line = buffer.update(cx, |buffer, _cx| {
372 // range.start.to_point(&buffer.snapshot()).row
373 // }).ok();
374 // if let Some(abs_path) = abs_path.clone() {
375 // event_stream.update_fields(ToolCallUpdateFields {
376 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
377 // ..Default::default()
378 // });
379 // }
380 // }
381 }
382 }
383 }
384
385 // If format_on_save is enabled, format the buffer
386 let format_on_save_enabled = buffer
387 .read_with(cx, |buffer, cx| {
388 let settings = language_settings::language_settings(
389 buffer.language().map(|l| l.name()),
390 buffer.file(),
391 cx,
392 );
393 settings.format_on_save != FormatOnSave::Off
394 })
395 .unwrap_or(false);
396
397 let edit_agent_output = output.await?;
398
399 if format_on_save_enabled {
400 action_log.update(cx, |log, cx| {
401 log.buffer_edited(buffer.clone(), cx);
402 })?;
403
404 let format_task = project.update(cx, |project, cx| {
405 project.format(
406 HashSet::from_iter([buffer.clone()]),
407 LspFormatTarget::Buffers,
408 false, // Don't push to history since the tool did it.
409 FormatTrigger::Save,
410 cx,
411 )
412 })?;
413 format_task.await.log_err();
414 }
415
416 project
417 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
418 .await?;
419
420 action_log.update(cx, |log, cx| {
421 log.buffer_edited(buffer.clone(), cx);
422 })?;
423
424 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
425 let (new_text, unified_diff) = cx
426 .background_spawn({
427 let new_snapshot = new_snapshot.clone();
428 let old_text = old_text.clone();
429 async move {
430 let new_text = new_snapshot.text();
431 let diff = language::unified_diff(&old_text, &new_text);
432 (new_text, diff)
433 }
434 })
435 .await;
436
437 let input_path = input.path.display();
438 if unified_diff.is_empty() {
439 anyhow::ensure!(
440 !hallucinated_old_text,
441 formatdoc! {"
442 Some edits were produced but none of them could be applied.
443 Read the relevant sections of {input_path} again so that
444 I can perform the requested edits.
445 "}
446 );
447 anyhow::ensure!(
448 ambiguous_ranges.is_empty(),
449 {
450 let line_numbers = ambiguous_ranges
451 .iter()
452 .map(|range| range.start.to_string())
453 .collect::<Vec<_>>()
454 .join(", ");
455 formatdoc! {"
456 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
457 relevant sections of {input_path} again and extend <old_text> so
458 that I can perform the requested edits.
459 "}
460 }
461 );
462 }
463
464 Ok(EditFileToolOutput {
465 input_path: input.path,
466 new_text,
467 old_text,
468 diff: unified_diff,
469 edit_agent_output,
470 })
471 })
472 }
473
474 fn replay(
475 &self,
476 _input: Self::Input,
477 output: Self::Output,
478 event_stream: ToolCallEventStream,
479 cx: &mut App,
480 ) -> Result<()> {
481 event_stream.update_diff(cx.new(|cx| {
482 Diff::finalized(
483 output.input_path.to_string_lossy().into_owned(),
484 Some(output.old_text.to_string()),
485 output.new_text,
486 self.language_registry.clone(),
487 cx,
488 )
489 }));
490 Ok(())
491 }
492}
493
494/// Validate that the file path is valid, meaning:
495///
496/// - For `edit` and `overwrite`, the path must point to an existing file.
497/// - For `create`, the file must not already exist, but it's parent dir must exist.
498fn resolve_path(
499 input: &EditFileToolInput,
500 project: Entity<Project>,
501 cx: &mut App,
502) -> Result<ProjectPath> {
503 let project = project.read(cx);
504
505 match input.mode {
506 EditFileMode::Edit | EditFileMode::Overwrite => {
507 let path = project
508 .find_project_path(&input.path, cx)
509 .context("Can't edit file: path not found")?;
510
511 let entry = project
512 .entry_for_path(&path, cx)
513 .context("Can't edit file: path not found")?;
514
515 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
516 Ok(path)
517 }
518
519 EditFileMode::Create => {
520 if let Some(path) = project.find_project_path(&input.path, cx) {
521 anyhow::ensure!(
522 project.entry_for_path(&path, cx).is_none(),
523 "Can't create file: file already exists"
524 );
525 }
526
527 let parent_path = input
528 .path
529 .parent()
530 .context("Can't create file: incorrect path")?;
531
532 let parent_project_path = project.find_project_path(&parent_path, cx);
533
534 let parent_entry = parent_project_path
535 .as_ref()
536 .and_then(|path| project.entry_for_path(path, cx))
537 .context("Can't create file: parent directory doesn't exist")?;
538
539 anyhow::ensure!(
540 parent_entry.is_dir(),
541 "Can't create file: parent is not a directory"
542 );
543
544 let file_name = input
545 .path
546 .file_name()
547 .and_then(|file_name| file_name.to_str())
548 .and_then(|file_name| RelPath::unix(file_name).ok())
549 .context("Can't create file: invalid filename")?;
550
551 let new_file_path = parent_project_path.map(|parent| ProjectPath {
552 path: parent.path.join(file_name),
553 ..parent
554 });
555
556 new_file_path.context("Can't create file")
557 }
558 }
559}
560
561#[cfg(test)]
562mod tests {
563 use super::*;
564 use crate::{ContextServerRegistry, Templates};
565 use client::TelemetrySettings;
566 use fs::Fs;
567 use gpui::{TestAppContext, UpdateGlobal};
568 use language_model::fake_provider::FakeLanguageModel;
569 use prompt_store::ProjectContext;
570 use serde_json::json;
571 use settings::SettingsStore;
572 use text::Rope;
573 use util::{path, rel_path::rel_path};
574
575 #[gpui::test]
576 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
577 init_test(cx);
578
579 let fs = project::FakeFs::new(cx.executor());
580 fs.insert_tree("/root", json!({})).await;
581 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
582 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
583 let context_server_registry =
584 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
585 let model = Arc::new(FakeLanguageModel::default());
586 let thread = cx.new(|cx| {
587 Thread::new(
588 project.clone(),
589 cx.new(|_cx| ProjectContext::default()),
590 context_server_registry,
591 Templates::new(),
592 Some(model),
593 cx,
594 )
595 });
596 let result = cx
597 .update(|cx| {
598 let input = EditFileToolInput {
599 display_description: "Some edit".into(),
600 path: "root/nonexistent_file.txt".into(),
601 mode: EditFileMode::Edit,
602 };
603 Arc::new(EditFileTool::new(
604 project,
605 thread.downgrade(),
606 language_registry,
607 Templates::new(),
608 ))
609 .run(input, ToolCallEventStream::test().0, cx)
610 })
611 .await;
612 assert_eq!(
613 result.unwrap_err().to_string(),
614 "Can't edit file: path not found"
615 );
616 }
617
618 #[gpui::test]
619 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
620 let mode = &EditFileMode::Create;
621
622 let result = test_resolve_path(mode, "root/new.txt", cx);
623 assert_resolved_path_eq(result.await, rel_path("new.txt"));
624
625 let result = test_resolve_path(mode, "new.txt", cx);
626 assert_resolved_path_eq(result.await, rel_path("new.txt"));
627
628 let result = test_resolve_path(mode, "dir/new.txt", cx);
629 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
630
631 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
632 assert_eq!(
633 result.await.unwrap_err().to_string(),
634 "Can't create file: file already exists"
635 );
636
637 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
638 assert_eq!(
639 result.await.unwrap_err().to_string(),
640 "Can't create file: parent directory doesn't exist"
641 );
642 }
643
644 #[gpui::test]
645 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
646 let mode = &EditFileMode::Edit;
647
648 let path_with_root = "root/dir/subdir/existing.txt";
649 let path_without_root = "dir/subdir/existing.txt";
650 let result = test_resolve_path(mode, path_with_root, cx);
651 assert_resolved_path_eq(result.await, rel_path(path_without_root));
652
653 let result = test_resolve_path(mode, path_without_root, cx);
654 assert_resolved_path_eq(result.await, rel_path(path_without_root));
655
656 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
657 assert_eq!(
658 result.await.unwrap_err().to_string(),
659 "Can't edit file: path not found"
660 );
661
662 let result = test_resolve_path(mode, "root/dir", cx);
663 assert_eq!(
664 result.await.unwrap_err().to_string(),
665 "Can't edit file: path is a directory"
666 );
667 }
668
669 async fn test_resolve_path(
670 mode: &EditFileMode,
671 path: &str,
672 cx: &mut TestAppContext,
673 ) -> anyhow::Result<ProjectPath> {
674 init_test(cx);
675
676 let fs = project::FakeFs::new(cx.executor());
677 fs.insert_tree(
678 "/root",
679 json!({
680 "dir": {
681 "subdir": {
682 "existing.txt": "hello"
683 }
684 }
685 }),
686 )
687 .await;
688 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
689
690 let input = EditFileToolInput {
691 display_description: "Some edit".into(),
692 path: path.into(),
693 mode: mode.clone(),
694 };
695
696 cx.update(|cx| resolve_path(&input, project, cx))
697 }
698
699 #[track_caller]
700 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
701 let actual = path.expect("Should return valid path").path;
702 assert_eq!(actual.as_ref(), expected);
703 }
704
705 #[gpui::test]
706 async fn test_format_on_save(cx: &mut TestAppContext) {
707 init_test(cx);
708
709 let fs = project::FakeFs::new(cx.executor());
710 fs.insert_tree("/root", json!({"src": {}})).await;
711
712 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
713
714 // Set up a Rust language with LSP formatting support
715 let rust_language = Arc::new(language::Language::new(
716 language::LanguageConfig {
717 name: "Rust".into(),
718 matcher: language::LanguageMatcher {
719 path_suffixes: vec!["rs".to_string()],
720 ..Default::default()
721 },
722 ..Default::default()
723 },
724 None,
725 ));
726
727 // Register the language and fake LSP
728 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
729 language_registry.add(rust_language);
730
731 let mut fake_language_servers = language_registry.register_fake_lsp(
732 "Rust",
733 language::FakeLspAdapter {
734 capabilities: lsp::ServerCapabilities {
735 document_formatting_provider: Some(lsp::OneOf::Left(true)),
736 ..Default::default()
737 },
738 ..Default::default()
739 },
740 );
741
742 // Create the file
743 fs.save(
744 path!("/root/src/main.rs").as_ref(),
745 &Rope::from_str_small("initial content"),
746 language::LineEnding::Unix,
747 )
748 .await
749 .unwrap();
750
751 // Open the buffer to trigger LSP initialization
752 let buffer = project
753 .update(cx, |project, cx| {
754 project.open_local_buffer(path!("/root/src/main.rs"), cx)
755 })
756 .await
757 .unwrap();
758
759 // Register the buffer with language servers
760 let _handle = project.update(cx, |project, cx| {
761 project.register_buffer_with_language_servers(&buffer, cx)
762 });
763
764 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
765 const FORMATTED_CONTENT: &str =
766 "This file was formatted by the fake formatter in the test.\n";
767
768 // Get the fake language server and set up formatting handler
769 let fake_language_server = fake_language_servers.next().await.unwrap();
770 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
771 |_, _| async move {
772 Ok(Some(vec![lsp::TextEdit {
773 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
774 new_text: FORMATTED_CONTENT.to_string(),
775 }]))
776 }
777 });
778
779 let context_server_registry =
780 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
781 let model = Arc::new(FakeLanguageModel::default());
782 let thread = cx.new(|cx| {
783 Thread::new(
784 project.clone(),
785 cx.new(|_cx| ProjectContext::default()),
786 context_server_registry,
787 Templates::new(),
788 Some(model.clone()),
789 cx,
790 )
791 });
792
793 // First, test with format_on_save enabled
794 cx.update(|cx| {
795 SettingsStore::update_global(cx, |store, cx| {
796 store.update_user_settings(cx, |settings| {
797 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
798 settings.project.all_languages.defaults.formatter =
799 Some(language::language_settings::FormatterList::default());
800 });
801 });
802 });
803
804 // Have the model stream unformatted content
805 let edit_result = {
806 let edit_task = cx.update(|cx| {
807 let input = EditFileToolInput {
808 display_description: "Create main function".into(),
809 path: "root/src/main.rs".into(),
810 mode: EditFileMode::Overwrite,
811 };
812 Arc::new(EditFileTool::new(
813 project.clone(),
814 thread.downgrade(),
815 language_registry.clone(),
816 Templates::new(),
817 ))
818 .run(input, ToolCallEventStream::test().0, cx)
819 });
820
821 // Stream the unformatted content
822 cx.executor().run_until_parked();
823 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
824 model.end_last_completion_stream();
825
826 edit_task.await
827 };
828 assert!(edit_result.is_ok());
829
830 // Wait for any async operations (e.g. formatting) to complete
831 cx.executor().run_until_parked();
832
833 // Read the file to verify it was formatted automatically
834 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
835 assert_eq!(
836 // Ignore carriage returns on Windows
837 new_content.replace("\r\n", "\n"),
838 FORMATTED_CONTENT,
839 "Code should be formatted when format_on_save is enabled"
840 );
841
842 let stale_buffer_count = thread
843 .read_with(cx, |thread, _cx| thread.action_log.clone())
844 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
845
846 assert_eq!(
847 stale_buffer_count, 0,
848 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
849 This causes the agent to think the file was modified externally when it was just formatted.",
850 stale_buffer_count
851 );
852
853 // Next, test with format_on_save disabled
854 cx.update(|cx| {
855 SettingsStore::update_global(cx, |store, cx| {
856 store.update_user_settings(cx, |settings| {
857 settings.project.all_languages.defaults.format_on_save =
858 Some(FormatOnSave::Off);
859 });
860 });
861 });
862
863 // Stream unformatted edits again
864 let edit_result = {
865 let edit_task = cx.update(|cx| {
866 let input = EditFileToolInput {
867 display_description: "Update main function".into(),
868 path: "root/src/main.rs".into(),
869 mode: EditFileMode::Overwrite,
870 };
871 Arc::new(EditFileTool::new(
872 project.clone(),
873 thread.downgrade(),
874 language_registry,
875 Templates::new(),
876 ))
877 .run(input, ToolCallEventStream::test().0, cx)
878 });
879
880 // Stream the unformatted content
881 cx.executor().run_until_parked();
882 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
883 model.end_last_completion_stream();
884
885 edit_task.await
886 };
887 assert!(edit_result.is_ok());
888
889 // Wait for any async operations (e.g. formatting) to complete
890 cx.executor().run_until_parked();
891
892 // Verify the file was not formatted
893 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
894 assert_eq!(
895 // Ignore carriage returns on Windows
896 new_content.replace("\r\n", "\n"),
897 UNFORMATTED_CONTENT,
898 "Code should not be formatted when format_on_save is disabled"
899 );
900 }
901
902 #[gpui::test]
903 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
904 init_test(cx);
905
906 let fs = project::FakeFs::new(cx.executor());
907 fs.insert_tree("/root", json!({"src": {}})).await;
908
909 // Create a simple file with trailing whitespace
910 fs.save(
911 path!("/root/src/main.rs").as_ref(),
912 &Rope::from_str_small("initial content"),
913 language::LineEnding::Unix,
914 )
915 .await
916 .unwrap();
917
918 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
919 let context_server_registry =
920 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
921 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
922 let model = Arc::new(FakeLanguageModel::default());
923 let thread = cx.new(|cx| {
924 Thread::new(
925 project.clone(),
926 cx.new(|_cx| ProjectContext::default()),
927 context_server_registry,
928 Templates::new(),
929 Some(model.clone()),
930 cx,
931 )
932 });
933
934 // First, test with remove_trailing_whitespace_on_save enabled
935 cx.update(|cx| {
936 SettingsStore::update_global(cx, |store, cx| {
937 store.update_user_settings(cx, |settings| {
938 settings
939 .project
940 .all_languages
941 .defaults
942 .remove_trailing_whitespace_on_save = Some(true);
943 });
944 });
945 });
946
947 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
948 "fn main() { \n println!(\"Hello!\"); \n}\n";
949
950 // Have the model stream content that contains trailing whitespace
951 let edit_result = {
952 let edit_task = cx.update(|cx| {
953 let input = EditFileToolInput {
954 display_description: "Create main function".into(),
955 path: "root/src/main.rs".into(),
956 mode: EditFileMode::Overwrite,
957 };
958 Arc::new(EditFileTool::new(
959 project.clone(),
960 thread.downgrade(),
961 language_registry.clone(),
962 Templates::new(),
963 ))
964 .run(input, ToolCallEventStream::test().0, cx)
965 });
966
967 // Stream the content with trailing whitespace
968 cx.executor().run_until_parked();
969 model.send_last_completion_stream_text_chunk(
970 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
971 );
972 model.end_last_completion_stream();
973
974 edit_task.await
975 };
976 assert!(edit_result.is_ok());
977
978 // Wait for any async operations (e.g. formatting) to complete
979 cx.executor().run_until_parked();
980
981 // Read the file to verify trailing whitespace was removed automatically
982 assert_eq!(
983 // Ignore carriage returns on Windows
984 fs.load(path!("/root/src/main.rs").as_ref())
985 .await
986 .unwrap()
987 .replace("\r\n", "\n"),
988 "fn main() {\n println!(\"Hello!\");\n}\n",
989 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
990 );
991
992 // Next, test with remove_trailing_whitespace_on_save disabled
993 cx.update(|cx| {
994 SettingsStore::update_global(cx, |store, cx| {
995 store.update_user_settings(cx, |settings| {
996 settings
997 .project
998 .all_languages
999 .defaults
1000 .remove_trailing_whitespace_on_save = Some(false);
1001 });
1002 });
1003 });
1004
1005 // Stream edits again with trailing whitespace
1006 let edit_result = {
1007 let edit_task = cx.update(|cx| {
1008 let input = EditFileToolInput {
1009 display_description: "Update main function".into(),
1010 path: "root/src/main.rs".into(),
1011 mode: EditFileMode::Overwrite,
1012 };
1013 Arc::new(EditFileTool::new(
1014 project.clone(),
1015 thread.downgrade(),
1016 language_registry,
1017 Templates::new(),
1018 ))
1019 .run(input, ToolCallEventStream::test().0, cx)
1020 });
1021
1022 // Stream the content with trailing whitespace
1023 cx.executor().run_until_parked();
1024 model.send_last_completion_stream_text_chunk(
1025 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1026 );
1027 model.end_last_completion_stream();
1028
1029 edit_task.await
1030 };
1031 assert!(edit_result.is_ok());
1032
1033 // Wait for any async operations (e.g. formatting) to complete
1034 cx.executor().run_until_parked();
1035
1036 // Verify the file still has trailing whitespace
1037 // Read the file again - it should still have trailing whitespace
1038 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1039 assert_eq!(
1040 // Ignore carriage returns on Windows
1041 final_content.replace("\r\n", "\n"),
1042 CONTENT_WITH_TRAILING_WHITESPACE,
1043 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1044 );
1045 }
1046
1047 #[gpui::test]
1048 async fn test_authorize(cx: &mut TestAppContext) {
1049 init_test(cx);
1050 let fs = project::FakeFs::new(cx.executor());
1051 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1052 let context_server_registry =
1053 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1054 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1055 let model = Arc::new(FakeLanguageModel::default());
1056 let thread = cx.new(|cx| {
1057 Thread::new(
1058 project.clone(),
1059 cx.new(|_cx| ProjectContext::default()),
1060 context_server_registry,
1061 Templates::new(),
1062 Some(model.clone()),
1063 cx,
1064 )
1065 });
1066 let tool = Arc::new(EditFileTool::new(
1067 project.clone(),
1068 thread.downgrade(),
1069 language_registry,
1070 Templates::new(),
1071 ));
1072 fs.insert_tree("/root", json!({})).await;
1073
1074 // Test 1: Path with .zed component should require confirmation
1075 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1076 let _auth = cx.update(|cx| {
1077 tool.authorize(
1078 &EditFileToolInput {
1079 display_description: "test 1".into(),
1080 path: ".zed/settings.json".into(),
1081 mode: EditFileMode::Edit,
1082 },
1083 &stream_tx,
1084 cx,
1085 )
1086 });
1087
1088 let event = stream_rx.expect_authorization().await;
1089 assert_eq!(
1090 event.tool_call.fields.title,
1091 Some("test 1 (local settings)".into())
1092 );
1093
1094 // Test 2: Path outside project should require confirmation
1095 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1096 let _auth = cx.update(|cx| {
1097 tool.authorize(
1098 &EditFileToolInput {
1099 display_description: "test 2".into(),
1100 path: "/etc/hosts".into(),
1101 mode: EditFileMode::Edit,
1102 },
1103 &stream_tx,
1104 cx,
1105 )
1106 });
1107
1108 let event = stream_rx.expect_authorization().await;
1109 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1110
1111 // Test 3: Relative path without .zed should not require confirmation
1112 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1113 cx.update(|cx| {
1114 tool.authorize(
1115 &EditFileToolInput {
1116 display_description: "test 3".into(),
1117 path: "root/src/main.rs".into(),
1118 mode: EditFileMode::Edit,
1119 },
1120 &stream_tx,
1121 cx,
1122 )
1123 })
1124 .await
1125 .unwrap();
1126 assert!(stream_rx.try_next().is_err());
1127
1128 // Test 4: Path with .zed in the middle should require confirmation
1129 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1130 let _auth = cx.update(|cx| {
1131 tool.authorize(
1132 &EditFileToolInput {
1133 display_description: "test 4".into(),
1134 path: "root/.zed/tasks.json".into(),
1135 mode: EditFileMode::Edit,
1136 },
1137 &stream_tx,
1138 cx,
1139 )
1140 });
1141 let event = stream_rx.expect_authorization().await;
1142 assert_eq!(
1143 event.tool_call.fields.title,
1144 Some("test 4 (local settings)".into())
1145 );
1146
1147 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1148 cx.update(|cx| {
1149 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1150 settings.always_allow_tool_actions = true;
1151 agent_settings::AgentSettings::override_global(settings, cx);
1152 });
1153
1154 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1155 cx.update(|cx| {
1156 tool.authorize(
1157 &EditFileToolInput {
1158 display_description: "test 5.1".into(),
1159 path: ".zed/settings.json".into(),
1160 mode: EditFileMode::Edit,
1161 },
1162 &stream_tx,
1163 cx,
1164 )
1165 })
1166 .await
1167 .unwrap();
1168 assert!(stream_rx.try_next().is_err());
1169
1170 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1171 cx.update(|cx| {
1172 tool.authorize(
1173 &EditFileToolInput {
1174 display_description: "test 5.2".into(),
1175 path: "/etc/hosts".into(),
1176 mode: EditFileMode::Edit,
1177 },
1178 &stream_tx,
1179 cx,
1180 )
1181 })
1182 .await
1183 .unwrap();
1184 assert!(stream_rx.try_next().is_err());
1185 }
1186
1187 #[gpui::test]
1188 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1189 init_test(cx);
1190 let fs = project::FakeFs::new(cx.executor());
1191 fs.insert_tree("/project", json!({})).await;
1192 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1193 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1194 let context_server_registry =
1195 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1196 let model = Arc::new(FakeLanguageModel::default());
1197 let thread = cx.new(|cx| {
1198 Thread::new(
1199 project.clone(),
1200 cx.new(|_cx| ProjectContext::default()),
1201 context_server_registry,
1202 Templates::new(),
1203 Some(model.clone()),
1204 cx,
1205 )
1206 });
1207 let tool = Arc::new(EditFileTool::new(
1208 project.clone(),
1209 thread.downgrade(),
1210 language_registry,
1211 Templates::new(),
1212 ));
1213
1214 // Test global config paths - these should require confirmation if they exist and are outside the project
1215 let test_cases = vec![
1216 (
1217 "/etc/hosts",
1218 true,
1219 "System file should require confirmation",
1220 ),
1221 (
1222 "/usr/local/bin/script",
1223 true,
1224 "System bin file should require confirmation",
1225 ),
1226 (
1227 "project/normal_file.rs",
1228 false,
1229 "Normal project file should not require confirmation",
1230 ),
1231 ];
1232
1233 for (path, should_confirm, description) in test_cases {
1234 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1235 let auth = cx.update(|cx| {
1236 tool.authorize(
1237 &EditFileToolInput {
1238 display_description: "Edit file".into(),
1239 path: path.into(),
1240 mode: EditFileMode::Edit,
1241 },
1242 &stream_tx,
1243 cx,
1244 )
1245 });
1246
1247 if should_confirm {
1248 stream_rx.expect_authorization().await;
1249 } else {
1250 auth.await.unwrap();
1251 assert!(
1252 stream_rx.try_next().is_err(),
1253 "Failed for case: {} - path: {} - expected no confirmation but got one",
1254 description,
1255 path
1256 );
1257 }
1258 }
1259 }
1260
1261 #[gpui::test]
1262 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1263 init_test(cx);
1264 let fs = project::FakeFs::new(cx.executor());
1265
1266 // Create multiple worktree directories
1267 fs.insert_tree(
1268 "/workspace/frontend",
1269 json!({
1270 "src": {
1271 "main.js": "console.log('frontend');"
1272 }
1273 }),
1274 )
1275 .await;
1276 fs.insert_tree(
1277 "/workspace/backend",
1278 json!({
1279 "src": {
1280 "main.rs": "fn main() {}"
1281 }
1282 }),
1283 )
1284 .await;
1285 fs.insert_tree(
1286 "/workspace/shared",
1287 json!({
1288 ".zed": {
1289 "settings.json": "{}"
1290 }
1291 }),
1292 )
1293 .await;
1294
1295 // Create project with multiple worktrees
1296 let project = Project::test(
1297 fs.clone(),
1298 [
1299 path!("/workspace/frontend").as_ref(),
1300 path!("/workspace/backend").as_ref(),
1301 path!("/workspace/shared").as_ref(),
1302 ],
1303 cx,
1304 )
1305 .await;
1306 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1307 let context_server_registry =
1308 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1309 let model = Arc::new(FakeLanguageModel::default());
1310 let thread = cx.new(|cx| {
1311 Thread::new(
1312 project.clone(),
1313 cx.new(|_cx| ProjectContext::default()),
1314 context_server_registry.clone(),
1315 Templates::new(),
1316 Some(model.clone()),
1317 cx,
1318 )
1319 });
1320 let tool = Arc::new(EditFileTool::new(
1321 project.clone(),
1322 thread.downgrade(),
1323 language_registry,
1324 Templates::new(),
1325 ));
1326
1327 // Test files in different worktrees
1328 let test_cases = vec![
1329 ("frontend/src/main.js", false, "File in first worktree"),
1330 ("backend/src/main.rs", false, "File in second worktree"),
1331 (
1332 "shared/.zed/settings.json",
1333 true,
1334 ".zed file in third worktree",
1335 ),
1336 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1337 (
1338 "../outside/file.txt",
1339 true,
1340 "Relative path outside worktrees",
1341 ),
1342 ];
1343
1344 for (path, should_confirm, description) in test_cases {
1345 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1346 let auth = cx.update(|cx| {
1347 tool.authorize(
1348 &EditFileToolInput {
1349 display_description: "Edit file".into(),
1350 path: path.into(),
1351 mode: EditFileMode::Edit,
1352 },
1353 &stream_tx,
1354 cx,
1355 )
1356 });
1357
1358 if should_confirm {
1359 stream_rx.expect_authorization().await;
1360 } else {
1361 auth.await.unwrap();
1362 assert!(
1363 stream_rx.try_next().is_err(),
1364 "Failed for case: {} - path: {} - expected no confirmation but got one",
1365 description,
1366 path
1367 );
1368 }
1369 }
1370 }
1371
1372 #[gpui::test]
1373 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1374 init_test(cx);
1375 let fs = project::FakeFs::new(cx.executor());
1376 fs.insert_tree(
1377 "/project",
1378 json!({
1379 ".zed": {
1380 "settings.json": "{}"
1381 },
1382 "src": {
1383 ".zed": {
1384 "local.json": "{}"
1385 }
1386 }
1387 }),
1388 )
1389 .await;
1390 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1391 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1392 let context_server_registry =
1393 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1394 let model = Arc::new(FakeLanguageModel::default());
1395 let thread = cx.new(|cx| {
1396 Thread::new(
1397 project.clone(),
1398 cx.new(|_cx| ProjectContext::default()),
1399 context_server_registry.clone(),
1400 Templates::new(),
1401 Some(model.clone()),
1402 cx,
1403 )
1404 });
1405 let tool = Arc::new(EditFileTool::new(
1406 project.clone(),
1407 thread.downgrade(),
1408 language_registry,
1409 Templates::new(),
1410 ));
1411
1412 // Test edge cases
1413 let test_cases = vec![
1414 // Empty path - find_project_path returns Some for empty paths
1415 ("", false, "Empty path is treated as project root"),
1416 // Root directory
1417 ("/", true, "Root directory should be outside project"),
1418 // Parent directory references - find_project_path resolves these
1419 (
1420 "project/../other",
1421 true,
1422 "Path with .. that goes outside of root directory",
1423 ),
1424 (
1425 "project/./src/file.rs",
1426 false,
1427 "Path with . should work normally",
1428 ),
1429 // Windows-style paths (if on Windows)
1430 #[cfg(target_os = "windows")]
1431 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1432 #[cfg(target_os = "windows")]
1433 ("project\\src\\main.rs", false, "Windows-style project path"),
1434 ];
1435
1436 for (path, should_confirm, description) in test_cases {
1437 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1438 let auth = cx.update(|cx| {
1439 tool.authorize(
1440 &EditFileToolInput {
1441 display_description: "Edit file".into(),
1442 path: path.into(),
1443 mode: EditFileMode::Edit,
1444 },
1445 &stream_tx,
1446 cx,
1447 )
1448 });
1449
1450 cx.run_until_parked();
1451
1452 if should_confirm {
1453 stream_rx.expect_authorization().await;
1454 } else {
1455 assert!(
1456 stream_rx.try_next().is_err(),
1457 "Failed for case: {} - path: {} - expected no confirmation but got one",
1458 description,
1459 path
1460 );
1461 auth.await.unwrap();
1462 }
1463 }
1464 }
1465
1466 #[gpui::test]
1467 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1468 init_test(cx);
1469 let fs = project::FakeFs::new(cx.executor());
1470 fs.insert_tree(
1471 "/project",
1472 json!({
1473 "existing.txt": "content",
1474 ".zed": {
1475 "settings.json": "{}"
1476 }
1477 }),
1478 )
1479 .await;
1480 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1481 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1482 let context_server_registry =
1483 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1484 let model = Arc::new(FakeLanguageModel::default());
1485 let thread = cx.new(|cx| {
1486 Thread::new(
1487 project.clone(),
1488 cx.new(|_cx| ProjectContext::default()),
1489 context_server_registry.clone(),
1490 Templates::new(),
1491 Some(model.clone()),
1492 cx,
1493 )
1494 });
1495 let tool = Arc::new(EditFileTool::new(
1496 project.clone(),
1497 thread.downgrade(),
1498 language_registry,
1499 Templates::new(),
1500 ));
1501
1502 // Test different EditFileMode values
1503 let modes = vec![
1504 EditFileMode::Edit,
1505 EditFileMode::Create,
1506 EditFileMode::Overwrite,
1507 ];
1508
1509 for mode in modes {
1510 // Test .zed path with different modes
1511 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1512 let _auth = cx.update(|cx| {
1513 tool.authorize(
1514 &EditFileToolInput {
1515 display_description: "Edit settings".into(),
1516 path: "project/.zed/settings.json".into(),
1517 mode: mode.clone(),
1518 },
1519 &stream_tx,
1520 cx,
1521 )
1522 });
1523
1524 stream_rx.expect_authorization().await;
1525
1526 // Test outside path with different modes
1527 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1528 let _auth = cx.update(|cx| {
1529 tool.authorize(
1530 &EditFileToolInput {
1531 display_description: "Edit file".into(),
1532 path: "/outside/file.txt".into(),
1533 mode: mode.clone(),
1534 },
1535 &stream_tx,
1536 cx,
1537 )
1538 });
1539
1540 stream_rx.expect_authorization().await;
1541
1542 // Test normal path with different modes
1543 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1544 cx.update(|cx| {
1545 tool.authorize(
1546 &EditFileToolInput {
1547 display_description: "Edit file".into(),
1548 path: "project/normal.txt".into(),
1549 mode: mode.clone(),
1550 },
1551 &stream_tx,
1552 cx,
1553 )
1554 })
1555 .await
1556 .unwrap();
1557 assert!(stream_rx.try_next().is_err());
1558 }
1559 }
1560
1561 #[gpui::test]
1562 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1563 init_test(cx);
1564 let fs = project::FakeFs::new(cx.executor());
1565 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1566 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1567 let context_server_registry =
1568 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1569 let model = Arc::new(FakeLanguageModel::default());
1570 let thread = cx.new(|cx| {
1571 Thread::new(
1572 project.clone(),
1573 cx.new(|_cx| ProjectContext::default()),
1574 context_server_registry,
1575 Templates::new(),
1576 Some(model.clone()),
1577 cx,
1578 )
1579 });
1580 let tool = Arc::new(EditFileTool::new(
1581 project,
1582 thread.downgrade(),
1583 language_registry,
1584 Templates::new(),
1585 ));
1586
1587 cx.update(|cx| {
1588 // ...
1589 assert_eq!(
1590 tool.initial_title(
1591 Err(json!({
1592 "path": "src/main.rs",
1593 "display_description": "",
1594 "old_string": "old code",
1595 "new_string": "new code"
1596 })),
1597 cx
1598 ),
1599 "src/main.rs"
1600 );
1601 assert_eq!(
1602 tool.initial_title(
1603 Err(json!({
1604 "path": "",
1605 "display_description": "Fix error handling",
1606 "old_string": "old code",
1607 "new_string": "new code"
1608 })),
1609 cx
1610 ),
1611 "Fix error handling"
1612 );
1613 assert_eq!(
1614 tool.initial_title(
1615 Err(json!({
1616 "path": "src/main.rs",
1617 "display_description": "Fix error handling",
1618 "old_string": "old code",
1619 "new_string": "new code"
1620 })),
1621 cx
1622 ),
1623 "src/main.rs"
1624 );
1625 assert_eq!(
1626 tool.initial_title(
1627 Err(json!({
1628 "path": "",
1629 "display_description": "",
1630 "old_string": "old code",
1631 "new_string": "new code"
1632 })),
1633 cx
1634 ),
1635 DEFAULT_UI_TEXT
1636 );
1637 assert_eq!(
1638 tool.initial_title(Err(serde_json::Value::Null), cx),
1639 DEFAULT_UI_TEXT
1640 );
1641 });
1642 }
1643
/// Verifies that the `Diff` emitted by the edit tool starts out
/// `Diff::Pending` and ends up `Diff::Finalized` in three scenarios: the edit
/// completes, the edit fails, and the edit task is dropped mid-flight.
#[gpui::test]
async fn test_diff_finalization(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = project::FakeFs::new(cx.executor());
    // A single empty `main.rs`; the project's only worktree is rooted at `/`.
    fs.insert_tree("/", json!({"main.rs": ""})).await;

    let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
    let languages = project.read_with(cx, |project, _cx| project.languages().clone());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // Ensure the diff is finalized after the edit completes.
    {
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        // The test expects a fields update first, then the diff.
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        // NOTE(review): presumably lets the tool send its completion request
        // to the fake model before that stream is ended below.
        cx.run_until_parked();
        // Ending the fake model's stream lets the edit finish successfully.
        model.end_last_completion_stream();
        edit.await.unwrap();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
    }

    // Ensure the diff is finalized if an error occurs while editing.
    {
        // With requests forbidden, the edit is expected to fail (see
        // `unwrap_err` below).
        model.forbid_requests();
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        edit.await.unwrap_err();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        // Re-allow requests so the next scenario can reach the fake model.
        model.allow_requests();
    }

    // Ensure the diff is finalized if the tool call gets dropped.
    {
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        // Drop the in-flight edit task, then let the executor process
        // whatever cleanup that triggers before checking the diff state.
        drop(edit);
        cx.run_until_parked();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
    }
}
1752
1753 fn init_test(cx: &mut TestAppContext) {
1754 cx.update(|cx| {
1755 let settings_store = SettingsStore::test(cx);
1756 cx.set_global(settings_store);
1757 language::init(cx);
1758 TelemetrySettings::register(cx);
1759 agent_settings::AgentSettings::register(cx);
1760 Project::init_settings(cx);
1761 });
1762 }
1763}