1use crate::{AgentTool, Thread, ToolCallEventStream};
2use acp_thread::Diff;
3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
6use cloud_llm_client::CompletionIntent;
7use collections::HashSet;
8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
9use indoc::formatdoc;
10use language::language_settings::{self, FormatOnSave};
11use language::{LanguageRegistry, ToPoint};
12use language_model::LanguageModelToolResultContent;
13use paths;
14use project::lsp_store::{FormatTrigger, LspFormatTarget};
15use project::{Project, ProjectPath};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::Settings;
19use smol::stream::StreamExt as _;
20use std::ffi::OsStr;
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23use ui::SharedString;
24use util::ResultExt;
25use util::rel_path::RelPath;
26
/// Fallback title returned by `EditFileTool::initial_title` when the (possibly
/// partial) tool input has neither a usable path nor a display description.
const DEFAULT_UI_TEXT: &str = "Editing file";
28
29/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
30///
31/// Before using this tool:
32///
33/// 1. Use the `read_file` tool to understand the file's contents and context
34///
35/// 2. Verify the directory path is correct (only applicable when creating new files):
36/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
37#[derive(Debug, Serialize, Deserialize, JsonSchema)]
38pub struct EditFileToolInput {
39 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
40 ///
41 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
42 ///
43 /// NEVER mention the file path in this description.
44 ///
45 /// <example>Fix API endpoint URLs</example>
46 /// <example>Update copyright year in `page_footer`</example>
47 ///
48 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
49 pub display_description: String,
50
51 /// The full path of the file to create or modify in the project.
52 ///
53 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
54 ///
55 /// The following examples assume we have two root directories in the project:
56 /// - /a/b/backend
57 /// - /c/d/frontend
58 ///
59 /// <example>
60 /// `backend/src/main.rs`
61 ///
62 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
63 /// </example>
64 ///
65 /// <example>
66 /// `frontend/db.js`
67 /// </example>
68 pub path: PathBuf,
69 /// The mode of operation on the file. Possible values:
70 /// - 'edit': Make granular edits to an existing file.
71 /// - 'create': Create a new file if it doesn't exist.
72 /// - 'overwrite': Replace the entire contents of an existing file.
73 ///
74 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
75 pub mode: EditFileMode,
76}
77
// Lenient view of a still-streaming `EditFileToolInput`. `initial_title`
// parses the raw JSON into this when the full input fails to deserialize, so
// it can show a path or description early. Missing fields default to "".
// Plain `//` comments are used so the derived JSON schema stays unchanged.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
85
// How `edit_file` treats the target file. Serialized in lowercase
// ("edit", "create", "overwrite") and inlined into the input schema.
// Plain `//` comments are used deliberately: `///` docs would presumably be
// picked up by the `JsonSchema` derive and change the model-facing schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // Granular edits to an existing file (`EditAgent::edit` in `run`).
    Edit,
    // New file; `resolve_path` rejects paths that already exist. The content is
    // produced with `EditAgent::overwrite`, like `Overwrite`.
    Create,
    // Replaces the whole content of an existing file (`EditAgent::overwrite`).
    Overwrite,
}
94
/// Result of a successful `edit_file` tool call.
///
/// Converted into the model-visible tool result by the
/// `From<EditFileToolOutput> for LanguageModelToolResultContent` impl, and used
/// by `EditFileTool::replay` to rebuild the diff view.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as given in the tool input.
    // The alias presumably keeps outputs serialized under an older field name
    // deserializable — confirm before removing.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Buffer text after the edit, formatting (if enabled) and save.
    new_text: String,
    /// Buffer text captured before the edit agent ran.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    /// Defaults to empty when absent from serialized data.
    #[serde(default)]
    diff: String,
    /// Raw output of the `EditAgent` that performed the edit.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
106
107impl From<EditFileToolOutput> for LanguageModelToolResultContent {
108 fn from(output: EditFileToolOutput) -> Self {
109 if output.diff.is_empty() {
110 "No edits were made.".into()
111 } else {
112 format!(
113 "Edited {}:\n\n```diff\n{}\n```",
114 output.input_path.display(),
115 output.diff
116 )
117 .into()
118 }
119 }
120}
121
/// The `edit_file` agent tool: creates, overwrites, or edits a file in the
/// project, delegating the text generation to an `EditAgent`.
pub struct EditFileTool {
    // Owning thread, held weakly; operations fail with "thread was dropped"
    // once it is gone.
    thread: WeakEntity<Thread>,
    // Passed to `Diff::finalized` when replaying a past tool call.
    language_registry: Arc<LanguageRegistry>,
    // Used by `initial_title` to resolve display paths.
    project: Entity<Project>,
}
127
impl EditFileTool {
    /// Creates the tool for `project`, owned by `thread`.
    pub fn new(
        project: Entity<Project>,
        thread: WeakEntity<Thread>,
        language_registry: Arc<LanguageRegistry>,
    ) -> Self {
        Self {
            project,
            thread,
            language_registry,
        }
    }

