1use crate::{AgentTool, Thread, ToolCallEventStream};
2use acp_thread::Diff;
3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
6use cloud_llm_client::CompletionIntent;
7use collections::HashSet;
8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
9use indoc::formatdoc;
10use language::language_settings::{self, FormatOnSave};
11use language::{LanguageRegistry, ToPoint};
12use language_model::LanguageModelToolResultContent;
13use paths;
14use project::lsp_store::{FormatTrigger, LspFormatTarget};
15use project::{Project, ProjectPath};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::Settings;
19use smol::stream::StreamExt as _;
20use std::path::{Path, PathBuf};
21use std::sync::Arc;
22use ui::SharedString;
23use util::ResultExt;
24
/// Tool-call title shown while the streamed input has neither a usable path nor a description.
const DEFAULT_UI_TEXT: &str = "Editing file";
26
27/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
28///
29/// Before using this tool:
30///
31/// 1. Use the `read_file` tool to understand the file's contents and context
32///
33/// 2. Verify the directory path is correct (only applicable when creating new files):
34/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
35#[derive(Debug, Serialize, Deserialize, JsonSchema)]
36pub struct EditFileToolInput {
37 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
38 ///
39 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
40 ///
41 /// NEVER mention the file path in this description.
42 ///
43 /// <example>Fix API endpoint URLs</example>
44 /// <example>Update copyright year in `page_footer`</example>
45 ///
46 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
47 pub display_description: String,
48
49 /// The full path of the file to create or modify in the project.
50 ///
51 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
52 ///
53 /// The following examples assume we have two root directories in the project:
54 /// - /a/b/backend
55 /// - /c/d/frontend
56 ///
57 /// <example>
58 /// `backend/src/main.rs`
59 ///
60 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
61 /// </example>
62 ///
63 /// <example>
64 /// `frontend/db.js`
65 /// </example>
66 pub path: PathBuf,
67 /// The mode of operation on the file. Possible values:
68 /// - 'edit': Make granular edits to an existing file.
69 /// - 'create': Create a new file if it doesn't exist.
70 /// - 'overwrite': Replace the entire contents of an existing file.
71 ///
72 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
73 pub mode: EditFileMode,
74}
75
76#[derive(Debug, Serialize, Deserialize, JsonSchema)]
77struct EditFileToolPartialInput {
78 #[serde(default)]
79 path: String,
80 #[serde(default)]
81 display_description: String,
82}
83
84#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
85#[serde(rename_all = "lowercase")]
86#[schemars(inline)]
87pub enum EditFileMode {
88 Edit,
89 Create,
90 Overwrite,
91}
92
/// Result of a completed `edit_file` call.
///
/// It is serialized with serde (field declaration order determines the serialized key order),
/// consumed by `replay` to rebuild the diff from `old_text`/`new_text`, and converted into the
/// text returned to the model.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as supplied in `EditFileToolInput::path`.
    // Deserialization also accepts the key `original_path` (presumably a former field name kept
    // for previously stored outputs — confirm before removing).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit (and any format-on-save) was saved.
    new_text: String,
    /// Full buffer text captured before the edit agent ran.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; an empty string is reported to the model as
    /// "No edits were made." Defaults to empty when absent from serialized data.
    #[serde(default)]
    diff: String,
    /// Output returned by the edit agent.
    // Deserialization also accepts the key `raw_output` (presumably a former field name).
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
104
105impl From<EditFileToolOutput> for LanguageModelToolResultContent {
106 fn from(output: EditFileToolOutput) -> Self {
107 if output.diff.is_empty() {
108 "No edits were made.".into()
109 } else {
110 format!(
111 "Edited {}:\n\n```diff\n{}\n```",
112 output.input_path.display(),
113 output.diff
114 )
115 .into()
116 }
117 }
118}
119
/// Agent tool that creates, overwrites, or edits a single project file, delegating generation of
/// the new text to an `EditAgent`.
pub struct EditFileTool {
    /// Thread this tool runs in, held weakly; `run` and `authorize` fail with
    /// "thread was dropped" once it can no longer be accessed.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when replaying a completed call.
    language_registry: Arc<LanguageRegistry>,
    /// Project used to resolve the target path when computing the tool-call title.
    project: Entity<Project>,
}
125
126impl EditFileTool {
127 pub fn new(
128 project: Entity<Project>,
129 thread: WeakEntity<Thread>,
130 language_registry: Arc<LanguageRegistry>,
131 ) -> Self {
132 Self {
133 project,
134 thread,
135 language_registry,
136 }
137 }
138
139 fn authorize(
140 &self,
141 input: &EditFileToolInput,
142 event_stream: &ToolCallEventStream,
143 cx: &mut App,
144 ) -> Task<Result<()>> {
145 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
146 return Task::ready(Ok(()));
147 }
148
149 // If any path component matches the local settings folder, then this could affect
150 // the editor in ways beyond the project source, so prompt.
151 let local_settings_folder = paths::local_settings_folder_relative_path();
152 let path = Path::new(&input.path);
153 if path
154 .components()
155 .any(|component| component.as_os_str() == local_settings_folder.as_os_str())
156 {
157 return event_stream.authorize(
158 format!("{} (local settings)", input.display_description),
159 cx,
160 );
161 }
162
163 // It's also possible that the global config dir is configured to be inside the project,
164 // so check for that edge case too.
165 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
166 && canonical_path.starts_with(paths::config_dir())
167 {
168 return event_stream.authorize(
169 format!("{} (global settings)", input.display_description),
170 cx,
171 );
172 }
173
174 // Check if path is inside the global config directory
175 // First check if it's already inside project - if not, try to canonicalize
176 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
177 thread.project().read(cx).find_project_path(&input.path, cx)
178 }) else {
179 return Task::ready(Err(anyhow!("thread was dropped")));
180 };
181
182 // If the path is inside the project, and it's not one of the above edge cases,
183 // then no confirmation is necessary. Otherwise, confirmation is necessary.
184 if project_path.is_some() {
185 Task::ready(Ok(()))
186 } else {
187 event_stream.authorize(&input.display_description, cx)
188 }
189 }
190}
191
192impl AgentTool for EditFileTool {
193 type Input = EditFileToolInput;
194 type Output = EditFileToolOutput;
195
196 fn name() -> &'static str {
197 "edit_file"
198 }
199
200 fn kind() -> acp::ToolKind {
201 acp::ToolKind::Edit
202 }
203
204 fn initial_title(
205 &self,
206 input: Result<Self::Input, serde_json::Value>,
207 cx: &mut App,
208 ) -> SharedString {
209 match input {
210 Ok(input) => self
211 .project
212 .read(cx)
213 .find_project_path(&input.path, cx)
214 .and_then(|project_path| {
215 self.project
216 .read(cx)
217 .short_full_path_for_project_path(&project_path, cx)
218 })
219 .unwrap_or(Path::new(&input.path).into())
220 .to_string_lossy()
221 .to_string()
222 .into(),
223 Err(raw_input) => {
224 if let Some(input) =
225 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
226 {
227 let path = input.path.trim();
228 if !path.is_empty() {
229 return self
230 .project
231 .read(cx)
232 .find_project_path(&input.path, cx)
233 .and_then(|project_path| {
234 self.project
235 .read(cx)
236 .short_full_path_for_project_path(&project_path, cx)
237 })
238 .unwrap_or(Path::new(&input.path).into())
239 .to_string_lossy()
240 .to_string()
241 .into();
242 }
243
244 let description = input.display_description.trim();
245 if !description.is_empty() {
246 return description.to_string().into();
247 }
248 }
249
250 DEFAULT_UI_TEXT.into()
251 }
252 }
253 }
254
255 fn run(
256 self: Arc<Self>,
257 input: Self::Input,
258 event_stream: ToolCallEventStream,
259 cx: &mut App,
260 ) -> Task<Result<Self::Output>> {
261 let Ok(project) = self
262 .thread
263 .read_with(cx, |thread, _cx| thread.project().clone())
264 else {
265 return Task::ready(Err(anyhow!("thread was dropped")));
266 };
267 let project_path = match resolve_path(&input, project.clone(), cx) {
268 Ok(path) => path,
269 Err(err) => return Task::ready(Err(anyhow!(err))),
270 };
271 let abs_path = project.read(cx).absolute_path(&project_path, cx);
272 if let Some(abs_path) = abs_path.clone() {
273 event_stream.update_fields(ToolCallUpdateFields {
274 locations: Some(vec![acp::ToolCallLocation {
275 path: abs_path,
276 line: None,
277 meta: None,
278 }]),
279 ..Default::default()
280 });
281 }
282
283 let authorize = self.authorize(&input, &event_stream, cx);
284 cx.spawn(async move |cx: &mut AsyncApp| {
285 authorize.await?;
286
287 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
288 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
289 (request, thread.model().cloned(), thread.action_log().clone())
290 })?;
291 let request = request?;
292 let model = model.context("No language model configured")?;
293
294 let edit_format = EditFormat::from_model(model.clone())?;
295 let edit_agent = EditAgent::new(
296 model,
297 project.clone(),
298 action_log.clone(),
299 // TODO: move edit agent to this crate so we can use our templates
300 assistant_tools::templates::Templates::new(),
301 edit_format,
302 );
303
304 let buffer = project
305 .update(cx, |project, cx| {
306 project.open_buffer(project_path.clone(), cx)
307 })?
308 .await?;
309
310 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
311 event_stream.update_diff(diff.clone());
312 let _finalize_diff = util::defer({
313 let diff = diff.downgrade();
314 let mut cx = cx.clone();
315 move || {
316 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
317 }
318 });
319
320 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
321 let old_text = cx
322 .background_spawn({
323 let old_snapshot = old_snapshot.clone();
324 async move { Arc::new(old_snapshot.text()) }
325 })
326 .await;
327
328
329 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
330 edit_agent.edit(
331 buffer.clone(),
332 input.display_description.clone(),
333 &request,
334 cx,
335 )
336 } else {
337 edit_agent.overwrite(
338 buffer.clone(),
339 input.display_description.clone(),
340 &request,
341 cx,
342 )
343 };
344
345 let mut hallucinated_old_text = false;
346 let mut ambiguous_ranges = Vec::new();
347 let mut emitted_location = false;
348 while let Some(event) = events.next().await {
349 match event {
350 EditAgentOutputEvent::Edited(range) => {
351 if !emitted_location {
352 let line = buffer.update(cx, |buffer, _cx| {
353 range.start.to_point(&buffer.snapshot()).row
354 }).ok();
355 if let Some(abs_path) = abs_path.clone() {
356 event_stream.update_fields(ToolCallUpdateFields {
357 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
358 ..Default::default()
359 });
360 }
361 emitted_location = true;
362 }
363 },
364 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
365 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
366 EditAgentOutputEvent::ResolvingEditRange(range) => {
367 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
368 // if !emitted_location {
369 // let line = buffer.update(cx, |buffer, _cx| {
370 // range.start.to_point(&buffer.snapshot()).row
371 // }).ok();
372 // if let Some(abs_path) = abs_path.clone() {
373 // event_stream.update_fields(ToolCallUpdateFields {
374 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
375 // ..Default::default()
376 // });
377 // }
378 // }
379 }
380 }
381 }
382
383 // If format_on_save is enabled, format the buffer
384 let format_on_save_enabled = buffer
385 .read_with(cx, |buffer, cx| {
386 let settings = language_settings::language_settings(
387 buffer.language().map(|l| l.name()),
388 buffer.file(),
389 cx,
390 );
391 settings.format_on_save != FormatOnSave::Off
392 })
393 .unwrap_or(false);
394
395 let edit_agent_output = output.await?;
396
397 if format_on_save_enabled {
398 action_log.update(cx, |log, cx| {
399 log.buffer_edited(buffer.clone(), cx);
400 })?;
401
402 let format_task = project.update(cx, |project, cx| {
403 project.format(
404 HashSet::from_iter([buffer.clone()]),
405 LspFormatTarget::Buffers,
406 false, // Don't push to history since the tool did it.
407 FormatTrigger::Save,
408 cx,
409 )
410 })?;
411 format_task.await.log_err();
412 }
413
414 project
415 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
416 .await?;
417
418 action_log.update(cx, |log, cx| {
419 log.buffer_edited(buffer.clone(), cx);
420 })?;
421
422 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
423 let (new_text, unified_diff) = cx
424 .background_spawn({
425 let new_snapshot = new_snapshot.clone();
426 let old_text = old_text.clone();
427 async move {
428 let new_text = new_snapshot.text();
429 let diff = language::unified_diff(&old_text, &new_text);
430 (new_text, diff)
431 }
432 })
433 .await;
434
435 let input_path = input.path.display();
436 if unified_diff.is_empty() {
437 anyhow::ensure!(
438 !hallucinated_old_text,
439 formatdoc! {"
440 Some edits were produced but none of them could be applied.
441 Read the relevant sections of {input_path} again so that
442 I can perform the requested edits.
443 "}
444 );
445 anyhow::ensure!(
446 ambiguous_ranges.is_empty(),
447 {
448 let line_numbers = ambiguous_ranges
449 .iter()
450 .map(|range| range.start.to_string())
451 .collect::<Vec<_>>()
452 .join(", ");
453 formatdoc! {"
454 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
455 relevant sections of {input_path} again and extend <old_text> so
456 that I can perform the requested edits.
457 "}
458 }
459 );
460 }
461
462 Ok(EditFileToolOutput {
463 input_path: input.path,
464 new_text,
465 old_text,
466 diff: unified_diff,
467 edit_agent_output,
468 })
469 })
470 }
471
472 fn replay(
473 &self,
474 _input: Self::Input,
475 output: Self::Output,
476 event_stream: ToolCallEventStream,
477 cx: &mut App,
478 ) -> Result<()> {
479 event_stream.update_diff(cx.new(|cx| {
480 Diff::finalized(
481 output.input_path,
482 Some(output.old_text.to_string()),
483 output.new_text,
484 self.language_registry.clone(),
485 cx,
486 )
487 }));
488 Ok(())
489 }
490}
491
492/// Validate that the file path is valid, meaning:
493///
494/// - For `edit` and `overwrite`, the path must point to an existing file.
495/// - For `create`, the file must not already exist, but it's parent dir must exist.
496fn resolve_path(
497 input: &EditFileToolInput,
498 project: Entity<Project>,
499 cx: &mut App,
500) -> Result<ProjectPath> {
501 let project = project.read(cx);
502
503 match input.mode {
504 EditFileMode::Edit | EditFileMode::Overwrite => {
505 let path = project
506 .find_project_path(&input.path, cx)
507 .context("Can't edit file: path not found")?;
508
509 let entry = project
510 .entry_for_path(&path, cx)
511 .context("Can't edit file: path not found")?;
512
513 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
514 Ok(path)
515 }
516
517 EditFileMode::Create => {
518 if let Some(path) = project.find_project_path(&input.path, cx) {
519 anyhow::ensure!(
520 project.entry_for_path(&path, cx).is_none(),
521 "Can't create file: file already exists"
522 );
523 }
524
525 let parent_path = input
526 .path
527 .parent()
528 .context("Can't create file: incorrect path")?;
529
530 let parent_project_path = project.find_project_path(&parent_path, cx);
531
532 let parent_entry = parent_project_path
533 .as_ref()
534 .and_then(|path| project.entry_for_path(path, cx))
535 .context("Can't create file: parent directory doesn't exist")?;
536
537 anyhow::ensure!(
538 parent_entry.is_dir(),
539 "Can't create file: parent is not a directory"
540 );
541
542 let file_name = input
543 .path
544 .file_name()
545 .context("Can't create file: invalid filename")?;
546
547 let new_file_path = parent_project_path.map(|parent| ProjectPath {
548 path: Arc::from(parent.path.join(file_name)),
549 ..parent
550 });
551
552 new_file_path.context("Can't create file")
553 }
554 }
555}
556
557#[cfg(test)]
558mod tests {
559 use super::*;
560 use crate::{ContextServerRegistry, Templates};
561 use client::TelemetrySettings;
562 use fs::Fs;
563 use gpui::{TestAppContext, UpdateGlobal};
564 use language_model::fake_provider::FakeLanguageModel;
565 use prompt_store::ProjectContext;
566 use serde_json::json;
567 use settings::SettingsStore;
568 use util::path;
569
570 #[gpui::test]
571 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
572 init_test(cx);
573
574 let fs = project::FakeFs::new(cx.executor());
575 fs.insert_tree("/root", json!({})).await;
576 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
577 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
578 let context_server_registry =
579 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
580 let model = Arc::new(FakeLanguageModel::default());
581 let thread = cx.new(|cx| {
582 Thread::new(
583 project.clone(),
584 cx.new(|_cx| ProjectContext::default()),
585 context_server_registry,
586 Templates::new(),
587 Some(model),
588 cx,
589 )
590 });
591 let result = cx
592 .update(|cx| {
593 let input = EditFileToolInput {
594 display_description: "Some edit".into(),
595 path: "root/nonexistent_file.txt".into(),
596 mode: EditFileMode::Edit,
597 };
598 Arc::new(EditFileTool::new(
599 project,
600 thread.downgrade(),
601 language_registry,
602 ))
603 .run(input, ToolCallEventStream::test().0, cx)
604 })
605 .await;
606 assert_eq!(
607 result.unwrap_err().to_string(),
608 "Can't edit file: path not found"
609 );
610 }
611
612 #[gpui::test]
613 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
614 let mode = &EditFileMode::Create;
615
616 let result = test_resolve_path(mode, "root/new.txt", cx);
617 assert_resolved_path_eq(result.await, "new.txt");
618
619 let result = test_resolve_path(mode, "new.txt", cx);
620 assert_resolved_path_eq(result.await, "new.txt");
621
622 let result = test_resolve_path(mode, "dir/new.txt", cx);
623 assert_resolved_path_eq(result.await, "dir/new.txt");
624
625 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
626 assert_eq!(
627 result.await.unwrap_err().to_string(),
628 "Can't create file: file already exists"
629 );
630
631 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
632 assert_eq!(
633 result.await.unwrap_err().to_string(),
634 "Can't create file: parent directory doesn't exist"
635 );
636 }
637
638 #[gpui::test]
639 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
640 let mode = &EditFileMode::Edit;
641
642 let path_with_root = "root/dir/subdir/existing.txt";
643 let path_without_root = "dir/subdir/existing.txt";
644 let result = test_resolve_path(mode, path_with_root, cx);
645 assert_resolved_path_eq(result.await, path_without_root);
646
647 let result = test_resolve_path(mode, path_without_root, cx);
648 assert_resolved_path_eq(result.await, path_without_root);
649
650 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
651 assert_eq!(
652 result.await.unwrap_err().to_string(),
653 "Can't edit file: path not found"
654 );
655
656 let result = test_resolve_path(mode, "root/dir", cx);
657 assert_eq!(
658 result.await.unwrap_err().to_string(),
659 "Can't edit file: path is a directory"
660 );
661 }
662
663 async fn test_resolve_path(
664 mode: &EditFileMode,
665 path: &str,
666 cx: &mut TestAppContext,
667 ) -> anyhow::Result<ProjectPath> {
668 init_test(cx);
669
670 let fs = project::FakeFs::new(cx.executor());
671 fs.insert_tree(
672 "/root",
673 json!({
674 "dir": {
675 "subdir": {
676 "existing.txt": "hello"
677 }
678 }
679 }),
680 )
681 .await;
682 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
683
684 let input = EditFileToolInput {
685 display_description: "Some edit".into(),
686 path: path.into(),
687 mode: mode.clone(),
688 };
689
690 cx.update(|cx| resolve_path(&input, project, cx))
691 }
692
693 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
694 let actual = path
695 .expect("Should return valid path")
696 .path
697 .to_str()
698 .unwrap()
699 .replace("\\", "/"); // Naive Windows paths normalization
700 assert_eq!(actual, expected);
701 }
702
703 #[gpui::test]
704 async fn test_format_on_save(cx: &mut TestAppContext) {
705 init_test(cx);
706
707 let fs = project::FakeFs::new(cx.executor());
708 fs.insert_tree("/root", json!({"src": {}})).await;
709
710 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
711
712 // Set up a Rust language with LSP formatting support
713 let rust_language = Arc::new(language::Language::new(
714 language::LanguageConfig {
715 name: "Rust".into(),
716 matcher: language::LanguageMatcher {
717 path_suffixes: vec!["rs".to_string()],
718 ..Default::default()
719 },
720 ..Default::default()
721 },
722 None,
723 ));
724
725 // Register the language and fake LSP
726 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
727 language_registry.add(rust_language);
728
729 let mut fake_language_servers = language_registry.register_fake_lsp(
730 "Rust",
731 language::FakeLspAdapter {
732 capabilities: lsp::ServerCapabilities {
733 document_formatting_provider: Some(lsp::OneOf::Left(true)),
734 ..Default::default()
735 },
736 ..Default::default()
737 },
738 );
739
740 // Create the file
741 fs.save(
742 path!("/root/src/main.rs").as_ref(),
743 &"initial content".into(),
744 language::LineEnding::Unix,
745 )
746 .await
747 .unwrap();
748
749 // Open the buffer to trigger LSP initialization
750 let buffer = project
751 .update(cx, |project, cx| {
752 project.open_local_buffer(path!("/root/src/main.rs"), cx)
753 })
754 .await
755 .unwrap();
756
757 // Register the buffer with language servers
758 let _handle = project.update(cx, |project, cx| {
759 project.register_buffer_with_language_servers(&buffer, cx)
760 });
761
762 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
763 const FORMATTED_CONTENT: &str =
764 "This file was formatted by the fake formatter in the test.\n";
765
766 // Get the fake language server and set up formatting handler
767 let fake_language_server = fake_language_servers.next().await.unwrap();
768 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
769 |_, _| async move {
770 Ok(Some(vec![lsp::TextEdit {
771 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
772 new_text: FORMATTED_CONTENT.to_string(),
773 }]))
774 }
775 });
776
777 let context_server_registry =
778 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
779 let model = Arc::new(FakeLanguageModel::default());
780 let thread = cx.new(|cx| {
781 Thread::new(
782 project.clone(),
783 cx.new(|_cx| ProjectContext::default()),
784 context_server_registry,
785 Templates::new(),
786 Some(model.clone()),
787 cx,
788 )
789 });
790
791 // First, test with format_on_save enabled
792 cx.update(|cx| {
793 SettingsStore::update_global(cx, |store, cx| {
794 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
795 cx,
796 |settings| {
797 settings.defaults.format_on_save = Some(FormatOnSave::On);
798 settings.defaults.formatter =
799 Some(language::language_settings::SelectedFormatter::Auto);
800 },
801 );
802 });
803 });
804
805 // Have the model stream unformatted content
806 let edit_result = {
807 let edit_task = cx.update(|cx| {
808 let input = EditFileToolInput {
809 display_description: "Create main function".into(),
810 path: "root/src/main.rs".into(),
811 mode: EditFileMode::Overwrite,
812 };
813 Arc::new(EditFileTool::new(
814 project.clone(),
815 thread.downgrade(),
816 language_registry.clone(),
817 ))
818 .run(input, ToolCallEventStream::test().0, cx)
819 });
820
821 // Stream the unformatted content
822 cx.executor().run_until_parked();
823 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
824 model.end_last_completion_stream();
825
826 edit_task.await
827 };
828 assert!(edit_result.is_ok());
829
830 // Wait for any async operations (e.g. formatting) to complete
831 cx.executor().run_until_parked();
832
833 // Read the file to verify it was formatted automatically
834 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
835 assert_eq!(
836 // Ignore carriage returns on Windows
837 new_content.replace("\r\n", "\n"),
838 FORMATTED_CONTENT,
839 "Code should be formatted when format_on_save is enabled"
840 );
841
842 let stale_buffer_count = thread
843 .read_with(cx, |thread, _cx| thread.action_log.clone())
844 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
845
846 assert_eq!(
847 stale_buffer_count, 0,
848 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
849 This causes the agent to think the file was modified externally when it was just formatted.",
850 stale_buffer_count
851 );
852
853 // Next, test with format_on_save disabled
854 cx.update(|cx| {
855 SettingsStore::update_global(cx, |store, cx| {
856 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
857 cx,
858 |settings| {
859 settings.defaults.format_on_save = Some(FormatOnSave::Off);
860 },
861 );
862 });
863 });
864
865 // Stream unformatted edits again
866 let edit_result = {
867 let edit_task = cx.update(|cx| {
868 let input = EditFileToolInput {
869 display_description: "Update main function".into(),
870 path: "root/src/main.rs".into(),
871 mode: EditFileMode::Overwrite,
872 };
873 Arc::new(EditFileTool::new(
874 project.clone(),
875 thread.downgrade(),
876 language_registry,
877 ))
878 .run(input, ToolCallEventStream::test().0, cx)
879 });
880
881 // Stream the unformatted content
882 cx.executor().run_until_parked();
883 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
884 model.end_last_completion_stream();
885
886 edit_task.await
887 };
888 assert!(edit_result.is_ok());
889
890 // Wait for any async operations (e.g. formatting) to complete
891 cx.executor().run_until_parked();
892
893 // Verify the file was not formatted
894 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
895 assert_eq!(
896 // Ignore carriage returns on Windows
897 new_content.replace("\r\n", "\n"),
898 UNFORMATTED_CONTENT,
899 "Code should not be formatted when format_on_save is disabled"
900 );
901 }
902
903 #[gpui::test]
904 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
905 init_test(cx);
906
907 let fs = project::FakeFs::new(cx.executor());
908 fs.insert_tree("/root", json!({"src": {}})).await;
909
910 // Create a simple file with trailing whitespace
911 fs.save(
912 path!("/root/src/main.rs").as_ref(),
913 &"initial content".into(),
914 language::LineEnding::Unix,
915 )
916 .await
917 .unwrap();
918
919 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
920 let context_server_registry =
921 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
922 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
923 let model = Arc::new(FakeLanguageModel::default());
924 let thread = cx.new(|cx| {
925 Thread::new(
926 project.clone(),
927 cx.new(|_cx| ProjectContext::default()),
928 context_server_registry,
929 Templates::new(),
930 Some(model.clone()),
931 cx,
932 )
933 });
934
935 // First, test with remove_trailing_whitespace_on_save enabled
936 cx.update(|cx| {
937 SettingsStore::update_global(cx, |store, cx| {
938 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
939 cx,
940 |settings| {
941 settings.defaults.remove_trailing_whitespace_on_save = Some(true);
942 },
943 );
944 });
945 });
946
947 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
948 "fn main() { \n println!(\"Hello!\"); \n}\n";
949
950 // Have the model stream content that contains trailing whitespace
951 let edit_result = {
952 let edit_task = cx.update(|cx| {
953 let input = EditFileToolInput {
954 display_description: "Create main function".into(),
955 path: "root/src/main.rs".into(),
956 mode: EditFileMode::Overwrite,
957 };
958 Arc::new(EditFileTool::new(
959 project.clone(),
960 thread.downgrade(),
961 language_registry.clone(),
962 ))
963 .run(input, ToolCallEventStream::test().0, cx)
964 });
965
966 // Stream the content with trailing whitespace
967 cx.executor().run_until_parked();
968 model.send_last_completion_stream_text_chunk(
969 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
970 );
971 model.end_last_completion_stream();
972
973 edit_task.await
974 };
975 assert!(edit_result.is_ok());
976
977 // Wait for any async operations (e.g. formatting) to complete
978 cx.executor().run_until_parked();
979
980 // Read the file to verify trailing whitespace was removed automatically
981 assert_eq!(
982 // Ignore carriage returns on Windows
983 fs.load(path!("/root/src/main.rs").as_ref())
984 .await
985 .unwrap()
986 .replace("\r\n", "\n"),
987 "fn main() {\n println!(\"Hello!\");\n}\n",
988 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
989 );
990
991 // Next, test with remove_trailing_whitespace_on_save disabled
992 cx.update(|cx| {
993 SettingsStore::update_global(cx, |store, cx| {
994 store.update_user_settings::<language::language_settings::AllLanguageSettings>(
995 cx,
996 |settings| {
997 settings.defaults.remove_trailing_whitespace_on_save = Some(false);
998 },
999 );
1000 });
1001 });
1002
1003 // Stream edits again with trailing whitespace
1004 let edit_result = {
1005 let edit_task = cx.update(|cx| {
1006 let input = EditFileToolInput {
1007 display_description: "Update main function".into(),
1008 path: "root/src/main.rs".into(),
1009 mode: EditFileMode::Overwrite,
1010 };
1011 Arc::new(EditFileTool::new(
1012 project.clone(),
1013 thread.downgrade(),
1014 language_registry,
1015 ))
1016 .run(input, ToolCallEventStream::test().0, cx)
1017 });
1018
1019 // Stream the content with trailing whitespace
1020 cx.executor().run_until_parked();
1021 model.send_last_completion_stream_text_chunk(
1022 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1023 );
1024 model.end_last_completion_stream();
1025
1026 edit_task.await
1027 };
1028 assert!(edit_result.is_ok());
1029
1030 // Wait for any async operations (e.g. formatting) to complete
1031 cx.executor().run_until_parked();
1032
1033 // Verify the file still has trailing whitespace
1034 // Read the file again - it should still have trailing whitespace
1035 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1036 assert_eq!(
1037 // Ignore carriage returns on Windows
1038 final_content.replace("\r\n", "\n"),
1039 CONTENT_WITH_TRAILING_WHITESPACE,
1040 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1041 );
1042 }
1043
1044 #[gpui::test]
1045 async fn test_authorize(cx: &mut TestAppContext) {
1046 init_test(cx);
1047 let fs = project::FakeFs::new(cx.executor());
1048 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1049 let context_server_registry =
1050 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1051 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1052 let model = Arc::new(FakeLanguageModel::default());
1053 let thread = cx.new(|cx| {
1054 Thread::new(
1055 project.clone(),
1056 cx.new(|_cx| ProjectContext::default()),
1057 context_server_registry,
1058 Templates::new(),
1059 Some(model.clone()),
1060 cx,
1061 )
1062 });
1063 let tool = Arc::new(EditFileTool::new(
1064 project.clone(),
1065 thread.downgrade(),
1066 language_registry,
1067 ));
1068 fs.insert_tree("/root", json!({})).await;
1069
1070 // Test 1: Path with .zed component should require confirmation
1071 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1072 let _auth = cx.update(|cx| {
1073 tool.authorize(
1074 &EditFileToolInput {
1075 display_description: "test 1".into(),
1076 path: ".zed/settings.json".into(),
1077 mode: EditFileMode::Edit,
1078 },
1079 &stream_tx,
1080 cx,
1081 )
1082 });
1083
1084 let event = stream_rx.expect_authorization().await;
1085 assert_eq!(
1086 event.tool_call.fields.title,
1087 Some("test 1 (local settings)".into())
1088 );
1089
1090 // Test 2: Path outside project should require confirmation
1091 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1092 let _auth = cx.update(|cx| {
1093 tool.authorize(
1094 &EditFileToolInput {
1095 display_description: "test 2".into(),
1096 path: "/etc/hosts".into(),
1097 mode: EditFileMode::Edit,
1098 },
1099 &stream_tx,
1100 cx,
1101 )
1102 });
1103
1104 let event = stream_rx.expect_authorization().await;
1105 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1106
1107 // Test 3: Relative path without .zed should not require confirmation
1108 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1109 cx.update(|cx| {
1110 tool.authorize(
1111 &EditFileToolInput {
1112 display_description: "test 3".into(),
1113 path: "root/src/main.rs".into(),
1114 mode: EditFileMode::Edit,
1115 },
1116 &stream_tx,
1117 cx,
1118 )
1119 })
1120 .await
1121 .unwrap();
1122 assert!(stream_rx.try_next().is_err());
1123
1124 // Test 4: Path with .zed in the middle should require confirmation
1125 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1126 let _auth = cx.update(|cx| {
1127 tool.authorize(
1128 &EditFileToolInput {
1129 display_description: "test 4".into(),
1130 path: "root/.zed/tasks.json".into(),
1131 mode: EditFileMode::Edit,
1132 },
1133 &stream_tx,
1134 cx,
1135 )
1136 });
1137 let event = stream_rx.expect_authorization().await;
1138 assert_eq!(
1139 event.tool_call.fields.title,
1140 Some("test 4 (local settings)".into())
1141 );
1142
1143 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1144 cx.update(|cx| {
1145 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1146 settings.always_allow_tool_actions = true;
1147 agent_settings::AgentSettings::override_global(settings, cx);
1148 });
1149
1150 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1151 cx.update(|cx| {
1152 tool.authorize(
1153 &EditFileToolInput {
1154 display_description: "test 5.1".into(),
1155 path: ".zed/settings.json".into(),
1156 mode: EditFileMode::Edit,
1157 },
1158 &stream_tx,
1159 cx,
1160 )
1161 })
1162 .await
1163 .unwrap();
1164 assert!(stream_rx.try_next().is_err());
1165
1166 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1167 cx.update(|cx| {
1168 tool.authorize(
1169 &EditFileToolInput {
1170 display_description: "test 5.2".into(),
1171 path: "/etc/hosts".into(),
1172 mode: EditFileMode::Edit,
1173 },
1174 &stream_tx,
1175 cx,
1176 )
1177 })
1178 .await
1179 .unwrap();
1180 assert!(stream_rx.try_next().is_err());
1181 }
1182
1183 #[gpui::test]
1184 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1185 init_test(cx);
1186 let fs = project::FakeFs::new(cx.executor());
1187 fs.insert_tree("/project", json!({})).await;
1188 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1189 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1190 let context_server_registry =
1191 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1192 let model = Arc::new(FakeLanguageModel::default());
1193 let thread = cx.new(|cx| {
1194 Thread::new(
1195 project.clone(),
1196 cx.new(|_cx| ProjectContext::default()),
1197 context_server_registry,
1198 Templates::new(),
1199 Some(model.clone()),
1200 cx,
1201 )
1202 });
1203 let tool = Arc::new(EditFileTool::new(
1204 project.clone(),
1205 thread.downgrade(),
1206 language_registry,
1207 ));
1208
1209 // Test global config paths - these should require confirmation if they exist and are outside the project
1210 let test_cases = vec![
1211 (
1212 "/etc/hosts",
1213 true,
1214 "System file should require confirmation",
1215 ),
1216 (
1217 "/usr/local/bin/script",
1218 true,
1219 "System bin file should require confirmation",
1220 ),
1221 (
1222 "project/normal_file.rs",
1223 false,
1224 "Normal project file should not require confirmation",
1225 ),
1226 ];
1227
1228 for (path, should_confirm, description) in test_cases {
1229 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1230 let auth = cx.update(|cx| {
1231 tool.authorize(
1232 &EditFileToolInput {
1233 display_description: "Edit file".into(),
1234 path: path.into(),
1235 mode: EditFileMode::Edit,
1236 },
1237 &stream_tx,
1238 cx,
1239 )
1240 });
1241
1242 if should_confirm {
1243 stream_rx.expect_authorization().await;
1244 } else {
1245 auth.await.unwrap();
1246 assert!(
1247 stream_rx.try_next().is_err(),
1248 "Failed for case: {} - path: {} - expected no confirmation but got one",
1249 description,
1250 path
1251 );
1252 }
1253 }
1254 }
1255
1256 #[gpui::test]
1257 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1258 init_test(cx);
1259 let fs = project::FakeFs::new(cx.executor());
1260
1261 // Create multiple worktree directories
1262 fs.insert_tree(
1263 "/workspace/frontend",
1264 json!({
1265 "src": {
1266 "main.js": "console.log('frontend');"
1267 }
1268 }),
1269 )
1270 .await;
1271 fs.insert_tree(
1272 "/workspace/backend",
1273 json!({
1274 "src": {
1275 "main.rs": "fn main() {}"
1276 }
1277 }),
1278 )
1279 .await;
1280 fs.insert_tree(
1281 "/workspace/shared",
1282 json!({
1283 ".zed": {
1284 "settings.json": "{}"
1285 }
1286 }),
1287 )
1288 .await;
1289
1290 // Create project with multiple worktrees
1291 let project = Project::test(
1292 fs.clone(),
1293 [
1294 path!("/workspace/frontend").as_ref(),
1295 path!("/workspace/backend").as_ref(),
1296 path!("/workspace/shared").as_ref(),
1297 ],
1298 cx,
1299 )
1300 .await;
1301 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1302 let context_server_registry =
1303 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1304 let model = Arc::new(FakeLanguageModel::default());
1305 let thread = cx.new(|cx| {
1306 Thread::new(
1307 project.clone(),
1308 cx.new(|_cx| ProjectContext::default()),
1309 context_server_registry.clone(),
1310 Templates::new(),
1311 Some(model.clone()),
1312 cx,
1313 )
1314 });
1315 let tool = Arc::new(EditFileTool::new(
1316 project.clone(),
1317 thread.downgrade(),
1318 language_registry,
1319 ));
1320
1321 // Test files in different worktrees
1322 let test_cases = vec![
1323 ("frontend/src/main.js", false, "File in first worktree"),
1324 ("backend/src/main.rs", false, "File in second worktree"),
1325 (
1326 "shared/.zed/settings.json",
1327 true,
1328 ".zed file in third worktree",
1329 ),
1330 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1331 (
1332 "../outside/file.txt",
1333 true,
1334 "Relative path outside worktrees",
1335 ),
1336 ];
1337
1338 for (path, should_confirm, description) in test_cases {
1339 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1340 let auth = cx.update(|cx| {
1341 tool.authorize(
1342 &EditFileToolInput {
1343 display_description: "Edit file".into(),
1344 path: path.into(),
1345 mode: EditFileMode::Edit,
1346 },
1347 &stream_tx,
1348 cx,
1349 )
1350 });
1351
1352 if should_confirm {
1353 stream_rx.expect_authorization().await;
1354 } else {
1355 auth.await.unwrap();
1356 assert!(
1357 stream_rx.try_next().is_err(),
1358 "Failed for case: {} - path: {} - expected no confirmation but got one",
1359 description,
1360 path
1361 );
1362 }
1363 }
1364 }
1365
1366 #[gpui::test]
1367 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1368 init_test(cx);
1369 let fs = project::FakeFs::new(cx.executor());
1370 fs.insert_tree(
1371 "/project",
1372 json!({
1373 ".zed": {
1374 "settings.json": "{}"
1375 },
1376 "src": {
1377 ".zed": {
1378 "local.json": "{}"
1379 }
1380 }
1381 }),
1382 )
1383 .await;
1384 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1385 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1386 let context_server_registry =
1387 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1388 let model = Arc::new(FakeLanguageModel::default());
1389 let thread = cx.new(|cx| {
1390 Thread::new(
1391 project.clone(),
1392 cx.new(|_cx| ProjectContext::default()),
1393 context_server_registry.clone(),
1394 Templates::new(),
1395 Some(model.clone()),
1396 cx,
1397 )
1398 });
1399 let tool = Arc::new(EditFileTool::new(
1400 project.clone(),
1401 thread.downgrade(),
1402 language_registry,
1403 ));
1404
1405 // Test edge cases
1406 let test_cases = vec![
1407 // Empty path - find_project_path returns Some for empty paths
1408 ("", false, "Empty path is treated as project root"),
1409 // Root directory
1410 ("/", true, "Root directory should be outside project"),
1411 // Parent directory references - find_project_path resolves these
1412 (
1413 "project/../other",
1414 false,
1415 "Path with .. is resolved by find_project_path",
1416 ),
1417 (
1418 "project/./src/file.rs",
1419 false,
1420 "Path with . should work normally",
1421 ),
1422 // Windows-style paths (if on Windows)
1423 #[cfg(target_os = "windows")]
1424 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1425 #[cfg(target_os = "windows")]
1426 ("project\\src\\main.rs", false, "Windows-style project path"),
1427 ];
1428
1429 for (path, should_confirm, description) in test_cases {
1430 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1431 let auth = cx.update(|cx| {
1432 tool.authorize(
1433 &EditFileToolInput {
1434 display_description: "Edit file".into(),
1435 path: path.into(),
1436 mode: EditFileMode::Edit,
1437 },
1438 &stream_tx,
1439 cx,
1440 )
1441 });
1442
1443 if should_confirm {
1444 stream_rx.expect_authorization().await;
1445 } else {
1446 auth.await.unwrap();
1447 assert!(
1448 stream_rx.try_next().is_err(),
1449 "Failed for case: {} - path: {} - expected no confirmation but got one",
1450 description,
1451 path
1452 );
1453 }
1454 }
1455 }
1456
1457 #[gpui::test]
1458 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1459 init_test(cx);
1460 let fs = project::FakeFs::new(cx.executor());
1461 fs.insert_tree(
1462 "/project",
1463 json!({
1464 "existing.txt": "content",
1465 ".zed": {
1466 "settings.json": "{}"
1467 }
1468 }),
1469 )
1470 .await;
1471 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1472 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1473 let context_server_registry =
1474 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1475 let model = Arc::new(FakeLanguageModel::default());
1476 let thread = cx.new(|cx| {
1477 Thread::new(
1478 project.clone(),
1479 cx.new(|_cx| ProjectContext::default()),
1480 context_server_registry.clone(),
1481 Templates::new(),
1482 Some(model.clone()),
1483 cx,
1484 )
1485 });
1486 let tool = Arc::new(EditFileTool::new(
1487 project.clone(),
1488 thread.downgrade(),
1489 language_registry,
1490 ));
1491
1492 // Test different EditFileMode values
1493 let modes = vec![
1494 EditFileMode::Edit,
1495 EditFileMode::Create,
1496 EditFileMode::Overwrite,
1497 ];
1498
1499 for mode in modes {
1500 // Test .zed path with different modes
1501 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1502 let _auth = cx.update(|cx| {
1503 tool.authorize(
1504 &EditFileToolInput {
1505 display_description: "Edit settings".into(),
1506 path: "project/.zed/settings.json".into(),
1507 mode: mode.clone(),
1508 },
1509 &stream_tx,
1510 cx,
1511 )
1512 });
1513
1514 stream_rx.expect_authorization().await;
1515
1516 // Test outside path with different modes
1517 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1518 let _auth = cx.update(|cx| {
1519 tool.authorize(
1520 &EditFileToolInput {
1521 display_description: "Edit file".into(),
1522 path: "/outside/file.txt".into(),
1523 mode: mode.clone(),
1524 },
1525 &stream_tx,
1526 cx,
1527 )
1528 });
1529
1530 stream_rx.expect_authorization().await;
1531
1532 // Test normal path with different modes
1533 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1534 cx.update(|cx| {
1535 tool.authorize(
1536 &EditFileToolInput {
1537 display_description: "Edit file".into(),
1538 path: "project/normal.txt".into(),
1539 mode: mode.clone(),
1540 },
1541 &stream_tx,
1542 cx,
1543 )
1544 })
1545 .await
1546 .unwrap();
1547 assert!(stream_rx.try_next().is_err());
1548 }
1549 }
1550
1551 #[gpui::test]
1552 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1553 init_test(cx);
1554 let fs = project::FakeFs::new(cx.executor());
1555 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1556 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1557 let context_server_registry =
1558 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1559 let model = Arc::new(FakeLanguageModel::default());
1560 let thread = cx.new(|cx| {
1561 Thread::new(
1562 project.clone(),
1563 cx.new(|_cx| ProjectContext::default()),
1564 context_server_registry,
1565 Templates::new(),
1566 Some(model.clone()),
1567 cx,
1568 )
1569 });
1570 let tool = Arc::new(EditFileTool::new(
1571 project,
1572 thread.downgrade(),
1573 language_registry,
1574 ));
1575
1576 cx.update(|cx| {
1577 // ...
1578 assert_eq!(
1579 tool.initial_title(
1580 Err(json!({
1581 "path": "src/main.rs",
1582 "display_description": "",
1583 "old_string": "old code",
1584 "new_string": "new code"
1585 })),
1586 cx
1587 ),
1588 "src/main.rs"
1589 );
1590 assert_eq!(
1591 tool.initial_title(
1592 Err(json!({
1593 "path": "",
1594 "display_description": "Fix error handling",
1595 "old_string": "old code",
1596 "new_string": "new code"
1597 })),
1598 cx
1599 ),
1600 "Fix error handling"
1601 );
1602 assert_eq!(
1603 tool.initial_title(
1604 Err(json!({
1605 "path": "src/main.rs",
1606 "display_description": "Fix error handling",
1607 "old_string": "old code",
1608 "new_string": "new code"
1609 })),
1610 cx
1611 ),
1612 "src/main.rs"
1613 );
1614 assert_eq!(
1615 tool.initial_title(
1616 Err(json!({
1617 "path": "",
1618 "display_description": "",
1619 "old_string": "old code",
1620 "new_string": "new code"
1621 })),
1622 cx
1623 ),
1624 DEFAULT_UI_TEXT
1625 );
1626 assert_eq!(
1627 tool.initial_title(Err(serde_json::Value::Null), cx),
1628 DEFAULT_UI_TEXT
1629 );
1630 });
1631 }
1632
    // Checks that the `Diff` entity emitted by `EditFileTool::run` starts out as
    // `Diff::Pending` and always ends up `Diff::Finalized`, in three scenarios:
    // a successful edit, an edit that fails, and an edit whose task is dropped.
    #[gpui::test]
    async fn test_diff_finalization(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        // A single empty `main.rs` at the root of the fake file system; the project's
        // only worktree is `/`.
        fs.insert_tree("/", json!({"main.rs": ""})).await;

        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // Ensure the diff is finalized after the edit completes.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            // The run first reports a tool-call fields update, then emits the diff entity.
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            // Let the tool get as far as its completion request, then end the fake model's
            // stream (no text chunks are sent) so the edit can finish successfully.
            cx.run_until_parked();
            model.end_last_completion_stream();
            edit.await.unwrap();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }

        // Ensure the diff is finalized if an error occurs while editing.
        {
            // Make the fake model refuse requests so this edit fails; requests are
            // re-allowed at the end of this block for the next scenario.
            model.forbid_requests();
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            edit.await.unwrap_err();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
            model.allow_requests();
        }

        // Ensure the diff is finalized if the tool call gets dropped.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            // Drop the edit task before the model's stream is ended, then let the executor
            // settle before checking the diff.
            drop(edit);
            cx.run_until_parked();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }
    }
1738
1739 fn init_test(cx: &mut TestAppContext) {
1740 cx.update(|cx| {
1741 let settings_store = SettingsStore::test(cx);
1742 cx.set_global(settings_store);
1743 language::init(cx);
1744 TelemetrySettings::register(cx);
1745 agent_settings::AgentSettings::register(cx);
1746 Project::init_settings(cx);
1747 });
1748 }
1749}