1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Title shown for an `edit_file` call whose (partial) input has neither a usable path nor a
/// non-empty description.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
// NOTE: this struct derives `JsonSchema`, and `schemars` turns `///` doc comments into schema
// `description`s. The `///` text below is therefore (presumably) what the model reads as the
// tool's instructions — treat its wording as prompt text, not ordinary rustdoc.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    // Validated against `mode` by `resolve_path` before any edit happens.
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
79
// Lenient counterpart of `EditFileToolInput`. `initial_title` parses raw tool input with it
// when the input did not deserialize as a complete `EditFileToolInput` (presumably because it
// is still being streamed). Missing fields default to empty strings via `#[serde(default)]`.
// Plain `//` comments are used so the derived `JsonSchema` stays unchanged.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
87
// How `edit_file` treats the target path. Serialized in lowercase ("edit", "create",
// "overwrite"); `#[schemars(inline)]` inlines this schema into the parent tool schema.
// Plain `//` comments are used on purpose: `///` doc comments on the variants would likely be
// picked up by `schemars` and change the generated schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // Granular edits to an existing file (handled by `EditAgent::edit` in `run`).
    Edit,
    // New file; `resolve_path` rejects paths that already exist. Written via
    // `EditAgent::overwrite` in `run`.
    Create,
    // Replace the whole content of an existing file (also `EditAgent::overwrite`).
    Overwrite,
}
96
/// Result of a completed `edit_file` call.
///
/// `replay` rebuilds a finalized diff view from `input_path`, `old_text`, and `new_text`.
/// The serde aliases presumably keep outputs stored under older field names deserializable.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as given in the tool input.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Buffer text after the edits were applied, formatted (when enabled), and saved.
    new_text: String,
    /// Buffer text captured before any edits were applied.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    /// Defaults to empty when absent from serialized data.
    #[serde(default)]
    diff: String,
    /// Raw output returned by the edit agent.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