    /// Decides whether this edit needs explicit user confirmation.
    ///
    /// Resolves immediately with `Ok(())` in two cases:
    /// - `always_allow_tool_actions` is enabled.
    /// - The path is inside the project and touches neither the local settings
    ///   folder nor the global config directory.
    ///
    /// Otherwise it requests authorization through `event_stream`, adding a
    /// "(local settings)" or "(global settings)" suffix to the title where
    /// applicable.
    ///
    /// The order of the checks matters. The settings checks must run before the
    /// "inside the project" shortcut, because those folders can live inside
    /// the project.
    ///
    /// Returns an error task ("thread was dropped") if the owning thread no
    /// longer exists.
    fn authorize(
        &self,
        input: &EditFileToolInput,
        event_stream: &ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<()>> {
        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
            return Task::ready(Ok(()));
        }

        // If any path component matches the local settings folder, then this could affect
        // the editor in ways beyond the project source, so prompt.
        let local_settings_folder = paths::local_settings_folder_name();
        let path = Path::new(&input.path);
        if path.components().any(|component| {
            component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
        }) {
            return event_stream.authorize(
                format!("{} (local settings)", input.display_description),
                cx,
            );
        }

        // It's also possible that the global config dir is configured to be inside the project,
        // so check for that edge case too.
        // TODO this is broken when remoting
        // NOTE(review): `canonicalize` resolves a relative `input.path` against the
        // process working directory, not the project root, and does blocking I/O.
        if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
            && canonical_path.starts_with(paths::config_dir())
        {
            return event_stream.authorize(
                format!("{} (global settings)", input.display_description),
                cx,
            );
        }

        // Neither settings check matched: look the path up in the thread's
        // project to decide whether it is inside the project.
        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
            thread.project().read(cx).find_project_path(&input.path, cx)
        }) else {
            return Task::ready(Err(anyhow!("thread was dropped")));
        };

        // If the path is inside the project, and it's not one of the above edge cases,
        // then no confirmation is necessary. Otherwise, confirmation is necessary.
        if project_path.is_some() {
            Task::ready(Ok(()))
        } else {
            event_stream.authorize(&input.display_description, cx)
        }
    }
}
193
194impl AgentTool for EditFileTool {
195 type Input = EditFileToolInput;
196 type Output = EditFileToolOutput;
197
198 fn name() -> &'static str {
199 "edit_file"
200 }
201
202 fn kind() -> acp::ToolKind {
203 acp::ToolKind::Edit
204 }
205
206 fn initial_title(
207 &self,
208 input: Result<Self::Input, serde_json::Value>,
209 cx: &mut App,
210 ) -> SharedString {
211 match input {
212 Ok(input) => self
213 .project
214 .read(cx)
215 .find_project_path(&input.path, cx)
216 .and_then(|project_path| {
217 self.project
218 .read(cx)
219 .short_full_path_for_project_path(&project_path, cx)
220 })
221 .unwrap_or(input.path.to_string_lossy().to_string())
222 .into(),
223 Err(raw_input) => {
224 if let Some(input) =
225 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
226 {
227 let path = input.path.trim();
228 if !path.is_empty() {
229 return self
230 .project
231 .read(cx)
232 .find_project_path(&input.path, cx)
233 .and_then(|project_path| {
234 self.project
235 .read(cx)
236 .short_full_path_for_project_path(&project_path, cx)
237 })
238 .unwrap_or(input.path)
239 .into();
240 }
241
242 let description = input.display_description.trim();
243 if !description.is_empty() {
244 return description.to_string().into();
245 }
246 }
247
248 DEFAULT_UI_TEXT.into()
249 }
250 }
251 }
252
253 fn run(
254 self: Arc<Self>,
255 input: Self::Input,
256 event_stream: ToolCallEventStream,
257 cx: &mut App,
258 ) -> Task<Result<Self::Output>> {
259 let Ok(project) = self
260 .thread
261 .read_with(cx, |thread, _cx| thread.project().clone())
262 else {
263 return Task::ready(Err(anyhow!("thread was dropped")));
264 };
265 let project_path = match resolve_path(&input, project.clone(), cx) {
266 Ok(path) => path,
267 Err(err) => return Task::ready(Err(anyhow!(err))),
268 };
269 let abs_path = project.read(cx).absolute_path(&project_path, cx);
270 if let Some(abs_path) = abs_path.clone() {
271 event_stream.update_fields(ToolCallUpdateFields {
272 locations: Some(vec![acp::ToolCallLocation {
273 path: abs_path,
274 line: None,
275 meta: None,
276 }]),
277 ..Default::default()
278 });
279 }
280
281 let authorize = self.authorize(&input, &event_stream, cx);
282 cx.spawn(async move |cx: &mut AsyncApp| {
283 authorize.await?;
284
285 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
286 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
287 (request, thread.model().cloned(), thread.action_log().clone())
288 })?;
289 let request = request?;
290 let model = model.context("No language model configured")?;
291
292 let edit_format = EditFormat::from_model(model.clone())?;
293 let edit_agent = EditAgent::new(
294 model,
295 project.clone(),
296 action_log.clone(),
297 // TODO: move edit agent to this crate so we can use our templates
298 assistant_tools::templates::Templates::new(),
299 edit_format,
300 );
301
302 let buffer = project
303 .update(cx, |project, cx| {
304 project.open_buffer(project_path.clone(), cx)
305 })?
306 .await?;
307
308 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
309 event_stream.update_diff(diff.clone());
310 let _finalize_diff = util::defer({
311 let diff = diff.downgrade();
312 let mut cx = cx.clone();
313 move || {
314 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
315 }
316 });
317
318 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
319 let old_text = cx
320 .background_spawn({
321 let old_snapshot = old_snapshot.clone();
322 async move { Arc::new(old_snapshot.text()) }
323 })
324 .await;
325
326
327 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
328 edit_agent.edit(
329 buffer.clone(),
330 input.display_description.clone(),
331 &request,
332 cx,
333 )
334 } else {
335 edit_agent.overwrite(
336 buffer.clone(),
337 input.display_description.clone(),
338 &request,
339 cx,
340 )
341 };
342
343 let mut hallucinated_old_text = false;
344 let mut ambiguous_ranges = Vec::new();
345 let mut emitted_location = false;
346 while let Some(event) = events.next().await {
347 match event {
348 EditAgentOutputEvent::Edited(range) => {
349 if !emitted_location {
350 let line = buffer.update(cx, |buffer, _cx| {
351 range.start.to_point(&buffer.snapshot()).row
352 }).ok();
353 if let Some(abs_path) = abs_path.clone() {
354 event_stream.update_fields(ToolCallUpdateFields {
355 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
356 ..Default::default()
357 });
358 }
359 emitted_location = true;
360 }
361 },
362 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
363 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
364 EditAgentOutputEvent::ResolvingEditRange(range) => {
365 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
366 // if !emitted_location {
367 // let line = buffer.update(cx, |buffer, _cx| {
368 // range.start.to_point(&buffer.snapshot()).row
369 // }).ok();
370 // if let Some(abs_path) = abs_path.clone() {
371 // event_stream.update_fields(ToolCallUpdateFields {
372 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
373 // ..Default::default()
374 // });
375 // }
376 // }
377 }
378 }
379 }
380
381 // If format_on_save is enabled, format the buffer
382 let format_on_save_enabled = buffer
383 .read_with(cx, |buffer, cx| {
384 let settings = language_settings::language_settings(
385 buffer.language().map(|l| l.name()),
386 buffer.file(),
387 cx,
388 );
389 settings.format_on_save != FormatOnSave::Off
390 })
391 .unwrap_or(false);
392
393 let edit_agent_output = output.await?;
394
395 if format_on_save_enabled {
396 action_log.update(cx, |log, cx| {
397 log.buffer_edited(buffer.clone(), cx);
398 })?;
399
400 let format_task = project.update(cx, |project, cx| {
401 project.format(
402 HashSet::from_iter([buffer.clone()]),
403 LspFormatTarget::Buffers,
404 false, // Don't push to history since the tool did it.
405 FormatTrigger::Save,
406 cx,
407 )
408 })?;
409 format_task.await.log_err();
410 }
411
412 project
413 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
414 .await?;
415
416 action_log.update(cx, |log, cx| {
417 log.buffer_edited(buffer.clone(), cx);
418 })?;
419
420 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
421 let (new_text, unified_diff) = cx
422 .background_spawn({
423 let new_snapshot = new_snapshot.clone();
424 let old_text = old_text.clone();
425 async move {
426 let new_text = new_snapshot.text();
427 let diff = language::unified_diff(&old_text, &new_text);
428 (new_text, diff)
429 }
430 })
431 .await;
432
433 let input_path = input.path.display();
434 if unified_diff.is_empty() {
435 anyhow::ensure!(
436 !hallucinated_old_text,
437 formatdoc! {"
438 Some edits were produced but none of them could be applied.
439 Read the relevant sections of {input_path} again so that
440 I can perform the requested edits.
441 "}
442 );
443 anyhow::ensure!(
444 ambiguous_ranges.is_empty(),
445 {
446 let line_numbers = ambiguous_ranges
447 .iter()
448 .map(|range| range.start.to_string())
449 .collect::<Vec<_>>()
450 .join(", ");
451 formatdoc! {"
452 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
453 relevant sections of {input_path} again and extend <old_text> so
454 that I can perform the requested edits.
455 "}
456 }
457 );
458 }
459
460 Ok(EditFileToolOutput {
461 input_path: input.path,
462 new_text,
463 old_text,
464 diff: unified_diff,
465 edit_agent_output,
466 })
467 })
468 }
469
470 fn replay(
471 &self,
472 _input: Self::Input,
473 output: Self::Output,
474 event_stream: ToolCallEventStream,
475 cx: &mut App,
476 ) -> Result<()> {
477 event_stream.update_diff(cx.new(|cx| {
478 Diff::finalized(
479 output.input_path.to_string_lossy().to_string(),
480 Some(output.old_text.to_string()),
481 output.new_text,
482 self.language_registry.clone(),
483 cx,
484 )
485 }));
486 Ok(())
487 }
488}
489
490/// Validate that the file path is valid, meaning:
491///
492/// - For `edit` and `overwrite`, the path must point to an existing file.
493/// - For `create`, the file must not already exist, but it's parent dir must exist.
494fn resolve_path(
495 input: &EditFileToolInput,
496 project: Entity<Project>,
497 cx: &mut App,
498) -> Result<ProjectPath> {
499 let project = project.read(cx);
500
501 match input.mode {
502 EditFileMode::Edit | EditFileMode::Overwrite => {
503 let path = project
504 .find_project_path(&input.path, cx)
505 .context("Can't edit file: path not found")?;
506
507 let entry = project
508 .entry_for_path(&path, cx)
509 .context("Can't edit file: path not found")?;
510
511 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
512 Ok(path)
513 }
514
515 EditFileMode::Create => {
516 if let Some(path) = project.find_project_path(&input.path, cx) {
517 anyhow::ensure!(
518 project.entry_for_path(&path, cx).is_none(),
519 "Can't create file: file already exists"
520 );
521 }
522
523 let parent_path = input
524 .path
525 .parent()
526 .context("Can't create file: incorrect path")?;
527
528 let parent_project_path = project.find_project_path(&parent_path, cx);
529
530 let parent_entry = parent_project_path
531 .as_ref()
532 .and_then(|path| project.entry_for_path(path, cx))
533 .context("Can't create file: parent directory doesn't exist")?;
534
535 anyhow::ensure!(
536 parent_entry.is_dir(),
537 "Can't create file: parent is not a directory"
538 );
539
540 let file_name = input
541 .path
542 .file_name()
543 .and_then(|file_name| file_name.to_str())
544 .and_then(|file_name| RelPath::new(file_name).ok())
545 .context("Can't create file: invalid filename")?;
546
547 let new_file_path = parent_project_path.map(|parent| ProjectPath {
548 path: parent.path.join(file_name),
549 ..parent
550 });
551
552 new_file_path.context("Can't create file")
553 }
554 }
555}
556
557#[cfg(test)]
558mod tests {
559 use super::*;
560 use crate::{ContextServerRegistry, Templates};
561 use client::TelemetrySettings;
562 use fs::Fs;
563 use gpui::{TestAppContext, UpdateGlobal};
564 use language_model::fake_provider::FakeLanguageModel;
565 use prompt_store::ProjectContext;
566 use serde_json::json;
567 use settings::SettingsStore;
568 use util::path;
569
    /// Running the tool in `Edit` mode on a path with no project entry must
    /// fail with the `resolve_path` error. `run` returns that error before
    /// spawning any work.
    #[gpui::test]
    async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model),
                cx,
            )
        });
        let result = cx
            .update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Some edit".into(),
                    path: "root/nonexistent_file.txt".into(),
                    mode: EditFileMode::Edit,
                };
                Arc::new(EditFileTool::new(
                    project,
                    thread.downgrade(),
                    language_registry,
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            })
            .await;
        assert_eq!(
            result.unwrap_err().to_string(),
            "Can't edit file: path not found"
        );
    }
