1use crate::{AgentTool, Thread, ToolCallEventStream};
2use acp_thread::Diff;
3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
6use cloud_llm_client::CompletionIntent;
7use collections::HashSet;
8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
9use indoc::formatdoc;
10use language::language_settings::{self, FormatOnSave};
11use language::{LanguageRegistry, ToPoint};
12use language_model::LanguageModelToolResultContent;
13use paths;
14use project::lsp_store::{FormatTrigger, LspFormatTarget};
15use project::{Project, ProjectPath};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::Settings;
19use smol::stream::StreamExt as _;
20use std::path::{Path, PathBuf};
21use std::sync::Arc;
22use ui::SharedString;
23use util::ResultExt;
24
/// Fallback tool-call title used by `initial_title` when neither a path nor a description
/// can be extracted from the input.
const DEFAULT_UI_TEXT: &str = "Editing file";
26
27/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
28///
29/// Before using this tool:
30///
31/// 1. Use the `read_file` tool to understand the file's contents and context
32///
33/// 2. Verify the directory path is correct (only applicable when creating new files):
34/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
35#[derive(Debug, Serialize, Deserialize, JsonSchema)]
36pub struct EditFileToolInput {
37 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
38 ///
39 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
40 ///
41 /// NEVER mention the file path in this description.
42 ///
43 /// <example>Fix API endpoint URLs</example>
44 /// <example>Update copyright year in `page_footer`</example>
45 ///
46 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
47 pub display_description: String,
48
49 /// The full path of the file to create or modify in the project.
50 ///
51 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
52 ///
53 /// The following examples assume we have two root directories in the project:
54 /// - /a/b/backend
55 /// - /c/d/frontend
56 ///
57 /// <example>
58 /// `backend/src/main.rs`
59 ///
60 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
61 /// </example>
62 ///
63 /// <example>
64 /// `frontend/db.js`
65 /// </example>
66 pub path: PathBuf,
67 /// The mode of operation on the file. Possible values:
68 /// - 'edit': Make granular edits to an existing file.
69 /// - 'create': Create a new file if it doesn't exist.
70 /// - 'overwrite': Replace the entire contents of an existing file.
71 ///
72 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
73 pub mode: EditFileMode,
74}
75
// Lenient view of `EditFileToolInput` that `initial_title` falls back to when the full input
// does not deserialize (presumably because it is still being streamed). Missing fields
// default to empty strings, so whatever has arrived so far can still be shown as a title.
// Plain `//` comments are used so the derived JSON schema is unchanged.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // The (possibly incomplete) `path` value; empty when absent.
    #[serde(default)]
    path: String,
    // The (possibly incomplete) `display_description` value; empty when absent.
    #[serde(default)]
    display_description: String,
}
83
// How `edit_file` treats the target path. Serialized in lowercase (`"edit"`, `"create"`,
// `"overwrite"`); `schemars(inline)` asks schemars to inline this schema where it is used
// rather than emit a `$ref`. Plain `//` comments keep the generated schema unchanged.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // Targeted edits to an existing file; the only mode routed to `EditAgent::edit`.
    Edit,
    // A new file: must not exist yet, and its parent must be an existing directory.
    // Routed to `EditAgent::overwrite`.
    Create,
    // Replace the contents of an existing file; routed to `EditAgent::overwrite`.
    Overwrite,
}
92
/// Result of a successful `edit_file` call; also the data `replay` uses to rebuild the diff.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as supplied in the tool input. `original_path` is also accepted as a
    /// field name when deserializing.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit, optional formatting, and save.
    new_text: String,
    /// Full buffer text captured before the edit agent started.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; an empty string means no change.
    /// Defaults to empty when the field is missing from serialized data.
    #[serde(default)]
    diff: String,
    /// Output reported by the edit agent. `raw_output` is also accepted as a field name when
    /// deserializing.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