/// Agent tool that creates, edits, or overwrites a single project file (see [`EditFileMode`]).
pub struct EditFileTool {
    /// Held weakly; `run` and `authorize` fail with "thread was dropped" once the thread is gone.
    thread: WeakEntity<Thread>,
    /// Used by `replay` to build a finalized diff view.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to resolve short display paths.
    project: Entity<Project>,
    /// Templates handed to the `EditAgent` created in `run`.
    templates: Arc<Templates>,
}
130
131impl EditFileTool {
132 pub const fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
203 fn name() -> &'static str {
204 "edit_file"
205 }
206
207 fn kind() -> acp::ToolKind {
208 acp::ToolKind::Edit
209 }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(ToolCallUpdateFields {
277 locations: Some(vec![acp::ToolCallLocation {
278 path: abs_path,
279 line: None,
280 meta: None,
281 }]),
282 ..Default::default()
283 });
284 }
285
286 let authorize = self.authorize(&input, &event_stream, cx);
287 cx.spawn(async move |cx: &mut AsyncApp| {
288 authorize.await?;
289
290 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
291 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
292 (request, thread.model().cloned(), thread.action_log().clone())
293 })?;
294 let request = request?;
295 let model = model.context("No language model configured")?;
296
297 let edit_format = EditFormat::from_model(model.clone())?;
298 let edit_agent = EditAgent::new(
299 model,
300 project.clone(),
301 action_log.clone(),
302 self.templates.clone(),
303 edit_format,
304 );
305
306 let buffer = project
307 .update(cx, |project, cx| {
308 project.open_buffer(project_path.clone(), cx)
309 })?
310 .await?;
311
312 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
313 event_stream.update_diff(diff.clone());
314 let _finalize_diff = util::defer({
315 let diff = diff.downgrade();
316 let mut cx = cx.clone();
317 move || {
318 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
319 }
320 });
321
322 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
323 let old_text = cx
324 .background_spawn({
325 let old_snapshot = old_snapshot.clone();
326 async move { Arc::new(old_snapshot.text()) }
327 })
328 .await;
329
330
331 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
332 edit_agent.edit(
333 buffer.clone(),
334 input.display_description.clone(),
335 &request,
336 cx,
337 )
338 } else {
339 edit_agent.overwrite(
340 buffer.clone(),
341 input.display_description.clone(),
342 &request,
343 cx,
344 )
345 };
346
347 let mut hallucinated_old_text = false;
348 let mut ambiguous_ranges = Vec::new();
349 let mut emitted_location = false;
350 while let Some(event) = events.next().await {
351 match event {
352 EditAgentOutputEvent::Edited(range) => {
353 if !emitted_location {
354 let line = buffer.update(cx, |buffer, _cx| {
355 range.start.to_point(&buffer.snapshot()).row
356 }).ok();
357 if let Some(abs_path) = abs_path.clone() {
358 event_stream.update_fields(ToolCallUpdateFields {
359 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
360 ..Default::default()
361 });
362 }
363 emitted_location = true;
364 }
365 },
366 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
367 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
368 EditAgentOutputEvent::ResolvingEditRange(range) => {
369 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
370 // if !emitted_location {
371 // let line = buffer.update(cx, |buffer, _cx| {
372 // range.start.to_point(&buffer.snapshot()).row
373 // }).ok();
374 // if let Some(abs_path) = abs_path.clone() {
375 // event_stream.update_fields(ToolCallUpdateFields {
376 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
377 // ..Default::default()
378 // });
379 // }
380 // }
381 }
382 }
383 }
384
385 // If format_on_save is enabled, format the buffer
386 let format_on_save_enabled = buffer
387 .read_with(cx, |buffer, cx| {
388 let settings = language_settings::language_settings(
389 buffer.language().map(|l| l.name()),
390 buffer.file(),
391 cx,
392 );
393 settings.format_on_save != FormatOnSave::Off
394 })
395 .unwrap_or(false);
396
397 let edit_agent_output = output.await?;
398
399 if format_on_save_enabled {
400 action_log.update(cx, |log, cx| {
401 log.buffer_edited(buffer.clone(), cx);
402 })?;
403
404 let format_task = project.update(cx, |project, cx| {
405 project.format(
406 HashSet::from_iter([buffer.clone()]),
407 LspFormatTarget::Buffers,
408 false, // Don't push to history since the tool did it.
409 FormatTrigger::Save,
410 cx,
411 )
412 })?;
413 format_task.await.log_err();
414 }
415
416 project
417 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
418 .await?;
419
420 action_log.update(cx, |log, cx| {
421 log.buffer_edited(buffer.clone(), cx);
422 })?;
423
424 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
425 let (new_text, unified_diff) = cx
426 .background_spawn({
427 let new_snapshot = new_snapshot.clone();
428 let old_text = old_text.clone();
429 async move {
430 let new_text = new_snapshot.text();
431 let diff = language::unified_diff(&old_text, &new_text);
432 (new_text, diff)
433 }
434 })
435 .await;
436
437 let input_path = input.path.display();
438 if unified_diff.is_empty() {
439 anyhow::ensure!(
440 !hallucinated_old_text,
441 formatdoc! {"
442 Some edits were produced but none of them could be applied.
443 Read the relevant sections of {input_path} again so that
444 I can perform the requested edits.
445 "}
446 );
447 anyhow::ensure!(
448 ambiguous_ranges.is_empty(),
449 {
450 let line_numbers = ambiguous_ranges
451 .iter()
452 .map(|range| range.start.to_string())
453 .collect::<Vec<_>>()
454 .join(", ");
455 formatdoc! {"
456 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
457 relevant sections of {input_path} again and extend <old_text> so
458 that I can perform the requested edits.
459 "}
460 }
461 );
462 }
463
464 Ok(EditFileToolOutput {
465 input_path: input.path,
466 new_text,
467 old_text,
468 diff: unified_diff,
469 edit_agent_output,
470 })
471 })
472 }
473
474 fn replay(
475 &self,
476 _input: Self::Input,
477 output: Self::Output,
478 event_stream: ToolCallEventStream,
479 cx: &mut App,
480 ) -> Result<()> {
481 event_stream.update_diff(cx.new(|cx| {
482 Diff::finalized(
483 output.input_path.to_string_lossy().into_owned(),
484 Some(output.old_text.to_string()),
485 output.new_text,
486 self.language_registry.clone(),
487 cx,
488 )
489 }));
490 Ok(())
491 }
492}
493
494/// Validate that the file path is valid, meaning:
495///
496/// - For `edit` and `overwrite`, the path must point to an existing file.
497/// - For `create`, the file must not already exist, but it's parent dir must exist.
498fn resolve_path(
499 input: &EditFileToolInput,
500 project: Entity<Project>,
501 cx: &mut App,
502) -> Result<ProjectPath> {
503 let project = project.read(cx);
504
505 match input.mode {
506 EditFileMode::Edit | EditFileMode::Overwrite => {
507 let path = project
508 .find_project_path(&input.path, cx)
509 .context("Can't edit file: path not found")?;
510
511 let entry = project
512 .entry_for_path(&path, cx)
513 .context("Can't edit file: path not found")?;
514
515 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
516 Ok(path)
517 }
518
519 EditFileMode::Create => {
520 if let Some(path) = project.find_project_path(&input.path, cx) {
521 anyhow::ensure!(
522 project.entry_for_path(&path, cx).is_none(),
523 "Can't create file: file already exists"
524 );
525 }
526
527 let parent_path = input
528 .path
529 .parent()
530 .context("Can't create file: incorrect path")?;
531
532 let parent_project_path = project.find_project_path(&parent_path, cx);
533
534 let parent_entry = parent_project_path
535 .as_ref()
536 .and_then(|path| project.entry_for_path(path, cx))
537 .context("Can't create file: parent directory doesn't exist")?;
538
539 anyhow::ensure!(
540 parent_entry.is_dir(),
541 "Can't create file: parent is not a directory"
542 );
543
544 let file_name = input
545 .path
546 .file_name()
547 .and_then(|file_name| file_name.to_str())
548 .and_then(|file_name| RelPath::unix(file_name).ok())
549 .context("Can't create file: invalid filename")?;
550
551 let new_file_path = parent_project_path.map(|parent| ProjectPath {
552 path: parent.path.join(file_name),
553 ..parent
554 });
555
556 new_file_path.context("Can't create file")
557 }
558 }
559}
560
561#[cfg(test)]
562mod tests {
563 use super::*;
564 use crate::{ContextServerRegistry, Templates};
565 use client::TelemetrySettings;
566 use fs::Fs;
567 use gpui::{TestAppContext, UpdateGlobal};
568 use language_model::fake_provider::FakeLanguageModel;
569 use prompt_store::ProjectContext;
570 use serde_json::json;
571 use settings::SettingsStore;
572 use util::{path, rel_path::rel_path};
573
574 #[gpui::test]
575 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
576 init_test(cx);
577
578 let fs = project::FakeFs::new(cx.executor());
579 fs.insert_tree("/root", json!({})).await;
580 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
581 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
582 let context_server_registry =
583 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
584 let model = Arc::new(FakeLanguageModel::default());
585 let thread = cx.new(|cx| {
586 Thread::new(
587 project.clone(),
588 cx.new(|_cx| ProjectContext::default()),
589 context_server_registry,
590 Templates::new(),
591 Some(model),
592 cx,
593 )
594 });
595 let result = cx
596 .update(|cx| {
597 let input = EditFileToolInput {
598 display_description: "Some edit".into(),
599 path: "root/nonexistent_file.txt".into(),
600 mode: EditFileMode::Edit,
601 };
602 Arc::new(EditFileTool::new(
603 project,
604 thread.downgrade(),
605 language_registry,
606 Templates::new(),
607 ))
608 .run(input, ToolCallEventStream::test().0, cx)
609 })
610 .await;
611 assert_eq!(
612 result.unwrap_err().to_string(),
613 "Can't edit file: path not found"
614 );
615 }
616
617 #[gpui::test]
618 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
619 let mode = &EditFileMode::Create;
620
621 let result = test_resolve_path(mode, "root/new.txt", cx);
622 assert_resolved_path_eq(result.await, rel_path("new.txt"));
623
624 let result = test_resolve_path(mode, "new.txt", cx);
625 assert_resolved_path_eq(result.await, rel_path("new.txt"));
626
627 let result = test_resolve_path(mode, "dir/new.txt", cx);
628 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
629
630 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
631 assert_eq!(
632 result.await.unwrap_err().to_string(),
633 "Can't create file: file already exists"
634 );
635
636 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
637 assert_eq!(
638 result.await.unwrap_err().to_string(),
639 "Can't create file: parent directory doesn't exist"
640 );
641 }
642
643 #[gpui::test]
644 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
645 let mode = &EditFileMode::Edit;
646
647 let path_with_root = "root/dir/subdir/existing.txt";
648 let path_without_root = "dir/subdir/existing.txt";
649 let result = test_resolve_path(mode, path_with_root, cx);
650 assert_resolved_path_eq(result.await, rel_path(path_without_root));
651
652 let result = test_resolve_path(mode, path_without_root, cx);
653 assert_resolved_path_eq(result.await, rel_path(path_without_root));
654
655 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
656 assert_eq!(
657 result.await.unwrap_err().to_string(),
658 "Can't edit file: path not found"
659 );
660
661 let result = test_resolve_path(mode, "root/dir", cx);
662 assert_eq!(
663 result.await.unwrap_err().to_string(),
664 "Can't edit file: path is a directory"
665 );
666 }
667
668 async fn test_resolve_path(
669 mode: &EditFileMode,
670 path: &str,
671 cx: &mut TestAppContext,
672 ) -> anyhow::Result<ProjectPath> {
673 init_test(cx);
674
675 let fs = project::FakeFs::new(cx.executor());
676 fs.insert_tree(
677 "/root",
678 json!({
679 "dir": {
680 "subdir": {
681 "existing.txt": "hello"
682 }
683 }
684 }),
685 )
686 .await;
687 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
688
689 let input = EditFileToolInput {
690 display_description: "Some edit".into(),
691 path: path.into(),
692 mode: mode.clone(),
693 };
694
695 cx.update(|cx| resolve_path(&input, project, cx))
696 }
697
698 #[track_caller]
699 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
700 let actual = path.expect("Should return valid path").path;
701 assert_eq!(actual.as_ref(), expected);
702 }
703
704 #[gpui::test]
705 async fn test_format_on_save(cx: &mut TestAppContext) {
706 init_test(cx);
707
708 let fs = project::FakeFs::new(cx.executor());
709 fs.insert_tree("/root", json!({"src": {}})).await;
710
711 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
712
713 // Set up a Rust language with LSP formatting support
714 let rust_language = Arc::new(language::Language::new(
715 language::LanguageConfig {
716 name: "Rust".into(),
717 matcher: language::LanguageMatcher {
718 path_suffixes: vec!["rs".to_string()],
719 ..Default::default()
720 },
721 ..Default::default()
722 },
723 None,
724 ));
725
726 // Register the language and fake LSP
727 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
728 language_registry.add(rust_language);
729
730 let mut fake_language_servers = language_registry.register_fake_lsp(
731 "Rust",
732 language::FakeLspAdapter {
733 capabilities: lsp::ServerCapabilities {
734 document_formatting_provider: Some(lsp::OneOf::Left(true)),
735 ..Default::default()
736 },
737 ..Default::default()
738 },
739 );
740
741 // Create the file
742 fs.save(
743 path!("/root/src/main.rs").as_ref(),
744 &"initial content".into(),
745 language::LineEnding::Unix,
746 )
747 .await
748 .unwrap();
749
750 // Open the buffer to trigger LSP initialization
751 let buffer = project
752 .update(cx, |project, cx| {
753 project.open_local_buffer(path!("/root/src/main.rs"), cx)
754 })
755 .await
756 .unwrap();
757
758 // Register the buffer with language servers
759 let _handle = project.update(cx, |project, cx| {
760 project.register_buffer_with_language_servers(&buffer, cx)
761 });
762
763 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
764 const FORMATTED_CONTENT: &str =
765 "This file was formatted by the fake formatter in the test.\n";
766
767 // Get the fake language server and set up formatting handler
768 let fake_language_server = fake_language_servers.next().await.unwrap();
769 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
770 |_, _| async move {
771 Ok(Some(vec![lsp::TextEdit {
772 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
773 new_text: FORMATTED_CONTENT.to_string(),
774 }]))
775 }
776 });
777
778 let context_server_registry =
779 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
780 let model = Arc::new(FakeLanguageModel::default());
781 let thread = cx.new(|cx| {
782 Thread::new(
783 project.clone(),
784 cx.new(|_cx| ProjectContext::default()),
785 context_server_registry,
786 Templates::new(),
787 Some(model.clone()),
788 cx,
789 )
790 });
791
792 // First, test with format_on_save enabled
793 cx.update(|cx| {
794 SettingsStore::update_global(cx, |store, cx| {
795 store.update_user_settings(cx, |settings| {
796 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
797 settings.project.all_languages.defaults.formatter =
798 Some(language::language_settings::FormatterList::default());
799 });
800 });
801 });
802
803 // Have the model stream unformatted content
804 let edit_result = {
805 let edit_task = cx.update(|cx| {
806 let input = EditFileToolInput {
807 display_description: "Create main function".into(),
808 path: "root/src/main.rs".into(),
809 mode: EditFileMode::Overwrite,
810 };
811 Arc::new(EditFileTool::new(
812 project.clone(),
813 thread.downgrade(),
814 language_registry.clone(),
815 Templates::new(),
816 ))
817 .run(input, ToolCallEventStream::test().0, cx)
818 });
819
820 // Stream the unformatted content
821 cx.executor().run_until_parked();
822 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
823 model.end_last_completion_stream();
824
825 edit_task.await
826 };
827 assert!(edit_result.is_ok());
828
829 // Wait for any async operations (e.g. formatting) to complete
830 cx.executor().run_until_parked();
831
832 // Read the file to verify it was formatted automatically
833 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
834 assert_eq!(
835 // Ignore carriage returns on Windows
836 new_content.replace("\r\n", "\n"),
837 FORMATTED_CONTENT,
838 "Code should be formatted when format_on_save is enabled"
839 );
840
841 let stale_buffer_count = thread
842 .read_with(cx, |thread, _cx| thread.action_log.clone())
843 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
844
845 assert_eq!(
846 stale_buffer_count, 0,
847 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
848 This causes the agent to think the file was modified externally when it was just formatted.",
849 stale_buffer_count
850 );
851
852 // Next, test with format_on_save disabled
853 cx.update(|cx| {
854 SettingsStore::update_global(cx, |store, cx| {
855 store.update_user_settings(cx, |settings| {
856 settings.project.all_languages.defaults.format_on_save =
857 Some(FormatOnSave::Off);
858 });
859 });
860 });
861
862 // Stream unformatted edits again
863 let edit_result = {
864 let edit_task = cx.update(|cx| {
865 let input = EditFileToolInput {
866 display_description: "Update main function".into(),
867 path: "root/src/main.rs".into(),
868 mode: EditFileMode::Overwrite,
869 };
870 Arc::new(EditFileTool::new(
871 project.clone(),
872 thread.downgrade(),
873 language_registry,
874 Templates::new(),
875 ))
876 .run(input, ToolCallEventStream::test().0, cx)
877 });
878
879 // Stream the unformatted content
880 cx.executor().run_until_parked();
881 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
882 model.end_last_completion_stream();
883
884 edit_task.await
885 };
886 assert!(edit_result.is_ok());
887
888 // Wait for any async operations (e.g. formatting) to complete
889 cx.executor().run_until_parked();
890
891 // Verify the file was not formatted
892 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
893 assert_eq!(
894 // Ignore carriage returns on Windows
895 new_content.replace("\r\n", "\n"),
896 UNFORMATTED_CONTENT,
897 "Code should not be formatted when format_on_save is disabled"
898 );
899 }
900
901 #[gpui::test]
902 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
903 init_test(cx);
904
905 let fs = project::FakeFs::new(cx.executor());
906 fs.insert_tree("/root", json!({"src": {}})).await;
907
908 // Create a simple file with trailing whitespace
909 fs.save(
910 path!("/root/src/main.rs").as_ref(),
911 &"initial content".into(),
912 language::LineEnding::Unix,
913 )
914 .await
915 .unwrap();
916
917 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
918 let context_server_registry =
919 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
920 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
921 let model = Arc::new(FakeLanguageModel::default());
922 let thread = cx.new(|cx| {
923 Thread::new(
924 project.clone(),
925 cx.new(|_cx| ProjectContext::default()),
926 context_server_registry,
927 Templates::new(),
928 Some(model.clone()),
929 cx,
930 )
931 });
932
933 // First, test with remove_trailing_whitespace_on_save enabled
934 cx.update(|cx| {
935 SettingsStore::update_global(cx, |store, cx| {
936 store.update_user_settings(cx, |settings| {
937 settings
938 .project
939 .all_languages
940 .defaults
941 .remove_trailing_whitespace_on_save = Some(true);
942 });
943 });
944 });
945
946 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
947 "fn main() { \n println!(\"Hello!\"); \n}\n";
948
949 // Have the model stream content that contains trailing whitespace
950 let edit_result = {
951 let edit_task = cx.update(|cx| {
952 let input = EditFileToolInput {
953 display_description: "Create main function".into(),
954 path: "root/src/main.rs".into(),
955 mode: EditFileMode::Overwrite,
956 };
957 Arc::new(EditFileTool::new(
958 project.clone(),
959 thread.downgrade(),
960 language_registry.clone(),
961 Templates::new(),
962 ))
963 .run(input, ToolCallEventStream::test().0, cx)
964 });
965
966 // Stream the content with trailing whitespace
967 cx.executor().run_until_parked();
968 model.send_last_completion_stream_text_chunk(
969 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
970 );
971 model.end_last_completion_stream();
972
973 edit_task.await
974 };
975 assert!(edit_result.is_ok());
976
977 // Wait for any async operations (e.g. formatting) to complete
978 cx.executor().run_until_parked();
979
980 // Read the file to verify trailing whitespace was removed automatically
981 assert_eq!(
982 // Ignore carriage returns on Windows
983 fs.load(path!("/root/src/main.rs").as_ref())
984 .await
985 .unwrap()
986 .replace("\r\n", "\n"),
987 "fn main() {\n println!(\"Hello!\");\n}\n",
988 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
989 );
990
991 // Next, test with remove_trailing_whitespace_on_save disabled
992 cx.update(|cx| {
993 SettingsStore::update_global(cx, |store, cx| {
994 store.update_user_settings(cx, |settings| {
995 settings
996 .project
997 .all_languages
998 .defaults
999 .remove_trailing_whitespace_on_save = Some(false);
1000 });
1001 });
1002 });
1003
1004 // Stream edits again with trailing whitespace
1005 let edit_result = {
1006 let edit_task = cx.update(|cx| {
1007 let input = EditFileToolInput {
1008 display_description: "Update main function".into(),
1009 path: "root/src/main.rs".into(),
1010 mode: EditFileMode::Overwrite,
1011 };
1012 Arc::new(EditFileTool::new(
1013 project.clone(),
1014 thread.downgrade(),
1015 language_registry,
1016 Templates::new(),
1017 ))
1018 .run(input, ToolCallEventStream::test().0, cx)
1019 });
1020
1021 // Stream the content with trailing whitespace
1022 cx.executor().run_until_parked();
1023 model.send_last_completion_stream_text_chunk(
1024 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1025 );
1026 model.end_last_completion_stream();
1027
1028 edit_task.await
1029 };
1030 assert!(edit_result.is_ok());
1031
1032 // Wait for any async operations (e.g. formatting) to complete
1033 cx.executor().run_until_parked();
1034
1035 // Verify the file still has trailing whitespace
1036 // Read the file again - it should still have trailing whitespace
1037 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1038 assert_eq!(
1039 // Ignore carriage returns on Windows
1040 final_content.replace("\r\n", "\n"),
1041 CONTENT_WITH_TRAILING_WHITESPACE,
1042 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1043 );
1044 }
1045
1046 #[gpui::test]
1047 async fn test_authorize(cx: &mut TestAppContext) {
1048 init_test(cx);
1049 let fs = project::FakeFs::new(cx.executor());
1050 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1051 let context_server_registry =
1052 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1053 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1054 let model = Arc::new(FakeLanguageModel::default());
1055 let thread = cx.new(|cx| {
1056 Thread::new(
1057 project.clone(),
1058 cx.new(|_cx| ProjectContext::default()),
1059 context_server_registry,
1060 Templates::new(),
1061 Some(model.clone()),
1062 cx,
1063 )
1064 });
1065 let tool = Arc::new(EditFileTool::new(
1066 project.clone(),
1067 thread.downgrade(),
1068 language_registry,
1069 Templates::new(),
1070 ));
1071 fs.insert_tree("/root", json!({})).await;
1072
1073 // Test 1: Path with .zed component should require confirmation
1074 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1075 let _auth = cx.update(|cx| {
1076 tool.authorize(
1077 &EditFileToolInput {
1078 display_description: "test 1".into(),
1079 path: ".zed/settings.json".into(),
1080 mode: EditFileMode::Edit,
1081 },
1082 &stream_tx,
1083 cx,
1084 )
1085 });
1086
1087 let event = stream_rx.expect_authorization().await;
1088 assert_eq!(
1089 event.tool_call.fields.title,
1090 Some("test 1 (local settings)".into())
1091 );
1092
1093 // Test 2: Path outside project should require confirmation
1094 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1095 let _auth = cx.update(|cx| {
1096 tool.authorize(
1097 &EditFileToolInput {
1098 display_description: "test 2".into(),
1099 path: "/etc/hosts".into(),
1100 mode: EditFileMode::Edit,
1101 },
1102 &stream_tx,
1103 cx,
1104 )
1105 });
1106
1107 let event = stream_rx.expect_authorization().await;
1108 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1109
1110 // Test 3: Relative path without .zed should not require confirmation
1111 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1112 cx.update(|cx| {
1113 tool.authorize(
1114 &EditFileToolInput {
1115 display_description: "test 3".into(),
1116 path: "root/src/main.rs".into(),
1117 mode: EditFileMode::Edit,
1118 },
1119 &stream_tx,
1120 cx,
1121 )
1122 })
1123 .await
1124 .unwrap();
1125 assert!(stream_rx.try_next().is_err());
1126
1127 // Test 4: Path with .zed in the middle should require confirmation
1128 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1129 let _auth = cx.update(|cx| {
1130 tool.authorize(
1131 &EditFileToolInput {
1132 display_description: "test 4".into(),
1133 path: "root/.zed/tasks.json".into(),
1134 mode: EditFileMode::Edit,
1135 },
1136 &stream_tx,
1137 cx,
1138 )
1139 });
1140 let event = stream_rx.expect_authorization().await;
1141 assert_eq!(
1142 event.tool_call.fields.title,
1143 Some("test 4 (local settings)".into())
1144 );
1145
1146 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1147 cx.update(|cx| {
1148 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1149 settings.always_allow_tool_actions = true;
1150 agent_settings::AgentSettings::override_global(settings, cx);
1151 });
1152
1153 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1154 cx.update(|cx| {
1155 tool.authorize(
1156 &EditFileToolInput {
1157 display_description: "test 5.1".into(),
1158 path: ".zed/settings.json".into(),
1159 mode: EditFileMode::Edit,
1160 },
1161 &stream_tx,
1162 cx,
1163 )
1164 })
1165 .await
1166 .unwrap();
1167 assert!(stream_rx.try_next().is_err());
1168
1169 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1170 cx.update(|cx| {
1171 tool.authorize(
1172 &EditFileToolInput {
1173 display_description: "test 5.2".into(),
1174 path: "/etc/hosts".into(),
1175 mode: EditFileMode::Edit,
1176 },
1177 &stream_tx,
1178 cx,
1179 )
1180 })
1181 .await
1182 .unwrap();
1183 assert!(stream_rx.try_next().is_err());
1184 }
1185
1186 #[gpui::test]
1187 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1188 init_test(cx);
1189 let fs = project::FakeFs::new(cx.executor());
1190 fs.insert_tree("/project", json!({})).await;
1191 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1192 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1193 let context_server_registry =
1194 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1195 let model = Arc::new(FakeLanguageModel::default());
1196 let thread = cx.new(|cx| {
1197 Thread::new(
1198 project.clone(),
1199 cx.new(|_cx| ProjectContext::default()),
1200 context_server_registry,
1201 Templates::new(),
1202 Some(model.clone()),
1203 cx,
1204 )
1205 });
1206 let tool = Arc::new(EditFileTool::new(
1207 project.clone(),
1208 thread.downgrade(),
1209 language_registry,
1210 Templates::new(),
1211 ));
1212
1213 // Test global config paths - these should require confirmation if they exist and are outside the project
1214 let test_cases = vec![
1215 (
1216 "/etc/hosts",
1217 true,
1218 "System file should require confirmation",
1219 ),
1220 (
1221 "/usr/local/bin/script",
1222 true,
1223 "System bin file should require confirmation",
1224 ),
1225 (
1226 "project/normal_file.rs",
1227 false,
1228 "Normal project file should not require confirmation",
1229 ),
1230 ];
1231
1232 for (path, should_confirm, description) in test_cases {
1233 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1234 let auth = cx.update(|cx| {
1235 tool.authorize(
1236 &EditFileToolInput {
1237 display_description: "Edit file".into(),
1238 path: path.into(),
1239 mode: EditFileMode::Edit,
1240 },
1241 &stream_tx,
1242 cx,
1243 )
1244 });
1245
1246 if should_confirm {
1247 stream_rx.expect_authorization().await;
1248 } else {
1249 auth.await.unwrap();
1250 assert!(
1251 stream_rx.try_next().is_err(),
1252 "Failed for case: {} - path: {} - expected no confirmation but got one",
1253 description,
1254 path
1255 );
1256 }
1257 }
1258 }
1259
1260 #[gpui::test]
1261 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1262 init_test(cx);
1263 let fs = project::FakeFs::new(cx.executor());
1264
1265 // Create multiple worktree directories
1266 fs.insert_tree(
1267 "/workspace/frontend",
1268 json!({
1269 "src": {
1270 "main.js": "console.log('frontend');"
1271 }
1272 }),
1273 )
1274 .await;
1275 fs.insert_tree(
1276 "/workspace/backend",
1277 json!({
1278 "src": {
1279 "main.rs": "fn main() {}"
1280 }
1281 }),
1282 )
1283 .await;
1284 fs.insert_tree(
1285 "/workspace/shared",
1286 json!({
1287 ".zed": {
1288 "settings.json": "{}"
1289 }
1290 }),
1291 )
1292 .await;
1293
1294 // Create project with multiple worktrees
1295 let project = Project::test(
1296 fs.clone(),
1297 [
1298 path!("/workspace/frontend").as_ref(),
1299 path!("/workspace/backend").as_ref(),
1300 path!("/workspace/shared").as_ref(),
1301 ],
1302 cx,
1303 )
1304 .await;
1305 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1306 let context_server_registry =
1307 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1308 let model = Arc::new(FakeLanguageModel::default());
1309 let thread = cx.new(|cx| {
1310 Thread::new(
1311 project.clone(),
1312 cx.new(|_cx| ProjectContext::default()),
1313 context_server_registry.clone(),
1314 Templates::new(),
1315 Some(model.clone()),
1316 cx,
1317 )
1318 });
1319 let tool = Arc::new(EditFileTool::new(
1320 project.clone(),
1321 thread.downgrade(),
1322 language_registry,
1323 Templates::new(),
1324 ));
1325
1326 // Test files in different worktrees
1327 let test_cases = vec![
1328 ("frontend/src/main.js", false, "File in first worktree"),
1329 ("backend/src/main.rs", false, "File in second worktree"),
1330 (
1331 "shared/.zed/settings.json",
1332 true,
1333 ".zed file in third worktree",
1334 ),
1335 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1336 (
1337 "../outside/file.txt",
1338 true,
1339 "Relative path outside worktrees",
1340 ),
1341 ];
1342
1343 for (path, should_confirm, description) in test_cases {
1344 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1345 let auth = cx.update(|cx| {
1346 tool.authorize(
1347 &EditFileToolInput {
1348 display_description: "Edit file".into(),
1349 path: path.into(),
1350 mode: EditFileMode::Edit,
1351 },
1352 &stream_tx,
1353 cx,
1354 )
1355 });
1356
1357 if should_confirm {
1358 stream_rx.expect_authorization().await;
1359 } else {
1360 auth.await.unwrap();
1361 assert!(
1362 stream_rx.try_next().is_err(),
1363 "Failed for case: {} - path: {} - expected no confirmation but got one",
1364 description,
1365 path
1366 );
1367 }
1368 }
1369 }
1370
1371 #[gpui::test]
1372 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1373 init_test(cx);
1374 let fs = project::FakeFs::new(cx.executor());
1375 fs.insert_tree(
1376 "/project",
1377 json!({
1378 ".zed": {
1379 "settings.json": "{}"
1380 },
1381 "src": {
1382 ".zed": {
1383 "local.json": "{}"
1384 }
1385 }
1386 }),
1387 )
1388 .await;
1389 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1390 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1391 let context_server_registry =
1392 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1393 let model = Arc::new(FakeLanguageModel::default());
1394 let thread = cx.new(|cx| {
1395 Thread::new(
1396 project.clone(),
1397 cx.new(|_cx| ProjectContext::default()),
1398 context_server_registry.clone(),
1399 Templates::new(),
1400 Some(model.clone()),
1401 cx,
1402 )
1403 });
1404 let tool = Arc::new(EditFileTool::new(
1405 project.clone(),
1406 thread.downgrade(),
1407 language_registry,
1408 Templates::new(),
1409 ));
1410
1411 // Test edge cases
1412 let test_cases = vec![
1413 // Empty path - find_project_path returns Some for empty paths
1414 ("", false, "Empty path is treated as project root"),
1415 // Root directory
1416 ("/", true, "Root directory should be outside project"),
1417 // Parent directory references - find_project_path resolves these
1418 (
1419 "project/../other",
1420 true,
1421 "Path with .. that goes outside of root directory",
1422 ),
1423 (
1424 "project/./src/file.rs",
1425 false,
1426 "Path with . should work normally",
1427 ),
1428 // Windows-style paths (if on Windows)
1429 #[cfg(target_os = "windows")]
1430 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1431 #[cfg(target_os = "windows")]
1432 ("project\\src\\main.rs", false, "Windows-style project path"),
1433 ];
1434
1435 for (path, should_confirm, description) in test_cases {
1436 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1437 let auth = cx.update(|cx| {
1438 tool.authorize(
1439 &EditFileToolInput {
1440 display_description: "Edit file".into(),
1441 path: path.into(),
1442 mode: EditFileMode::Edit,
1443 },
1444 &stream_tx,
1445 cx,
1446 )
1447 });
1448
1449 cx.run_until_parked();
1450
1451 if should_confirm {
1452 stream_rx.expect_authorization().await;
1453 } else {
1454 assert!(
1455 stream_rx.try_next().is_err(),
1456 "Failed for case: {} - path: {} - expected no confirmation but got one",
1457 description,
1458 path
1459 );
1460 auth.await.unwrap();
1461 }
1462 }
1463 }
1464
1465 #[gpui::test]
1466 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1467 init_test(cx);
1468 let fs = project::FakeFs::new(cx.executor());
1469 fs.insert_tree(
1470 "/project",
1471 json!({
1472 "existing.txt": "content",
1473 ".zed": {
1474 "settings.json": "{}"
1475 }
1476 }),
1477 )
1478 .await;
1479 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1480 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1481 let context_server_registry =
1482 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1483 let model = Arc::new(FakeLanguageModel::default());
1484 let thread = cx.new(|cx| {
1485 Thread::new(
1486 project.clone(),
1487 cx.new(|_cx| ProjectContext::default()),
1488 context_server_registry.clone(),
1489 Templates::new(),
1490 Some(model.clone()),
1491 cx,
1492 )
1493 });
1494 let tool = Arc::new(EditFileTool::new(
1495 project.clone(),
1496 thread.downgrade(),
1497 language_registry,
1498 Templates::new(),
1499 ));
1500
1501 // Test different EditFileMode values
1502 let modes = vec![
1503 EditFileMode::Edit,
1504 EditFileMode::Create,
1505 EditFileMode::Overwrite,
1506 ];
1507
1508 for mode in modes {
1509 // Test .zed path with different modes
1510 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1511 let _auth = cx.update(|cx| {
1512 tool.authorize(
1513 &EditFileToolInput {
1514 display_description: "Edit settings".into(),
1515 path: "project/.zed/settings.json".into(),
1516 mode: mode.clone(),
1517 },
1518 &stream_tx,
1519 cx,
1520 )
1521 });
1522
1523 stream_rx.expect_authorization().await;
1524
1525 // Test outside path with different modes
1526 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1527 let _auth = cx.update(|cx| {
1528 tool.authorize(
1529 &EditFileToolInput {
1530 display_description: "Edit file".into(),
1531 path: "/outside/file.txt".into(),
1532 mode: mode.clone(),
1533 },
1534 &stream_tx,
1535 cx,
1536 )
1537 });
1538
1539 stream_rx.expect_authorization().await;
1540
1541 // Test normal path with different modes
1542 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1543 cx.update(|cx| {
1544 tool.authorize(
1545 &EditFileToolInput {
1546 display_description: "Edit file".into(),
1547 path: "project/normal.txt".into(),
1548 mode: mode.clone(),
1549 },
1550 &stream_tx,
1551 cx,
1552 )
1553 })
1554 .await
1555 .unwrap();
1556 assert!(stream_rx.try_next().is_err());
1557 }
1558 }
1559
1560 #[gpui::test]
1561 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1562 init_test(cx);
1563 let fs = project::FakeFs::new(cx.executor());
1564 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1565 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1566 let context_server_registry =
1567 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1568 let model = Arc::new(FakeLanguageModel::default());
1569 let thread = cx.new(|cx| {
1570 Thread::new(
1571 project.clone(),
1572 cx.new(|_cx| ProjectContext::default()),
1573 context_server_registry,
1574 Templates::new(),
1575 Some(model.clone()),
1576 cx,
1577 )
1578 });
1579 let tool = Arc::new(EditFileTool::new(
1580 project,
1581 thread.downgrade(),
1582 language_registry,
1583 Templates::new(),
1584 ));
1585
1586 cx.update(|cx| {
1587 // ...
1588 assert_eq!(
1589 tool.initial_title(
1590 Err(json!({
1591 "path": "src/main.rs",
1592 "display_description": "",
1593 "old_string": "old code",
1594 "new_string": "new code"
1595 })),
1596 cx
1597 ),
1598 "src/main.rs"
1599 );
1600 assert_eq!(
1601 tool.initial_title(
1602 Err(json!({
1603 "path": "",
1604 "display_description": "Fix error handling",
1605 "old_string": "old code",
1606 "new_string": "new code"
1607 })),
1608 cx
1609 ),
1610 "Fix error handling"
1611 );
1612 assert_eq!(
1613 tool.initial_title(
1614 Err(json!({
1615 "path": "src/main.rs",
1616 "display_description": "Fix error handling",
1617 "old_string": "old code",
1618 "new_string": "new code"
1619 })),
1620 cx
1621 ),
1622 "src/main.rs"
1623 );
1624 assert_eq!(
1625 tool.initial_title(
1626 Err(json!({
1627 "path": "",
1628 "display_description": "",
1629 "old_string": "old code",
1630 "new_string": "new code"
1631 })),
1632 cx
1633 ),
1634 DEFAULT_UI_TEXT
1635 );
1636 assert_eq!(
1637 tool.initial_title(Err(serde_json::Value::Null), cx),
1638 DEFAULT_UI_TEXT
1639 );
1640 });
1641 }
1642
    // Checks that the `Diff` streamed out by the tool starts `Pending` and ends up `Finalized`
    // in each of the three ways an edit can end: normal completion, an error, and the tool
    // call's task being dropped. Each scenario builds a fresh tool against the same thread.
    #[gpui::test]
    async fn test_diff_finalization(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/", json!({"main.rs": ""})).await;

        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // Ensure the diff is finalized after the edit completes.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            // The tool reports a fields update first, then emits the diff it will populate.
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            // NOTE(review): presumably lets the tool issue its completion request so that
            // `end_last_completion_stream` below has a stream to end — confirm.
            cx.run_until_parked();
            model.end_last_completion_stream();
            edit.await.unwrap();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }

        // Ensure the diff is finalized if an error occurs while editing.
        {
            // Make the fake model refuse requests; the edit is then expected to fail
            // (see `unwrap_err` below).
            model.forbid_requests();
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            edit.await.unwrap_err();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
            // Restore normal model behavior for the next scenario.
            model.allow_requests();
        }

        // Ensure the diff is finalized if the tool call gets dropped.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            // Abandon the in-flight edit, then park so any work triggered by the drop can run
            // before the diff state is checked.
            drop(edit);
            cx.run_until_parked();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }
    }
1751
1752 fn init_test(cx: &mut TestAppContext) {
1753 cx.update(|cx| {
1754 let settings_store = SettingsStore::test(cx);
1755 cx.set_global(settings_store);
1756 language::init(cx);
1757 TelemetrySettings::register(cx);
1758 agent_settings::AgentSettings::register(cx);
1759 Project::init_settings(cx);
1760 });
1761 }
1762}