611
612 #[gpui::test]
613 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
614 let mode = &EditFileMode::Create;
615
616 let result = test_resolve_path(mode, "root/new.txt", cx);
617 assert_resolved_path_eq(result.await, "new.txt");
618
619 let result = test_resolve_path(mode, "new.txt", cx);
620 assert_resolved_path_eq(result.await, "new.txt");
621
622 let result = test_resolve_path(mode, "dir/new.txt", cx);
623 assert_resolved_path_eq(result.await, "dir/new.txt");
624
625 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
626 assert_eq!(
627 result.await.unwrap_err().to_string(),
628 "Can't create file: file already exists"
629 );
630
631 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
632 assert_eq!(
633 result.await.unwrap_err().to_string(),
634 "Can't create file: parent directory doesn't exist"
635 );
636 }
637
638 #[gpui::test]
639 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
640 let mode = &EditFileMode::Edit;
641
642 let path_with_root = "root/dir/subdir/existing.txt";
643 let path_without_root = "dir/subdir/existing.txt";
644 let result = test_resolve_path(mode, path_with_root, cx);
645 assert_resolved_path_eq(result.await, path_without_root);
646
647 let result = test_resolve_path(mode, path_without_root, cx);
648 assert_resolved_path_eq(result.await, path_without_root);
649
650 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
651 assert_eq!(
652 result.await.unwrap_err().to_string(),
653 "Can't edit file: path not found"
654 );
655
656 let result = test_resolve_path(mode, "root/dir", cx);
657 assert_eq!(
658 result.await.unwrap_err().to_string(),
659 "Can't edit file: path is a directory"
660 );
661 }
662
    /// Test fixture. Builds a fresh project rooted at `/root` that contains only
    /// `dir/subdir/existing.txt`, then runs `resolve_path` for `path` in `mode`.
    async fn test_resolve_path(
        mode: &EditFileMode,
        path: &str,
        cx: &mut TestAppContext,
    ) -> anyhow::Result<ProjectPath> {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dir": {
                    "subdir": {
                        "existing.txt": "hello"
                    }
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        let input = EditFileToolInput {
            display_description: "Some edit".into(),
            path: path.into(),
            mode: mode.clone(),
        };

        cx.update(|cx| resolve_path(&input, project, cx))
    }