104
105impl From<EditFileToolOutput> for LanguageModelToolResultContent {
106 fn from(output: EditFileToolOutput) -> Self {
107 if output.diff.is_empty() {
108 "No edits were made.".into()
109 } else {
110 format!(
111 "Edited {}:\n\n```diff\n{}\n```",
112 output.input_path.display(),
113 output.diff
114 )
115 .into()
116 }
117 }
118}
119
/// Agent tool that creates, edits, or overwrites a single project file, delegating the text
/// changes themselves to an [`EditAgent`].
pub struct EditFileTool {
    /// Thread this tool belongs to. It is held weakly, so calls fail with "thread was dropped"
    /// once the thread is gone.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when a past call is replayed.
    language_registry: Arc<LanguageRegistry>,
    /// Project used when computing tool-call titles (`initial_title`).
    project: Entity<Project>,
}
125
126impl EditFileTool {
127 pub fn new(
128 project: Entity<Project>,
129 thread: WeakEntity<Thread>,
130 language_registry: Arc<LanguageRegistry>,
131 ) -> Self {
132 Self {
133 project,
134 thread,
135 language_registry,
136 }
137 }
138
139 fn authorize(
140 &self,
141 input: &EditFileToolInput,
142 event_stream: &ToolCallEventStream,
143 cx: &mut App,
144 ) -> Task<Result<()>> {
145 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
146 return Task::ready(Ok(()));
147 }
148
149 // If any path component matches the local settings folder, then this could affect
150 // the editor in ways beyond the project source, so prompt.
151 let local_settings_folder = paths::local_settings_folder_relative_path();
152 let path = Path::new(&input.path);
153 if path
154 .components()
155 .any(|component| component.as_os_str() == local_settings_folder.as_os_str())
156 {
157 return event_stream.authorize(
158 format!("{} (local settings)", input.display_description),
159 cx,
160 );
161 }
162
163 // It's also possible that the global config dir is configured to be inside the project,
164 // so check for that edge case too.
165 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
166 && canonical_path.starts_with(paths::config_dir())
167 {
168 return event_stream.authorize(
169 format!("{} (global settings)", input.display_description),
170 cx,
171 );
172 }
173
174 // Check if path is inside the global config directory
175 // First check if it's already inside project - if not, try to canonicalize
176 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
177 thread.project().read(cx).find_project_path(&input.path, cx)
178 }) else {
179 return Task::ready(Err(anyhow!("thread was dropped")));
180 };
181
182 // If the path is inside the project, and it's not one of the above edge cases,
183 // then no confirmation is necessary. Otherwise, confirmation is necessary.
184 if project_path.is_some() {
185 Task::ready(Ok(()))
186 } else {
187 event_stream.authorize(&input.display_description, cx)
188 }
189 }
190}
191
192impl AgentTool for EditFileTool {
193 type Input = EditFileToolInput;
194 type Output = EditFileToolOutput;
195
196 fn name() -> &'static str {
197 "edit_file"
198 }
199
200 fn kind() -> acp::ToolKind {
201 acp::ToolKind::Edit
202 }
203
204 fn initial_title(
205 &self,
206 input: Result<Self::Input, serde_json::Value>,
207 cx: &mut App,
208 ) -> SharedString {
209 match input {
210 Ok(input) => self
211 .project
212 .read(cx)
213 .find_project_path(&input.path, cx)
214 .and_then(|project_path| {
215 self.project
216 .read(cx)
217 .short_full_path_for_project_path(&project_path, cx)
218 })
219 .unwrap_or(Path::new(&input.path).into())
220 .to_string_lossy()
221 .to_string()
222 .into(),
223 Err(raw_input) => {
224 if let Some(input) =
225 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
226 {
227 let path = input.path.trim();
228 if !path.is_empty() {
229 return self
230 .project
231 .read(cx)
232 .find_project_path(&input.path, cx)
233 .and_then(|project_path| {
234 self.project
235 .read(cx)
236 .short_full_path_for_project_path(&project_path, cx)
237 })
238 .unwrap_or(Path::new(&input.path).into())
239 .to_string_lossy()
240 .to_string()
241 .into();
242 }
243
244 let description = input.display_description.trim();
245 if !description.is_empty() {
246 return description.to_string().into();
247 }
248 }
249
250 DEFAULT_UI_TEXT.into()
251 }
252 }
253 }
254
255 fn run(
256 self: Arc<Self>,
257 input: Self::Input,
258 event_stream: ToolCallEventStream,
259 cx: &mut App,
260 ) -> Task<Result<Self::Output>> {
261 let Ok(project) = self
262 .thread
263 .read_with(cx, |thread, _cx| thread.project().clone())
264 else {
265 return Task::ready(Err(anyhow!("thread was dropped")));
266 };
267 let project_path = match resolve_path(&input, project.clone(), cx) {
268 Ok(path) => path,
269 Err(err) => return Task::ready(Err(anyhow!(err))),
270 };
271 let abs_path = project.read(cx).absolute_path(&project_path, cx);
272 if let Some(abs_path) = abs_path.clone() {
273 event_stream.update_fields(ToolCallUpdateFields {
274 locations: Some(vec![acp::ToolCallLocation {
275 path: abs_path,
276 line: None,
277 meta: None,
278 }]),
279 ..Default::default()
280 });
281 }
282
283 let authorize = self.authorize(&input, &event_stream, cx);
284 cx.spawn(async move |cx: &mut AsyncApp| {
285 authorize.await?;
286
287 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
288 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
289 (request, thread.model().cloned(), thread.action_log().clone())
290 })?;
291 let request = request?;
292 let model = model.context("No language model configured")?;
293
294 let edit_format = EditFormat::from_model(model.clone())?;
295 let edit_agent = EditAgent::new(
296 model,
297 project.clone(),
298 action_log.clone(),
299 // TODO: move edit agent to this crate so we can use our templates
300 assistant_tools::templates::Templates::new(),
301 edit_format,
302 );
303
304 let buffer = project
305 .update(cx, |project, cx| {
306 project.open_buffer(project_path.clone(), cx)
307 })?
308 .await?;
309
310 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
311 event_stream.update_diff(diff.clone());
312 let _finalize_diff = util::defer({
313 let diff = diff.downgrade();
314 let mut cx = cx.clone();
315 move || {
316 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
317 }
318 });
319
320 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
321 let old_text = cx
322 .background_spawn({
323 let old_snapshot = old_snapshot.clone();
324 async move { Arc::new(old_snapshot.text()) }
325 })
326 .await;
327
328
329 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
330 edit_agent.edit(
331 buffer.clone(),
332 input.display_description.clone(),
333 &request,
334 cx,
335 )
336 } else {
337 edit_agent.overwrite(
338 buffer.clone(),
339 input.display_description.clone(),
340 &request,
341 cx,
342 )
343 };
344
345 let mut hallucinated_old_text = false;
346 let mut ambiguous_ranges = Vec::new();
347 let mut emitted_location = false;
348 while let Some(event) = events.next().await {
349 match event {
350 EditAgentOutputEvent::Edited(range) => {
351 if !emitted_location {
352 let line = buffer.update(cx, |buffer, _cx| {
353 range.start.to_point(&buffer.snapshot()).row
354 }).ok();
355 if let Some(abs_path) = abs_path.clone() {
356 event_stream.update_fields(ToolCallUpdateFields {
357 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
358 ..Default::default()
359 });
360 }
361 emitted_location = true;
362 }
363 },
364 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
365 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
366 EditAgentOutputEvent::ResolvingEditRange(range) => {
367 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
368 // if !emitted_location {
369 // let line = buffer.update(cx, |buffer, _cx| {
370 // range.start.to_point(&buffer.snapshot()).row
371 // }).ok();
372 // if let Some(abs_path) = abs_path.clone() {
373 // event_stream.update_fields(ToolCallUpdateFields {
374 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
375 // ..Default::default()
376 // });
377 // }
378 // }
379 }
380 }
381 }
382
383 // If format_on_save is enabled, format the buffer
384 let format_on_save_enabled = buffer
385 .read_with(cx, |buffer, cx| {
386 let settings = language_settings::language_settings(
387 buffer.language().map(|l| l.name()),
388 buffer.file(),
389 cx,
390 );
391 settings.format_on_save != FormatOnSave::Off
392 })
393 .unwrap_or(false);
394
395 let edit_agent_output = output.await?;
396
397 if format_on_save_enabled {
398 action_log.update(cx, |log, cx| {
399 log.buffer_edited(buffer.clone(), cx);
400 })?;
401
402 let format_task = project.update(cx, |project, cx| {
403 project.format(
404 HashSet::from_iter([buffer.clone()]),
405 LspFormatTarget::Buffers,
406 false, // Don't push to history since the tool did it.
407 FormatTrigger::Save,
408 cx,
409 )
410 })?;
411 format_task.await.log_err();
412 }
413
414 project
415 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
416 .await?;
417
418 action_log.update(cx, |log, cx| {
419 log.buffer_edited(buffer.clone(), cx);
420 })?;
421
422 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
423 let (new_text, unified_diff) = cx
424 .background_spawn({
425 let new_snapshot = new_snapshot.clone();
426 let old_text = old_text.clone();
427 async move {
428 let new_text = new_snapshot.text();
429 let diff = language::unified_diff(&old_text, &new_text);
430 (new_text, diff)
431 }
432 })
433 .await;
434
435 let input_path = input.path.display();
436 if unified_diff.is_empty() {
437 anyhow::ensure!(
438 !hallucinated_old_text,
439 formatdoc! {"
440 Some edits were produced but none of them could be applied.
441 Read the relevant sections of {input_path} again so that
442 I can perform the requested edits.
443 "}
444 );
445 anyhow::ensure!(
446 ambiguous_ranges.is_empty(),
447 {
448 let line_numbers = ambiguous_ranges
449 .iter()
450 .map(|range| range.start.to_string())
451 .collect::<Vec<_>>()
452 .join(", ");
453 formatdoc! {"
454 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
455 relevant sections of {input_path} again and extend <old_text> so
456 that I can perform the requested edits.
457 "}
458 }
459 );
460 }
461
462 Ok(EditFileToolOutput {
463 input_path: input.path,
464 new_text,
465 old_text,
466 diff: unified_diff,
467 edit_agent_output,
468 })
469 })
470 }
471
472 fn replay(
473 &self,
474 _input: Self::Input,
475 output: Self::Output,
476 event_stream: ToolCallEventStream,
477 cx: &mut App,
478 ) -> Result<()> {
479 event_stream.update_diff(cx.new(|cx| {
480 Diff::finalized(
481 output.input_path,
482 Some(output.old_text.to_string()),
483 output.new_text,
484 self.language_registry.clone(),
485 cx,
486 )
487 }));
488 Ok(())
489 }
490}
491
492/// Validate that the file path is valid, meaning:
493///
494/// - For `edit` and `overwrite`, the path must point to an existing file.
495/// - For `create`, the file must not already exist, but it's parent dir must exist.
496fn resolve_path(
497 input: &EditFileToolInput,
498 project: Entity<Project>,
499 cx: &mut App,
500) -> Result<ProjectPath> {
501 let project = project.read(cx);
502
503 match input.mode {
504 EditFileMode::Edit | EditFileMode::Overwrite => {
505 let path = project
506 .find_project_path(&input.path, cx)
507 .context("Can't edit file: path not found")?;
508
509 let entry = project
510 .entry_for_path(&path, cx)
511 .context("Can't edit file: path not found")?;
512
513 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
514 Ok(path)
515 }
516
517 EditFileMode::Create => {
518 if let Some(path) = project.find_project_path(&input.path, cx) {
519 anyhow::ensure!(
520 project.entry_for_path(&path, cx).is_none(),
521 "Can't create file: file already exists"
522 );
523 }
524
525 let parent_path = input
526 .path
527 .parent()
528 .context("Can't create file: incorrect path")?;
529
530 let parent_project_path = project.find_project_path(&parent_path, cx);
531
532 let parent_entry = parent_project_path
533 .as_ref()
534 .and_then(|path| project.entry_for_path(path, cx))
535 .context("Can't create file: parent directory doesn't exist")?;
536
537 anyhow::ensure!(
538 parent_entry.is_dir(),
539 "Can't create file: parent is not a directory"
540 );
541
542 let file_name = input
543 .path
544 .file_name()
545 .context("Can't create file: invalid filename")?;
546
547 let new_file_path = parent_project_path.map(|parent| ProjectPath {
548 path: Arc::from(parent.path.join(file_name)),
549 ..parent
550 });
551
552 new_file_path.context("Can't create file")
553 }
554 }
555}
556
557#[cfg(test)]
558mod tests {
559 use super::*;
560 use crate::{ContextServerRegistry, Templates};
561 use client::TelemetrySettings;
562 use fs::Fs;
563 use gpui::{TestAppContext, UpdateGlobal};
564 use language_model::fake_provider::FakeLanguageModel;
565 use prompt_store::ProjectContext;
566 use serde_json::json;
567 use settings::SettingsStore;
568 use util::path;
569
570 #[gpui::test]
571 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
572 init_test(cx);
573
574 let fs = project::FakeFs::new(cx.executor());
575 fs.insert_tree("/root", json!({})).await;
576 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
577 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
578 let context_server_registry =
579 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
580 let model = Arc::new(FakeLanguageModel::default());
581 let thread = cx.new(|cx| {
582 Thread::new(
583 project.clone(),
584 cx.new(|_cx| ProjectContext::default()),
585 context_server_registry,
586 Templates::new(),
587 Some(model),
588 cx,
589 )
590 });
591 let result = cx
592 .update(|cx| {
593 let input = EditFileToolInput {
594 display_description: "Some edit".into(),
595 path: "root/nonexistent_file.txt".into(),
596 mode: EditFileMode::Edit,
597 };
598 Arc::new(EditFileTool::new(
599 project,
600 thread.downgrade(),
601 language_registry,
602 ))
603 .run(input, ToolCallEventStream::test().0, cx)
604 })
605 .await;
606 assert_eq!(
607 result.unwrap_err().to_string(),
608 "Can't edit file: path not found"
609 );
610 }
611
612 #[gpui::test]
613 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
614 let mode = &EditFileMode::Create;
615
616 let result = test_resolve_path(mode, "root/new.txt", cx);
617 assert_resolved_path_eq(result.await, "new.txt");
618
619 let result = test_resolve_path(mode, "new.txt", cx);
620 assert_resolved_path_eq(result.await, "new.txt");
621
622 let result = test_resolve_path(mode, "dir/new.txt", cx);
623 assert_resolved_path_eq(result.await, "dir/new.txt");
624
625 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
626 assert_eq!(
627 result.await.unwrap_err().to_string(),
628 "Can't create file: file already exists"
629 );
630
631 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
632 assert_eq!(
633 result.await.unwrap_err().to_string(),
634 "Can't create file: parent directory doesn't exist"
635 );
636 }
637
638 #[gpui::test]
639 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
640 let mode = &EditFileMode::Edit;
641
642 let path_with_root = "root/dir/subdir/existing.txt";
643 let path_without_root = "dir/subdir/existing.txt";
644 let result = test_resolve_path(mode, path_with_root, cx);
645 assert_resolved_path_eq(result.await, path_without_root);
646
647 let result = test_resolve_path(mode, path_without_root, cx);
648 assert_resolved_path_eq(result.await, path_without_root);
649
650 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
651 assert_eq!(
652 result.await.unwrap_err().to_string(),
653 "Can't edit file: path not found"
654 );
655
656 let result = test_resolve_path(mode, "root/dir", cx);
657 assert_eq!(
658 result.await.unwrap_err().to_string(),
659 "Can't edit file: path is a directory"
660 );
661 }
662
663 async fn test_resolve_path(
664 mode: &EditFileMode,
665 path: &str,
666 cx: &mut TestAppContext,
667 ) -> anyhow::Result<ProjectPath> {
668 init_test(cx);
669
670 let fs = project::FakeFs::new(cx.executor());
671 fs.insert_tree(
672 "/root",
673 json!({
674 "dir": {
675 "subdir": {
676 "existing.txt": "hello"
677 }
678 }
679 }),
680 )
681 .await;
682 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
683
684 let input = EditFileToolInput {
685 display_description: "Some edit".into(),
686 path: path.into(),
687 mode: mode.clone(),
688 };
689
690 cx.update(|cx| resolve_path(&input, project, cx))
691 }
692
693 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
694 let actual = path
695 .expect("Should return valid path")
696 .path
697 .to_str()
698 .unwrap()
699 .replace("\\", "/"); // Naive Windows paths normalization
700 assert_eq!(actual, expected);
701 }
702
703 #[gpui::test]
704 async fn test_format_on_save(cx: &mut TestAppContext) {
705 init_test(cx);
706
707 let fs = project::FakeFs::new(cx.executor());
708 fs.insert_tree("/root", json!({"src": {}})).await;
709
710 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
711
712 // Set up a Rust language with LSP formatting support
713 let rust_language = Arc::new(language::Language::new(
714 language::LanguageConfig {
715 name: "Rust".into(),
716 matcher: language::LanguageMatcher {
717 path_suffixes: vec!["rs".to_string()],
718 ..Default::default()
719 },
720 ..Default::default()
721 },
722 None,
723 ));
724
725 // Register the language and fake LSP
726 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
727 language_registry.add(rust_language);
728
729 let mut fake_language_servers = language_registry.register_fake_lsp(
730 "Rust",
731 language::FakeLspAdapter {
732 capabilities: lsp::ServerCapabilities {
733 document_formatting_provider: Some(lsp::OneOf::Left(true)),
734 ..Default::default()
735 },
736 ..Default::default()
737 },
738 );
739
740 // Create the file
741 fs.save(
742 path!("/root/src/main.rs").as_ref(),
743 &"initial content".into(),
744 language::LineEnding::Unix,
745 )
746 .await
747 .unwrap();
748
749 // Open the buffer to trigger LSP initialization
750 let buffer = project
751 .update(cx, |project, cx| {
752 project.open_local_buffer(path!("/root/src/main.rs"), cx)
753 })
754 .await
755 .unwrap();
756
757 // Register the buffer with language servers
758 let _handle = project.update(cx, |project, cx| {
759 project.register_buffer_with_language_servers(&buffer, cx)
760 });
761
762 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
763 const FORMATTED_CONTENT: &str =
764 "This file was formatted by the fake formatter in the test.\n";
765
766 // Get the fake language server and set up formatting handler
767 let fake_language_server = fake_language_servers.next().await.unwrap();
768 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
769 |_, _| async move {
770 Ok(Some(vec![lsp::TextEdit {
771 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
772 new_text: FORMATTED_CONTENT.to_string(),
773 }]))
774 }
775 });
776
777 let context_server_registry =
778 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
779 let model = Arc::new(FakeLanguageModel::default());
780 let thread = cx.new(|cx| {
781 Thread::new(
782 project.clone(),
783 cx.new(|_cx| ProjectContext::default()),
784 context_server_registry,
785 Templates::new(),
786 Some(model.clone()),
787 cx,
788 )
789 });
790
791 // First, test with format_on_save enabled
792 cx.update(|cx| {
793 SettingsStore::update_global(cx, |store, cx| {
794 store.update_user_settings(cx, |settings| {
795 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
796 settings.project.all_languages.defaults.formatter =
797 Some(language::language_settings::SelectedFormatter::Auto);
798 });
799 });
800 });
801
802 // Have the model stream unformatted content
803 let edit_result = {
804 let edit_task = cx.update(|cx| {
805 let input = EditFileToolInput {
806 display_description: "Create main function".into(),
807 path: "root/src/main.rs".into(),
808 mode: EditFileMode::Overwrite,
809 };
810 Arc::new(EditFileTool::new(
811 project.clone(),
812 thread.downgrade(),
813 language_registry.clone(),
814 ))
815 .run(input, ToolCallEventStream::test().0, cx)
816 });
817
818 // Stream the unformatted content
819 cx.executor().run_until_parked();
820 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
821 model.end_last_completion_stream();
822
823 edit_task.await
824 };
825 assert!(edit_result.is_ok());
826
827 // Wait for any async operations (e.g. formatting) to complete
828 cx.executor().run_until_parked();
829
830 // Read the file to verify it was formatted automatically
831 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
832 assert_eq!(
833 // Ignore carriage returns on Windows
834 new_content.replace("\r\n", "\n"),
835 FORMATTED_CONTENT,
836 "Code should be formatted when format_on_save is enabled"
837 );
838
839 let stale_buffer_count = thread
840 .read_with(cx, |thread, _cx| thread.action_log.clone())
841 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
842
843 assert_eq!(
844 stale_buffer_count, 0,
845 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
846 This causes the agent to think the file was modified externally when it was just formatted.",
847 stale_buffer_count
848 );
849
850 // Next, test with format_on_save disabled
851 cx.update(|cx| {
852 SettingsStore::update_global(cx, |store, cx| {
853 store.update_user_settings(cx, |settings| {
854 settings.project.all_languages.defaults.format_on_save =
855 Some(FormatOnSave::Off);
856 });
857 });
858 });
859
860 // Stream unformatted edits again
861 let edit_result = {
862 let edit_task = cx.update(|cx| {
863 let input = EditFileToolInput {
864 display_description: "Update main function".into(),
865 path: "root/src/main.rs".into(),
866 mode: EditFileMode::Overwrite,
867 };
868 Arc::new(EditFileTool::new(
869 project.clone(),
870 thread.downgrade(),
871 language_registry,
872 ))
873 .run(input, ToolCallEventStream::test().0, cx)
874 });
875
876 // Stream the unformatted content
877 cx.executor().run_until_parked();
878 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
879 model.end_last_completion_stream();
880
881 edit_task.await
882 };
883 assert!(edit_result.is_ok());
884
885 // Wait for any async operations (e.g. formatting) to complete
886 cx.executor().run_until_parked();
887
888 // Verify the file was not formatted
889 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
890 assert_eq!(
891 // Ignore carriage returns on Windows
892 new_content.replace("\r\n", "\n"),
893 UNFORMATTED_CONTENT,
894 "Code should not be formatted when format_on_save is disabled"
895 );
896 }
897
898 #[gpui::test]
899 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
900 init_test(cx);
901
902 let fs = project::FakeFs::new(cx.executor());
903 fs.insert_tree("/root", json!({"src": {}})).await;
904
905 // Create a simple file with trailing whitespace
906 fs.save(
907 path!("/root/src/main.rs").as_ref(),
908 &"initial content".into(),
909 language::LineEnding::Unix,
910 )
911 .await
912 .unwrap();
913
914 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
915 let context_server_registry =
916 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
917 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
918 let model = Arc::new(FakeLanguageModel::default());
919 let thread = cx.new(|cx| {
920 Thread::new(
921 project.clone(),
922 cx.new(|_cx| ProjectContext::default()),
923 context_server_registry,
924 Templates::new(),
925 Some(model.clone()),
926 cx,
927 )
928 });
929
930 // First, test with remove_trailing_whitespace_on_save enabled
931 cx.update(|cx| {
932 SettingsStore::update_global(cx, |store, cx| {
933 store.update_user_settings(cx, |settings| {
934 settings
935 .project
936 .all_languages
937 .defaults
938 .remove_trailing_whitespace_on_save = Some(true);
939 });
940 });
941 });
942
943 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
944 "fn main() { \n println!(\"Hello!\"); \n}\n";
945
946 // Have the model stream content that contains trailing whitespace
947 let edit_result = {
948 let edit_task = cx.update(|cx| {
949 let input = EditFileToolInput {
950 display_description: "Create main function".into(),
951 path: "root/src/main.rs".into(),
952 mode: EditFileMode::Overwrite,
953 };
954 Arc::new(EditFileTool::new(
955 project.clone(),
956 thread.downgrade(),
957 language_registry.clone(),
958 ))
959 .run(input, ToolCallEventStream::test().0, cx)
960 });
961
962 // Stream the content with trailing whitespace
963 cx.executor().run_until_parked();
964 model.send_last_completion_stream_text_chunk(
965 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
966 );
967 model.end_last_completion_stream();
968
969 edit_task.await
970 };
971 assert!(edit_result.is_ok());
972
973 // Wait for any async operations (e.g. formatting) to complete
974 cx.executor().run_until_parked();
975
976 // Read the file to verify trailing whitespace was removed automatically
977 assert_eq!(
978 // Ignore carriage returns on Windows
979 fs.load(path!("/root/src/main.rs").as_ref())
980 .await
981 .unwrap()
982 .replace("\r\n", "\n"),
983 "fn main() {\n println!(\"Hello!\");\n}\n",
984 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
985 );
986
987 // Next, test with remove_trailing_whitespace_on_save disabled
988 cx.update(|cx| {
989 SettingsStore::update_global(cx, |store, cx| {
990 store.update_user_settings(cx, |settings| {
991 settings
992 .project
993 .all_languages
994 .defaults
995 .remove_trailing_whitespace_on_save = Some(false);
996 });
997 });
998 });
999
1000 // Stream edits again with trailing whitespace
1001 let edit_result = {
1002 let edit_task = cx.update(|cx| {
1003 let input = EditFileToolInput {
1004 display_description: "Update main function".into(),
1005 path: "root/src/main.rs".into(),
1006 mode: EditFileMode::Overwrite,
1007 };
1008 Arc::new(EditFileTool::new(
1009 project.clone(),
1010 thread.downgrade(),
1011 language_registry,
1012 ))
1013 .run(input, ToolCallEventStream::test().0, cx)
1014 });
1015
1016 // Stream the content with trailing whitespace
1017 cx.executor().run_until_parked();
1018 model.send_last_completion_stream_text_chunk(
1019 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1020 );
1021 model.end_last_completion_stream();
1022
1023 edit_task.await
1024 };
1025 assert!(edit_result.is_ok());
1026
1027 // Wait for any async operations (e.g. formatting) to complete
1028 cx.executor().run_until_parked();
1029
1030 // Verify the file still has trailing whitespace
1031 // Read the file again - it should still have trailing whitespace
1032 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1033 assert_eq!(
1034 // Ignore carriage returns on Windows
1035 final_content.replace("\r\n", "\n"),
1036 CONTENT_WITH_TRAILING_WHITESPACE,
1037 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1038 );
1039 }
1040
1041 #[gpui::test]
1042 async fn test_authorize(cx: &mut TestAppContext) {
1043 init_test(cx);
1044 let fs = project::FakeFs::new(cx.executor());
1045 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1046 let context_server_registry =
1047 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1048 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1049 let model = Arc::new(FakeLanguageModel::default());
1050 let thread = cx.new(|cx| {
1051 Thread::new(
1052 project.clone(),
1053 cx.new(|_cx| ProjectContext::default()),
1054 context_server_registry,
1055 Templates::new(),
1056 Some(model.clone()),
1057 cx,
1058 )
1059 });
1060 let tool = Arc::new(EditFileTool::new(
1061 project.clone(),
1062 thread.downgrade(),
1063 language_registry,
1064 ));
1065 fs.insert_tree("/root", json!({})).await;
1066
1067 // Test 1: Path with .zed component should require confirmation
1068 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1069 let _auth = cx.update(|cx| {
1070 tool.authorize(
1071 &EditFileToolInput {
1072 display_description: "test 1".into(),
1073 path: ".zed/settings.json".into(),
1074 mode: EditFileMode::Edit,
1075 },
1076 &stream_tx,
1077 cx,
1078 )
1079 });
1080
1081 let event = stream_rx.expect_authorization().await;
1082 assert_eq!(
1083 event.tool_call.fields.title,
1084 Some("test 1 (local settings)".into())
1085 );
1086
1087 // Test 2: Path outside project should require confirmation
1088 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1089 let _auth = cx.update(|cx| {
1090 tool.authorize(
1091 &EditFileToolInput {
1092 display_description: "test 2".into(),
1093 path: "/etc/hosts".into(),
1094 mode: EditFileMode::Edit,
1095 },
1096 &stream_tx,
1097 cx,
1098 )
1099 });
1100
1101 let event = stream_rx.expect_authorization().await;
1102 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1103
1104 // Test 3: Relative path without .zed should not require confirmation
1105 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1106 cx.update(|cx| {
1107 tool.authorize(
1108 &EditFileToolInput {
1109 display_description: "test 3".into(),
1110 path: "root/src/main.rs".into(),
1111 mode: EditFileMode::Edit,
1112 },
1113 &stream_tx,
1114 cx,
1115 )
1116 })
1117 .await
1118 .unwrap();
1119 assert!(stream_rx.try_next().is_err());
1120
1121 // Test 4: Path with .zed in the middle should require confirmation
1122 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1123 let _auth = cx.update(|cx| {
1124 tool.authorize(
1125 &EditFileToolInput {
1126 display_description: "test 4".into(),
1127 path: "root/.zed/tasks.json".into(),
1128 mode: EditFileMode::Edit,
1129 },
1130 &stream_tx,
1131 cx,
1132 )
1133 });
1134 let event = stream_rx.expect_authorization().await;
1135 assert_eq!(
1136 event.tool_call.fields.title,
1137 Some("test 4 (local settings)".into())
1138 );
1139
1140 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1141 cx.update(|cx| {
1142 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1143 settings.always_allow_tool_actions = true;
1144 agent_settings::AgentSettings::override_global(settings, cx);
1145 });
1146
1147 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1148 cx.update(|cx| {
1149 tool.authorize(
1150 &EditFileToolInput {
1151 display_description: "test 5.1".into(),
1152 path: ".zed/settings.json".into(),
1153 mode: EditFileMode::Edit,
1154 },
1155 &stream_tx,
1156 cx,
1157 )
1158 })
1159 .await
1160 .unwrap();
1161 assert!(stream_rx.try_next().is_err());
1162
1163 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1164 cx.update(|cx| {
1165 tool.authorize(
1166 &EditFileToolInput {
1167 display_description: "test 5.2".into(),
1168 path: "/etc/hosts".into(),
1169 mode: EditFileMode::Edit,
1170 },
1171 &stream_tx,
1172 cx,
1173 )
1174 })
1175 .await
1176 .unwrap();
1177 assert!(stream_rx.try_next().is_err());
1178 }
1179
1180 #[gpui::test]
1181 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1182 init_test(cx);
1183 let fs = project::FakeFs::new(cx.executor());
1184 fs.insert_tree("/project", json!({})).await;
1185 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1186 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1187 let context_server_registry =
1188 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1189 let model = Arc::new(FakeLanguageModel::default());
1190 let thread = cx.new(|cx| {
1191 Thread::new(
1192 project.clone(),
1193 cx.new(|_cx| ProjectContext::default()),
1194 context_server_registry,
1195 Templates::new(),
1196 Some(model.clone()),
1197 cx,
1198 )
1199 });
1200 let tool = Arc::new(EditFileTool::new(
1201 project.clone(),
1202 thread.downgrade(),
1203 language_registry,
1204 ));
1205
1206 // Test global config paths - these should require confirmation if they exist and are outside the project
1207 let test_cases = vec![
1208 (
1209 "/etc/hosts",
1210 true,
1211 "System file should require confirmation",
1212 ),
1213 (
1214 "/usr/local/bin/script",
1215 true,
1216 "System bin file should require confirmation",
1217 ),
1218 (
1219 "project/normal_file.rs",
1220 false,
1221 "Normal project file should not require confirmation",
1222 ),
1223 ];
1224
1225 for (path, should_confirm, description) in test_cases {
1226 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1227 let auth = cx.update(|cx| {
1228 tool.authorize(
1229 &EditFileToolInput {
1230 display_description: "Edit file".into(),
1231 path: path.into(),
1232 mode: EditFileMode::Edit,
1233 },
1234 &stream_tx,
1235 cx,
1236 )
1237 });
1238
1239 if should_confirm {
1240 stream_rx.expect_authorization().await;
1241 } else {
1242 auth.await.unwrap();
1243 assert!(
1244 stream_rx.try_next().is_err(),
1245 "Failed for case: {} - path: {} - expected no confirmation but got one",
1246 description,
1247 path
1248 );
1249 }
1250 }
1251 }
1252
1253 #[gpui::test]
1254 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1255 init_test(cx);
1256 let fs = project::FakeFs::new(cx.executor());
1257
1258 // Create multiple worktree directories
1259 fs.insert_tree(
1260 "/workspace/frontend",
1261 json!({
1262 "src": {
1263 "main.js": "console.log('frontend');"
1264 }
1265 }),
1266 )
1267 .await;
1268 fs.insert_tree(
1269 "/workspace/backend",
1270 json!({
1271 "src": {
1272 "main.rs": "fn main() {}"
1273 }
1274 }),
1275 )
1276 .await;
1277 fs.insert_tree(
1278 "/workspace/shared",
1279 json!({
1280 ".zed": {
1281 "settings.json": "{}"
1282 }
1283 }),
1284 )
1285 .await;
1286
1287 // Create project with multiple worktrees
1288 let project = Project::test(
1289 fs.clone(),
1290 [
1291 path!("/workspace/frontend").as_ref(),
1292 path!("/workspace/backend").as_ref(),
1293 path!("/workspace/shared").as_ref(),
1294 ],
1295 cx,
1296 )
1297 .await;
1298 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1299 let context_server_registry =
1300 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1301 let model = Arc::new(FakeLanguageModel::default());
1302 let thread = cx.new(|cx| {
1303 Thread::new(
1304 project.clone(),
1305 cx.new(|_cx| ProjectContext::default()),
1306 context_server_registry.clone(),
1307 Templates::new(),
1308 Some(model.clone()),
1309 cx,
1310 )
1311 });
1312 let tool = Arc::new(EditFileTool::new(
1313 project.clone(),
1314 thread.downgrade(),
1315 language_registry,
1316 ));
1317
1318 // Test files in different worktrees
1319 let test_cases = vec![
1320 ("frontend/src/main.js", false, "File in first worktree"),
1321 ("backend/src/main.rs", false, "File in second worktree"),
1322 (
1323 "shared/.zed/settings.json",
1324 true,
1325 ".zed file in third worktree",
1326 ),
1327 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1328 (
1329 "../outside/file.txt",
1330 true,
1331 "Relative path outside worktrees",
1332 ),
1333 ];
1334
1335 for (path, should_confirm, description) in test_cases {
1336 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1337 let auth = cx.update(|cx| {
1338 tool.authorize(
1339 &EditFileToolInput {
1340 display_description: "Edit file".into(),
1341 path: path.into(),
1342 mode: EditFileMode::Edit,
1343 },
1344 &stream_tx,
1345 cx,
1346 )
1347 });
1348
1349 if should_confirm {
1350 stream_rx.expect_authorization().await;
1351 } else {
1352 auth.await.unwrap();
1353 assert!(
1354 stream_rx.try_next().is_err(),
1355 "Failed for case: {} - path: {} - expected no confirmation but got one",
1356 description,
1357 path
1358 );
1359 }
1360 }
1361 }
1362
1363 #[gpui::test]
1364 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1365 init_test(cx);
1366 let fs = project::FakeFs::new(cx.executor());
1367 fs.insert_tree(
1368 "/project",
1369 json!({
1370 ".zed": {
1371 "settings.json": "{}"
1372 },
1373 "src": {
1374 ".zed": {
1375 "local.json": "{}"
1376 }
1377 }
1378 }),
1379 )
1380 .await;
1381 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1382 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1383 let context_server_registry =
1384 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1385 let model = Arc::new(FakeLanguageModel::default());
1386 let thread = cx.new(|cx| {
1387 Thread::new(
1388 project.clone(),
1389 cx.new(|_cx| ProjectContext::default()),
1390 context_server_registry.clone(),
1391 Templates::new(),
1392 Some(model.clone()),
1393 cx,
1394 )
1395 });
1396 let tool = Arc::new(EditFileTool::new(
1397 project.clone(),
1398 thread.downgrade(),
1399 language_registry,
1400 ));
1401
1402 // Test edge cases
1403 let test_cases = vec![
1404 // Empty path - find_project_path returns Some for empty paths
1405 ("", false, "Empty path is treated as project root"),
1406 // Root directory
1407 ("/", true, "Root directory should be outside project"),
1408 // Parent directory references - find_project_path resolves these
1409 (
1410 "project/../other",
1411 false,
1412 "Path with .. is resolved by find_project_path",
1413 ),
1414 (
1415 "project/./src/file.rs",
1416 false,
1417 "Path with . should work normally",
1418 ),
1419 // Windows-style paths (if on Windows)
1420 #[cfg(target_os = "windows")]
1421 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1422 #[cfg(target_os = "windows")]
1423 ("project\\src\\main.rs", false, "Windows-style project path"),
1424 ];
1425
1426 for (path, should_confirm, description) in test_cases {
1427 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1428 let auth = cx.update(|cx| {
1429 tool.authorize(
1430 &EditFileToolInput {
1431 display_description: "Edit file".into(),
1432 path: path.into(),
1433 mode: EditFileMode::Edit,
1434 },
1435 &stream_tx,
1436 cx,
1437 )
1438 });
1439
1440 if should_confirm {
1441 stream_rx.expect_authorization().await;
1442 } else {
1443 auth.await.unwrap();
1444 assert!(
1445 stream_rx.try_next().is_err(),
1446 "Failed for case: {} - path: {} - expected no confirmation but got one",
1447 description,
1448 path
1449 );
1450 }
1451 }
1452 }
1453
1454 #[gpui::test]
1455 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1456 init_test(cx);
1457 let fs = project::FakeFs::new(cx.executor());
1458 fs.insert_tree(
1459 "/project",
1460 json!({
1461 "existing.txt": "content",
1462 ".zed": {
1463 "settings.json": "{}"
1464 }
1465 }),
1466 )
1467 .await;
1468 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1469 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1470 let context_server_registry =
1471 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1472 let model = Arc::new(FakeLanguageModel::default());
1473 let thread = cx.new(|cx| {
1474 Thread::new(
1475 project.clone(),
1476 cx.new(|_cx| ProjectContext::default()),
1477 context_server_registry.clone(),
1478 Templates::new(),
1479 Some(model.clone()),
1480 cx,
1481 )
1482 });
1483 let tool = Arc::new(EditFileTool::new(
1484 project.clone(),
1485 thread.downgrade(),
1486 language_registry,
1487 ));
1488
1489 // Test different EditFileMode values
1490 let modes = vec![
1491 EditFileMode::Edit,
1492 EditFileMode::Create,
1493 EditFileMode::Overwrite,
1494 ];
1495
1496 for mode in modes {
1497 // Test .zed path with different modes
1498 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1499 let _auth = cx.update(|cx| {
1500 tool.authorize(
1501 &EditFileToolInput {
1502 display_description: "Edit settings".into(),
1503 path: "project/.zed/settings.json".into(),
1504 mode: mode.clone(),
1505 },
1506 &stream_tx,
1507 cx,
1508 )
1509 });
1510
1511 stream_rx.expect_authorization().await;
1512
1513 // Test outside path with different modes
1514 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1515 let _auth = cx.update(|cx| {
1516 tool.authorize(
1517 &EditFileToolInput {
1518 display_description: "Edit file".into(),
1519 path: "/outside/file.txt".into(),
1520 mode: mode.clone(),
1521 },
1522 &stream_tx,
1523 cx,
1524 )
1525 });
1526
1527 stream_rx.expect_authorization().await;
1528
1529 // Test normal path with different modes
1530 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1531 cx.update(|cx| {
1532 tool.authorize(
1533 &EditFileToolInput {
1534 display_description: "Edit file".into(),
1535 path: "project/normal.txt".into(),
1536 mode: mode.clone(),
1537 },
1538 &stream_tx,
1539 cx,
1540 )
1541 })
1542 .await
1543 .unwrap();
1544 assert!(stream_rx.try_next().is_err());
1545 }
1546 }
1547
1548 #[gpui::test]
1549 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1550 init_test(cx);
1551 let fs = project::FakeFs::new(cx.executor());
1552 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1553 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1554 let context_server_registry =
1555 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1556 let model = Arc::new(FakeLanguageModel::default());
1557 let thread = cx.new(|cx| {
1558 Thread::new(
1559 project.clone(),
1560 cx.new(|_cx| ProjectContext::default()),
1561 context_server_registry,
1562 Templates::new(),
1563 Some(model.clone()),
1564 cx,
1565 )
1566 });
1567 let tool = Arc::new(EditFileTool::new(
1568 project,
1569 thread.downgrade(),
1570 language_registry,
1571 ));
1572
1573 cx.update(|cx| {
1574 // ...
1575 assert_eq!(
1576 tool.initial_title(
1577 Err(json!({
1578 "path": "src/main.rs",
1579 "display_description": "",
1580 "old_string": "old code",
1581 "new_string": "new code"
1582 })),
1583 cx
1584 ),
1585 "src/main.rs"
1586 );
1587 assert_eq!(
1588 tool.initial_title(
1589 Err(json!({
1590 "path": "",
1591 "display_description": "Fix error handling",
1592 "old_string": "old code",
1593 "new_string": "new code"
1594 })),
1595 cx
1596 ),
1597 "Fix error handling"
1598 );
1599 assert_eq!(
1600 tool.initial_title(
1601 Err(json!({
1602 "path": "src/main.rs",
1603 "display_description": "Fix error handling",
1604 "old_string": "old code",
1605 "new_string": "new code"
1606 })),
1607 cx
1608 ),
1609 "src/main.rs"
1610 );
1611 assert_eq!(
1612 tool.initial_title(
1613 Err(json!({
1614 "path": "",
1615 "display_description": "",
1616 "old_string": "old code",
1617 "new_string": "new code"
1618 })),
1619 cx
1620 ),
1621 DEFAULT_UI_TEXT
1622 );
1623 assert_eq!(
1624 tool.initial_title(Err(serde_json::Value::Null), cx),
1625 DEFAULT_UI_TEXT
1626 );
1627 });
1628 }
1629
1630 #[gpui::test]
1631 async fn test_diff_finalization(cx: &mut TestAppContext) {
1632 init_test(cx);
1633 let fs = project::FakeFs::new(cx.executor());
1634 fs.insert_tree("/", json!({"main.rs": ""})).await;
1635
1636 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1637 let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1638 let context_server_registry =
1639 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1640 let model = Arc::new(FakeLanguageModel::default());
1641 let thread = cx.new(|cx| {
1642 Thread::new(
1643 project.clone(),
1644 cx.new(|_cx| ProjectContext::default()),
1645 context_server_registry.clone(),
1646 Templates::new(),
1647 Some(model.clone()),
1648 cx,
1649 )
1650 });
1651
1652 // Ensure the diff is finalized after the edit completes.
1653 {
1654 let tool = Arc::new(EditFileTool::new(
1655 project.clone(),
1656 thread.downgrade(),
1657 languages.clone(),
1658 ));
1659 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1660 let edit = cx.update(|cx| {
1661 tool.run(
1662 EditFileToolInput {
1663 display_description: "Edit file".into(),
1664 path: path!("/main.rs").into(),
1665 mode: EditFileMode::Edit,
1666 },
1667 stream_tx,
1668 cx,
1669 )
1670 });
1671 stream_rx.expect_update_fields().await;
1672 let diff = stream_rx.expect_diff().await;
1673 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1674 cx.run_until_parked();
1675 model.end_last_completion_stream();
1676 edit.await.unwrap();
1677 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1678 }
1679
1680 // Ensure the diff is finalized if an error occurs while editing.
1681 {
1682 model.forbid_requests();
1683 let tool = Arc::new(EditFileTool::new(
1684 project.clone(),
1685 thread.downgrade(),
1686 languages.clone(),
1687 ));
1688 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1689 let edit = cx.update(|cx| {
1690 tool.run(
1691 EditFileToolInput {
1692 display_description: "Edit file".into(),
1693 path: path!("/main.rs").into(),
1694 mode: EditFileMode::Edit,
1695 },
1696 stream_tx,
1697 cx,
1698 )
1699 });
1700 stream_rx.expect_update_fields().await;
1701 let diff = stream_rx.expect_diff().await;
1702 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1703 edit.await.unwrap_err();
1704 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1705 model.allow_requests();
1706 }
1707
1708 // Ensure the diff is finalized if the tool call gets dropped.
1709 {
1710 let tool = Arc::new(EditFileTool::new(
1711 project.clone(),
1712 thread.downgrade(),
1713 languages.clone(),
1714 ));
1715 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1716 let edit = cx.update(|cx| {
1717 tool.run(
1718 EditFileToolInput {
1719 display_description: "Edit file".into(),
1720 path: path!("/main.rs").into(),
1721 mode: EditFileMode::Edit,
1722 },
1723 stream_tx,
1724 cx,
1725 )
1726 });
1727 stream_rx.expect_update_fields().await;
1728 let diff = stream_rx.expect_diff().await;
1729 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1730 drop(edit);
1731 cx.run_until_parked();
1732 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1733 }
1734 }
1735
1736 fn init_test(cx: &mut TestAppContext) {
1737 cx.update(|cx| {
1738 let settings_store = SettingsStore::test(cx);
1739 cx.set_global(settings_store);
1740 language::init(cx);
1741 TelemetrySettings::register(cx);
1742 agent_settings::AgentSettings::register(cx);
1743 Project::init_settings(cx);
1744 });
1745 }
1746}