1use crate::{AgentTool, Thread, ToolCallEventStream};
2use acp_thread::Diff;
3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
6use cloud_llm_client::CompletionIntent;
7use collections::HashSet;
8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
9use indoc::formatdoc;
10use language::language_settings::{self, FormatOnSave};
11use language::{LanguageRegistry, ToPoint};
12use language_model::LanguageModelToolResultContent;
13use paths;
14use project::lsp_store::{FormatTrigger, LspFormatTarget};
15use project::{Project, ProjectPath};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::Settings;
19use smol::stream::StreamExt as _;
20use std::ffi::OsStr;
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23use ui::SharedString;
24use util::ResultExt;
25use util::rel_path::RelPath;
26
/// Title shown for an `edit_file` tool call when neither a path nor a
/// description can be extracted from its (possibly partial) input.
const DEFAULT_UI_TEXT: &str = "Editing file";
28
29/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
30///
31/// Before using this tool:
32///
33/// 1. Use the `read_file` tool to understand the file's contents and context
34///
35/// 2. Verify the directory path is correct (only applicable when creating new files):
36/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
37#[derive(Debug, Serialize, Deserialize, JsonSchema)]
38pub struct EditFileToolInput {
39 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
40 ///
41 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
42 ///
43 /// NEVER mention the file path in this description.
44 ///
45 /// <example>Fix API endpoint URLs</example>
46 /// <example>Update copyright year in `page_footer`</example>
47 ///
48 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
49 pub display_description: String,
50
51 /// The full path of the file to create or modify in the project.
52 ///
53 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
54 ///
55 /// The following examples assume we have two root directories in the project:
56 /// - /a/b/backend
57 /// - /c/d/frontend
58 ///
59 /// <example>
60 /// `backend/src/main.rs`
61 ///
62 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
63 /// </example>
64 ///
65 /// <example>
66 /// `frontend/db.js`
67 /// </example>
68 pub path: PathBuf,
69 /// The mode of operation on the file. Possible values:
70 /// - 'edit': Make granular edits to an existing file.
71 /// - 'create': Create a new file if it doesn't exist.
72 /// - 'overwrite': Replace the entire contents of an existing file.
73 ///
74 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
75 pub mode: EditFileMode,
76}
77
// Lenient view of `EditFileToolInput` that `initial_title` falls back to when
// the raw input does not parse as a full `EditFileToolInput` (presumably
// while the model is still streaming it). Both fields default to empty
// strings, so deserialization succeeds before they are present.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Target path as typed by the model; may be empty or incomplete.
    #[serde(default)]
    path: String,
    // Edit description; used as the title when `path` is still empty.
    #[serde(default)]
    display_description: String,
}
85
// How `edit_file` treats the target path; serialized in lowercase
// ("edit", "create", "overwrite"). `run` sends `Edit` through
// `EditAgent::edit` and both `Create` and `Overwrite` through
// `EditAgent::overwrite`. `resolve_path` requires an existing file for
// `Edit`/`Overwrite`, and a not-yet-existing file whose parent directory
// exists for `Create`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    Edit,
    Create,
    Overwrite,
}
94
/// Result of a completed `edit_file` call.
///
/// It is (de)serializable; `replay` rebuilds a finalized diff view from it, and
/// the `From<EditFileToolOutput> for LanguageModelToolResultContent` impl
/// renders it for the model.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as supplied in the tool input. Deserialization also
    /// accepts the key `original_path` (presumably an older field name).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit, optional formatting, and save.
    new_text: String,
    /// Full buffer text captured before the edit started.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty means nothing changed.
    /// Defaults to an empty string when absent from serialized data.
    #[serde(default)]
    diff: String,
    /// Output returned by the edit agent. Deserialization also accepts the key
    /// `raw_output`.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