692
693 #[track_caller]
694 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
695 let actual = path.expect("Should return valid path").path;
696 let actual = actual.as_str();
697 assert_eq!(actual, expected);
698 }
699
    /// With `format_on_save` on, the tool's output is run through the (fake) LSP
    /// formatter before saving, and the buffer is not left stale in the action
    /// log. With it off, the model's output is saved verbatim.
    #[gpui::test]
    async fn test_format_on_save(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Set up a Rust language with LSP formatting support
        let rust_language = Arc::new(language::Language::new(
            language::LanguageConfig {
                name: "Rust".into(),
                matcher: language::LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            None,
        ));

        // Register the language and fake LSP
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(rust_language);

        let mut fake_language_servers = language_registry.register_fake_lsp(
            "Rust",
            language::FakeLspAdapter {
                capabilities: lsp::ServerCapabilities {
                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
                    ..Default::default()
                },
                ..Default::default()
            },
        );

        // Create the file
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        // Open the buffer to trigger LSP initialization
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/root/src/main.rs"), cx)
            })
            .await
            .unwrap();

        // Register the buffer with language servers
        let _handle = project.update(cx, |project, cx| {
            project.register_buffer_with_language_servers(&buffer, cx)
        });

        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
        const FORMATTED_CONTENT: &str =
            "This file was formatted by the fake formatter in the test.\n";

        // Get the fake language server and set up formatting handler.
        // The handler replaces the first line with a fixed marker text.
        let fake_language_server = fake_language_servers.next().await.unwrap();
        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
            |_, _| async move {
                Ok(Some(vec![lsp::TextEdit {
                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
                    new_text: FORMATTED_CONTENT.to_string(),
                }]))
            }
        });

        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // First, test with format_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
                    settings.project.all_languages.defaults.formatter =
                        Some(language::language_settings::SelectedFormatter::Auto);
                });
            });
        });

        // Have the model stream unformatted content
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry.clone(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the unformatted content.
            // Parking first lets the tool issue its completion request, so the
            // fake model has a "last completion" to stream into.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify it was formatted automatically
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            new_content.replace("\r\n", "\n"),
            FORMATTED_CONTENT,
            "Code should be formatted when format_on_save is enabled"
        );

        let stale_buffer_count = thread
            .read_with(cx, |thread, _cx| thread.action_log.clone())
            .read_with(cx, |log, cx| log.stale_buffers(cx).count());

        assert_eq!(
            stale_buffer_count, 0,
            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
            This causes the agent to think the file was modified externally when it was just formatted.",
            stale_buffer_count
        );

        // Next, test with format_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings.project.all_languages.defaults.format_on_save =
                        Some(FormatOnSave::Off);
                });
            });
        });

        // Stream unformatted edits again
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry,
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the unformatted content
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Verify the file was not formatted
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            new_content.replace("\r\n", "\n"),
            UNFORMATTED_CONTENT,
            "Code should not be formatted when format_on_save is disabled"
        );
    }
