1use crate::{AgentTool, Thread, ToolCallEventStream};
2use acp_thread::Diff;
3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
6use cloud_llm_client::CompletionIntent;
7use collections::HashSet;
8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
9use indoc::formatdoc;
10use language::language_settings::{self, FormatOnSave};
11use language::{LanguageRegistry, ToPoint};
12use language_model::LanguageModelToolResultContent;
13use paths;
14use project::lsp_store::{FormatTrigger, LspFormatTarget};
15use project::{Project, ProjectPath};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::Settings;
19use smol::stream::StreamExt as _;
20use std::path::{Path, PathBuf};
21use std::sync::Arc;
22use ui::SharedString;
23use util::ResultExt;
24
/// Title shown for an edit tool call when neither a description nor a path can be
/// recovered from its (possibly still-streaming) input; see `initial_title`.
const DEFAULT_UI_TEXT: &str = "Editing file";

// NOTE(review): `EditFileToolInput` derives `JsonSchema`, and schemars turns `///` doc
// comments into schema `description`s. The `///` text on this struct and its fields is
// therefore presumably part of the tool definition sent to the model — treat it as prompt
// text. Use plain `//` comments (like this one) for notes meant only for maintainers.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    // Declared first on purpose: the doc below asks the model to emit this field first so
    // the UI can show it immediately (see `initial_title`, which parses partial input).
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    // Resolved against the project's worktrees by `resolve_path`; the first component is
    // expected to be a worktree root name.
    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
