1use crate::{AgentTool, Thread, ToolCallEventStream};
2use acp_thread::Diff;
3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
6use cloud_llm_client::CompletionIntent;
7use collections::HashSet;
8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
9use indoc::formatdoc;
10use language::language_settings::{self, FormatOnSave};
11use language::{LanguageRegistry, ToPoint};
12use language_model::LanguageModelToolResultContent;
13use paths;
14use project::lsp_store::{FormatTrigger, LspFormatTarget};
15use project::{Project, ProjectPath};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::Settings;
19use smol::stream::StreamExt as _;
20use std::path::{Path, PathBuf};
21use std::sync::Arc;
22use ui::SharedString;
23use util::ResultExt;
24
/// Title used for the tool call when neither a path nor a description can be
/// recovered from the (possibly still-streaming) tool input.
const DEFAULT_UI_TEXT: &str = "Editing file";
26
27/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
28///
29/// Before using this tool:
30///
31/// 1. Use the `read_file` tool to understand the file's contents and context
32///
33/// 2. Verify the directory path is correct (only applicable when creating new files):
34/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
35#[derive(Debug, Serialize, Deserialize, JsonSchema)]
36pub struct EditFileToolInput {
37 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
38 ///
39 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
40 ///
41 /// NEVER mention the file path in this description.
42 ///
43 /// <example>Fix API endpoint URLs</example>
44 /// <example>Update copyright year in `page_footer`</example>
45 ///
46 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
47 pub display_description: String,
48
49 /// The full path of the file to create or modify in the project.
50 ///
51 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
52 ///
53 /// The following examples assume we have two root directories in the project:
54 /// - /a/b/backend
55 /// - /c/d/frontend
56 ///
57 /// <example>
58 /// `backend/src/main.rs`
59 ///
60 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
61 /// </example>
62 ///
63 /// <example>
64 /// `frontend/db.js`
65 /// </example>
66 pub path: PathBuf,
67 /// The mode of operation on the file. Possible values:
68 /// - 'edit': Make granular edits to an existing file.
69 /// - 'create': Create a new file if it doesn't exist.
70 /// - 'overwrite': Replace the entire contents of an existing file.
71 ///
72 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
73 pub mode: EditFileMode,
74}
75
76#[derive(Debug, Serialize, Deserialize, JsonSchema)]
77struct EditFileToolPartialInput {
78 #[serde(default)]
79 path: String,
80 #[serde(default)]
81 display_description: String,
82}
83
84#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
85#[serde(rename_all = "lowercase")]
86#[schemars(inline)]
87pub enum EditFileMode {
88 Edit,
89 Create,
90 Overwrite,
91}
92
93#[derive(Debug, Serialize, Deserialize)]
94pub struct EditFileToolOutput {
95 #[serde(alias = "original_path")]
96 input_path: PathBuf,
97 new_text: String,
98 old_text: Arc<String>,
99 #[serde(default)]
100 diff: String,
101 #[serde(alias = "raw_output")]
102 edit_agent_output: EditAgentOutput,
103}
104
105impl From<EditFileToolOutput> for LanguageModelToolResultContent {
106 fn from(output: EditFileToolOutput) -> Self {
107 if output.diff.is_empty() {
108 "No edits were made.".into()
109 } else {
110 format!(
111 "Edited {}:\n\n```diff\n{}\n```",
112 output.input_path.display(),
113 output.diff
114 )
115 .into()
116 }
117 }
118}
119
120pub struct EditFileTool {
121 thread: WeakEntity<Thread>,
122 language_registry: Arc<LanguageRegistry>,
123 project: Entity<Project>,
124}
125
126impl EditFileTool {
127 pub fn new(
128 project: Entity<Project>,
129 thread: WeakEntity<Thread>,
130 language_registry: Arc<LanguageRegistry>,
131 ) -> Self {
132 Self {
133 project,
134 thread,
135 language_registry,
136 }
137 }
138
139 fn authorize(
140 &self,
141 input: &EditFileToolInput,
142 event_stream: &ToolCallEventStream,
143 cx: &mut App,
144 ) -> Task<Result<()>> {
145 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
146 return Task::ready(Ok(()));
147 }
148
149 // If any path component matches the local settings folder, then this could affect
150 // the editor in ways beyond the project source, so prompt.
151 let local_settings_folder = paths::local_settings_folder_relative_path();
152 let path = Path::new(&input.path);
153 if path
154 .components()
155 .any(|component| component.as_os_str() == local_settings_folder.as_os_str())
156 {
157 return event_stream.authorize(
158 format!("{} (local settings)", input.display_description),
159 cx,
160 );
161 }
162
163 // It's also possible that the global config dir is configured to be inside the project,
164 // so check for that edge case too.
165 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
166 && canonical_path.starts_with(paths::config_dir())
167 {
168 return event_stream.authorize(
169 format!("{} (global settings)", input.display_description),
170 cx,
171 );
172 }
173
174 // Check if path is inside the global config directory
175 // First check if it's already inside project - if not, try to canonicalize
176 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
177 thread.project().read(cx).find_project_path(&input.path, cx)
178 }) else {
179 return Task::ready(Err(anyhow!("thread was dropped")));
180 };
181
182 // If the path is inside the project, and it's not one of the above edge cases,
183 // then no confirmation is necessary. Otherwise, confirmation is necessary.
184 if project_path.is_some() {
185 Task::ready(Ok(()))
186 } else {
187 event_stream.authorize(&input.display_description, cx)
188 }
189 }
190}
191
192impl AgentTool for EditFileTool {
193 type Input = EditFileToolInput;
194 type Output = EditFileToolOutput;
195
196 fn name() -> &'static str {
197 "edit_file"
198 }
199
200 fn kind() -> acp::ToolKind {
201 acp::ToolKind::Edit
202 }
203
204 fn initial_title(
205 &self,
206 input: Result<Self::Input, serde_json::Value>,
207 cx: &mut App,
208 ) -> SharedString {
209 match input {
210 Ok(input) => self
211 .project
212 .read(cx)
213 .find_project_path(&input.path, cx)
214 .and_then(|project_path| {
215 self.project
216 .read(cx)
217 .short_full_path_for_project_path(&project_path, cx)
218 })
219 .unwrap_or(Path::new(&input.path).into())
220 .to_string_lossy()
221 .to_string()
222 .into(),
223 Err(raw_input) => {
224 if let Some(input) =
225 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
226 {
227 let path = input.path.trim();
228 if !path.is_empty() {
229 return self
230 .project
231 .read(cx)
232 .find_project_path(&input.path, cx)
233 .and_then(|project_path| {
234 self.project
235 .read(cx)
236 .short_full_path_for_project_path(&project_path, cx)
237 })
238 .unwrap_or(Path::new(&input.path).into())
239 .to_string_lossy()
240 .to_string()
241 .into();
242 }
243
244 let description = input.display_description.trim();
245 if !description.is_empty() {
246 return description.to_string().into();
247 }
248 }
249
250 DEFAULT_UI_TEXT.into()
251 }
252 }
253 }
254
255 fn run(
256 self: Arc<Self>,
257 input: Self::Input,
258 event_stream: ToolCallEventStream,
259 cx: &mut App,
260 ) -> Task<Result<Self::Output>> {
261 let Ok(project) = self
262 .thread
263 .read_with(cx, |thread, _cx| thread.project().clone())
264 else {
265 return Task::ready(Err(anyhow!("thread was dropped")));
266 };
267 let project_path = match resolve_path(&input, project.clone(), cx) {
268 Ok(path) => path,
269 Err(err) => return Task::ready(Err(anyhow!(err))),
270 };
271 let abs_path = project.read(cx).absolute_path(&project_path, cx);
272 if let Some(abs_path) = abs_path.clone() {
273 event_stream.update_fields(ToolCallUpdateFields {
274 locations: Some(vec![acp::ToolCallLocation {
275 path: abs_path,
276 line: None,
277 }]),
278 ..Default::default()
279 });
280 }
281
282 let authorize = self.authorize(&input, &event_stream, cx);
283 cx.spawn(async move |cx: &mut AsyncApp| {
284 authorize.await?;
285
286 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
287 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
288 (request, thread.model().cloned(), thread.action_log().clone())
289 })?;
290 let request = request?;
291 let model = model.context("No language model configured")?;
292
293 let edit_format = EditFormat::from_model(model.clone())?;
294 let edit_agent = EditAgent::new(
295 model,
296 project.clone(),
297 action_log.clone(),
298 // TODO: move edit agent to this crate so we can use our templates
299 assistant_tools::templates::Templates::new(),
300 edit_format,
301 );
302
303 let buffer = project
304 .update(cx, |project, cx| {
305 project.open_buffer(project_path.clone(), cx)
306 })?
307 .await?;
308
309 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
310 event_stream.update_diff(diff.clone());
311 let _finalize_diff = util::defer({
312 let diff = diff.downgrade();
313 let mut cx = cx.clone();
314 move || {
315 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
316 }
317 });
318
319 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
320 let old_text = cx
321 .background_spawn({
322 let old_snapshot = old_snapshot.clone();
323 async move { Arc::new(old_snapshot.text()) }
324 })
325 .await;
326
327
328 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
329 edit_agent.edit(
330 buffer.clone(),
331 input.display_description.clone(),
332 &request,
333 cx,
334 )
335 } else {
336 edit_agent.overwrite(
337 buffer.clone(),
338 input.display_description.clone(),
339 &request,
340 cx,
341 )
342 };
343
344 let mut hallucinated_old_text = false;
345 let mut ambiguous_ranges = Vec::new();
346 let mut emitted_location = false;
347 while let Some(event) = events.next().await {
348 match event {
349 EditAgentOutputEvent::Edited(range) => {
350 if !emitted_location {
351 let line = buffer.update(cx, |buffer, _cx| {
352 range.start.to_point(&buffer.snapshot()).row
353 }).ok();
354 if let Some(abs_path) = abs_path.clone() {
355 event_stream.update_fields(ToolCallUpdateFields {
356 locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
357 ..Default::default()
358 });
359 }
360 emitted_location = true;
361 }
362 },
363 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
364 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
365 EditAgentOutputEvent::ResolvingEditRange(range) => {
366 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
367 // if !emitted_location {
368 // let line = buffer.update(cx, |buffer, _cx| {
369 // range.start.to_point(&buffer.snapshot()).row
370 // }).ok();
371 // if let Some(abs_path) = abs_path.clone() {
372 // event_stream.update_fields(ToolCallUpdateFields {
373 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
374 // ..Default::default()
375 // });
376 // }
377 // }
378 }
379 }
380 }
381
382 // If format_on_save is enabled, format the buffer
383 let format_on_save_enabled = buffer
384 .read_with(cx, |buffer, cx| {
385 let settings = language_settings::language_settings(
386 buffer.language().map(|l| l.name()),
387 buffer.file(),
388 cx,
389 );
390 settings.format_on_save != FormatOnSave::Off
391 })
392 .unwrap_or(false);
393
394 let edit_agent_output = output.await?;
395
396 if format_on_save_enabled {
397 action_log.update(cx, |log, cx| {
398 log.buffer_edited(buffer.clone(), cx);
399 })?;
400
401 let format_task = project.update(cx, |project, cx| {
402 project.format(
403 HashSet::from_iter([buffer.clone()]),
404 LspFormatTarget::Buffers,
405 false, // Don't push to history since the tool did it.
406 FormatTrigger::Save,
407 cx,
408 )
409 })?;
410 format_task.await.log_err();
411 }
412
413 project
414 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
415 .await?;
416
417 action_log.update(cx, |log, cx| {
418 log.buffer_edited(buffer.clone(), cx);
419 })?;
420
421 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
422 let (new_text, unified_diff) = cx
423 .background_spawn({
424 let new_snapshot = new_snapshot.clone();
425 let old_text = old_text.clone();
426 async move {
427 let new_text = new_snapshot.text();
428 let diff = language::unified_diff(&old_text, &new_text);
429 (new_text, diff)
430 }
431 })
432 .await;
433
434 let input_path = input.path.display();
435 if unified_diff.is_empty() {
436 anyhow::ensure!(
437 !hallucinated_old_text,
438 formatdoc! {"
439 Some edits were produced but none of them could be applied.
440 Read the relevant sections of {input_path} again so that
441 I can perform the requested edits.
442 "}
443 );
444 anyhow::ensure!(
445 ambiguous_ranges.is_empty(),
446 {
447 let line_numbers = ambiguous_ranges
448 .iter()
449 .map(|range| range.start.to_string())
450 .collect::<Vec<_>>()
451 .join(", ");
452 formatdoc! {"
453 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
454 relevant sections of {input_path} again and extend <old_text> so
455 that I can perform the requested edits.
456 "}
457 }
458 );
459 }
460
461 Ok(EditFileToolOutput {
462 input_path: input.path,
463 new_text,
464 old_text,
465 diff: unified_diff,
466 edit_agent_output,
467 })
468 })
469 }
470
471 fn replay(
472 &self,
473 _input: Self::Input,
474 output: Self::Output,
475 event_stream: ToolCallEventStream,
476 cx: &mut App,
477 ) -> Result<()> {
478 event_stream.update_diff(cx.new(|cx| {
479 Diff::finalized(
480 output.input_path,
481 Some(output.old_text.to_string()),
482 output.new_text,
483 self.language_registry.clone(),
484 cx,
485 )
486 }));
487 Ok(())
488 }
489}
490
491/// Validate that the file path is valid, meaning:
492///
493/// - For `edit` and `overwrite`, the path must point to an existing file.
494/// - For `create`, the file must not already exist, but it's parent dir must exist.
495fn resolve_path(
496 input: &EditFileToolInput,
497 project: Entity<Project>,
498 cx: &mut App,
499) -> Result<ProjectPath> {
500 let project = project.read(cx);
501
502 match input.mode {
503 EditFileMode::Edit | EditFileMode::Overwrite => {
504 let path = project
505 .find_project_path(&input.path, cx)
506 .context("Can't edit file: path not found")?;
507
508 let entry = project
509 .entry_for_path(&path, cx)
510 .context("Can't edit file: path not found")?;
511
512 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
513 Ok(path)
514 }
515
516 EditFileMode::Create => {
517 if let Some(path) = project.find_project_path(&input.path, cx) {
518 anyhow::ensure!(
519 project.entry_for_path(&path, cx).is_none(),
520 "Can't create file: file already exists"
521 );
522 }
523
524 let parent_path = input
525 .path
526 .parent()
527 .context("Can't create file: incorrect path")?;
528
529 let parent_project_path = project.find_project_path(&parent_path, cx);
530
531 let parent_entry = parent_project_path
532 .as_ref()
533 .and_then(|path| project.entry_for_path(path, cx))
534 .context("Can't create file: parent directory doesn't exist")?;
535
536 anyhow::ensure!(
537 parent_entry.is_dir(),
538 "Can't create file: parent is not a directory"
539 );
540
541 let file_name = input
542 .path
543 .file_name()
544 .context("Can't create file: invalid filename")?;
545
546 let new_file_path = parent_project_path.map(|parent| ProjectPath {
547 path: Arc::from(parent.path.join(file_name)),
548 ..parent
549 });
550
551 new_file_path.context("Can't create file")
552 }
553 }
554}
555
556#[cfg(test)]
557mod tests {
558 use super::*;
559 use crate::{ContextServerRegistry, Templates};
560 use client::TelemetrySettings;
561 use fs::Fs;
562 use gpui::{TestAppContext, UpdateGlobal};
563 use language_model::fake_provider::FakeLanguageModel;
564 use prompt_store::ProjectContext;
565 use serde_json::json;
566 use settings::SettingsStore;
567 use util::path;
568
569 #[gpui::test]
570 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
571 init_test(cx);
572
573 let fs = project::FakeFs::new(cx.executor());
574 fs.insert_tree("/root", json!({})).await;
575 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
576 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
577 let context_server_registry =
578 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
579 let model = Arc::new(FakeLanguageModel::default());
580 let thread = cx.new(|cx| {
581 Thread::new(
582 project.clone(),
583 cx.new(|_cx| ProjectContext::default()),
584 context_server_registry,
585 Templates::new(),
586 Some(model),
587 cx,
588 )
589 });
590 let result = cx
591 .update(|cx| {
592 let input = EditFileToolInput {
593 display_description: "Some edit".into(),
594 path: "root/nonexistent_file.txt".into(),
595 mode: EditFileMode::Edit,
596 };
597 Arc::new(EditFileTool::new(
598 project,
599 thread.downgrade(),
600 language_registry,
601 ))
602 .run(input, ToolCallEventStream::test().0, cx)
603 })
604 .await;
605 assert_eq!(
606 result.unwrap_err().to_string(),
607 "Can't edit file: path not found"
608 );
609 }
610
611 #[gpui::test]
612 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
613 let mode = &EditFileMode::Create;
614
615 let result = test_resolve_path(mode, "root/new.txt", cx);
616 assert_resolved_path_eq(result.await, "new.txt");
617
618 let result = test_resolve_path(mode, "new.txt", cx);
619 assert_resolved_path_eq(result.await, "new.txt");
620
621 let result = test_resolve_path(mode, "dir/new.txt", cx);
622 assert_resolved_path_eq(result.await, "dir/new.txt");
623
624 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
625 assert_eq!(
626 result.await.unwrap_err().to_string(),
627 "Can't create file: file already exists"
628 );
629
630 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
631 assert_eq!(
632 result.await.unwrap_err().to_string(),
633 "Can't create file: parent directory doesn't exist"
634 );
635 }
636
637 #[gpui::test]
638 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
639 let mode = &EditFileMode::Edit;
640
641 let path_with_root = "root/dir/subdir/existing.txt";
642 let path_without_root = "dir/subdir/existing.txt";
643 let result = test_resolve_path(mode, path_with_root, cx);
644 assert_resolved_path_eq(result.await, path_without_root);
645
646 let result = test_resolve_path(mode, path_without_root, cx);
647 assert_resolved_path_eq(result.await, path_without_root);
648
649 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
650 assert_eq!(
651 result.await.unwrap_err().to_string(),
652 "Can't edit file: path not found"
653 );
654
655 let result = test_resolve_path(mode, "root/dir", cx);
656 assert_eq!(
657 result.await.unwrap_err().to_string(),
658 "Can't edit file: path is a directory"
659 );
660 }
661
662 async fn test_resolve_path(
663 mode: &EditFileMode,
664 path: &str,
665 cx: &mut TestAppContext,
666 ) -> anyhow::Result<ProjectPath> {
667 init_test(cx);
668
669 let fs = project::FakeFs::new(cx.executor());
670 fs.insert_tree(
671 "/root",
672 json!({
673 "dir": {
674 "subdir": {
675 "existing.txt": "hello"
676 }
677 }
678 }),
679 )
680 .await;
681 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
682
683 let input = EditFileToolInput {
684 display_description: "Some edit".into(),
685 path: path.into(),
686 mode: mode.clone(),
687 };
688
689 cx.update(|cx| resolve_path(&input, project, cx))
690 }
691
692 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
693 let actual = path
694 .expect("Should return valid path")
695 .path
696 .to_str()
697 .unwrap()
698 .replace("\\", "/"); // Naive Windows paths normalization
699 assert_eq!(actual, expected);
700 }
701
702 #[gpui::test]
703 async fn test_format_on_save(cx: &mut TestAppContext) {
704 init_test(cx);
705
706 let fs = project::FakeFs::new(cx.executor());
707 fs.insert_tree("/root", json!({"src": {}})).await;
708
709 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
710
711 // Set up a Rust language with LSP formatting support
712 let rust_language = Arc::new(language::Language::new(
713 language::LanguageConfig {
714 name: "Rust".into(),
715 matcher: language::LanguageMatcher {
716 path_suffixes: vec!["rs".to_string()],
717 ..Default::default()
718 },
719 ..Default::default()
720 },
721 None,
722 ));
723
724 // Register the language and fake LSP
725 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
726 language_registry.add(rust_language);
727
728 let mut fake_language_servers = language_registry.register_fake_lsp(
729 "Rust",
730 language::FakeLspAdapter {
731 capabilities: lsp::ServerCapabilities {
732 document_formatting_provider: Some(lsp::OneOf::Left(true)),
733 ..Default::default()
734 },
735 ..Default::default()
736 },
737 );
738
739 // Create the file
740 fs.save(
741 path!("/root/src/main.rs").as_ref(),
742 &"initial content".into(),
743 language::LineEnding::Unix,
744 )
745 .await
746 .unwrap();
747
748 // Open the buffer to trigger LSP initialization
749 let buffer = project
750 .update(cx, |project, cx| {
751 project.open_local_buffer(path!("/root/src/main.rs"), cx)
752 })
753 .await
754 .unwrap();
755
756 // Register the buffer with language servers
757 let _handle = project.update(cx, |project, cx| {
758 project.register_buffer_with_language_servers(&buffer, cx)
759 });
760
761 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
762 const FORMATTED_CONTENT: &str =
763 "This file was formatted by the fake formatter in the test.\n";
764
765 // Get the fake language server and set up formatting handler
766 let fake_language_server = fake_language_servers.next().await.unwrap();
767 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
768 |_, _| async move {
769 Ok(Some(vec![lsp::TextEdit {
770 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
771 new_text: FORMATTED_CONTENT.to_string(),
772 }]))
773 }
774 });
775
776 let context_server_registry =
777 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
778 let model = Arc::new(FakeLanguageModel::default());
779 let thread = cx.new(|cx| {
780 Thread::new(
781 project.clone(),
782 cx.new(|_cx| ProjectContext::default()),
783 context_server_registry,
784 Templates::new(),
785 Some(model.clone()),
786 cx,
787 )
788 });
789
790 // First, test with format_on_save enabled
791 cx.update(|cx| {
792 SettingsStore::update_global(cx, |store, cx| {
793 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
794 cx,
795 |settings| {
796 settings.defaults.format_on_save = Some(FormatOnSave::On);
797 settings.defaults.formatter =
798 Some(language::language_settings::SelectedFormatter::Auto);
799 },
800 );
801 });
802 });
803
804 // Have the model stream unformatted content
805 let edit_result = {
806 let edit_task = cx.update(|cx| {
807 let input = EditFileToolInput {
808 display_description: "Create main function".into(),
809 path: "root/src/main.rs".into(),
810 mode: EditFileMode::Overwrite,
811 };
812 Arc::new(EditFileTool::new(
813 project.clone(),
814 thread.downgrade(),
815 language_registry.clone(),
816 ))
817 .run(input, ToolCallEventStream::test().0, cx)
818 });
819
820 // Stream the unformatted content
821 cx.executor().run_until_parked();
822 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
823 model.end_last_completion_stream();
824
825 edit_task.await
826 };
827 assert!(edit_result.is_ok());
828
829 // Wait for any async operations (e.g. formatting) to complete
830 cx.executor().run_until_parked();
831
832 // Read the file to verify it was formatted automatically
833 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
834 assert_eq!(
835 // Ignore carriage returns on Windows
836 new_content.replace("\r\n", "\n"),
837 FORMATTED_CONTENT,
838 "Code should be formatted when format_on_save is enabled"
839 );
840
841 let stale_buffer_count = thread
842 .read_with(cx, |thread, _cx| thread.action_log.clone())
843 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
844
845 assert_eq!(
846 stale_buffer_count, 0,
847 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
848 This causes the agent to think the file was modified externally when it was just formatted.",
849 stale_buffer_count
850 );
851
852 // Next, test with format_on_save disabled
853 cx.update(|cx| {
854 SettingsStore::update_global(cx, |store, cx| {
855 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
856 cx,
857 |settings| {
858 settings.defaults.format_on_save = Some(FormatOnSave::Off);
859 },
860 );
861 });
862 });
863
864 // Stream unformatted edits again
865 let edit_result = {
866 let edit_task = cx.update(|cx| {
867 let input = EditFileToolInput {
868 display_description: "Update main function".into(),
869 path: "root/src/main.rs".into(),
870 mode: EditFileMode::Overwrite,
871 };
872 Arc::new(EditFileTool::new(
873 project.clone(),
874 thread.downgrade(),
875 language_registry,
876 ))
877 .run(input, ToolCallEventStream::test().0, cx)
878 });
879
880 // Stream the unformatted content
881 cx.executor().run_until_parked();
882 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
883 model.end_last_completion_stream();
884
885 edit_task.await
886 };
887 assert!(edit_result.is_ok());
888
889 // Wait for any async operations (e.g. formatting) to complete
890 cx.executor().run_until_parked();
891
892 // Verify the file was not formatted
893 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
894 assert_eq!(
895 // Ignore carriage returns on Windows
896 new_content.replace("\r\n", "\n"),
897 UNFORMATTED_CONTENT,
898 "Code should not be formatted when format_on_save is disabled"
899 );
900 }
901
902 #[gpui::test]
903 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
904 init_test(cx);
905
906 let fs = project::FakeFs::new(cx.executor());
907 fs.insert_tree("/root", json!({"src": {}})).await;
908
909 // Create a simple file with trailing whitespace
910 fs.save(
911 path!("/root/src/main.rs").as_ref(),
912 &"initial content".into(),
913 language::LineEnding::Unix,
914 )
915 .await
916 .unwrap();
917
918 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
919 let context_server_registry =
920 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
921 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
922 let model = Arc::new(FakeLanguageModel::default());
923 let thread = cx.new(|cx| {
924 Thread::new(
925 project.clone(),
926 cx.new(|_cx| ProjectContext::default()),
927 context_server_registry,
928 Templates::new(),
929 Some(model.clone()),
930 cx,
931 )
932 });
933
934 // First, test with remove_trailing_whitespace_on_save enabled
935 cx.update(|cx| {
936 SettingsStore::update_global(cx, |store, cx| {
937 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
938 cx,
939 |settings| {
940 settings.defaults.remove_trailing_whitespace_on_save = Some(true);
941 },
942 );
943 });
944 });
945
946 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
947 "fn main() { \n println!(\"Hello!\"); \n}\n";
948
949 // Have the model stream content that contains trailing whitespace
950 let edit_result = {
951 let edit_task = cx.update(|cx| {
952 let input = EditFileToolInput {
953 display_description: "Create main function".into(),
954 path: "root/src/main.rs".into(),
955 mode: EditFileMode::Overwrite,
956 };
957 Arc::new(EditFileTool::new(
958 project.clone(),
959 thread.downgrade(),
960 language_registry.clone(),
961 ))
962 .run(input, ToolCallEventStream::test().0, cx)
963 });
964
965 // Stream the content with trailing whitespace
966 cx.executor().run_until_parked();
967 model.send_last_completion_stream_text_chunk(
968 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
969 );
970 model.end_last_completion_stream();
971
972 edit_task.await
973 };
974 assert!(edit_result.is_ok());
975
976 // Wait for any async operations (e.g. formatting) to complete
977 cx.executor().run_until_parked();
978
979 // Read the file to verify trailing whitespace was removed automatically
980 assert_eq!(
981 // Ignore carriage returns on Windows
982 fs.load(path!("/root/src/main.rs").as_ref())
983 .await
984 .unwrap()
985 .replace("\r\n", "\n"),
986 "fn main() {\n println!(\"Hello!\");\n}\n",
987 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
988 );
989
990 // Next, test with remove_trailing_whitespace_on_save disabled
991 cx.update(|cx| {
992 SettingsStore::update_global(cx, |store, cx| {
993 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
994 cx,
995 |settings| {
996 settings.defaults.remove_trailing_whitespace_on_save = Some(false);
997 },
998 );
999 });
1000 });
1001
1002 // Stream edits again with trailing whitespace
1003 let edit_result = {
1004 let edit_task = cx.update(|cx| {
1005 let input = EditFileToolInput {
1006 display_description: "Update main function".into(),
1007 path: "root/src/main.rs".into(),
1008 mode: EditFileMode::Overwrite,
1009 };
1010 Arc::new(EditFileTool::new(
1011 project.clone(),
1012 thread.downgrade(),
1013 language_registry,
1014 ))
1015 .run(input, ToolCallEventStream::test().0, cx)
1016 });
1017
1018 // Stream the content with trailing whitespace
1019 cx.executor().run_until_parked();
1020 model.send_last_completion_stream_text_chunk(
1021 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1022 );
1023 model.end_last_completion_stream();
1024
1025 edit_task.await
1026 };
1027 assert!(edit_result.is_ok());
1028
1029 // Wait for any async operations (e.g. formatting) to complete
1030 cx.executor().run_until_parked();
1031
1032 // Verify the file still has trailing whitespace
1033 // Read the file again - it should still have trailing whitespace
1034 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1035 assert_eq!(
1036 // Ignore carriage returns on Windows
1037 final_content.replace("\r\n", "\n"),
1038 CONTENT_WITH_TRAILING_WHITESPACE,
1039 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1040 );
1041 }
1042
1043 #[gpui::test]
1044 async fn test_authorize(cx: &mut TestAppContext) {
1045 init_test(cx);
1046 let fs = project::FakeFs::new(cx.executor());
1047 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1048 let context_server_registry =
1049 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1050 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1051 let model = Arc::new(FakeLanguageModel::default());
1052 let thread = cx.new(|cx| {
1053 Thread::new(
1054 project.clone(),
1055 cx.new(|_cx| ProjectContext::default()),
1056 context_server_registry,
1057 Templates::new(),
1058 Some(model.clone()),
1059 cx,
1060 )
1061 });
1062 let tool = Arc::new(EditFileTool::new(
1063 project.clone(),
1064 thread.downgrade(),
1065 language_registry,
1066 ));
1067 fs.insert_tree("/root", json!({})).await;
1068
1069 // Test 1: Path with .zed component should require confirmation
1070 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1071 let _auth = cx.update(|cx| {
1072 tool.authorize(
1073 &EditFileToolInput {
1074 display_description: "test 1".into(),
1075 path: ".zed/settings.json".into(),
1076 mode: EditFileMode::Edit,
1077 },
1078 &stream_tx,
1079 cx,
1080 )
1081 });
1082
1083 let event = stream_rx.expect_authorization().await;
1084 assert_eq!(
1085 event.tool_call.fields.title,
1086 Some("test 1 (local settings)".into())
1087 );
1088
1089 // Test 2: Path outside project should require confirmation
1090 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1091 let _auth = cx.update(|cx| {
1092 tool.authorize(
1093 &EditFileToolInput {
1094 display_description: "test 2".into(),
1095 path: "/etc/hosts".into(),
1096 mode: EditFileMode::Edit,
1097 },
1098 &stream_tx,
1099 cx,
1100 )
1101 });
1102
1103 let event = stream_rx.expect_authorization().await;
1104 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1105
1106 // Test 3: Relative path without .zed should not require confirmation
1107 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1108 cx.update(|cx| {
1109 tool.authorize(
1110 &EditFileToolInput {
1111 display_description: "test 3".into(),
1112 path: "root/src/main.rs".into(),
1113 mode: EditFileMode::Edit,
1114 },
1115 &stream_tx,
1116 cx,
1117 )
1118 })
1119 .await
1120 .unwrap();
1121 assert!(stream_rx.try_next().is_err());
1122
1123 // Test 4: Path with .zed in the middle should require confirmation
1124 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1125 let _auth = cx.update(|cx| {
1126 tool.authorize(
1127 &EditFileToolInput {
1128 display_description: "test 4".into(),
1129 path: "root/.zed/tasks.json".into(),
1130 mode: EditFileMode::Edit,
1131 },
1132 &stream_tx,
1133 cx,
1134 )
1135 });
1136 let event = stream_rx.expect_authorization().await;
1137 assert_eq!(
1138 event.tool_call.fields.title,
1139 Some("test 4 (local settings)".into())
1140 );
1141
1142 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1143 cx.update(|cx| {
1144 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1145 settings.always_allow_tool_actions = true;
1146 agent_settings::AgentSettings::override_global(settings, cx);
1147 });
1148
1149 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1150 cx.update(|cx| {
1151 tool.authorize(
1152 &EditFileToolInput {
1153 display_description: "test 5.1".into(),
1154 path: ".zed/settings.json".into(),
1155 mode: EditFileMode::Edit,
1156 },
1157 &stream_tx,
1158 cx,
1159 )
1160 })
1161 .await
1162 .unwrap();
1163 assert!(stream_rx.try_next().is_err());
1164
1165 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1166 cx.update(|cx| {
1167 tool.authorize(
1168 &EditFileToolInput {
1169 display_description: "test 5.2".into(),
1170 path: "/etc/hosts".into(),
1171 mode: EditFileMode::Edit,
1172 },
1173 &stream_tx,
1174 cx,
1175 )
1176 })
1177 .await
1178 .unwrap();
1179 assert!(stream_rx.try_next().is_err());
1180 }
1181
1182 #[gpui::test]
1183 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1184 init_test(cx);
1185 let fs = project::FakeFs::new(cx.executor());
1186 fs.insert_tree("/project", json!({})).await;
1187 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1188 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1189 let context_server_registry =
1190 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1191 let model = Arc::new(FakeLanguageModel::default());
1192 let thread = cx.new(|cx| {
1193 Thread::new(
1194 project.clone(),
1195 cx.new(|_cx| ProjectContext::default()),
1196 context_server_registry,
1197 Templates::new(),
1198 Some(model.clone()),
1199 cx,
1200 )
1201 });
1202 let tool = Arc::new(EditFileTool::new(
1203 project.clone(),
1204 thread.downgrade(),
1205 language_registry,
1206 ));
1207
1208 // Test global config paths - these should require confirmation if they exist and are outside the project
1209 let test_cases = vec![
1210 (
1211 "/etc/hosts",
1212 true,
1213 "System file should require confirmation",
1214 ),
1215 (
1216 "/usr/local/bin/script",
1217 true,
1218 "System bin file should require confirmation",
1219 ),
1220 (
1221 "project/normal_file.rs",
1222 false,
1223 "Normal project file should not require confirmation",
1224 ),
1225 ];
1226
1227 for (path, should_confirm, description) in test_cases {
1228 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1229 let auth = cx.update(|cx| {
1230 tool.authorize(
1231 &EditFileToolInput {
1232 display_description: "Edit file".into(),
1233 path: path.into(),
1234 mode: EditFileMode::Edit,
1235 },
1236 &stream_tx,
1237 cx,
1238 )
1239 });
1240
1241 if should_confirm {
1242 stream_rx.expect_authorization().await;
1243 } else {
1244 auth.await.unwrap();
1245 assert!(
1246 stream_rx.try_next().is_err(),
1247 "Failed for case: {} - path: {} - expected no confirmation but got one",
1248 description,
1249 path
1250 );
1251 }
1252 }
1253 }
1254
1255 #[gpui::test]
1256 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1257 init_test(cx);
1258 let fs = project::FakeFs::new(cx.executor());
1259
1260 // Create multiple worktree directories
1261 fs.insert_tree(
1262 "/workspace/frontend",
1263 json!({
1264 "src": {
1265 "main.js": "console.log('frontend');"
1266 }
1267 }),
1268 )
1269 .await;
1270 fs.insert_tree(
1271 "/workspace/backend",
1272 json!({
1273 "src": {
1274 "main.rs": "fn main() {}"
1275 }
1276 }),
1277 )
1278 .await;
1279 fs.insert_tree(
1280 "/workspace/shared",
1281 json!({
1282 ".zed": {
1283 "settings.json": "{}"
1284 }
1285 }),
1286 )
1287 .await;
1288
1289 // Create project with multiple worktrees
1290 let project = Project::test(
1291 fs.clone(),
1292 [
1293 path!("/workspace/frontend").as_ref(),
1294 path!("/workspace/backend").as_ref(),
1295 path!("/workspace/shared").as_ref(),
1296 ],
1297 cx,
1298 )
1299 .await;
1300 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1301 let context_server_registry =
1302 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1303 let model = Arc::new(FakeLanguageModel::default());
1304 let thread = cx.new(|cx| {
1305 Thread::new(
1306 project.clone(),
1307 cx.new(|_cx| ProjectContext::default()),
1308 context_server_registry.clone(),
1309 Templates::new(),
1310 Some(model.clone()),
1311 cx,
1312 )
1313 });
1314 let tool = Arc::new(EditFileTool::new(
1315 project.clone(),
1316 thread.downgrade(),
1317 language_registry,
1318 ));
1319
1320 // Test files in different worktrees
1321 let test_cases = vec![
1322 ("frontend/src/main.js", false, "File in first worktree"),
1323 ("backend/src/main.rs", false, "File in second worktree"),
1324 (
1325 "shared/.zed/settings.json",
1326 true,
1327 ".zed file in third worktree",
1328 ),
1329 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1330 (
1331 "../outside/file.txt",
1332 true,
1333 "Relative path outside worktrees",
1334 ),
1335 ];
1336
1337 for (path, should_confirm, description) in test_cases {
1338 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1339 let auth = cx.update(|cx| {
1340 tool.authorize(
1341 &EditFileToolInput {
1342 display_description: "Edit file".into(),
1343 path: path.into(),
1344 mode: EditFileMode::Edit,
1345 },
1346 &stream_tx,
1347 cx,
1348 )
1349 });
1350
1351 if should_confirm {
1352 stream_rx.expect_authorization().await;
1353 } else {
1354 auth.await.unwrap();
1355 assert!(
1356 stream_rx.try_next().is_err(),
1357 "Failed for case: {} - path: {} - expected no confirmation but got one",
1358 description,
1359 path
1360 );
1361 }
1362 }
1363 }
1364
1365 #[gpui::test]
1366 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1367 init_test(cx);
1368 let fs = project::FakeFs::new(cx.executor());
1369 fs.insert_tree(
1370 "/project",
1371 json!({
1372 ".zed": {
1373 "settings.json": "{}"
1374 },
1375 "src": {
1376 ".zed": {
1377 "local.json": "{}"
1378 }
1379 }
1380 }),
1381 )
1382 .await;
1383 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1384 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1385 let context_server_registry =
1386 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1387 let model = Arc::new(FakeLanguageModel::default());
1388 let thread = cx.new(|cx| {
1389 Thread::new(
1390 project.clone(),
1391 cx.new(|_cx| ProjectContext::default()),
1392 context_server_registry.clone(),
1393 Templates::new(),
1394 Some(model.clone()),
1395 cx,
1396 )
1397 });
1398 let tool = Arc::new(EditFileTool::new(
1399 project.clone(),
1400 thread.downgrade(),
1401 language_registry,
1402 ));
1403
1404 // Test edge cases
1405 let test_cases = vec![
1406 // Empty path - find_project_path returns Some for empty paths
1407 ("", false, "Empty path is treated as project root"),
1408 // Root directory
1409 ("/", true, "Root directory should be outside project"),
1410 // Parent directory references - find_project_path resolves these
1411 (
1412 "project/../other",
1413 false,
1414 "Path with .. is resolved by find_project_path",
1415 ),
1416 (
1417 "project/./src/file.rs",
1418 false,
1419 "Path with . should work normally",
1420 ),
1421 // Windows-style paths (if on Windows)
1422 #[cfg(target_os = "windows")]
1423 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1424 #[cfg(target_os = "windows")]
1425 ("project\\src\\main.rs", false, "Windows-style project path"),
1426 ];
1427
1428 for (path, should_confirm, description) in test_cases {
1429 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1430 let auth = cx.update(|cx| {
1431 tool.authorize(
1432 &EditFileToolInput {
1433 display_description: "Edit file".into(),
1434 path: path.into(),
1435 mode: EditFileMode::Edit,
1436 },
1437 &stream_tx,
1438 cx,
1439 )
1440 });
1441
1442 if should_confirm {
1443 stream_rx.expect_authorization().await;
1444 } else {
1445 auth.await.unwrap();
1446 assert!(
1447 stream_rx.try_next().is_err(),
1448 "Failed for case: {} - path: {} - expected no confirmation but got one",
1449 description,
1450 path
1451 );
1452 }
1453 }
1454 }
1455
1456 #[gpui::test]
1457 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1458 init_test(cx);
1459 let fs = project::FakeFs::new(cx.executor());
1460 fs.insert_tree(
1461 "/project",
1462 json!({
1463 "existing.txt": "content",
1464 ".zed": {
1465 "settings.json": "{}"
1466 }
1467 }),
1468 )
1469 .await;
1470 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1471 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1472 let context_server_registry =
1473 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1474 let model = Arc::new(FakeLanguageModel::default());
1475 let thread = cx.new(|cx| {
1476 Thread::new(
1477 project.clone(),
1478 cx.new(|_cx| ProjectContext::default()),
1479 context_server_registry.clone(),
1480 Templates::new(),
1481 Some(model.clone()),
1482 cx,
1483 )
1484 });
1485 let tool = Arc::new(EditFileTool::new(
1486 project.clone(),
1487 thread.downgrade(),
1488 language_registry,
1489 ));
1490
1491 // Test different EditFileMode values
1492 let modes = vec![
1493 EditFileMode::Edit,
1494 EditFileMode::Create,
1495 EditFileMode::Overwrite,
1496 ];
1497
1498 for mode in modes {
1499 // Test .zed path with different modes
1500 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1501 let _auth = cx.update(|cx| {
1502 tool.authorize(
1503 &EditFileToolInput {
1504 display_description: "Edit settings".into(),
1505 path: "project/.zed/settings.json".into(),
1506 mode: mode.clone(),
1507 },
1508 &stream_tx,
1509 cx,
1510 )
1511 });
1512
1513 stream_rx.expect_authorization().await;
1514
1515 // Test outside path with different modes
1516 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1517 let _auth = cx.update(|cx| {
1518 tool.authorize(
1519 &EditFileToolInput {
1520 display_description: "Edit file".into(),
1521 path: "/outside/file.txt".into(),
1522 mode: mode.clone(),
1523 },
1524 &stream_tx,
1525 cx,
1526 )
1527 });
1528
1529 stream_rx.expect_authorization().await;
1530
1531 // Test normal path with different modes
1532 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1533 cx.update(|cx| {
1534 tool.authorize(
1535 &EditFileToolInput {
1536 display_description: "Edit file".into(),
1537 path: "project/normal.txt".into(),
1538 mode: mode.clone(),
1539 },
1540 &stream_tx,
1541 cx,
1542 )
1543 })
1544 .await
1545 .unwrap();
1546 assert!(stream_rx.try_next().is_err());
1547 }
1548 }
1549
1550 #[gpui::test]
1551 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1552 init_test(cx);
1553 let fs = project::FakeFs::new(cx.executor());
1554 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1555 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1556 let context_server_registry =
1557 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1558 let model = Arc::new(FakeLanguageModel::default());
1559 let thread = cx.new(|cx| {
1560 Thread::new(
1561 project.clone(),
1562 cx.new(|_cx| ProjectContext::default()),
1563 context_server_registry,
1564 Templates::new(),
1565 Some(model.clone()),
1566 cx,
1567 )
1568 });
1569 let tool = Arc::new(EditFileTool::new(
1570 project,
1571 thread.downgrade(),
1572 language_registry,
1573 ));
1574
1575 cx.update(|cx| {
1576 // ...
1577 assert_eq!(
1578 tool.initial_title(
1579 Err(json!({
1580 "path": "src/main.rs",
1581 "display_description": "",
1582 "old_string": "old code",
1583 "new_string": "new code"
1584 })),
1585 cx
1586 ),
1587 "src/main.rs"
1588 );
1589 assert_eq!(
1590 tool.initial_title(
1591 Err(json!({
1592 "path": "",
1593 "display_description": "Fix error handling",
1594 "old_string": "old code",
1595 "new_string": "new code"
1596 })),
1597 cx
1598 ),
1599 "Fix error handling"
1600 );
1601 assert_eq!(
1602 tool.initial_title(
1603 Err(json!({
1604 "path": "src/main.rs",
1605 "display_description": "Fix error handling",
1606 "old_string": "old code",
1607 "new_string": "new code"
1608 })),
1609 cx
1610 ),
1611 "src/main.rs"
1612 );
1613 assert_eq!(
1614 tool.initial_title(
1615 Err(json!({
1616 "path": "",
1617 "display_description": "",
1618 "old_string": "old code",
1619 "new_string": "new code"
1620 })),
1621 cx
1622 ),
1623 DEFAULT_UI_TEXT
1624 );
1625 assert_eq!(
1626 tool.initial_title(Err(serde_json::Value::Null), cx),
1627 DEFAULT_UI_TEXT
1628 );
1629 });
1630 }
1631
1632 #[gpui::test]
1633 async fn test_diff_finalization(cx: &mut TestAppContext) {
1634 init_test(cx);
1635 let fs = project::FakeFs::new(cx.executor());
1636 fs.insert_tree("/", json!({"main.rs": ""})).await;
1637
1638 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1639 let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1640 let context_server_registry =
1641 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1642 let model = Arc::new(FakeLanguageModel::default());
1643 let thread = cx.new(|cx| {
1644 Thread::new(
1645 project.clone(),
1646 cx.new(|_cx| ProjectContext::default()),
1647 context_server_registry.clone(),
1648 Templates::new(),
1649 Some(model.clone()),
1650 cx,
1651 )
1652 });
1653
1654 // Ensure the diff is finalized after the edit completes.
1655 {
1656 let tool = Arc::new(EditFileTool::new(
1657 project.clone(),
1658 thread.downgrade(),
1659 languages.clone(),
1660 ));
1661 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1662 let edit = cx.update(|cx| {
1663 tool.run(
1664 EditFileToolInput {
1665 display_description: "Edit file".into(),
1666 path: path!("/main.rs").into(),
1667 mode: EditFileMode::Edit,
1668 },
1669 stream_tx,
1670 cx,
1671 )
1672 });
1673 stream_rx.expect_update_fields().await;
1674 let diff = stream_rx.expect_diff().await;
1675 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1676 cx.run_until_parked();
1677 model.end_last_completion_stream();
1678 edit.await.unwrap();
1679 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1680 }
1681
1682 // Ensure the diff is finalized if an error occurs while editing.
1683 {
1684 model.forbid_requests();
1685 let tool = Arc::new(EditFileTool::new(
1686 project.clone(),
1687 thread.downgrade(),
1688 languages.clone(),
1689 ));
1690 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1691 let edit = cx.update(|cx| {
1692 tool.run(
1693 EditFileToolInput {
1694 display_description: "Edit file".into(),
1695 path: path!("/main.rs").into(),
1696 mode: EditFileMode::Edit,
1697 },
1698 stream_tx,
1699 cx,
1700 )
1701 });
1702 stream_rx.expect_update_fields().await;
1703 let diff = stream_rx.expect_diff().await;
1704 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1705 edit.await.unwrap_err();
1706 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1707 model.allow_requests();
1708 }
1709
1710 // Ensure the diff is finalized if the tool call gets dropped.
1711 {
1712 let tool = Arc::new(EditFileTool::new(
1713 project.clone(),
1714 thread.downgrade(),
1715 languages.clone(),
1716 ));
1717 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1718 let edit = cx.update(|cx| {
1719 tool.run(
1720 EditFileToolInput {
1721 display_description: "Edit file".into(),
1722 path: path!("/main.rs").into(),
1723 mode: EditFileMode::Edit,
1724 },
1725 stream_tx,
1726 cx,
1727 )
1728 });
1729 stream_rx.expect_update_fields().await;
1730 let diff = stream_rx.expect_diff().await;
1731 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1732 drop(edit);
1733 cx.run_until_parked();
1734 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1735 }
1736 }
1737
1738 fn init_test(cx: &mut TestAppContext) {
1739 cx.update(|cx| {
1740 let settings_store = SettingsStore::test(cx);
1741 cx.set_global(settings_store);
1742 language::init(cx);
1743 TelemetrySettings::register(cx);
1744 agent_settings::AgentSettings::register(cx);
1745 Project::init_settings(cx);
1746 });
1747 }
1748}