894
    /// The tool saves through the normal save path, so
    /// `remove_trailing_whitespace_on_save` applies. Trailing whitespace from
    /// the model is stripped when the setting is on and kept when it is off.
    #[gpui::test]
    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        // Seed the file with placeholder content (no trailing whitespace yet).
        // The whitespace comes from the model output streamed below.
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // First, test with remove_trailing_whitespace_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings
                        .project
                        .all_languages
                        .defaults
                        .remove_trailing_whitespace_on_save = Some(true);
                });
            });
        });

        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
            "fn main() { \n println!(\"Hello!\"); \n}\n";

        // Have the model stream content that contains trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry.clone(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the content with trailing whitespace.
            // Parking first lets the tool issue its completion request.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify trailing whitespace was removed automatically
        assert_eq!(
            // Ignore carriage returns on Windows
            fs.load(path!("/root/src/main.rs").as_ref())
                .await
                .unwrap()
                .replace("\r\n", "\n"),
            "fn main() {\n println!(\"Hello!\");\n}\n",
            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
        );

        // Next, test with remove_trailing_whitespace_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings
                        .project
                        .all_languages
                        .defaults
                        .remove_trailing_whitespace_on_save = Some(false);
                });
            });
        });

        // Stream edits again with trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry,
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the content with trailing whitespace
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file again: with the setting disabled, the trailing
        // whitespace from the model output must still be present.
        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            final_content.replace("\r\n", "\n"),
            CONTENT_WITH_TRAILING_WHITESPACE,
            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
        );
    }