75
// Lenient view of `EditFileToolInput`, used by `initial_title` when the full input fails
// to deserialize (e.g. while the tool call's JSON is still streaming in).
//
// Every field falls back to an empty string when absent (`#[serde(default)]`), so a title
// can be derived from whichever field has arrived so far. Other keys such as `mode` are
// ignored, because serde does not reject unknown fields by default.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
83
// How `edit_file` treats its target file. Serialized as "edit" / "create" / "overwrite"
// (`rename_all = "lowercase"`) and inlined into the input schema (`schemars(inline)`).
//
// `resolve_path` validates each mode differently: `Edit` and `Overwrite` require an
// existing file, while `Create` requires an existing parent directory and no existing
// entry. `run` applies granular edits only for `Edit`; `Create` and `Overwrite` both go
// through the edit agent's whole-file `overwrite` path.
//
// Avoid adding `///` docs to the variants: with `JsonSchema`, documented unit variants
// would likely change the generated schema for this field (the model-facing descriptions
// live on `EditFileToolInput::mode`).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    Edit,
    Create,
    Overwrite,
}
92
/// Result of a completed `edit_file` call.
///
/// `replay` rebuilds the finalized diff view from this value. It is (de)serializable,
/// presumably so finished tool calls can be persisted with the thread. The `serde`
/// aliases let it also read alternate field names.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as supplied in the tool input (also read from `original_path`).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit, any format-on-save, and the save completed.
    new_text: String,
    /// Full buffer text captured before the edit agent started.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    /// Defaults to an empty string when missing from serialized data.
    #[serde(default)]
    diff: String,
    /// Raw output of the edit agent (also read from `raw_output`).
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
104
105impl From<EditFileToolOutput> for LanguageModelToolResultContent {
106 fn from(output: EditFileToolOutput) -> Self {
107 if output.diff.is_empty() {
108 "No edits were made.".into()
109 } else {
110 format!(
111 "Edited {}:\n\n```diff\n{}\n```",
112 output.input_path.display(),
113 output.diff
114 )
115 .into()
116 }
117 }
118}
119
/// Agent tool that creates, overwrites, or edits a file in the thread's project,
/// delegating generation of the new text to an [`EditAgent`].
pub struct EditFileTool {
    /// The owning thread, held weakly; operations fail with "thread was dropped"
    /// once it has been released.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when a past tool call is replayed.
    language_registry: Arc<LanguageRegistry>,
}
124
125impl EditFileTool {
126 pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
127 Self {
128 thread,
129 language_registry,
130 }
131 }
132
133 fn authorize(
134 &self,
135 input: &EditFileToolInput,
136 event_stream: &ToolCallEventStream,
137 cx: &mut App,
138 ) -> Task<Result<()>> {
139 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
140 return Task::ready(Ok(()));
141 }
142
143 // If any path component matches the local settings folder, then this could affect
144 // the editor in ways beyond the project source, so prompt.
145 let local_settings_folder = paths::local_settings_folder_relative_path();
146 let path = Path::new(&input.path);
147 if path
148 .components()
149 .any(|component| component.as_os_str() == local_settings_folder.as_os_str())
150 {
151 return event_stream.authorize(
152 format!("{} (local settings)", input.display_description),
153 cx,
154 );
155 }
156
157 // It's also possible that the global config dir is configured to be inside the project,
158 // so check for that edge case too.
159 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
160 && canonical_path.starts_with(paths::config_dir())
161 {
162 return event_stream.authorize(
163 format!("{} (global settings)", input.display_description),
164 cx,
165 );
166 }
167
168 // Check if path is inside the global config directory
169 // First check if it's already inside project - if not, try to canonicalize
170 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
171 thread.project().read(cx).find_project_path(&input.path, cx)
172 }) else {
173 return Task::ready(Err(anyhow!("thread was dropped")));
174 };
175
176 // If the path is inside the project, and it's not one of the above edge cases,
177 // then no confirmation is necessary. Otherwise, confirmation is necessary.
178 if project_path.is_some() {
179 Task::ready(Ok(()))
180 } else {
181 event_stream.authorize(&input.display_description, cx)
182 }
183 }
184}
185
186impl AgentTool for EditFileTool {
187 type Input = EditFileToolInput;
188 type Output = EditFileToolOutput;
189
190 fn name() -> &'static str {
191 "edit_file"
192 }
193
194 fn kind() -> acp::ToolKind {
195 acp::ToolKind::Edit
196 }
197
198 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
199 match input {
200 Ok(input) => input.display_description.into(),
201 Err(raw_input) => {
202 if let Some(input) =
203 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
204 {
205 let description = input.display_description.trim();
206 if !description.is_empty() {
207 return description.to_string().into();
208 }
209
210 let path = input.path.trim().to_string();
211 if !path.is_empty() {
212 return path.into();
213 }
214 }
215
216 DEFAULT_UI_TEXT.into()
217 }
218 }
219 }
220
221 fn run(
222 self: Arc<Self>,
223 input: Self::Input,
224 event_stream: ToolCallEventStream,
225 cx: &mut App,
226 ) -> Task<Result<Self::Output>> {
227 let Ok(project) = self
228 .thread
229 .read_with(cx, |thread, _cx| thread.project().clone())
230 else {
231 return Task::ready(Err(anyhow!("thread was dropped")));
232 };
233 let project_path = match resolve_path(&input, project.clone(), cx) {
234 Ok(path) => path,
235 Err(err) => return Task::ready(Err(anyhow!(err))),
236 };
237 let abs_path = project.read(cx).absolute_path(&project_path, cx);
238 if let Some(abs_path) = abs_path.clone() {
239 event_stream.update_fields(ToolCallUpdateFields {
240 locations: Some(vec![acp::ToolCallLocation {
241 path: abs_path,
242 line: None,
243 }]),
244 ..Default::default()
245 });
246 }
247
248 let authorize = self.authorize(&input, &event_stream, cx);
249 cx.spawn(async move |cx: &mut AsyncApp| {
250 authorize.await?;
251
252 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
253 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
254 (request, thread.model().cloned(), thread.action_log().clone())
255 })?;
256 let request = request?;
257 let model = model.context("No language model configured")?;
258
259 let edit_format = EditFormat::from_model(model.clone())?;
260 let edit_agent = EditAgent::new(
261 model,
262 project.clone(),
263 action_log.clone(),
264 // TODO: move edit agent to this crate so we can use our templates
265 assistant_tools::templates::Templates::new(),
266 edit_format,
267 );
268
269 let buffer = project
270 .update(cx, |project, cx| {
271 project.open_buffer(project_path.clone(), cx)
272 })?
273 .await?;
274
275 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
276 event_stream.update_diff(diff.clone());
277 let _finalize_diff = util::defer({
278 let diff = diff.downgrade();
279 let mut cx = cx.clone();
280 move || {
281 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
282 }
283 });
284
285 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
286 let old_text = cx
287 .background_spawn({
288 let old_snapshot = old_snapshot.clone();
289 async move { Arc::new(old_snapshot.text()) }
290 })
291 .await;
292
293
294 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
295 edit_agent.edit(
296 buffer.clone(),
297 input.display_description.clone(),
298 &request,
299 cx,
300 )
301 } else {
302 edit_agent.overwrite(
303 buffer.clone(),
304 input.display_description.clone(),
305 &request,
306 cx,
307 )
308 };
309
310 let mut hallucinated_old_text = false;
311 let mut ambiguous_ranges = Vec::new();
312 let mut emitted_location = false;
313 while let Some(event) = events.next().await {
314 match event {
315 EditAgentOutputEvent::Edited(range) => {
316 if !emitted_location {
317 let line = buffer.update(cx, |buffer, _cx| {
318 range.start.to_point(&buffer.snapshot()).row
319 }).ok();
320 if let Some(abs_path) = abs_path.clone() {
321 event_stream.update_fields(ToolCallUpdateFields {
322 locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
323 ..Default::default()
324 });
325 }
326 emitted_location = true;
327 }
328 },
329 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
330 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
331 EditAgentOutputEvent::ResolvingEditRange(range) => {
332 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
333 // if !emitted_location {
334 // let line = buffer.update(cx, |buffer, _cx| {
335 // range.start.to_point(&buffer.snapshot()).row
336 // }).ok();
337 // if let Some(abs_path) = abs_path.clone() {
338 // event_stream.update_fields(ToolCallUpdateFields {
339 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
340 // ..Default::default()
341 // });
342 // }
343 // }
344 }
345 }
346 }
347
348 // If format_on_save is enabled, format the buffer
349 let format_on_save_enabled = buffer
350 .read_with(cx, |buffer, cx| {
351 let settings = language_settings::language_settings(
352 buffer.language().map(|l| l.name()),
353 buffer.file(),
354 cx,
355 );
356 settings.format_on_save != FormatOnSave::Off
357 })
358 .unwrap_or(false);
359
360 let edit_agent_output = output.await?;
361
362 if format_on_save_enabled {
363 action_log.update(cx, |log, cx| {
364 log.buffer_edited(buffer.clone(), cx);
365 })?;
366
367 let format_task = project.update(cx, |project, cx| {
368 project.format(
369 HashSet::from_iter([buffer.clone()]),
370 LspFormatTarget::Buffers,
371 false, // Don't push to history since the tool did it.
372 FormatTrigger::Save,
373 cx,
374 )
375 })?;
376 format_task.await.log_err();
377 }
378
379 project
380 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
381 .await?;
382
383 action_log.update(cx, |log, cx| {
384 log.buffer_edited(buffer.clone(), cx);
385 })?;
386
387 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
388 let (new_text, unified_diff) = cx
389 .background_spawn({
390 let new_snapshot = new_snapshot.clone();
391 let old_text = old_text.clone();
392 async move {
393 let new_text = new_snapshot.text();
394 let diff = language::unified_diff(&old_text, &new_text);
395 (new_text, diff)
396 }
397 })
398 .await;
399
400 let input_path = input.path.display();
401 if unified_diff.is_empty() {
402 anyhow::ensure!(
403 !hallucinated_old_text,
404 formatdoc! {"
405 Some edits were produced but none of them could be applied.
406 Read the relevant sections of {input_path} again so that
407 I can perform the requested edits.
408 "}
409 );
410 anyhow::ensure!(
411 ambiguous_ranges.is_empty(),
412 {
413 let line_numbers = ambiguous_ranges
414 .iter()
415 .map(|range| range.start.to_string())
416 .collect::<Vec<_>>()
417 .join(", ");
418 formatdoc! {"
419 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
420 relevant sections of {input_path} again and extend <old_text> so
421 that I can perform the requested edits.
422 "}
423 }
424 );
425 }
426
427 Ok(EditFileToolOutput {
428 input_path: input.path,
429 new_text,
430 old_text,
431 diff: unified_diff,
432 edit_agent_output,
433 })
434 })
435 }
436
437 fn replay(
438 &self,
439 _input: Self::Input,
440 output: Self::Output,
441 event_stream: ToolCallEventStream,
442 cx: &mut App,
443 ) -> Result<()> {
444 event_stream.update_diff(cx.new(|cx| {
445 Diff::finalized(
446 output.input_path,
447 Some(output.old_text.to_string()),
448 output.new_text,
449 self.language_registry.clone(),
450 cx,
451 )
452 }));
453 Ok(())
454 }
455}
456
457/// Validate that the file path is valid, meaning:
458///
459/// - For `edit` and `overwrite`, the path must point to an existing file.
460/// - For `create`, the file must not already exist, but it's parent dir must exist.
461fn resolve_path(
462 input: &EditFileToolInput,
463 project: Entity<Project>,
464 cx: &mut App,
465) -> Result<ProjectPath> {
466 let project = project.read(cx);
467
468 match input.mode {
469 EditFileMode::Edit | EditFileMode::Overwrite => {
470 let path = project
471 .find_project_path(&input.path, cx)
472 .context("Can't edit file: path not found")?;
473
474 let entry = project
475 .entry_for_path(&path, cx)
476 .context("Can't edit file: path not found")?;
477
478 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
479 Ok(path)
480 }
481
482 EditFileMode::Create => {
483 if let Some(path) = project.find_project_path(&input.path, cx) {
484 anyhow::ensure!(
485 project.entry_for_path(&path, cx).is_none(),
486 "Can't create file: file already exists"
487 );
488 }
489
490 let parent_path = input
491 .path
492 .parent()
493 .context("Can't create file: incorrect path")?;
494
495 let parent_project_path = project.find_project_path(&parent_path, cx);
496
497 let parent_entry = parent_project_path
498 .as_ref()
499 .and_then(|path| project.entry_for_path(path, cx))
500 .context("Can't create file: parent directory doesn't exist")?;
501
502 anyhow::ensure!(
503 parent_entry.is_dir(),
504 "Can't create file: parent is not a directory"
505 );
506
507 let file_name = input
508 .path
509 .file_name()
510 .context("Can't create file: invalid filename")?;
511
512 let new_file_path = parent_project_path.map(|parent| ProjectPath {
513 path: Arc::from(parent.path.join(file_name)),
514 ..parent
515 });
516
517 new_file_path.context("Can't create file")
518 }
519 }
520}
521
522#[cfg(test)]
523mod tests {
524 use super::*;
525 use crate::{ContextServerRegistry, Templates};
526 use client::TelemetrySettings;
527 use fs::Fs;
528 use gpui::{TestAppContext, UpdateGlobal};
529 use language_model::fake_provider::FakeLanguageModel;
530 use prompt_store::ProjectContext;
531 use serde_json::json;
532 use settings::SettingsStore;
533 use util::path;
534
535 #[gpui::test]
536 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
537 init_test(cx);
538
539 let fs = project::FakeFs::new(cx.executor());
540 fs.insert_tree("/root", json!({})).await;
541 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
542 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
543 let context_server_registry =
544 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
545 let model = Arc::new(FakeLanguageModel::default());
546 let thread = cx.new(|cx| {
547 Thread::new(
548 project,
549 cx.new(|_cx| ProjectContext::default()),
550 context_server_registry,
551 Templates::new(),
552 Some(model),
553 cx,
554 )
555 });
556 let result = cx
557 .update(|cx| {
558 let input = EditFileToolInput {
559 display_description: "Some edit".into(),
560 path: "root/nonexistent_file.txt".into(),
561 mode: EditFileMode::Edit,
562 };
563 Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
564 input,
565 ToolCallEventStream::test().0,
566 cx,
567 )
568 })
569 .await;
570 assert_eq!(
571 result.unwrap_err().to_string(),
572 "Can't edit file: path not found"
573 );
574 }
575
576 #[gpui::test]
577 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
578 let mode = &EditFileMode::Create;
579
580 let result = test_resolve_path(mode, "root/new.txt", cx);
581 assert_resolved_path_eq(result.await, "new.txt");
582
583 let result = test_resolve_path(mode, "new.txt", cx);
584 assert_resolved_path_eq(result.await, "new.txt");
585
586 let result = test_resolve_path(mode, "dir/new.txt", cx);
587 assert_resolved_path_eq(result.await, "dir/new.txt");
588
589 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
590 assert_eq!(
591 result.await.unwrap_err().to_string(),
592 "Can't create file: file already exists"
593 );
594
595 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
596 assert_eq!(
597 result.await.unwrap_err().to_string(),
598 "Can't create file: parent directory doesn't exist"
599 );
600 }
601
602 #[gpui::test]
603 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
604 let mode = &EditFileMode::Edit;
605
606 let path_with_root = "root/dir/subdir/existing.txt";
607 let path_without_root = "dir/subdir/existing.txt";
608 let result = test_resolve_path(mode, path_with_root, cx);
609 assert_resolved_path_eq(result.await, path_without_root);
610
611 let result = test_resolve_path(mode, path_without_root, cx);
612 assert_resolved_path_eq(result.await, path_without_root);
613
614 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
615 assert_eq!(
616 result.await.unwrap_err().to_string(),
617 "Can't edit file: path not found"
618 );
619
620 let result = test_resolve_path(mode, "root/dir", cx);
621 assert_eq!(
622 result.await.unwrap_err().to_string(),
623 "Can't edit file: path is a directory"
624 );
625 }
626
627 async fn test_resolve_path(
628 mode: &EditFileMode,
629 path: &str,
630 cx: &mut TestAppContext,
631 ) -> anyhow::Result<ProjectPath> {
632 init_test(cx);
633
634 let fs = project::FakeFs::new(cx.executor());
635 fs.insert_tree(
636 "/root",
637 json!({
638 "dir": {
639 "subdir": {
640 "existing.txt": "hello"
641 }
642 }
643 }),
644 )
645 .await;
646 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
647
648 let input = EditFileToolInput {
649 display_description: "Some edit".into(),
650 path: path.into(),
651 mode: mode.clone(),
652 };
653
654 cx.update(|cx| resolve_path(&input, project, cx))
655 }
656
657 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
658 let actual = path
659 .expect("Should return valid path")
660 .path
661 .to_str()
662 .unwrap()
663 .replace("\\", "/"); // Naive Windows paths normalization
664 assert_eq!(actual, expected);
665 }
666
667 #[gpui::test]
668 async fn test_format_on_save(cx: &mut TestAppContext) {
669 init_test(cx);
670
671 let fs = project::FakeFs::new(cx.executor());
672 fs.insert_tree("/root", json!({"src": {}})).await;
673
674 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
675
676 // Set up a Rust language with LSP formatting support
677 let rust_language = Arc::new(language::Language::new(
678 language::LanguageConfig {
679 name: "Rust".into(),
680 matcher: language::LanguageMatcher {
681 path_suffixes: vec!["rs".to_string()],
682 ..Default::default()
683 },
684 ..Default::default()
685 },
686 None,
687 ));
688
689 // Register the language and fake LSP
690 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
691 language_registry.add(rust_language);
692
693 let mut fake_language_servers = language_registry.register_fake_lsp(
694 "Rust",
695 language::FakeLspAdapter {
696 capabilities: lsp::ServerCapabilities {
697 document_formatting_provider: Some(lsp::OneOf::Left(true)),
698 ..Default::default()
699 },
700 ..Default::default()
701 },
702 );
703
704 // Create the file
705 fs.save(
706 path!("/root/src/main.rs").as_ref(),
707 &"initial content".into(),
708 language::LineEnding::Unix,
709 )
710 .await
711 .unwrap();
712
713 // Open the buffer to trigger LSP initialization
714 let buffer = project
715 .update(cx, |project, cx| {
716 project.open_local_buffer(path!("/root/src/main.rs"), cx)
717 })
718 .await
719 .unwrap();
720
721 // Register the buffer with language servers
722 let _handle = project.update(cx, |project, cx| {
723 project.register_buffer_with_language_servers(&buffer, cx)
724 });
725
726 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
727 const FORMATTED_CONTENT: &str =
728 "This file was formatted by the fake formatter in the test.\n";
729
730 // Get the fake language server and set up formatting handler
731 let fake_language_server = fake_language_servers.next().await.unwrap();
732 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
733 |_, _| async move {
734 Ok(Some(vec![lsp::TextEdit {
735 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
736 new_text: FORMATTED_CONTENT.to_string(),
737 }]))
738 }
739 });
740
741 let context_server_registry =
742 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
743 let model = Arc::new(FakeLanguageModel::default());
744 let thread = cx.new(|cx| {
745 Thread::new(
746 project,
747 cx.new(|_cx| ProjectContext::default()),
748 context_server_registry,
749 Templates::new(),
750 Some(model.clone()),
751 cx,
752 )
753 });
754
755 // First, test with format_on_save enabled
756 cx.update(|cx| {
757 SettingsStore::update_global(cx, |store, cx| {
758 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
759 cx,
760 |settings| {
761 settings.defaults.format_on_save = Some(FormatOnSave::On);
762 settings.defaults.formatter =
763 Some(language::language_settings::SelectedFormatter::Auto);
764 },
765 );
766 });
767 });
768
769 // Have the model stream unformatted content
770 let edit_result = {
771 let edit_task = cx.update(|cx| {
772 let input = EditFileToolInput {
773 display_description: "Create main function".into(),
774 path: "root/src/main.rs".into(),
775 mode: EditFileMode::Overwrite,
776 };
777 Arc::new(EditFileTool::new(
778 thread.downgrade(),
779 language_registry.clone(),
780 ))
781 .run(input, ToolCallEventStream::test().0, cx)
782 });
783
784 // Stream the unformatted content
785 cx.executor().run_until_parked();
786 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
787 model.end_last_completion_stream();
788
789 edit_task.await
790 };
791 assert!(edit_result.is_ok());
792
793 // Wait for any async operations (e.g. formatting) to complete
794 cx.executor().run_until_parked();
795
796 // Read the file to verify it was formatted automatically
797 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
798 assert_eq!(
799 // Ignore carriage returns on Windows
800 new_content.replace("\r\n", "\n"),
801 FORMATTED_CONTENT,
802 "Code should be formatted when format_on_save is enabled"
803 );
804
805 let stale_buffer_count = thread
806 .read_with(cx, |thread, _cx| thread.action_log.clone())
807 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
808
809 assert_eq!(
810 stale_buffer_count, 0,
811 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
812 This causes the agent to think the file was modified externally when it was just formatted.",
813 stale_buffer_count
814 );
815
816 // Next, test with format_on_save disabled
817 cx.update(|cx| {
818 SettingsStore::update_global(cx, |store, cx| {
819 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
820 cx,
821 |settings| {
822 settings.defaults.format_on_save = Some(FormatOnSave::Off);
823 },
824 );
825 });
826 });
827
828 // Stream unformatted edits again
829 let edit_result = {
830 let edit_task = cx.update(|cx| {
831 let input = EditFileToolInput {
832 display_description: "Update main function".into(),
833 path: "root/src/main.rs".into(),
834 mode: EditFileMode::Overwrite,
835 };
836 Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
837 input,
838 ToolCallEventStream::test().0,
839 cx,
840 )
841 });
842
843 // Stream the unformatted content
844 cx.executor().run_until_parked();
845 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
846 model.end_last_completion_stream();
847
848 edit_task.await
849 };
850 assert!(edit_result.is_ok());
851
852 // Wait for any async operations (e.g. formatting) to complete
853 cx.executor().run_until_parked();
854
855 // Verify the file was not formatted
856 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
857 assert_eq!(
858 // Ignore carriage returns on Windows
859 new_content.replace("\r\n", "\n"),
860 UNFORMATTED_CONTENT,
861 "Code should not be formatted when format_on_save is disabled"
862 );
863 }
864
865 #[gpui::test]
866 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
867 init_test(cx);
868
869 let fs = project::FakeFs::new(cx.executor());
870 fs.insert_tree("/root", json!({"src": {}})).await;
871
872 // Create a simple file with trailing whitespace
873 fs.save(
874 path!("/root/src/main.rs").as_ref(),
875 &"initial content".into(),
876 language::LineEnding::Unix,
877 )
878 .await
879 .unwrap();
880
881 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
882 let context_server_registry =
883 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
884 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
885 let model = Arc::new(FakeLanguageModel::default());
886 let thread = cx.new(|cx| {
887 Thread::new(
888 project,
889 cx.new(|_cx| ProjectContext::default()),
890 context_server_registry,
891 Templates::new(),
892 Some(model.clone()),
893 cx,
894 )
895 });
896
897 // First, test with remove_trailing_whitespace_on_save enabled
898 cx.update(|cx| {
899 SettingsStore::update_global(cx, |store, cx| {
900 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
901 cx,
902 |settings| {
903 settings.defaults.remove_trailing_whitespace_on_save = Some(true);
904 },
905 );
906 });
907 });
908
909 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
910 "fn main() { \n println!(\"Hello!\"); \n}\n";
911
912 // Have the model stream content that contains trailing whitespace
913 let edit_result = {
914 let edit_task = cx.update(|cx| {
915 let input = EditFileToolInput {
916 display_description: "Create main function".into(),
917 path: "root/src/main.rs".into(),
918 mode: EditFileMode::Overwrite,
919 };
920 Arc::new(EditFileTool::new(
921 thread.downgrade(),
922 language_registry.clone(),
923 ))
924 .run(input, ToolCallEventStream::test().0, cx)
925 });
926
927 // Stream the content with trailing whitespace
928 cx.executor().run_until_parked();
929 model.send_last_completion_stream_text_chunk(
930 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
931 );
932 model.end_last_completion_stream();
933
934 edit_task.await
935 };
936 assert!(edit_result.is_ok());
937
938 // Wait for any async operations (e.g. formatting) to complete
939 cx.executor().run_until_parked();
940
941 // Read the file to verify trailing whitespace was removed automatically
942 assert_eq!(
943 // Ignore carriage returns on Windows
944 fs.load(path!("/root/src/main.rs").as_ref())
945 .await
946 .unwrap()
947 .replace("\r\n", "\n"),
948 "fn main() {\n println!(\"Hello!\");\n}\n",
949 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
950 );
951
952 // Next, test with remove_trailing_whitespace_on_save disabled
953 cx.update(|cx| {
954 SettingsStore::update_global(cx, |store, cx| {
955 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
956 cx,
957 |settings| {
958 settings.defaults.remove_trailing_whitespace_on_save = Some(false);
959 },
960 );
961 });
962 });
963
964 // Stream edits again with trailing whitespace
965 let edit_result = {
966 let edit_task = cx.update(|cx| {
967 let input = EditFileToolInput {
968 display_description: "Update main function".into(),
969 path: "root/src/main.rs".into(),
970 mode: EditFileMode::Overwrite,
971 };
972 Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
973 input,
974 ToolCallEventStream::test().0,
975 cx,
976 )
977 });
978
979 // Stream the content with trailing whitespace
980 cx.executor().run_until_parked();
981 model.send_last_completion_stream_text_chunk(
982 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
983 );
984 model.end_last_completion_stream();
985
986 edit_task.await
987 };
988 assert!(edit_result.is_ok());
989
990 // Wait for any async operations (e.g. formatting) to complete
991 cx.executor().run_until_parked();
992
993 // Verify the file still has trailing whitespace
994 // Read the file again - it should still have trailing whitespace
995 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
996 assert_eq!(
997 // Ignore carriage returns on Windows
998 final_content.replace("\r\n", "\n"),
999 CONTENT_WITH_TRAILING_WHITESPACE,
1000 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1001 );
1002 }
1003
1004 #[gpui::test]
1005 async fn test_authorize(cx: &mut TestAppContext) {
1006 init_test(cx);
1007 let fs = project::FakeFs::new(cx.executor());
1008 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1009 let context_server_registry =
1010 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1011 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1012 let model = Arc::new(FakeLanguageModel::default());
1013 let thread = cx.new(|cx| {
1014 Thread::new(
1015 project,
1016 cx.new(|_cx| ProjectContext::default()),
1017 context_server_registry,
1018 Templates::new(),
1019 Some(model.clone()),
1020 cx,
1021 )
1022 });
1023 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1024 fs.insert_tree("/root", json!({})).await;
1025
1026 // Test 1: Path with .zed component should require confirmation
1027 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1028 let _auth = cx.update(|cx| {
1029 tool.authorize(
1030 &EditFileToolInput {
1031 display_description: "test 1".into(),
1032 path: ".zed/settings.json".into(),
1033 mode: EditFileMode::Edit,
1034 },
1035 &stream_tx,
1036 cx,
1037 )
1038 });
1039
1040 let event = stream_rx.expect_authorization().await;
1041 assert_eq!(
1042 event.tool_call.fields.title,
1043 Some("test 1 (local settings)".into())
1044 );
1045
1046 // Test 2: Path outside project should require confirmation
1047 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1048 let _auth = cx.update(|cx| {
1049 tool.authorize(
1050 &EditFileToolInput {
1051 display_description: "test 2".into(),
1052 path: "/etc/hosts".into(),
1053 mode: EditFileMode::Edit,
1054 },
1055 &stream_tx,
1056 cx,
1057 )
1058 });
1059
1060 let event = stream_rx.expect_authorization().await;
1061 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1062
1063 // Test 3: Relative path without .zed should not require confirmation
1064 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1065 cx.update(|cx| {
1066 tool.authorize(
1067 &EditFileToolInput {
1068 display_description: "test 3".into(),
1069 path: "root/src/main.rs".into(),
1070 mode: EditFileMode::Edit,
1071 },
1072 &stream_tx,
1073 cx,
1074 )
1075 })
1076 .await
1077 .unwrap();
1078 assert!(stream_rx.try_next().is_err());
1079
1080 // Test 4: Path with .zed in the middle should require confirmation
1081 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1082 let _auth = cx.update(|cx| {
1083 tool.authorize(
1084 &EditFileToolInput {
1085 display_description: "test 4".into(),
1086 path: "root/.zed/tasks.json".into(),
1087 mode: EditFileMode::Edit,
1088 },
1089 &stream_tx,
1090 cx,
1091 )
1092 });
1093 let event = stream_rx.expect_authorization().await;
1094 assert_eq!(
1095 event.tool_call.fields.title,
1096 Some("test 4 (local settings)".into())
1097 );
1098
1099 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1100 cx.update(|cx| {
1101 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1102 settings.always_allow_tool_actions = true;
1103 agent_settings::AgentSettings::override_global(settings, cx);
1104 });
1105
1106 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1107 cx.update(|cx| {
1108 tool.authorize(
1109 &EditFileToolInput {
1110 display_description: "test 5.1".into(),
1111 path: ".zed/settings.json".into(),
1112 mode: EditFileMode::Edit,
1113 },
1114 &stream_tx,
1115 cx,
1116 )
1117 })
1118 .await
1119 .unwrap();
1120 assert!(stream_rx.try_next().is_err());
1121
1122 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1123 cx.update(|cx| {
1124 tool.authorize(
1125 &EditFileToolInput {
1126 display_description: "test 5.2".into(),
1127 path: "/etc/hosts".into(),
1128 mode: EditFileMode::Edit,
1129 },
1130 &stream_tx,
1131 cx,
1132 )
1133 })
1134 .await
1135 .unwrap();
1136 assert!(stream_rx.try_next().is_err());
1137 }
1138
1139 #[gpui::test]
1140 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1141 init_test(cx);
1142 let fs = project::FakeFs::new(cx.executor());
1143 fs.insert_tree("/project", json!({})).await;
1144 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1145 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1146 let context_server_registry =
1147 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1148 let model = Arc::new(FakeLanguageModel::default());
1149 let thread = cx.new(|cx| {
1150 Thread::new(
1151 project,
1152 cx.new(|_cx| ProjectContext::default()),
1153 context_server_registry,
1154 Templates::new(),
1155 Some(model.clone()),
1156 cx,
1157 )
1158 });
1159 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1160
1161 // Test global config paths - these should require confirmation if they exist and are outside the project
1162 let test_cases = vec![
1163 (
1164 "/etc/hosts",
1165 true,
1166 "System file should require confirmation",
1167 ),
1168 (
1169 "/usr/local/bin/script",
1170 true,
1171 "System bin file should require confirmation",
1172 ),
1173 (
1174 "project/normal_file.rs",
1175 false,
1176 "Normal project file should not require confirmation",
1177 ),
1178 ];
1179
1180 for (path, should_confirm, description) in test_cases {
1181 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1182 let auth = cx.update(|cx| {
1183 tool.authorize(
1184 &EditFileToolInput {
1185 display_description: "Edit file".into(),
1186 path: path.into(),
1187 mode: EditFileMode::Edit,
1188 },
1189 &stream_tx,
1190 cx,
1191 )
1192 });
1193
1194 if should_confirm {
1195 stream_rx.expect_authorization().await;
1196 } else {
1197 auth.await.unwrap();
1198 assert!(
1199 stream_rx.try_next().is_err(),
1200 "Failed for case: {} - path: {} - expected no confirmation but got one",
1201 description,
1202 path
1203 );
1204 }
1205 }
1206 }
1207
1208 #[gpui::test]
1209 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1210 init_test(cx);
1211 let fs = project::FakeFs::new(cx.executor());
1212
1213 // Create multiple worktree directories
1214 fs.insert_tree(
1215 "/workspace/frontend",
1216 json!({
1217 "src": {
1218 "main.js": "console.log('frontend');"
1219 }
1220 }),
1221 )
1222 .await;
1223 fs.insert_tree(
1224 "/workspace/backend",
1225 json!({
1226 "src": {
1227 "main.rs": "fn main() {}"
1228 }
1229 }),
1230 )
1231 .await;
1232 fs.insert_tree(
1233 "/workspace/shared",
1234 json!({
1235 ".zed": {
1236 "settings.json": "{}"
1237 }
1238 }),
1239 )
1240 .await;
1241
1242 // Create project with multiple worktrees
1243 let project = Project::test(
1244 fs.clone(),
1245 [
1246 path!("/workspace/frontend").as_ref(),
1247 path!("/workspace/backend").as_ref(),
1248 path!("/workspace/shared").as_ref(),
1249 ],
1250 cx,
1251 )
1252 .await;
1253 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1254 let context_server_registry =
1255 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1256 let model = Arc::new(FakeLanguageModel::default());
1257 let thread = cx.new(|cx| {
1258 Thread::new(
1259 project.clone(),
1260 cx.new(|_cx| ProjectContext::default()),
1261 context_server_registry.clone(),
1262 Templates::new(),
1263 Some(model.clone()),
1264 cx,
1265 )
1266 });
1267 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1268
1269 // Test files in different worktrees
1270 let test_cases = vec![
1271 ("frontend/src/main.js", false, "File in first worktree"),
1272 ("backend/src/main.rs", false, "File in second worktree"),
1273 (
1274 "shared/.zed/settings.json",
1275 true,
1276 ".zed file in third worktree",
1277 ),
1278 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1279 (
1280 "../outside/file.txt",
1281 true,
1282 "Relative path outside worktrees",
1283 ),
1284 ];
1285
1286 for (path, should_confirm, description) in test_cases {
1287 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1288 let auth = cx.update(|cx| {
1289 tool.authorize(
1290 &EditFileToolInput {
1291 display_description: "Edit file".into(),
1292 path: path.into(),
1293 mode: EditFileMode::Edit,
1294 },
1295 &stream_tx,
1296 cx,
1297 )
1298 });
1299
1300 if should_confirm {
1301 stream_rx.expect_authorization().await;
1302 } else {
1303 auth.await.unwrap();
1304 assert!(
1305 stream_rx.try_next().is_err(),
1306 "Failed for case: {} - path: {} - expected no confirmation but got one",
1307 description,
1308 path
1309 );
1310 }
1311 }
1312 }
1313
1314 #[gpui::test]
1315 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1316 init_test(cx);
1317 let fs = project::FakeFs::new(cx.executor());
1318 fs.insert_tree(
1319 "/project",
1320 json!({
1321 ".zed": {
1322 "settings.json": "{}"
1323 },
1324 "src": {
1325 ".zed": {
1326 "local.json": "{}"
1327 }
1328 }
1329 }),
1330 )
1331 .await;
1332 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1333 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1334 let context_server_registry =
1335 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1336 let model = Arc::new(FakeLanguageModel::default());
1337 let thread = cx.new(|cx| {
1338 Thread::new(
1339 project.clone(),
1340 cx.new(|_cx| ProjectContext::default()),
1341 context_server_registry.clone(),
1342 Templates::new(),
1343 Some(model.clone()),
1344 cx,
1345 )
1346 });
1347 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1348
1349 // Test edge cases
1350 let test_cases = vec![
1351 // Empty path - find_project_path returns Some for empty paths
1352 ("", false, "Empty path is treated as project root"),
1353 // Root directory
1354 ("/", true, "Root directory should be outside project"),
1355 // Parent directory references - find_project_path resolves these
1356 (
1357 "project/../other",
1358 false,
1359 "Path with .. is resolved by find_project_path",
1360 ),
1361 (
1362 "project/./src/file.rs",
1363 false,
1364 "Path with . should work normally",
1365 ),
1366 // Windows-style paths (if on Windows)
1367 #[cfg(target_os = "windows")]
1368 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1369 #[cfg(target_os = "windows")]
1370 ("project\\src\\main.rs", false, "Windows-style project path"),
1371 ];
1372
1373 for (path, should_confirm, description) in test_cases {
1374 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1375 let auth = cx.update(|cx| {
1376 tool.authorize(
1377 &EditFileToolInput {
1378 display_description: "Edit file".into(),
1379 path: path.into(),
1380 mode: EditFileMode::Edit,
1381 },
1382 &stream_tx,
1383 cx,
1384 )
1385 });
1386
1387 if should_confirm {
1388 stream_rx.expect_authorization().await;
1389 } else {
1390 auth.await.unwrap();
1391 assert!(
1392 stream_rx.try_next().is_err(),
1393 "Failed for case: {} - path: {} - expected no confirmation but got one",
1394 description,
1395 path
1396 );
1397 }
1398 }
1399 }
1400
1401 #[gpui::test]
1402 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1403 init_test(cx);
1404 let fs = project::FakeFs::new(cx.executor());
1405 fs.insert_tree(
1406 "/project",
1407 json!({
1408 "existing.txt": "content",
1409 ".zed": {
1410 "settings.json": "{}"
1411 }
1412 }),
1413 )
1414 .await;
1415 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1416 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1417 let context_server_registry =
1418 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1419 let model = Arc::new(FakeLanguageModel::default());
1420 let thread = cx.new(|cx| {
1421 Thread::new(
1422 project.clone(),
1423 cx.new(|_cx| ProjectContext::default()),
1424 context_server_registry.clone(),
1425 Templates::new(),
1426 Some(model.clone()),
1427 cx,
1428 )
1429 });
1430 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1431
1432 // Test different EditFileMode values
1433 let modes = vec![
1434 EditFileMode::Edit,
1435 EditFileMode::Create,
1436 EditFileMode::Overwrite,
1437 ];
1438
1439 for mode in modes {
1440 // Test .zed path with different modes
1441 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1442 let _auth = cx.update(|cx| {
1443 tool.authorize(
1444 &EditFileToolInput {
1445 display_description: "Edit settings".into(),
1446 path: "project/.zed/settings.json".into(),
1447 mode: mode.clone(),
1448 },
1449 &stream_tx,
1450 cx,
1451 )
1452 });
1453
1454 stream_rx.expect_authorization().await;
1455
1456 // Test outside path with different modes
1457 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1458 let _auth = cx.update(|cx| {
1459 tool.authorize(
1460 &EditFileToolInput {
1461 display_description: "Edit file".into(),
1462 path: "/outside/file.txt".into(),
1463 mode: mode.clone(),
1464 },
1465 &stream_tx,
1466 cx,
1467 )
1468 });
1469
1470 stream_rx.expect_authorization().await;
1471
1472 // Test normal path with different modes
1473 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1474 cx.update(|cx| {
1475 tool.authorize(
1476 &EditFileToolInput {
1477 display_description: "Edit file".into(),
1478 path: "project/normal.txt".into(),
1479 mode: mode.clone(),
1480 },
1481 &stream_tx,
1482 cx,
1483 )
1484 })
1485 .await
1486 .unwrap();
1487 assert!(stream_rx.try_next().is_err());
1488 }
1489 }
1490
1491 #[gpui::test]
1492 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1493 init_test(cx);
1494 let fs = project::FakeFs::new(cx.executor());
1495 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1496 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1497 let context_server_registry =
1498 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1499 let model = Arc::new(FakeLanguageModel::default());
1500 let thread = cx.new(|cx| {
1501 Thread::new(
1502 project.clone(),
1503 cx.new(|_cx| ProjectContext::default()),
1504 context_server_registry,
1505 Templates::new(),
1506 Some(model.clone()),
1507 cx,
1508 )
1509 });
1510 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1511
1512 assert_eq!(
1513 tool.initial_title(Err(json!({
1514 "path": "src/main.rs",
1515 "display_description": "",
1516 "old_string": "old code",
1517 "new_string": "new code"
1518 }))),
1519 "src/main.rs"
1520 );
1521 assert_eq!(
1522 tool.initial_title(Err(json!({
1523 "path": "",
1524 "display_description": "Fix error handling",
1525 "old_string": "old code",
1526 "new_string": "new code"
1527 }))),
1528 "Fix error handling"
1529 );
1530 assert_eq!(
1531 tool.initial_title(Err(json!({
1532 "path": "src/main.rs",
1533 "display_description": "Fix error handling",
1534 "old_string": "old code",
1535 "new_string": "new code"
1536 }))),
1537 "Fix error handling"
1538 );
1539 assert_eq!(
1540 tool.initial_title(Err(json!({
1541 "path": "",
1542 "display_description": "",
1543 "old_string": "old code",
1544 "new_string": "new code"
1545 }))),
1546 DEFAULT_UI_TEXT
1547 );
1548 assert_eq!(
1549 tool.initial_title(Err(serde_json::Value::Null)),
1550 DEFAULT_UI_TEXT
1551 );
1552 }
1553
1554 #[gpui::test]
1555 async fn test_diff_finalization(cx: &mut TestAppContext) {
1556 init_test(cx);
1557 let fs = project::FakeFs::new(cx.executor());
1558 fs.insert_tree("/", json!({"main.rs": ""})).await;
1559
1560 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1561 let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1562 let context_server_registry =
1563 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1564 let model = Arc::new(FakeLanguageModel::default());
1565 let thread = cx.new(|cx| {
1566 Thread::new(
1567 project.clone(),
1568 cx.new(|_cx| ProjectContext::default()),
1569 context_server_registry.clone(),
1570 Templates::new(),
1571 Some(model.clone()),
1572 cx,
1573 )
1574 });
1575
1576 // Ensure the diff is finalized after the edit completes.
1577 {
1578 let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
1579 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1580 let edit = cx.update(|cx| {
1581 tool.run(
1582 EditFileToolInput {
1583 display_description: "Edit file".into(),
1584 path: path!("/main.rs").into(),
1585 mode: EditFileMode::Edit,
1586 },
1587 stream_tx,
1588 cx,
1589 )
1590 });
1591 stream_rx.expect_update_fields().await;
1592 let diff = stream_rx.expect_diff().await;
1593 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1594 cx.run_until_parked();
1595 model.end_last_completion_stream();
1596 edit.await.unwrap();
1597 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1598 }
1599
1600 // Ensure the diff is finalized if an error occurs while editing.
1601 {
1602 model.forbid_requests();
1603 let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
1604 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1605 let edit = cx.update(|cx| {
1606 tool.run(
1607 EditFileToolInput {
1608 display_description: "Edit file".into(),
1609 path: path!("/main.rs").into(),
1610 mode: EditFileMode::Edit,
1611 },
1612 stream_tx,
1613 cx,
1614 )
1615 });
1616 stream_rx.expect_update_fields().await;
1617 let diff = stream_rx.expect_diff().await;
1618 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1619 edit.await.unwrap_err();
1620 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1621 model.allow_requests();
1622 }
1623
1624 // Ensure the diff is finalized if the tool call gets dropped.
1625 {
1626 let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
1627 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1628 let edit = cx.update(|cx| {
1629 tool.run(
1630 EditFileToolInput {
1631 display_description: "Edit file".into(),
1632 path: path!("/main.rs").into(),
1633 mode: EditFileMode::Edit,
1634 },
1635 stream_tx,
1636 cx,
1637 )
1638 });
1639 stream_rx.expect_update_fields().await;
1640 let diff = stream_rx.expect_diff().await;
1641 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1642 drop(edit);
1643 cx.run_until_parked();
1644 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1645 }
1646 }
1647
1648 fn init_test(cx: &mut TestAppContext) {
1649 cx.update(|cx| {
1650 let settings_store = SettingsStore::test(cx);
1651 cx.set_global(settings_store);
1652 language::init(cx);
1653 TelemetrySettings::register(cx);
1654 agent_settings::AgentSettings::register(cx);
1655 Project::init_settings(cx);
1656 });
1657 }
1658}