106
107impl From<EditFileToolOutput> for LanguageModelToolResultContent {
108 fn from(output: EditFileToolOutput) -> Self {
109 if output.diff.is_empty() {
110 "No edits were made.".into()
111 } else {
112 format!(
113 "Edited {}:\n\n```diff\n{}\n```",
114 output.input_path.display(),
115 output.diff
116 )
117 .into()
118 }
119 }
120}
121
/// Agent tool that creates, edits, or overwrites a single project file,
/// delegating the actual text changes to an `EditAgent`.
pub struct EditFileTool {
    /// Owning thread; supplies the project, model, completion request and
    /// action log. Held weakly, so calls fail with "thread was dropped" once
    /// the thread is gone.
    thread: WeakEntity<Thread>,
    /// Used by `replay` to build a finalized diff view.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to resolve display paths for tool-call titles.
    project: Entity<Project>,
}
127
128impl EditFileTool {
129 pub fn new(
130 project: Entity<Project>,
131 thread: WeakEntity<Thread>,
132 language_registry: Arc<LanguageRegistry>,
133 ) -> Self {
134 Self {
135 project,
136 thread,
137 language_registry,
138 }
139 }
140
141 fn authorize(
142 &self,
143 input: &EditFileToolInput,
144 event_stream: &ToolCallEventStream,
145 cx: &mut App,
146 ) -> Task<Result<()>> {
147 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
148 return Task::ready(Ok(()));
149 }
150
151 // If any path component matches the local settings folder, then this could affect
152 // the editor in ways beyond the project source, so prompt.
153 let local_settings_folder = paths::local_settings_folder_name();
154 let path = Path::new(&input.path);
155 if path.components().any(|component| {
156 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
157 }) {
158 return event_stream.authorize(
159 format!("{} (local settings)", input.display_description),
160 cx,
161 );
162 }
163
164 // It's also possible that the global config dir is configured to be inside the project,
165 // so check for that edge case too.
166 // TODO this is broken when remoting
167 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
168 && canonical_path.starts_with(paths::config_dir())
169 {
170 return event_stream.authorize(
171 format!("{} (global settings)", input.display_description),
172 cx,
173 );
174 }
175
176 // Check if path is inside the global config directory
177 // First check if it's already inside project - if not, try to canonicalize
178 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
179 thread.project().read(cx).find_project_path(&input.path, cx)
180 }) else {
181 return Task::ready(Err(anyhow!("thread was dropped")));
182 };
183
184 // If the path is inside the project, and it's not one of the above edge cases,
185 // then no confirmation is necessary. Otherwise, confirmation is necessary.
186 if project_path.is_some() {
187 Task::ready(Ok(()))
188 } else {
189 event_stream.authorize(&input.display_description, cx)
190 }
191 }
192}
193
194impl AgentTool for EditFileTool {
195 type Input = EditFileToolInput;
196 type Output = EditFileToolOutput;
197
198 fn name() -> &'static str {
199 "edit_file"
200 }
201
202 fn kind() -> acp::ToolKind {
203 acp::ToolKind::Edit
204 }
205
206 fn initial_title(
207 &self,
208 input: Result<Self::Input, serde_json::Value>,
209 cx: &mut App,
210 ) -> SharedString {
211 match input {
212 Ok(input) => self
213 .project
214 .read(cx)
215 .find_project_path(&input.path, cx)
216 .and_then(|project_path| {
217 self.project
218 .read(cx)
219 .short_full_path_for_project_path(&project_path, cx)
220 })
221 .unwrap_or(input.path.to_string_lossy().into_owned())
222 .into(),
223 Err(raw_input) => {
224 if let Some(input) =
225 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
226 {
227 let path = input.path.trim();
228 if !path.is_empty() {
229 return self
230 .project
231 .read(cx)
232 .find_project_path(&input.path, cx)
233 .and_then(|project_path| {
234 self.project
235 .read(cx)
236 .short_full_path_for_project_path(&project_path, cx)
237 })
238 .unwrap_or(input.path)
239 .into();
240 }
241
242 let description = input.display_description.trim();
243 if !description.is_empty() {
244 return description.to_string().into();
245 }
246 }
247
248 DEFAULT_UI_TEXT.into()
249 }
250 }
251 }
252
253 fn run(
254 self: Arc<Self>,
255 input: Self::Input,
256 event_stream: ToolCallEventStream,
257 cx: &mut App,
258 ) -> Task<Result<Self::Output>> {
259 let Ok(project) = self
260 .thread
261 .read_with(cx, |thread, _cx| thread.project().clone())
262 else {
263 return Task::ready(Err(anyhow!("thread was dropped")));
264 };
265 let project_path = match resolve_path(&input, project.clone(), cx) {
266 Ok(path) => path,
267 Err(err) => return Task::ready(Err(anyhow!(err))),
268 };
269 let abs_path = project.read(cx).absolute_path(&project_path, cx);
270 if let Some(abs_path) = abs_path.clone() {
271 event_stream.update_fields(ToolCallUpdateFields {
272 locations: Some(vec![acp::ToolCallLocation {
273 path: abs_path,
274 line: None,
275 meta: None,
276 }]),
277 ..Default::default()
278 });
279 }
280
281 let authorize = self.authorize(&input, &event_stream, cx);
282 cx.spawn(async move |cx: &mut AsyncApp| {
283 authorize.await?;
284
285 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
286 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
287 (request, thread.model().cloned(), thread.action_log().clone())
288 })?;
289 let request = request?;
290 let model = model.context("No language model configured")?;
291
292 let edit_format = EditFormat::from_model(model.clone())?;
293 let edit_agent = EditAgent::new(
294 model,
295 project.clone(),
296 action_log.clone(),
297 // TODO: move edit agent to this crate so we can use our templates
298 assistant_tools::templates::Templates::new(),
299 edit_format,
300 );
301
302 let buffer = project
303 .update(cx, |project, cx| {
304 project.open_buffer(project_path.clone(), cx)
305 })?
306 .await?;
307
308 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
309 event_stream.update_diff(diff.clone());
310 let _finalize_diff = util::defer({
311 let diff = diff.downgrade();
312 let mut cx = cx.clone();
313 move || {
314 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
315 }
316 });
317
318 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
319 let old_text = cx
320 .background_spawn({
321 let old_snapshot = old_snapshot.clone();
322 async move { Arc::new(old_snapshot.text()) }
323 })
324 .await;
325
326
327 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
328 edit_agent.edit(
329 buffer.clone(),
330 input.display_description.clone(),
331 &request,
332 cx,
333 )
334 } else {
335 edit_agent.overwrite(
336 buffer.clone(),
337 input.display_description.clone(),
338 &request,
339 cx,
340 )
341 };
342
343 let mut hallucinated_old_text = false;
344 let mut ambiguous_ranges = Vec::new();
345 let mut emitted_location = false;
346 while let Some(event) = events.next().await {
347 match event {
348 EditAgentOutputEvent::Edited(range) => {
349 if !emitted_location {
350 let line = buffer.update(cx, |buffer, _cx| {
351 range.start.to_point(&buffer.snapshot()).row
352 }).ok();
353 if let Some(abs_path) = abs_path.clone() {
354 event_stream.update_fields(ToolCallUpdateFields {
355 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
356 ..Default::default()
357 });
358 }
359 emitted_location = true;
360 }
361 },
362 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
363 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
364 EditAgentOutputEvent::ResolvingEditRange(range) => {
365 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
366 // if !emitted_location {
367 // let line = buffer.update(cx, |buffer, _cx| {
368 // range.start.to_point(&buffer.snapshot()).row
369 // }).ok();
370 // if let Some(abs_path) = abs_path.clone() {
371 // event_stream.update_fields(ToolCallUpdateFields {
372 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
373 // ..Default::default()
374 // });
375 // }
376 // }
377 }
378 }
379 }
380
381 // If format_on_save is enabled, format the buffer
382 let format_on_save_enabled = buffer
383 .read_with(cx, |buffer, cx| {
384 let settings = language_settings::language_settings(
385 buffer.language().map(|l| l.name()),
386 buffer.file(),
387 cx,
388 );
389 settings.format_on_save != FormatOnSave::Off
390 })
391 .unwrap_or(false);
392
393 let edit_agent_output = output.await?;
394
395 if format_on_save_enabled {
396 action_log.update(cx, |log, cx| {
397 log.buffer_edited(buffer.clone(), cx);
398 })?;
399
400 let format_task = project.update(cx, |project, cx| {
401 project.format(
402 HashSet::from_iter([buffer.clone()]),
403 LspFormatTarget::Buffers,
404 false, // Don't push to history since the tool did it.
405 FormatTrigger::Save,
406 cx,
407 )
408 })?;
409 format_task.await.log_err();
410 }
411
412 project
413 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
414 .await?;
415
416 action_log.update(cx, |log, cx| {
417 log.buffer_edited(buffer.clone(), cx);
418 })?;
419
420 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
421 let (new_text, unified_diff) = cx
422 .background_spawn({
423 let new_snapshot = new_snapshot.clone();
424 let old_text = old_text.clone();
425 async move {
426 let new_text = new_snapshot.text();
427 let diff = language::unified_diff(&old_text, &new_text);
428 (new_text, diff)
429 }
430 })
431 .await;
432
433 let input_path = input.path.display();
434 if unified_diff.is_empty() {
435 anyhow::ensure!(
436 !hallucinated_old_text,
437 formatdoc! {"
438 Some edits were produced but none of them could be applied.
439 Read the relevant sections of {input_path} again so that
440 I can perform the requested edits.
441 "}
442 );
443 anyhow::ensure!(
444 ambiguous_ranges.is_empty(),
445 {
446 let line_numbers = ambiguous_ranges
447 .iter()
448 .map(|range| range.start.to_string())
449 .collect::<Vec<_>>()
450 .join(", ");
451 formatdoc! {"
452 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
453 relevant sections of {input_path} again and extend <old_text> so
454 that I can perform the requested edits.
455 "}
456 }
457 );
458 }
459
460 Ok(EditFileToolOutput {
461 input_path: input.path,
462 new_text,
463 old_text,
464 diff: unified_diff,
465 edit_agent_output,
466 })
467 })
468 }
469
470 fn replay(
471 &self,
472 _input: Self::Input,
473 output: Self::Output,
474 event_stream: ToolCallEventStream,
475 cx: &mut App,
476 ) -> Result<()> {
477 event_stream.update_diff(cx.new(|cx| {
478 Diff::finalized(
479 output.input_path.to_string_lossy().into_owned(),
480 Some(output.old_text.to_string()),
481 output.new_text,
482 self.language_registry.clone(),
483 cx,
484 )
485 }));
486 Ok(())
487 }
488}
489
490/// Validate that the file path is valid, meaning:
491///
492/// - For `edit` and `overwrite`, the path must point to an existing file.
493/// - For `create`, the file must not already exist, but it's parent dir must exist.
494fn resolve_path(
495 input: &EditFileToolInput,
496 project: Entity<Project>,
497 cx: &mut App,
498) -> Result<ProjectPath> {
499 let project = project.read(cx);
500
501 match input.mode {
502 EditFileMode::Edit | EditFileMode::Overwrite => {
503 let path = project
504 .find_project_path(&input.path, cx)
505 .context("Can't edit file: path not found")?;
506
507 let entry = project
508 .entry_for_path(&path, cx)
509 .context("Can't edit file: path not found")?;
510
511 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
512 Ok(path)
513 }
514
515 EditFileMode::Create => {
516 if let Some(path) = project.find_project_path(&input.path, cx) {
517 anyhow::ensure!(
518 project.entry_for_path(&path, cx).is_none(),
519 "Can't create file: file already exists"
520 );
521 }
522
523 let parent_path = input
524 .path
525 .parent()
526 .context("Can't create file: incorrect path")?;
527
528 let parent_project_path = project.find_project_path(&parent_path, cx);
529
530 let parent_entry = parent_project_path
531 .as_ref()
532 .and_then(|path| project.entry_for_path(path, cx))
533 .context("Can't create file: parent directory doesn't exist")?;
534
535 anyhow::ensure!(
536 parent_entry.is_dir(),
537 "Can't create file: parent is not a directory"
538 );
539
540 let file_name = input
541 .path
542 .file_name()
543 .and_then(|file_name| file_name.to_str())
544 .and_then(|file_name| RelPath::unix(file_name).ok())
545 .context("Can't create file: invalid filename")?;
546
547 let new_file_path = parent_project_path.map(|parent| ProjectPath {
548 path: parent.path.join(file_name),
549 ..parent
550 });
551
552 new_file_path.context("Can't create file")
553 }
554 }
555}
556
557#[cfg(test)]
558mod tests {
559 use super::*;
560 use crate::{ContextServerRegistry, Templates};
561 use client::TelemetrySettings;
562 use fs::Fs;
563 use gpui::{TestAppContext, UpdateGlobal};
564 use language_model::fake_provider::FakeLanguageModel;
565 use prompt_store::ProjectContext;
566 use serde_json::json;
567 use settings::SettingsStore;
568 use util::{path, rel_path::rel_path};
569
570 #[gpui::test]
571 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
572 init_test(cx);
573
574 let fs = project::FakeFs::new(cx.executor());
575 fs.insert_tree("/root", json!({})).await;
576 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
577 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
578 let context_server_registry =
579 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
580 let model = Arc::new(FakeLanguageModel::default());
581 let thread = cx.new(|cx| {
582 Thread::new(
583 project.clone(),
584 cx.new(|_cx| ProjectContext::default()),
585 context_server_registry,
586 Templates::new(),
587 Some(model),
588 cx,
589 )
590 });
591 let result = cx
592 .update(|cx| {
593 let input = EditFileToolInput {
594 display_description: "Some edit".into(),
595 path: "root/nonexistent_file.txt".into(),
596 mode: EditFileMode::Edit,
597 };
598 Arc::new(EditFileTool::new(
599 project,
600 thread.downgrade(),
601 language_registry,
602 ))
603 .run(input, ToolCallEventStream::test().0, cx)
604 })
605 .await;
606 assert_eq!(
607 result.unwrap_err().to_string(),
608 "Can't edit file: path not found"
609 );
610 }
611
612 #[gpui::test]
613 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
614 let mode = &EditFileMode::Create;
615
616 let result = test_resolve_path(mode, "root/new.txt", cx);
617 assert_resolved_path_eq(result.await, rel_path("new.txt"));
618
619 let result = test_resolve_path(mode, "new.txt", cx);
620 assert_resolved_path_eq(result.await, rel_path("new.txt"));
621
622 let result = test_resolve_path(mode, "dir/new.txt", cx);
623 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
624
625 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
626 assert_eq!(
627 result.await.unwrap_err().to_string(),
628 "Can't create file: file already exists"
629 );
630
631 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
632 assert_eq!(
633 result.await.unwrap_err().to_string(),
634 "Can't create file: parent directory doesn't exist"
635 );
636 }
637
638 #[gpui::test]
639 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
640 let mode = &EditFileMode::Edit;
641
642 let path_with_root = "root/dir/subdir/existing.txt";
643 let path_without_root = "dir/subdir/existing.txt";
644 let result = test_resolve_path(mode, path_with_root, cx);
645 assert_resolved_path_eq(result.await, rel_path(path_without_root));
646
647 let result = test_resolve_path(mode, path_without_root, cx);
648 assert_resolved_path_eq(result.await, rel_path(path_without_root));
649
650 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
651 assert_eq!(
652 result.await.unwrap_err().to_string(),
653 "Can't edit file: path not found"
654 );
655
656 let result = test_resolve_path(mode, "root/dir", cx);
657 assert_eq!(
658 result.await.unwrap_err().to_string(),
659 "Can't edit file: path is a directory"
660 );
661 }
662
663 async fn test_resolve_path(
664 mode: &EditFileMode,
665 path: &str,
666 cx: &mut TestAppContext,
667 ) -> anyhow::Result<ProjectPath> {
668 init_test(cx);
669
670 let fs = project::FakeFs::new(cx.executor());
671 fs.insert_tree(
672 "/root",
673 json!({
674 "dir": {
675 "subdir": {
676 "existing.txt": "hello"
677 }
678 }
679 }),
680 )
681 .await;
682 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
683
684 let input = EditFileToolInput {
685 display_description: "Some edit".into(),
686 path: path.into(),
687 mode: mode.clone(),
688 };
689
690 cx.update(|cx| resolve_path(&input, project, cx))
691 }
692
693 #[track_caller]
694 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
695 let actual = path.expect("Should return valid path").path;
696 assert_eq!(actual.as_ref(), expected);
697 }
698
699 #[gpui::test]
700 async fn test_format_on_save(cx: &mut TestAppContext) {
701 init_test(cx);
702
703 let fs = project::FakeFs::new(cx.executor());
704 fs.insert_tree("/root", json!({"src": {}})).await;
705
706 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
707
708 // Set up a Rust language with LSP formatting support
709 let rust_language = Arc::new(language::Language::new(
710 language::LanguageConfig {
711 name: "Rust".into(),
712 matcher: language::LanguageMatcher {
713 path_suffixes: vec!["rs".to_string()],
714 ..Default::default()
715 },
716 ..Default::default()
717 },
718 None,
719 ));
720
721 // Register the language and fake LSP
722 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
723 language_registry.add(rust_language);
724
725 let mut fake_language_servers = language_registry.register_fake_lsp(
726 "Rust",
727 language::FakeLspAdapter {
728 capabilities: lsp::ServerCapabilities {
729 document_formatting_provider: Some(lsp::OneOf::Left(true)),
730 ..Default::default()
731 },
732 ..Default::default()
733 },
734 );
735
736 // Create the file
737 fs.save(
738 path!("/root/src/main.rs").as_ref(),
739 &"initial content".into(),
740 language::LineEnding::Unix,
741 )
742 .await
743 .unwrap();
744
745 // Open the buffer to trigger LSP initialization
746 let buffer = project
747 .update(cx, |project, cx| {
748 project.open_local_buffer(path!("/root/src/main.rs"), cx)
749 })
750 .await
751 .unwrap();
752
753 // Register the buffer with language servers
754 let _handle = project.update(cx, |project, cx| {
755 project.register_buffer_with_language_servers(&buffer, cx)
756 });
757
758 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
759 const FORMATTED_CONTENT: &str =
760 "This file was formatted by the fake formatter in the test.\n";
761
762 // Get the fake language server and set up formatting handler
763 let fake_language_server = fake_language_servers.next().await.unwrap();
764 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
765 |_, _| async move {
766 Ok(Some(vec![lsp::TextEdit {
767 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
768 new_text: FORMATTED_CONTENT.to_string(),
769 }]))
770 }
771 });
772
773 let context_server_registry =
774 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
775 let model = Arc::new(FakeLanguageModel::default());
776 let thread = cx.new(|cx| {
777 Thread::new(
778 project.clone(),
779 cx.new(|_cx| ProjectContext::default()),
780 context_server_registry,
781 Templates::new(),
782 Some(model.clone()),
783 cx,
784 )
785 });
786
787 // First, test with format_on_save enabled
788 cx.update(|cx| {
789 SettingsStore::update_global(cx, |store, cx| {
790 store.update_user_settings(cx, |settings| {
791 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
792 settings.project.all_languages.defaults.formatter =
793 Some(language::language_settings::FormatterList::default());
794 });
795 });
796 });
797
798 // Have the model stream unformatted content
799 let edit_result = {
800 let edit_task = cx.update(|cx| {
801 let input = EditFileToolInput {
802 display_description: "Create main function".into(),
803 path: "root/src/main.rs".into(),
804 mode: EditFileMode::Overwrite,
805 };
806 Arc::new(EditFileTool::new(
807 project.clone(),
808 thread.downgrade(),
809 language_registry.clone(),
810 ))
811 .run(input, ToolCallEventStream::test().0, cx)
812 });
813
814 // Stream the unformatted content
815 cx.executor().run_until_parked();
816 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
817 model.end_last_completion_stream();
818
819 edit_task.await
820 };
821 assert!(edit_result.is_ok());
822
823 // Wait for any async operations (e.g. formatting) to complete
824 cx.executor().run_until_parked();
825
826 // Read the file to verify it was formatted automatically
827 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
828 assert_eq!(
829 // Ignore carriage returns on Windows
830 new_content.replace("\r\n", "\n"),
831 FORMATTED_CONTENT,
832 "Code should be formatted when format_on_save is enabled"
833 );
834
835 let stale_buffer_count = thread
836 .read_with(cx, |thread, _cx| thread.action_log.clone())
837 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
838
839 assert_eq!(
840 stale_buffer_count, 0,
841 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
842 This causes the agent to think the file was modified externally when it was just formatted.",
843 stale_buffer_count
844 );
845
846 // Next, test with format_on_save disabled
847 cx.update(|cx| {
848 SettingsStore::update_global(cx, |store, cx| {
849 store.update_user_settings(cx, |settings| {
850 settings.project.all_languages.defaults.format_on_save =
851 Some(FormatOnSave::Off);
852 });
853 });
854 });
855
856 // Stream unformatted edits again
857 let edit_result = {
858 let edit_task = cx.update(|cx| {
859 let input = EditFileToolInput {
860 display_description: "Update main function".into(),
861 path: "root/src/main.rs".into(),
862 mode: EditFileMode::Overwrite,
863 };
864 Arc::new(EditFileTool::new(
865 project.clone(),
866 thread.downgrade(),
867 language_registry,
868 ))
869 .run(input, ToolCallEventStream::test().0, cx)
870 });
871
872 // Stream the unformatted content
873 cx.executor().run_until_parked();
874 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
875 model.end_last_completion_stream();
876
877 edit_task.await
878 };
879 assert!(edit_result.is_ok());
880
881 // Wait for any async operations (e.g. formatting) to complete
882 cx.executor().run_until_parked();
883
884 // Verify the file was not formatted
885 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
886 assert_eq!(
887 // Ignore carriage returns on Windows
888 new_content.replace("\r\n", "\n"),
889 UNFORMATTED_CONTENT,
890 "Code should not be formatted when format_on_save is disabled"
891 );
892 }
893
894 #[gpui::test]
895 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
896 init_test(cx);
897
898 let fs = project::FakeFs::new(cx.executor());
899 fs.insert_tree("/root", json!({"src": {}})).await;
900
901 // Create a simple file with trailing whitespace
902 fs.save(
903 path!("/root/src/main.rs").as_ref(),
904 &"initial content".into(),
905 language::LineEnding::Unix,
906 )
907 .await
908 .unwrap();
909
910 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
911 let context_server_registry =
912 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
913 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
914 let model = Arc::new(FakeLanguageModel::default());
915 let thread = cx.new(|cx| {
916 Thread::new(
917 project.clone(),
918 cx.new(|_cx| ProjectContext::default()),
919 context_server_registry,
920 Templates::new(),
921 Some(model.clone()),
922 cx,
923 )
924 });
925
926 // First, test with remove_trailing_whitespace_on_save enabled
927 cx.update(|cx| {
928 SettingsStore::update_global(cx, |store, cx| {
929 store.update_user_settings(cx, |settings| {
930 settings
931 .project
932 .all_languages
933 .defaults
934 .remove_trailing_whitespace_on_save = Some(true);
935 });
936 });
937 });
938
939 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
940 "fn main() { \n println!(\"Hello!\"); \n}\n";
941
942 // Have the model stream content that contains trailing whitespace
943 let edit_result = {
944 let edit_task = cx.update(|cx| {
945 let input = EditFileToolInput {
946 display_description: "Create main function".into(),
947 path: "root/src/main.rs".into(),
948 mode: EditFileMode::Overwrite,
949 };
950 Arc::new(EditFileTool::new(
951 project.clone(),
952 thread.downgrade(),
953 language_registry.clone(),
954 ))
955 .run(input, ToolCallEventStream::test().0, cx)
956 });
957
958 // Stream the content with trailing whitespace
959 cx.executor().run_until_parked();
960 model.send_last_completion_stream_text_chunk(
961 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
962 );
963 model.end_last_completion_stream();
964
965 edit_task.await
966 };
967 assert!(edit_result.is_ok());
968
969 // Wait for any async operations (e.g. formatting) to complete
970 cx.executor().run_until_parked();
971
972 // Read the file to verify trailing whitespace was removed automatically
973 assert_eq!(
974 // Ignore carriage returns on Windows
975 fs.load(path!("/root/src/main.rs").as_ref())
976 .await
977 .unwrap()
978 .replace("\r\n", "\n"),
979 "fn main() {\n println!(\"Hello!\");\n}\n",
980 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
981 );
982
983 // Next, test with remove_trailing_whitespace_on_save disabled
984 cx.update(|cx| {
985 SettingsStore::update_global(cx, |store, cx| {
986 store.update_user_settings(cx, |settings| {
987 settings
988 .project
989 .all_languages
990 .defaults
991 .remove_trailing_whitespace_on_save = Some(false);
992 });
993 });
994 });
995
996 // Stream edits again with trailing whitespace
997 let edit_result = {
998 let edit_task = cx.update(|cx| {
999 let input = EditFileToolInput {
1000 display_description: "Update main function".into(),
1001 path: "root/src/main.rs".into(),
1002 mode: EditFileMode::Overwrite,
1003 };
1004 Arc::new(EditFileTool::new(
1005 project.clone(),
1006 thread.downgrade(),
1007 language_registry,
1008 ))
1009 .run(input, ToolCallEventStream::test().0, cx)
1010 });
1011
1012 // Stream the content with trailing whitespace
1013 cx.executor().run_until_parked();
1014 model.send_last_completion_stream_text_chunk(
1015 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1016 );
1017 model.end_last_completion_stream();
1018
1019 edit_task.await
1020 };
1021 assert!(edit_result.is_ok());
1022
1023 // Wait for any async operations (e.g. formatting) to complete
1024 cx.executor().run_until_parked();
1025
1026 // Verify the file still has trailing whitespace
1027 // Read the file again - it should still have trailing whitespace
1028 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1029 assert_eq!(
1030 // Ignore carriage returns on Windows
1031 final_content.replace("\r\n", "\n"),
1032 CONTENT_WITH_TRAILING_WHITESPACE,
1033 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1034 );
1035 }
1036
/// Checks which inputs make `EditFileTool::authorize` request confirmation:
/// 1. a relative path whose first component is `.zed`,
/// 2. an absolute path outside the project,
/// 3. a project-relative path without `.zed` (no request expected),
/// 4. a `.zed` component in the middle of the path,
/// 5. the two prompting shapes again after `always_allow_tool_actions` is
///    enabled (no request expected).
#[gpui::test]
async fn test_authorize(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = project::FakeFs::new(cx.executor());
    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });
    let tool = Arc::new(EditFileTool::new(
        project.clone(),
        thread.downgrade(),
        language_registry,
    ));
    fs.insert_tree("/root", json!({})).await;

    // Test 1: Path with .zed component should require confirmation
    // The pending task is bound to a named `_auth` rather than `_`, because
    // `let _ = …` would drop it immediately. It is never awaited; only the
    // emitted request is inspected.
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    let _auth = cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 1".into(),
                path: ".zed/settings.json".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    });

    // For `.zed` paths, the prompt title is the description plus a
    // " (local settings)" suffix.
    let event = stream_rx.expect_authorization().await;
    assert_eq!(
        event.tool_call.fields.title,
        Some("test 1 (local settings)".into())
    );

    // Test 2: Path outside project should require confirmation
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    let _auth = cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 2".into(),
                path: "/etc/hosts".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    });

    // Outside-project prompts use the bare description as their title.
    let event = stream_rx.expect_authorization().await;
    assert_eq!(event.tool_call.fields.title, Some("test 2".into()));

    // Test 3: Relative path without .zed should not require confirmation
    // The task must resolve successfully, and nothing may have been queued on
    // the stream. NOTE(review): this presumes `try_next()` returns `Err` when
    // the channel is empty while `stream_tx` is still alive (and `Ok(None)`
    // once closed) — confirm against the receiver type.
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 3".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    })
    .await
    .unwrap();
    assert!(stream_rx.try_next().is_err());

    // Test 4: Path with .zed in the middle should require confirmation
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    let _auth = cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 4".into(),
                path: "root/.zed/tasks.json".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    });
    let event = stream_rx.expect_authorization().await;
    assert_eq!(
        event.tool_call.fields.title,
        Some("test 4 (local settings)".into())
    );

    // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
    // The global agent settings are overridden here and are not reset, so the
    // override stays in effect for the rest of this test.
    cx.update(|cx| {
        let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
        settings.always_allow_tool_actions = true;
        agent_settings::AgentSettings::override_global(settings, cx);
    });

    // 5.1: a `.zed` path, which prompted in Test 1, is now approved silently.
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 5.1".into(),
                path: ".zed/settings.json".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    })
    .await
    .unwrap();
    assert!(stream_rx.try_next().is_err());

    // 5.2: an outside-project path, which prompted in Test 2, is now approved
    // silently.
    let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
    cx.update(|cx| {
        tool.authorize(
            &EditFileToolInput {
                display_description: "test 5.2".into(),
                path: "/etc/hosts".into(),
                mode: EditFileMode::Edit,
            },
            &stream_tx,
            cx,
        )
    })
    .await
    .unwrap();
    assert!(stream_rx.try_next().is_err());
}
1175
1176 #[gpui::test]
1177 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1178 init_test(cx);
1179 let fs = project::FakeFs::new(cx.executor());
1180 fs.insert_tree("/project", json!({})).await;
1181 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1182 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1183 let context_server_registry =
1184 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1185 let model = Arc::new(FakeLanguageModel::default());
1186 let thread = cx.new(|cx| {
1187 Thread::new(
1188 project.clone(),
1189 cx.new(|_cx| ProjectContext::default()),
1190 context_server_registry,
1191 Templates::new(),
1192 Some(model.clone()),
1193 cx,
1194 )
1195 });
1196 let tool = Arc::new(EditFileTool::new(
1197 project.clone(),
1198 thread.downgrade(),
1199 language_registry,
1200 ));
1201
1202 // Test global config paths - these should require confirmation if they exist and are outside the project
1203 let test_cases = vec![
1204 (
1205 "/etc/hosts",
1206 true,
1207 "System file should require confirmation",
1208 ),
1209 (
1210 "/usr/local/bin/script",
1211 true,
1212 "System bin file should require confirmation",
1213 ),
1214 (
1215 "project/normal_file.rs",
1216 false,
1217 "Normal project file should not require confirmation",
1218 ),
1219 ];
1220
1221 for (path, should_confirm, description) in test_cases {
1222 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1223 let auth = cx.update(|cx| {
1224 tool.authorize(
1225 &EditFileToolInput {
1226 display_description: "Edit file".into(),
1227 path: path.into(),
1228 mode: EditFileMode::Edit,
1229 },
1230 &stream_tx,
1231 cx,
1232 )
1233 });
1234
1235 if should_confirm {
1236 stream_rx.expect_authorization().await;
1237 } else {
1238 auth.await.unwrap();
1239 assert!(
1240 stream_rx.try_next().is_err(),
1241 "Failed for case: {} - path: {} - expected no confirmation but got one",
1242 description,
1243 path
1244 );
1245 }
1246 }
1247 }
1248
1249 #[gpui::test]
1250 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1251 init_test(cx);
1252 let fs = project::FakeFs::new(cx.executor());
1253
1254 // Create multiple worktree directories
1255 fs.insert_tree(
1256 "/workspace/frontend",
1257 json!({
1258 "src": {
1259 "main.js": "console.log('frontend');"
1260 }
1261 }),
1262 )
1263 .await;
1264 fs.insert_tree(
1265 "/workspace/backend",
1266 json!({
1267 "src": {
1268 "main.rs": "fn main() {}"
1269 }
1270 }),
1271 )
1272 .await;
1273 fs.insert_tree(
1274 "/workspace/shared",
1275 json!({
1276 ".zed": {
1277 "settings.json": "{}"
1278 }
1279 }),
1280 )
1281 .await;
1282
1283 // Create project with multiple worktrees
1284 let project = Project::test(
1285 fs.clone(),
1286 [
1287 path!("/workspace/frontend").as_ref(),
1288 path!("/workspace/backend").as_ref(),
1289 path!("/workspace/shared").as_ref(),
1290 ],
1291 cx,
1292 )
1293 .await;
1294 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1295 let context_server_registry =
1296 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1297 let model = Arc::new(FakeLanguageModel::default());
1298 let thread = cx.new(|cx| {
1299 Thread::new(
1300 project.clone(),
1301 cx.new(|_cx| ProjectContext::default()),
1302 context_server_registry.clone(),
1303 Templates::new(),
1304 Some(model.clone()),
1305 cx,
1306 )
1307 });
1308 let tool = Arc::new(EditFileTool::new(
1309 project.clone(),
1310 thread.downgrade(),
1311 language_registry,
1312 ));
1313
1314 // Test files in different worktrees
1315 let test_cases = vec![
1316 ("frontend/src/main.js", false, "File in first worktree"),
1317 ("backend/src/main.rs", false, "File in second worktree"),
1318 (
1319 "shared/.zed/settings.json",
1320 true,
1321 ".zed file in third worktree",
1322 ),
1323 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1324 (
1325 "../outside/file.txt",
1326 true,
1327 "Relative path outside worktrees",
1328 ),
1329 ];
1330
1331 for (path, should_confirm, description) in test_cases {
1332 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1333 let auth = cx.update(|cx| {
1334 tool.authorize(
1335 &EditFileToolInput {
1336 display_description: "Edit file".into(),
1337 path: path.into(),
1338 mode: EditFileMode::Edit,
1339 },
1340 &stream_tx,
1341 cx,
1342 )
1343 });
1344
1345 if should_confirm {
1346 stream_rx.expect_authorization().await;
1347 } else {
1348 auth.await.unwrap();
1349 assert!(
1350 stream_rx.try_next().is_err(),
1351 "Failed for case: {} - path: {} - expected no confirmation but got one",
1352 description,
1353 path
1354 );
1355 }
1356 }
1357 }
1358
1359 #[gpui::test]
1360 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1361 init_test(cx);
1362 let fs = project::FakeFs::new(cx.executor());
1363 fs.insert_tree(
1364 "/project",
1365 json!({
1366 ".zed": {
1367 "settings.json": "{}"
1368 },
1369 "src": {
1370 ".zed": {
1371 "local.json": "{}"
1372 }
1373 }
1374 }),
1375 )
1376 .await;
1377 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1378 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1379 let context_server_registry =
1380 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1381 let model = Arc::new(FakeLanguageModel::default());
1382 let thread = cx.new(|cx| {
1383 Thread::new(
1384 project.clone(),
1385 cx.new(|_cx| ProjectContext::default()),
1386 context_server_registry.clone(),
1387 Templates::new(),
1388 Some(model.clone()),
1389 cx,
1390 )
1391 });
1392 let tool = Arc::new(EditFileTool::new(
1393 project.clone(),
1394 thread.downgrade(),
1395 language_registry,
1396 ));
1397
1398 // Test edge cases
1399 let test_cases = vec![
1400 // Empty path - find_project_path returns Some for empty paths
1401 ("", false, "Empty path is treated as project root"),
1402 // Root directory
1403 ("/", true, "Root directory should be outside project"),
1404 // Parent directory references - find_project_path resolves these
1405 (
1406 "project/../other",
1407 true,
1408 "Path with .. that goes outside of root directory",
1409 ),
1410 (
1411 "project/./src/file.rs",
1412 false,
1413 "Path with . should work normally",
1414 ),
1415 // Windows-style paths (if on Windows)
1416 #[cfg(target_os = "windows")]
1417 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1418 #[cfg(target_os = "windows")]
1419 ("project\\src\\main.rs", false, "Windows-style project path"),
1420 ];
1421
1422 for (path, should_confirm, description) in test_cases {
1423 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1424 let auth = cx.update(|cx| {
1425 tool.authorize(
1426 &EditFileToolInput {
1427 display_description: "Edit file".into(),
1428 path: path.into(),
1429 mode: EditFileMode::Edit,
1430 },
1431 &stream_tx,
1432 cx,
1433 )
1434 });
1435
1436 cx.run_until_parked();
1437
1438 if should_confirm {
1439 stream_rx.expect_authorization().await;
1440 } else {
1441 assert!(
1442 stream_rx.try_next().is_err(),
1443 "Failed for case: {} - path: {} - expected no confirmation but got one",
1444 description,
1445 path
1446 );
1447 auth.await.unwrap();
1448 }
1449 }
1450 }
1451
1452 #[gpui::test]
1453 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1454 init_test(cx);
1455 let fs = project::FakeFs::new(cx.executor());
1456 fs.insert_tree(
1457 "/project",
1458 json!({
1459 "existing.txt": "content",
1460 ".zed": {
1461 "settings.json": "{}"
1462 }
1463 }),
1464 )
1465 .await;
1466 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1467 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1468 let context_server_registry =
1469 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1470 let model = Arc::new(FakeLanguageModel::default());
1471 let thread = cx.new(|cx| {
1472 Thread::new(
1473 project.clone(),
1474 cx.new(|_cx| ProjectContext::default()),
1475 context_server_registry.clone(),
1476 Templates::new(),
1477 Some(model.clone()),
1478 cx,
1479 )
1480 });
1481 let tool = Arc::new(EditFileTool::new(
1482 project.clone(),
1483 thread.downgrade(),
1484 language_registry,
1485 ));
1486
1487 // Test different EditFileMode values
1488 let modes = vec![
1489 EditFileMode::Edit,
1490 EditFileMode::Create,
1491 EditFileMode::Overwrite,
1492 ];
1493
1494 for mode in modes {
1495 // Test .zed path with different modes
1496 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1497 let _auth = cx.update(|cx| {
1498 tool.authorize(
1499 &EditFileToolInput {
1500 display_description: "Edit settings".into(),
1501 path: "project/.zed/settings.json".into(),
1502 mode: mode.clone(),
1503 },
1504 &stream_tx,
1505 cx,
1506 )
1507 });
1508
1509 stream_rx.expect_authorization().await;
1510
1511 // Test outside path with different modes
1512 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1513 let _auth = cx.update(|cx| {
1514 tool.authorize(
1515 &EditFileToolInput {
1516 display_description: "Edit file".into(),
1517 path: "/outside/file.txt".into(),
1518 mode: mode.clone(),
1519 },
1520 &stream_tx,
1521 cx,
1522 )
1523 });
1524
1525 stream_rx.expect_authorization().await;
1526
1527 // Test normal path with different modes
1528 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1529 cx.update(|cx| {
1530 tool.authorize(
1531 &EditFileToolInput {
1532 display_description: "Edit file".into(),
1533 path: "project/normal.txt".into(),
1534 mode: mode.clone(),
1535 },
1536 &stream_tx,
1537 cx,
1538 )
1539 })
1540 .await
1541 .unwrap();
1542 assert!(stream_rx.try_next().is_err());
1543 }
1544 }
1545
1546 #[gpui::test]
1547 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1548 init_test(cx);
1549 let fs = project::FakeFs::new(cx.executor());
1550 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1551 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1552 let context_server_registry =
1553 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1554 let model = Arc::new(FakeLanguageModel::default());
1555 let thread = cx.new(|cx| {
1556 Thread::new(
1557 project.clone(),
1558 cx.new(|_cx| ProjectContext::default()),
1559 context_server_registry,
1560 Templates::new(),
1561 Some(model.clone()),
1562 cx,
1563 )
1564 });
1565 let tool = Arc::new(EditFileTool::new(
1566 project,
1567 thread.downgrade(),
1568 language_registry,
1569 ));
1570
1571 cx.update(|cx| {
1572 // ...
1573 assert_eq!(
1574 tool.initial_title(
1575 Err(json!({
1576 "path": "src/main.rs",
1577 "display_description": "",
1578 "old_string": "old code",
1579 "new_string": "new code"
1580 })),
1581 cx
1582 ),
1583 "src/main.rs"
1584 );
1585 assert_eq!(
1586 tool.initial_title(
1587 Err(json!({
1588 "path": "",
1589 "display_description": "Fix error handling",
1590 "old_string": "old code",
1591 "new_string": "new code"
1592 })),
1593 cx
1594 ),
1595 "Fix error handling"
1596 );
1597 assert_eq!(
1598 tool.initial_title(
1599 Err(json!({
1600 "path": "src/main.rs",
1601 "display_description": "Fix error handling",
1602 "old_string": "old code",
1603 "new_string": "new code"
1604 })),
1605 cx
1606 ),
1607 "src/main.rs"
1608 );
1609 assert_eq!(
1610 tool.initial_title(
1611 Err(json!({
1612 "path": "",
1613 "display_description": "",
1614 "old_string": "old code",
1615 "new_string": "new code"
1616 })),
1617 cx
1618 ),
1619 DEFAULT_UI_TEXT
1620 );
1621 assert_eq!(
1622 tool.initial_title(Err(serde_json::Value::Null), cx),
1623 DEFAULT_UI_TEXT
1624 );
1625 });
1626 }
1627
/// Verifies that the `Diff` emitted while `EditFileTool::run` is in progress
/// starts as `Diff::Pending` and ends as `Diff::Finalized` in three cases:
/// the edit completes, the edit fails, or the run task is dropped.
#[gpui::test]
async fn test_diff_finalization(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree("/", json!({"main.rs": ""})).await;

    let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
    let languages = project.read_with(cx, |project, _cx| project.languages().clone());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // Ensure the diff is finalized after the edit completes.
    {
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        // The run first emits a fields update, then the diff, which is still
        // pending at that point.
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        cx.run_until_parked();
        // End the fake model's completion stream so the edit can finish.
        model.end_last_completion_stream();
        edit.await.unwrap();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
    }

    // Ensure the diff is finalized if an error occurs while editing.
    {
        // Forbidding model requests makes this run fail (`unwrap_err` below).
        // Requests are re-allowed at the end of this scope for the next case.
        model.forbid_requests();
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        edit.await.unwrap_err();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        model.allow_requests();
    }

    // Ensure the diff is finalized if the tool call gets dropped.
    {
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        // Drop the run task without awaiting it. The executor is then
        // drained before the diff state is checked.
        drop(edit);
        cx.run_until_parked();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
    }
}
1733
1734 fn init_test(cx: &mut TestAppContext) {
1735 cx.update(|cx| {
1736 let settings_store = SettingsStore::test(cx);
1737 cx.set_global(settings_store);
1738 language::init(cx);
1739 TelemetrySettings::register(cx);
1740 agent_settings::AgentSettings::register(cx);
1741 Project::init_settings(cx);
1742 });
1743 }
1744}