1037
1038 #[gpui::test]
1039 async fn test_authorize(cx: &mut TestAppContext) {
1040 init_test(cx);
1041 let fs = project::FakeFs::new(cx.executor());
1042 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1043 let context_server_registry =
1044 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1045 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1046 let model = Arc::new(FakeLanguageModel::default());
1047 let thread = cx.new(|cx| {
1048 Thread::new(
1049 project.clone(),
1050 cx.new(|_cx| ProjectContext::default()),
1051 context_server_registry,
1052 Templates::new(),
1053 Some(model.clone()),
1054 cx,
1055 )
1056 });
1057 let tool = Arc::new(EditFileTool::new(
1058 project.clone(),
1059 thread.downgrade(),
1060 language_registry,
1061 ));
1062 fs.insert_tree("/root", json!({})).await;
1063
1064 // Test 1: Path with .zed component should require confirmation
1065 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1066 let _auth = cx.update(|cx| {
1067 tool.authorize(
1068 &EditFileToolInput {
1069 display_description: "test 1".into(),
1070 path: ".zed/settings.json".into(),
1071 mode: EditFileMode::Edit,
1072 },
1073 &stream_tx,
1074 cx,
1075 )
1076 });
1077
1078 let event = stream_rx.expect_authorization().await;
1079 assert_eq!(
1080 event.tool_call.fields.title,
1081 Some("test 1 (local settings)".into())
1082 );
1083
1084 // Test 2: Path outside project should require confirmation
1085 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1086 let _auth = cx.update(|cx| {
1087 tool.authorize(
1088 &EditFileToolInput {
1089 display_description: "test 2".into(),
1090 path: "/etc/hosts".into(),
1091 mode: EditFileMode::Edit,
1092 },
1093 &stream_tx,
1094 cx,
1095 )
1096 });
1097
1098 let event = stream_rx.expect_authorization().await;
1099 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1100
1101 // Test 3: Relative path without .zed should not require confirmation
1102 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1103 cx.update(|cx| {
1104 tool.authorize(
1105 &EditFileToolInput {
1106 display_description: "test 3".into(),
1107 path: "root/src/main.rs".into(),
1108 mode: EditFileMode::Edit,
1109 },
1110 &stream_tx,
1111 cx,
1112 )
1113 })
1114 .await
1115 .unwrap();
1116 assert!(stream_rx.try_next().is_err());
1117
1118 // Test 4: Path with .zed in the middle should require confirmation
1119 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1120 let _auth = cx.update(|cx| {
1121 tool.authorize(
1122 &EditFileToolInput {
1123 display_description: "test 4".into(),
1124 path: "root/.zed/tasks.json".into(),
1125 mode: EditFileMode::Edit,
1126 },
1127 &stream_tx,
1128 cx,
1129 )
1130 });
1131 let event = stream_rx.expect_authorization().await;
1132 assert_eq!(
1133 event.tool_call.fields.title,
1134 Some("test 4 (local settings)".into())
1135 );
1136
1137 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1138 cx.update(|cx| {
1139 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1140 settings.always_allow_tool_actions = true;
1141 agent_settings::AgentSettings::override_global(settings, cx);
1142 });
1143
1144 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1145 cx.update(|cx| {
1146 tool.authorize(
1147 &EditFileToolInput {
1148 display_description: "test 5.1".into(),
1149 path: ".zed/settings.json".into(),
1150 mode: EditFileMode::Edit,
1151 },
1152 &stream_tx,
1153 cx,
1154 )
1155 })
1156 .await
1157 .unwrap();
1158 assert!(stream_rx.try_next().is_err());
1159
1160 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1161 cx.update(|cx| {
1162 tool.authorize(
1163 &EditFileToolInput {
1164 display_description: "test 5.2".into(),
1165 path: "/etc/hosts".into(),
1166 mode: EditFileMode::Edit,
1167 },
1168 &stream_tx,
1169 cx,
1170 )
1171 })
1172 .await
1173 .unwrap();
1174 assert!(stream_rx.try_next().is_err());
1175 }
1176
1177 #[gpui::test]
1178 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1179 init_test(cx);
1180 let fs = project::FakeFs::new(cx.executor());
1181 fs.insert_tree("/project", json!({})).await;
1182 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1183 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1184 let context_server_registry =
1185 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1186 let model = Arc::new(FakeLanguageModel::default());
1187 let thread = cx.new(|cx| {
1188 Thread::new(
1189 project.clone(),
1190 cx.new(|_cx| ProjectContext::default()),
1191 context_server_registry,
1192 Templates::new(),
1193 Some(model.clone()),
1194 cx,
1195 )
1196 });
1197 let tool = Arc::new(EditFileTool::new(
1198 project.clone(),
1199 thread.downgrade(),
1200 language_registry,
1201 ));
1202
1203 // Test global config paths - these should require confirmation if they exist and are outside the project
1204 let test_cases = vec![
1205 (
1206 "/etc/hosts",
1207 true,
1208 "System file should require confirmation",
1209 ),
1210 (
1211 "/usr/local/bin/script",
1212 true,
1213 "System bin file should require confirmation",
1214 ),
1215 (
1216 "project/normal_file.rs",
1217 false,
1218 "Normal project file should not require confirmation",
1219 ),
1220 ];
1221
1222 for (path, should_confirm, description) in test_cases {
1223 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1224 let auth = cx.update(|cx| {
1225 tool.authorize(
1226 &EditFileToolInput {
1227 display_description: "Edit file".into(),
1228 path: path.into(),
1229 mode: EditFileMode::Edit,
1230 },
1231 &stream_tx,
1232 cx,
1233 )
1234 });
1235
1236 if should_confirm {
1237 stream_rx.expect_authorization().await;
1238 } else {
1239 auth.await.unwrap();
1240 assert!(
1241 stream_rx.try_next().is_err(),
1242 "Failed for case: {} - path: {} - expected no confirmation but got one",
1243 description,
1244 path
1245 );
1246 }
1247 }
1248 }
1249
1250 #[gpui::test]
1251 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1252 init_test(cx);
1253 let fs = project::FakeFs::new(cx.executor());
1254
1255 // Create multiple worktree directories
1256 fs.insert_tree(
1257 "/workspace/frontend",
1258 json!({
1259 "src": {
1260 "main.js": "console.log('frontend');"
1261 }
1262 }),
1263 )
1264 .await;
1265 fs.insert_tree(
1266 "/workspace/backend",
1267 json!({
1268 "src": {
1269 "main.rs": "fn main() {}"
1270 }
1271 }),
1272 )
1273 .await;
1274 fs.insert_tree(
1275 "/workspace/shared",
1276 json!({
1277 ".zed": {
1278 "settings.json": "{}"
1279 }
1280 }),
1281 )
1282 .await;
1283
1284 // Create project with multiple worktrees
1285 let project = Project::test(
1286 fs.clone(),
1287 [
1288 path!("/workspace/frontend").as_ref(),
1289 path!("/workspace/backend").as_ref(),
1290 path!("/workspace/shared").as_ref(),
1291 ],
1292 cx,
1293 )
1294 .await;
1295 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1296 let context_server_registry =
1297 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1298 let model = Arc::new(FakeLanguageModel::default());
1299 let thread = cx.new(|cx| {
1300 Thread::new(
1301 project.clone(),
1302 cx.new(|_cx| ProjectContext::default()),
1303 context_server_registry.clone(),
1304 Templates::new(),
1305 Some(model.clone()),
1306 cx,
1307 )
1308 });
1309 let tool = Arc::new(EditFileTool::new(
1310 project.clone(),
1311 thread.downgrade(),
1312 language_registry,
1313 ));
1314
1315 // Test files in different worktrees
1316 let test_cases = vec![
1317 ("frontend/src/main.js", false, "File in first worktree"),
1318 ("backend/src/main.rs", false, "File in second worktree"),
1319 (
1320 "shared/.zed/settings.json",
1321 true,
1322 ".zed file in third worktree",
1323 ),
1324 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1325 (
1326 "../outside/file.txt",
1327 true,
1328 "Relative path outside worktrees",
1329 ),
1330 ];
1331
1332 for (path, should_confirm, description) in test_cases {
1333 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1334 let auth = cx.update(|cx| {
1335 tool.authorize(
1336 &EditFileToolInput {
1337 display_description: "Edit file".into(),
1338 path: path.into(),
1339 mode: EditFileMode::Edit,
1340 },
1341 &stream_tx,
1342 cx,
1343 )
1344 });
1345
1346 if should_confirm {
1347 stream_rx.expect_authorization().await;
1348 } else {
1349 auth.await.unwrap();
1350 assert!(
1351 stream_rx.try_next().is_err(),
1352 "Failed for case: {} - path: {} - expected no confirmation but got one",
1353 description,
1354 path
1355 );
1356 }
1357 }
1358 }
1359
1360 #[gpui::test]
1361 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1362 init_test(cx);
1363 let fs = project::FakeFs::new(cx.executor());
1364 fs.insert_tree(
1365 "/project",
1366 json!({
1367 ".zed": {
1368 "settings.json": "{}"
1369 },
1370 "src": {
1371 ".zed": {
1372 "local.json": "{}"
1373 }
1374 }
1375 }),
1376 )
1377 .await;
1378 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1379 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1380 let context_server_registry =
1381 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1382 let model = Arc::new(FakeLanguageModel::default());
1383 let thread = cx.new(|cx| {
1384 Thread::new(
1385 project.clone(),
1386 cx.new(|_cx| ProjectContext::default()),
1387 context_server_registry.clone(),
1388 Templates::new(),
1389 Some(model.clone()),
1390 cx,
1391 )
1392 });
1393 let tool = Arc::new(EditFileTool::new(
1394 project.clone(),
1395 thread.downgrade(),
1396 language_registry,
1397 ));
1398
1399 // Test edge cases
1400 let test_cases = vec![
1401 // Empty path - find_project_path returns Some for empty paths
1402 ("", false, "Empty path is treated as project root"),
1403 // Root directory
1404 ("/", true, "Root directory should be outside project"),
1405 // Parent directory references - find_project_path resolves these
1406 (
1407 "project/../other",
1408 true,
1409 "Path with .. that goes outside of root directory",
1410 ),
1411 (
1412 "project/./src/file.rs",
1413 false,
1414 "Path with . should work normally",
1415 ),
1416 // Windows-style paths (if on Windows)
1417 #[cfg(target_os = "windows")]
1418 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1419 #[cfg(target_os = "windows")]
1420 ("project\\src\\main.rs", false, "Windows-style project path"),
1421 ];
1422
1423 for (path, should_confirm, description) in test_cases {
1424 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1425 let auth = cx.update(|cx| {
1426 tool.authorize(
1427 &EditFileToolInput {
1428 display_description: "Edit file".into(),
1429 path: path.into(),
1430 mode: EditFileMode::Edit,
1431 },
1432 &stream_tx,
1433 cx,
1434 )
1435 });
1436
1437 cx.run_until_parked();
1438
1439 if should_confirm {
1440 stream_rx.expect_authorization().await;
1441 } else {
1442 assert!(
1443 stream_rx.try_next().is_err(),
1444 "Failed for case: {} - path: {} - expected no confirmation but got one",
1445 description,
1446 path
1447 );
1448 auth.await.unwrap();
1449 }
1450 }
1451 }
1452
1453 #[gpui::test]
1454 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1455 init_test(cx);
1456 let fs = project::FakeFs::new(cx.executor());
1457 fs.insert_tree(
1458 "/project",
1459 json!({
1460 "existing.txt": "content",
1461 ".zed": {
1462 "settings.json": "{}"
1463 }
1464 }),
1465 )
1466 .await;
1467 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1468 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1469 let context_server_registry =
1470 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1471 let model = Arc::new(FakeLanguageModel::default());
1472 let thread = cx.new(|cx| {
1473 Thread::new(
1474 project.clone(),
1475 cx.new(|_cx| ProjectContext::default()),
1476 context_server_registry.clone(),
1477 Templates::new(),
1478 Some(model.clone()),
1479 cx,
1480 )
1481 });
1482 let tool = Arc::new(EditFileTool::new(
1483 project.clone(),
1484 thread.downgrade(),
1485 language_registry,
1486 ));
1487
1488 // Test different EditFileMode values
1489 let modes = vec![
1490 EditFileMode::Edit,
1491 EditFileMode::Create,
1492 EditFileMode::Overwrite,
1493 ];
1494
1495 for mode in modes {
1496 // Test .zed path with different modes
1497 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1498 let _auth = cx.update(|cx| {
1499 tool.authorize(
1500 &EditFileToolInput {
1501 display_description: "Edit settings".into(),
1502 path: "project/.zed/settings.json".into(),
1503 mode: mode.clone(),
1504 },
1505 &stream_tx,
1506 cx,
1507 )
1508 });
1509
1510 stream_rx.expect_authorization().await;
1511
1512 // Test outside path with different modes
1513 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1514 let _auth = cx.update(|cx| {
1515 tool.authorize(
1516 &EditFileToolInput {
1517 display_description: "Edit file".into(),
1518 path: "/outside/file.txt".into(),
1519 mode: mode.clone(),
1520 },
1521 &stream_tx,
1522 cx,
1523 )
1524 });
1525
1526 stream_rx.expect_authorization().await;
1527
1528 // Test normal path with different modes
1529 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1530 cx.update(|cx| {
1531 tool.authorize(
1532 &EditFileToolInput {
1533 display_description: "Edit file".into(),
1534 path: "project/normal.txt".into(),
1535 mode: mode.clone(),
1536 },
1537 &stream_tx,
1538 cx,
1539 )
1540 })
1541 .await
1542 .unwrap();
1543 assert!(stream_rx.try_next().is_err());
1544 }
1545 }
1546
1547 #[gpui::test]
1548 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1549 init_test(cx);
1550 let fs = project::FakeFs::new(cx.executor());
1551 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1552 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1553 let context_server_registry =
1554 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1555 let model = Arc::new(FakeLanguageModel::default());
1556 let thread = cx.new(|cx| {
1557 Thread::new(
1558 project.clone(),
1559 cx.new(|_cx| ProjectContext::default()),
1560 context_server_registry,
1561 Templates::new(),
1562 Some(model.clone()),
1563 cx,
1564 )
1565 });
1566 let tool = Arc::new(EditFileTool::new(
1567 project,
1568 thread.downgrade(),
1569 language_registry,
1570 ));
1571
1572 cx.update(|cx| {
1573 // ...
1574 assert_eq!(
1575 tool.initial_title(
1576 Err(json!({
1577 "path": "src/main.rs",
1578 "display_description": "",
1579 "old_string": "old code",
1580 "new_string": "new code"
1581 })),
1582 cx
1583 ),
1584 "src/main.rs"
1585 );
1586 assert_eq!(
1587 tool.initial_title(
1588 Err(json!({
1589 "path": "",
1590 "display_description": "Fix error handling",
1591 "old_string": "old code",
1592 "new_string": "new code"
1593 })),
1594 cx
1595 ),
1596 "Fix error handling"
1597 );
1598 assert_eq!(
1599 tool.initial_title(
1600 Err(json!({
1601 "path": "src/main.rs",
1602 "display_description": "Fix error handling",
1603 "old_string": "old code",
1604 "new_string": "new code"
1605 })),
1606 cx
1607 ),
1608 "src/main.rs"
1609 );
1610 assert_eq!(
1611 tool.initial_title(
1612 Err(json!({
1613 "path": "",
1614 "display_description": "",
1615 "old_string": "old code",
1616 "new_string": "new code"
1617 })),
1618 cx
1619 ),
1620 DEFAULT_UI_TEXT
1621 );
1622 assert_eq!(
1623 tool.initial_title(Err(serde_json::Value::Null), cx),
1624 DEFAULT_UI_TEXT
1625 );
1626 });
1627 }
1628
1629 #[gpui::test]
1630 async fn test_diff_finalization(cx: &mut TestAppContext) {
1631 init_test(cx);
1632 let fs = project::FakeFs::new(cx.executor());
1633 fs.insert_tree("/", json!({"main.rs": ""})).await;
1634
1635 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1636 let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1637 let context_server_registry =
1638 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1639 let model = Arc::new(FakeLanguageModel::default());
1640 let thread = cx.new(|cx| {
1641 Thread::new(
1642 project.clone(),
1643 cx.new(|_cx| ProjectContext::default()),
1644 context_server_registry.clone(),
1645 Templates::new(),
1646 Some(model.clone()),
1647 cx,
1648 )
1649 });
1650
1651 // Ensure the diff is finalized after the edit completes.
1652 {
1653 let tool = Arc::new(EditFileTool::new(
1654 project.clone(),
1655 thread.downgrade(),
1656 languages.clone(),
1657 ));
1658 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1659 let edit = cx.update(|cx| {
1660 tool.run(
1661 EditFileToolInput {
1662 display_description: "Edit file".into(),
1663 path: path!("/main.rs").into(),
1664 mode: EditFileMode::Edit,
1665 },
1666 stream_tx,
1667 cx,
1668 )
1669 });
1670 stream_rx.expect_update_fields().await;
1671 let diff = stream_rx.expect_diff().await;
1672 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1673 cx.run_until_parked();
1674 model.end_last_completion_stream();
1675 edit.await.unwrap();
1676 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1677 }
1678
1679 // Ensure the diff is finalized if an error occurs while editing.
1680 {
1681 model.forbid_requests();
1682 let tool = Arc::new(EditFileTool::new(
1683 project.clone(),
1684 thread.downgrade(),
1685 languages.clone(),
1686 ));
1687 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1688 let edit = cx.update(|cx| {
1689 tool.run(
1690 EditFileToolInput {
1691 display_description: "Edit file".into(),
1692 path: path!("/main.rs").into(),
1693 mode: EditFileMode::Edit,
1694 },
1695 stream_tx,
1696 cx,
1697 )
1698 });
1699 stream_rx.expect_update_fields().await;
1700 let diff = stream_rx.expect_diff().await;
1701 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1702 edit.await.unwrap_err();
1703 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1704 model.allow_requests();
1705 }
1706
1707 // Ensure the diff is finalized if the tool call gets dropped.
1708 {
1709 let tool = Arc::new(EditFileTool::new(
1710 project.clone(),
1711 thread.downgrade(),
1712 languages.clone(),
1713 ));
1714 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1715 let edit = cx.update(|cx| {
1716 tool.run(
1717 EditFileToolInput {
1718 display_description: "Edit file".into(),
1719 path: path!("/main.rs").into(),
1720 mode: EditFileMode::Edit,
1721 },
1722 stream_tx,
1723 cx,
1724 )
1725 });
1726 stream_rx.expect_update_fields().await;
1727 let diff = stream_rx.expect_diff().await;
1728 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1729 drop(edit);
1730 cx.run_until_parked();
1731 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1732 }
1733 }
1734
1735 fn init_test(cx: &mut TestAppContext) {
1736 cx.update(|cx| {
1737 let settings_store = SettingsStore::test(cx);
1738 cx.set_global(settings_store);
1739 language::init(cx);
1740 TelemetrySettings::register(cx);
1741 agent_settings::AgentSettings::register(cx);
1742 Project::init_settings(cx);
1743 });
1744 }
1745}