1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Fallback tool-call title, used by `initial_title` when the (partial) input
/// yields neither a non-empty path nor a non-empty description.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
// NOTE(review): the `///` comments on this struct and its fields are written
// for the model, not for maintainers. With `#[derive(JsonSchema)]` they are
// presumably exported as the tool/field descriptions in the generated schema,
// so treat any wording change here as a prompt change — confirm before editing.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
79
// Lenient view of `EditFileToolInput` used by `initial_title` when the raw
// JSON could not be parsed as the full input (presumably because it is still
// being streamed). Every field falls back to an empty string via
// `#[serde(default)]`, so any subset of the fields deserializes successfully.
//
// Plain `//` comments are used on purpose: this type derives `JsonSchema`,
// and `///` comments may end up in a generated schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Path as received so far; may be empty or incomplete.
    #[serde(default)]
    path: String,
    // Description as received so far; may be empty or incomplete.
    #[serde(default)]
    display_description: String,
}
87
// How the tool operates on the target file. Serialized in lowercase
// (`"edit"`, `"create"`, `"overwrite"`) because of `rename_all`.
//
// Plain `//` comments are used on purpose: this enum derives `JsonSchema`
// with `#[schemars(inline)]`, and `///` comments on the variants may change
// the schema that is generated for the tool input.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // Granular edits to an existing file; `run` dispatches to `EditAgent::edit`.
    Edit,
    // New file; `resolve_path` requires that it does not exist yet and that its
    // parent directory does. `run` dispatches to `EditAgent::overwrite`.
    Create,
    // Replace the whole content of an existing file; `run` dispatches to
    // `EditAgent::overwrite`.
    Overwrite,
}
96
/// Result of one `edit_file` invocation.
///
/// It is serializable so that a finished call can later be shown again through
/// `replay`, which rebuilds the diff from `old_text` and `new_text`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as the model supplied it. Also accepted under the name
    /// `original_path` when deserializing (presumably an older field name).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Buffer text after the edit (and optional formatting) was saved.
    new_text: String,
    /// Buffer text captured before the edit agent ran.
    old_text: Arc<String>,
    /// Unified diff between `old_text` and `new_text`. Defaults to an empty
    /// string when missing from serialized data; an empty diff is reported to
    /// the model as "No edits were made.".
    #[serde(default)]
    diff: String,
    /// Output returned by the edit agent. Also accepted under the name
    /// `raw_output` when deserializing (presumably an older field name).
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
/// Agent tool that creates, edits, or overwrites a single file in the project
/// by delegating the actual text changes to an `EditAgent`.
pub struct EditFileTool {
    /// Weak handle to the thread this tool belongs to; `run` and `authorize`
    /// read the project, model, and action log from it and fail with
    /// "thread was dropped" when it is gone.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when a stored output is replayed.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to turn the input path into a short display path.
    project: Entity<Project>,
    /// Templates forwarded to the `EditAgent`.
    templates: Arc<Templates>,
}
130
131impl EditFileTool {
132 pub fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
203 fn name() -> &'static str {
204 "edit_file"
205 }
206
207 fn kind() -> acp::ToolKind {
208 acp::ToolKind::Edit
209 }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(ToolCallUpdateFields {
277 locations: Some(vec![acp::ToolCallLocation {
278 path: abs_path,
279 line: None,
280 meta: None,
281 }]),
282 ..Default::default()
283 });
284 }
285
286 let authorize = self.authorize(&input, &event_stream, cx);
287 cx.spawn(async move |cx: &mut AsyncApp| {
288 authorize.await?;
289
290 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
291 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
292 (request, thread.model().cloned(), thread.action_log().clone())
293 })?;
294 let request = request?;
295 let model = model.context("No language model configured")?;
296
297 let edit_format = EditFormat::from_model(model.clone())?;
298 let edit_agent = EditAgent::new(
299 model,
300 project.clone(),
301 action_log.clone(),
302 self.templates.clone(),
303 edit_format,
304 );
305
306 let buffer = project
307 .update(cx, |project, cx| {
308 project.open_buffer(project_path.clone(), cx)
309 })?
310 .await?;
311
312 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
313 event_stream.update_diff(diff.clone());
314 let _finalize_diff = util::defer({
315 let diff = diff.downgrade();
316 let mut cx = cx.clone();
317 move || {
318 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
319 }
320 });
321
322 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
323 let old_text = cx
324 .background_spawn({
325 let old_snapshot = old_snapshot.clone();
326 async move { Arc::new(old_snapshot.text()) }
327 })
328 .await;
329
330
331 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
332 edit_agent.edit(
333 buffer.clone(),
334 input.display_description.clone(),
335 &request,
336 cx,
337 )
338 } else {
339 edit_agent.overwrite(
340 buffer.clone(),
341 input.display_description.clone(),
342 &request,
343 cx,
344 )
345 };
346
347 let mut hallucinated_old_text = false;
348 let mut ambiguous_ranges = Vec::new();
349 let mut emitted_location = false;
350 while let Some(event) = events.next().await {
351 match event {
352 EditAgentOutputEvent::Edited(range) => {
353 if !emitted_location {
354 let line = buffer.update(cx, |buffer, _cx| {
355 range.start.to_point(&buffer.snapshot()).row
356 }).ok();
357 if let Some(abs_path) = abs_path.clone() {
358 event_stream.update_fields(ToolCallUpdateFields {
359 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
360 ..Default::default()
361 });
362 }
363 emitted_location = true;
364 }
365 },
366 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
367 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
368 EditAgentOutputEvent::ResolvingEditRange(range) => {
369 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
370 // if !emitted_location {
371 // let line = buffer.update(cx, |buffer, _cx| {
372 // range.start.to_point(&buffer.snapshot()).row
373 // }).ok();
374 // if let Some(abs_path) = abs_path.clone() {
375 // event_stream.update_fields(ToolCallUpdateFields {
376 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
377 // ..Default::default()
378 // });
379 // }
380 // }
381 }
382 }
383 }
384
385 // If format_on_save is enabled, format the buffer
386 let format_on_save_enabled = buffer
387 .read_with(cx, |buffer, cx| {
388 let settings = language_settings::language_settings(
389 buffer.language().map(|l| l.name()),
390 buffer.file(),
391 cx,
392 );
393 settings.format_on_save != FormatOnSave::Off
394 })
395 .unwrap_or(false);
396
397 let edit_agent_output = output.await?;
398
399 if format_on_save_enabled {
400 action_log.update(cx, |log, cx| {
401 log.buffer_edited(buffer.clone(), cx);
402 })?;
403
404 let format_task = project.update(cx, |project, cx| {
405 project.format(
406 HashSet::from_iter([buffer.clone()]),
407 LspFormatTarget::Buffers,
408 false, // Don't push to history since the tool did it.
409 FormatTrigger::Save,
410 cx,
411 )
412 })?;
413 format_task.await.log_err();
414 }
415
416 project
417 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
418 .await?;
419
420 action_log.update(cx, |log, cx| {
421 log.buffer_edited(buffer.clone(), cx);
422 })?;
423
424 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
425 let (new_text, unified_diff) = cx
426 .background_spawn({
427 let new_snapshot = new_snapshot.clone();
428 let old_text = old_text.clone();
429 async move {
430 let new_text = new_snapshot.text();
431 let diff = language::unified_diff(&old_text, &new_text);
432 (new_text, diff)
433 }
434 })
435 .await;
436
437 let input_path = input.path.display();
438 if unified_diff.is_empty() {
439 anyhow::ensure!(
440 !hallucinated_old_text,
441 formatdoc! {"
442 Some edits were produced but none of them could be applied.
443 Read the relevant sections of {input_path} again so that
444 I can perform the requested edits.
445 "}
446 );
447 anyhow::ensure!(
448 ambiguous_ranges.is_empty(),
449 {
450 let line_numbers = ambiguous_ranges
451 .iter()
452 .map(|range| range.start.to_string())
453 .collect::<Vec<_>>()
454 .join(", ");
455 formatdoc! {"
456 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
457 relevant sections of {input_path} again and extend <old_text> so
458 that I can perform the requested edits.
459 "}
460 }
461 );
462 }
463
464 Ok(EditFileToolOutput {
465 input_path: input.path,
466 new_text,
467 old_text,
468 diff: unified_diff,
469 edit_agent_output,
470 })
471 })
472 }
473
474 fn replay(
475 &self,
476 _input: Self::Input,
477 output: Self::Output,
478 event_stream: ToolCallEventStream,
479 cx: &mut App,
480 ) -> Result<()> {
481 event_stream.update_diff(cx.new(|cx| {
482 Diff::finalized(
483 output.input_path.to_string_lossy().into_owned(),
484 Some(output.old_text.to_string()),
485 output.new_text,
486 self.language_registry.clone(),
487 cx,
488 )
489 }));
490 Ok(())
491 }
492}
493
494/// Validate that the file path is valid, meaning:
495///
496/// - For `edit` and `overwrite`, the path must point to an existing file.
497/// - For `create`, the file must not already exist, but it's parent dir must exist.
498fn resolve_path(
499 input: &EditFileToolInput,
500 project: Entity<Project>,
501 cx: &mut App,
502) -> Result<ProjectPath> {
503 let project = project.read(cx);
504
505 match input.mode {
506 EditFileMode::Edit | EditFileMode::Overwrite => {
507 let path = project
508 .find_project_path(&input.path, cx)
509 .context("Can't edit file: path not found")?;
510
511 let entry = project
512 .entry_for_path(&path, cx)
513 .context("Can't edit file: path not found")?;
514
515 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
516 Ok(path)
517 }
518
519 EditFileMode::Create => {
520 if let Some(path) = project.find_project_path(&input.path, cx) {
521 anyhow::ensure!(
522 project.entry_for_path(&path, cx).is_none(),
523 "Can't create file: file already exists"
524 );
525 }
526
527 let parent_path = input
528 .path
529 .parent()
530 .context("Can't create file: incorrect path")?;
531
532 let parent_project_path = project.find_project_path(&parent_path, cx);
533
534 let parent_entry = parent_project_path
535 .as_ref()
536 .and_then(|path| project.entry_for_path(path, cx))
537 .context("Can't create file: parent directory doesn't exist")?;
538
539 anyhow::ensure!(
540 parent_entry.is_dir(),
541 "Can't create file: parent is not a directory"
542 );
543
544 let file_name = input
545 .path
546 .file_name()
547 .and_then(|file_name| file_name.to_str())
548 .and_then(|file_name| RelPath::unix(file_name).ok())
549 .context("Can't create file: invalid filename")?;
550
551 let new_file_path = parent_project_path.map(|parent| ProjectPath {
552 path: parent.path.join(file_name),
553 ..parent
554 });
555
556 new_file_path.context("Can't create file")
557 }
558 }
559}
560
561#[cfg(test)]
562mod tests {
563 use super::*;
564 use crate::{ContextServerRegistry, Templates};
565 use client::TelemetrySettings;
566 use encoding::all::UTF_8;
567 use fs::{Fs, encodings::EncodingWrapper};
568 use gpui::{TestAppContext, UpdateGlobal};
569 use language_model::fake_provider::FakeLanguageModel;
570 use prompt_store::ProjectContext;
571 use serde_json::json;
572 use settings::SettingsStore;
573 use text::Rope;
574 use util::{path, rel_path::rel_path};
575
576 #[gpui::test]
577 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
578 init_test(cx);
579
580 let fs = project::FakeFs::new(cx.executor());
581 fs.insert_tree("/root", json!({})).await;
582 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
583 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
584 let context_server_registry =
585 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
586 let model = Arc::new(FakeLanguageModel::default());
587 let thread = cx.new(|cx| {
588 Thread::new(
589 project.clone(),
590 cx.new(|_cx| ProjectContext::default()),
591 context_server_registry,
592 Templates::new(),
593 Some(model),
594 cx,
595 )
596 });
597 let result = cx
598 .update(|cx| {
599 let input = EditFileToolInput {
600 display_description: "Some edit".into(),
601 path: "root/nonexistent_file.txt".into(),
602 mode: EditFileMode::Edit,
603 };
604 Arc::new(EditFileTool::new(
605 project,
606 thread.downgrade(),
607 language_registry,
608 Templates::new(),
609 ))
610 .run(input, ToolCallEventStream::test().0, cx)
611 })
612 .await;
613 assert_eq!(
614 result.unwrap_err().to_string(),
615 "Can't edit file: path not found"
616 );
617 }
618
619 #[gpui::test]
620 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
621 let mode = &EditFileMode::Create;
622
623 let result = test_resolve_path(mode, "root/new.txt", cx);
624 assert_resolved_path_eq(result.await, rel_path("new.txt"));
625
626 let result = test_resolve_path(mode, "new.txt", cx);
627 assert_resolved_path_eq(result.await, rel_path("new.txt"));
628
629 let result = test_resolve_path(mode, "dir/new.txt", cx);
630 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
631
632 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
633 assert_eq!(
634 result.await.unwrap_err().to_string(),
635 "Can't create file: file already exists"
636 );
637
638 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
639 assert_eq!(
640 result.await.unwrap_err().to_string(),
641 "Can't create file: parent directory doesn't exist"
642 );
643 }
644
645 #[gpui::test]
646 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
647 let mode = &EditFileMode::Edit;
648
649 let path_with_root = "root/dir/subdir/existing.txt";
650 let path_without_root = "dir/subdir/existing.txt";
651 let result = test_resolve_path(mode, path_with_root, cx);
652 assert_resolved_path_eq(result.await, rel_path(path_without_root));
653
654 let result = test_resolve_path(mode, path_without_root, cx);
655 assert_resolved_path_eq(result.await, rel_path(path_without_root));
656
657 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
658 assert_eq!(
659 result.await.unwrap_err().to_string(),
660 "Can't edit file: path not found"
661 );
662
663 let result = test_resolve_path(mode, "root/dir", cx);
664 assert_eq!(
665 result.await.unwrap_err().to_string(),
666 "Can't edit file: path is a directory"
667 );
668 }
669
670 async fn test_resolve_path(
671 mode: &EditFileMode,
672 path: &str,
673 cx: &mut TestAppContext,
674 ) -> anyhow::Result<ProjectPath> {
675 init_test(cx);
676
677 let fs = project::FakeFs::new(cx.executor());
678 fs.insert_tree(
679 "/root",
680 json!({
681 "dir": {
682 "subdir": {
683 "existing.txt": "hello"
684 }
685 }
686 }),
687 )
688 .await;
689 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
690
691 let input = EditFileToolInput {
692 display_description: "Some edit".into(),
693 path: path.into(),
694 mode: mode.clone(),
695 };
696
697 cx.update(|cx| resolve_path(&input, project, cx))
698 }
699
700 #[track_caller]
701 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
702 let actual = path.expect("Should return valid path").path;
703 assert_eq!(actual.as_ref(), expected);
704 }
705
// Checks that the tool formats the buffer through the language server when
// `format_on_save` is on, does not mark the buffer as stale in the action log
// afterwards, and leaves the content unformatted when `format_on_save` is off.
#[gpui::test]
async fn test_format_on_save(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({"src": {}})).await;

    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

    // Set up a Rust language with LSP formatting support
    let rust_language = Arc::new(language::Language::new(
        language::LanguageConfig {
            name: "Rust".into(),
            matcher: language::LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        },
        None,
    ));

    // Register the language and fake LSP
    let language_registry = project.read_with(cx, |project, _| project.languages().clone());
    language_registry.add(rust_language);

    let mut fake_language_servers = language_registry.register_fake_lsp(
        "Rust",
        language::FakeLspAdapter {
            capabilities: lsp::ServerCapabilities {
                document_formatting_provider: Some(lsp::OneOf::Left(true)),
                ..Default::default()
            },
            ..Default::default()
        },
    );

    // Create the file
    fs.save(
        path!("/root/src/main.rs").as_ref(),
        &Rope::from_str_small("initial content"),
        language::LineEnding::Unix,
        EncodingWrapper::new(UTF_8),
    )
    .await
    .unwrap();

    // Open the buffer to trigger LSP initialization
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/root/src/main.rs"), cx)
        })
        .await
        .unwrap();

    // Register the buffer with language servers; the handle is kept alive for
    // the rest of the test.
    let _handle = project.update(cx, |project, cx| {
        project.register_buffer_with_language_servers(&buffer, cx)
    });

    const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
    const FORMATTED_CONTENT: &str =
        "This file was formatted by the fake formatter in the test.\n";

    // Get the fake language server and set up a formatting handler that replaces
    // the first line with FORMATTED_CONTENT
    let fake_language_server = fake_language_servers.next().await.unwrap();
    fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
        |_, _| async move {
            Ok(Some(vec![lsp::TextEdit {
                range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
                new_text: FORMATTED_CONTENT.to_string(),
            }]))
        }
    });

    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // First, test with format_on_save enabled
    cx.update(|cx| {
        SettingsStore::update_global(cx, |store, cx| {
            store.update_user_settings(cx, |settings| {
                settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
                settings.project.all_languages.defaults.formatter =
                    Some(language::language_settings::FormatterList::default());
            });
        });
    });

    // Have the model stream unformatted content
    let edit_result = {
        let edit_task = cx.update(|cx| {
            let input = EditFileToolInput {
                display_description: "Create main function".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Overwrite,
            };
            Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                language_registry.clone(),
                Templates::new(),
            ))
            .run(input, ToolCallEventStream::test().0, cx)
        });

        // Let the tool run until it is waiting on the model, then stream the
        // unformatted content and end the completion
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(edit_result.is_ok());

    // Wait for any async operations (e.g. formatting) to complete
    cx.executor().run_until_parked();

    // Read the file to verify it was formatted automatically
    let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
    assert_eq!(
        // Ignore carriage returns on Windows
        new_content.replace("\r\n", "\n"),
        FORMATTED_CONTENT,
        "Code should be formatted when format_on_save is enabled"
    );

    // Formatting must not leave the buffer marked as stale in the action log
    let stale_buffer_count = thread
        .read_with(cx, |thread, _cx| thread.action_log.clone())
        .read_with(cx, |log, cx| log.stale_buffers(cx).count());

    assert_eq!(
        stale_buffer_count, 0,
        "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
         This causes the agent to think the file was modified externally when it was just formatted.",
        stale_buffer_count
    );

    // Next, test with format_on_save disabled
    cx.update(|cx| {
        SettingsStore::update_global(cx, |store, cx| {
            store.update_user_settings(cx, |settings| {
                settings.project.all_languages.defaults.format_on_save =
                    Some(FormatOnSave::Off);
            });
        });
    });

    // Stream unformatted edits again
    let edit_result = {
        let edit_task = cx.update(|cx| {
            let input = EditFileToolInput {
                display_description: "Update main function".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Overwrite,
            };
            Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                language_registry,
                Templates::new(),
            ))
            .run(input, ToolCallEventStream::test().0, cx)
        });

        // Stream the unformatted content
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(edit_result.is_ok());

    // Wait for any async operations (e.g. formatting) to complete
    cx.executor().run_until_parked();

    // Verify the file was not formatted
    let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
    assert_eq!(
        // Ignore carriage returns on Windows
        new_content.replace("\r\n", "\n"),
        UNFORMATTED_CONTENT,
        "Code should not be formatted when format_on_save is disabled"
    );
}
903
// Checks that content written by the tool has its trailing whitespace removed
// on save when `remove_trailing_whitespace_on_save` is on, and is kept
// untouched when the setting is off.
#[gpui::test]
async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({"src": {}})).await;

    // Create the file that the tool will overwrite. Its initial content has no
    // trailing whitespace; that is introduced by the streamed edits below.
    fs.save(
        path!("/root/src/main.rs").as_ref(),
        &Rope::from_str_small("initial content"),
        language::LineEnding::Unix,
        EncodingWrapper::new(UTF_8),
    )
    .await
    .unwrap();

    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // First, test with remove_trailing_whitespace_on_save enabled
    cx.update(|cx| {
        SettingsStore::update_global(cx, |store, cx| {
            store.update_user_settings(cx, |settings| {
                settings
                    .project
                    .all_languages
                    .defaults
                    .remove_trailing_whitespace_on_save = Some(true);
            });
        });
    });

    const CONTENT_WITH_TRAILING_WHITESPACE: &str =
        "fn main() { \n println!(\"Hello!\"); \n}\n";

    // Have the model stream content that contains trailing whitespace
    let edit_result = {
        let edit_task = cx.update(|cx| {
            let input = EditFileToolInput {
                display_description: "Create main function".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Overwrite,
            };
            Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                language_registry.clone(),
                Templates::new(),
            ))
            .run(input, ToolCallEventStream::test().0, cx)
        });

        // Stream the content with trailing whitespace
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(
            CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
        );
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(edit_result.is_ok());

    // Wait for any async operations (e.g. formatting) to complete
    cx.executor().run_until_parked();

    // Read the file to verify trailing whitespace was removed automatically
    assert_eq!(
        // Ignore carriage returns on Windows
        fs.load(path!("/root/src/main.rs").as_ref())
            .await
            .unwrap()
            .replace("\r\n", "\n"),
        "fn main() {\n println!(\"Hello!\");\n}\n",
        "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
    );

    // Next, test with remove_trailing_whitespace_on_save disabled
    cx.update(|cx| {
        SettingsStore::update_global(cx, |store, cx| {
            store.update_user_settings(cx, |settings| {
                settings
                    .project
                    .all_languages
                    .defaults
                    .remove_trailing_whitespace_on_save = Some(false);
            });
        });
    });

    // Stream edits again with trailing whitespace
    let edit_result = {
        let edit_task = cx.update(|cx| {
            let input = EditFileToolInput {
                display_description: "Update main function".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Overwrite,
            };
            Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                language_registry,
                Templates::new(),
            ))
            .run(input, ToolCallEventStream::test().0, cx)
        });

        // Stream the content with trailing whitespace
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(
            CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
        );
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(edit_result.is_ok());

    // Wait for any async operations (e.g. formatting) to complete
    cx.executor().run_until_parked();

    // Read the file again: with the setting off, the trailing whitespace must
    // still be there
    let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
    assert_eq!(
        // Ignore carriage returns on Windows
        final_content.replace("\r\n", "\n"),
        CONTENT_WITH_TRAILING_WHITESPACE,
        "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
    );
}
1049
1050 #[gpui::test]
1051 async fn test_authorize(cx: &mut TestAppContext) {
1052 init_test(cx);
1053 let fs = project::FakeFs::new(cx.executor());
1054 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1055 let context_server_registry =
1056 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1057 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1058 let model = Arc::new(FakeLanguageModel::default());
1059 let thread = cx.new(|cx| {
1060 Thread::new(
1061 project.clone(),
1062 cx.new(|_cx| ProjectContext::default()),
1063 context_server_registry,
1064 Templates::new(),
1065 Some(model.clone()),
1066 cx,
1067 )
1068 });
1069 let tool = Arc::new(EditFileTool::new(
1070 project.clone(),
1071 thread.downgrade(),
1072 language_registry,
1073 Templates::new(),
1074 ));
1075 fs.insert_tree("/root", json!({})).await;
1076
1077 // Test 1: Path with .zed component should require confirmation
1078 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1079 let _auth = cx.update(|cx| {
1080 tool.authorize(
1081 &EditFileToolInput {
1082 display_description: "test 1".into(),
1083 path: ".zed/settings.json".into(),
1084 mode: EditFileMode::Edit,
1085 },
1086 &stream_tx,
1087 cx,
1088 )
1089 });
1090
1091 let event = stream_rx.expect_authorization().await;
1092 assert_eq!(
1093 event.tool_call.fields.title,
1094 Some("test 1 (local settings)".into())
1095 );
1096
1097 // Test 2: Path outside project should require confirmation
1098 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1099 let _auth = cx.update(|cx| {
1100 tool.authorize(
1101 &EditFileToolInput {
1102 display_description: "test 2".into(),
1103 path: "/etc/hosts".into(),
1104 mode: EditFileMode::Edit,
1105 },
1106 &stream_tx,
1107 cx,
1108 )
1109 });
1110
1111 let event = stream_rx.expect_authorization().await;
1112 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1113
1114 // Test 3: Relative path without .zed should not require confirmation
1115 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1116 cx.update(|cx| {
1117 tool.authorize(
1118 &EditFileToolInput {
1119 display_description: "test 3".into(),
1120 path: "root/src/main.rs".into(),
1121 mode: EditFileMode::Edit,
1122 },
1123 &stream_tx,
1124 cx,
1125 )
1126 })
1127 .await
1128 .unwrap();
1129 assert!(stream_rx.try_next().is_err());
1130
1131 // Test 4: Path with .zed in the middle should require confirmation
1132 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1133 let _auth = cx.update(|cx| {
1134 tool.authorize(
1135 &EditFileToolInput {
1136 display_description: "test 4".into(),
1137 path: "root/.zed/tasks.json".into(),
1138 mode: EditFileMode::Edit,
1139 },
1140 &stream_tx,
1141 cx,
1142 )
1143 });
1144 let event = stream_rx.expect_authorization().await;
1145 assert_eq!(
1146 event.tool_call.fields.title,
1147 Some("test 4 (local settings)".into())
1148 );
1149
1150 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1151 cx.update(|cx| {
1152 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1153 settings.always_allow_tool_actions = true;
1154 agent_settings::AgentSettings::override_global(settings, cx);
1155 });
1156
1157 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1158 cx.update(|cx| {
1159 tool.authorize(
1160 &EditFileToolInput {
1161 display_description: "test 5.1".into(),
1162 path: ".zed/settings.json".into(),
1163 mode: EditFileMode::Edit,
1164 },
1165 &stream_tx,
1166 cx,
1167 )
1168 })
1169 .await
1170 .unwrap();
1171 assert!(stream_rx.try_next().is_err());
1172
1173 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1174 cx.update(|cx| {
1175 tool.authorize(
1176 &EditFileToolInput {
1177 display_description: "test 5.2".into(),
1178 path: "/etc/hosts".into(),
1179 mode: EditFileMode::Edit,
1180 },
1181 &stream_tx,
1182 cx,
1183 )
1184 })
1185 .await
1186 .unwrap();
1187 assert!(stream_rx.try_next().is_err());
1188 }
1189
1190 #[gpui::test]
1191 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1192 init_test(cx);
1193 let fs = project::FakeFs::new(cx.executor());
1194 fs.insert_tree("/project", json!({})).await;
1195 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1196 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1197 let context_server_registry =
1198 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1199 let model = Arc::new(FakeLanguageModel::default());
1200 let thread = cx.new(|cx| {
1201 Thread::new(
1202 project.clone(),
1203 cx.new(|_cx| ProjectContext::default()),
1204 context_server_registry,
1205 Templates::new(),
1206 Some(model.clone()),
1207 cx,
1208 )
1209 });
1210 let tool = Arc::new(EditFileTool::new(
1211 project.clone(),
1212 thread.downgrade(),
1213 language_registry,
1214 Templates::new(),
1215 ));
1216
1217 // Test global config paths - these should require confirmation if they exist and are outside the project
1218 let test_cases = vec![
1219 (
1220 "/etc/hosts",
1221 true,
1222 "System file should require confirmation",
1223 ),
1224 (
1225 "/usr/local/bin/script",
1226 true,
1227 "System bin file should require confirmation",
1228 ),
1229 (
1230 "project/normal_file.rs",
1231 false,
1232 "Normal project file should not require confirmation",
1233 ),
1234 ];
1235
1236 for (path, should_confirm, description) in test_cases {
1237 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1238 let auth = cx.update(|cx| {
1239 tool.authorize(
1240 &EditFileToolInput {
1241 display_description: "Edit file".into(),
1242 path: path.into(),
1243 mode: EditFileMode::Edit,
1244 },
1245 &stream_tx,
1246 cx,
1247 )
1248 });
1249
1250 if should_confirm {
1251 stream_rx.expect_authorization().await;
1252 } else {
1253 auth.await.unwrap();
1254 assert!(
1255 stream_rx.try_next().is_err(),
1256 "Failed for case: {} - path: {} - expected no confirmation but got one",
1257 description,
1258 path
1259 );
1260 }
1261 }
1262 }
1263
1264 #[gpui::test]
1265 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1266 init_test(cx);
1267 let fs = project::FakeFs::new(cx.executor());
1268
1269 // Create multiple worktree directories
1270 fs.insert_tree(
1271 "/workspace/frontend",
1272 json!({
1273 "src": {
1274 "main.js": "console.log('frontend');"
1275 }
1276 }),
1277 )
1278 .await;
1279 fs.insert_tree(
1280 "/workspace/backend",
1281 json!({
1282 "src": {
1283 "main.rs": "fn main() {}"
1284 }
1285 }),
1286 )
1287 .await;
1288 fs.insert_tree(
1289 "/workspace/shared",
1290 json!({
1291 ".zed": {
1292 "settings.json": "{}"
1293 }
1294 }),
1295 )
1296 .await;
1297
1298 // Create project with multiple worktrees
1299 let project = Project::test(
1300 fs.clone(),
1301 [
1302 path!("/workspace/frontend").as_ref(),
1303 path!("/workspace/backend").as_ref(),
1304 path!("/workspace/shared").as_ref(),
1305 ],
1306 cx,
1307 )
1308 .await;
1309 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1310 let context_server_registry =
1311 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1312 let model = Arc::new(FakeLanguageModel::default());
1313 let thread = cx.new(|cx| {
1314 Thread::new(
1315 project.clone(),
1316 cx.new(|_cx| ProjectContext::default()),
1317 context_server_registry.clone(),
1318 Templates::new(),
1319 Some(model.clone()),
1320 cx,
1321 )
1322 });
1323 let tool = Arc::new(EditFileTool::new(
1324 project.clone(),
1325 thread.downgrade(),
1326 language_registry,
1327 Templates::new(),
1328 ));
1329
1330 // Test files in different worktrees
1331 let test_cases = vec![
1332 ("frontend/src/main.js", false, "File in first worktree"),
1333 ("backend/src/main.rs", false, "File in second worktree"),
1334 (
1335 "shared/.zed/settings.json",
1336 true,
1337 ".zed file in third worktree",
1338 ),
1339 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1340 (
1341 "../outside/file.txt",
1342 true,
1343 "Relative path outside worktrees",
1344 ),
1345 ];
1346
1347 for (path, should_confirm, description) in test_cases {
1348 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1349 let auth = cx.update(|cx| {
1350 tool.authorize(
1351 &EditFileToolInput {
1352 display_description: "Edit file".into(),
1353 path: path.into(),
1354 mode: EditFileMode::Edit,
1355 },
1356 &stream_tx,
1357 cx,
1358 )
1359 });
1360
1361 if should_confirm {
1362 stream_rx.expect_authorization().await;
1363 } else {
1364 auth.await.unwrap();
1365 assert!(
1366 stream_rx.try_next().is_err(),
1367 "Failed for case: {} - path: {} - expected no confirmation but got one",
1368 description,
1369 path
1370 );
1371 }
1372 }
1373 }
1374
1375 #[gpui::test]
1376 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1377 init_test(cx);
1378 let fs = project::FakeFs::new(cx.executor());
1379 fs.insert_tree(
1380 "/project",
1381 json!({
1382 ".zed": {
1383 "settings.json": "{}"
1384 },
1385 "src": {
1386 ".zed": {
1387 "local.json": "{}"
1388 }
1389 }
1390 }),
1391 )
1392 .await;
1393 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1394 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1395 let context_server_registry =
1396 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1397 let model = Arc::new(FakeLanguageModel::default());
1398 let thread = cx.new(|cx| {
1399 Thread::new(
1400 project.clone(),
1401 cx.new(|_cx| ProjectContext::default()),
1402 context_server_registry.clone(),
1403 Templates::new(),
1404 Some(model.clone()),
1405 cx,
1406 )
1407 });
1408 let tool = Arc::new(EditFileTool::new(
1409 project.clone(),
1410 thread.downgrade(),
1411 language_registry,
1412 Templates::new(),
1413 ));
1414
1415 // Test edge cases
1416 let test_cases = vec![
1417 // Empty path - find_project_path returns Some for empty paths
1418 ("", false, "Empty path is treated as project root"),
1419 // Root directory
1420 ("/", true, "Root directory should be outside project"),
1421 // Parent directory references - find_project_path resolves these
1422 (
1423 "project/../other",
1424 true,
1425 "Path with .. that goes outside of root directory",
1426 ),
1427 (
1428 "project/./src/file.rs",
1429 false,
1430 "Path with . should work normally",
1431 ),
1432 // Windows-style paths (if on Windows)
1433 #[cfg(target_os = "windows")]
1434 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1435 #[cfg(target_os = "windows")]
1436 ("project\\src\\main.rs", false, "Windows-style project path"),
1437 ];
1438
1439 for (path, should_confirm, description) in test_cases {
1440 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1441 let auth = cx.update(|cx| {
1442 tool.authorize(
1443 &EditFileToolInput {
1444 display_description: "Edit file".into(),
1445 path: path.into(),
1446 mode: EditFileMode::Edit,
1447 },
1448 &stream_tx,
1449 cx,
1450 )
1451 });
1452
1453 cx.run_until_parked();
1454
1455 if should_confirm {
1456 stream_rx.expect_authorization().await;
1457 } else {
1458 assert!(
1459 stream_rx.try_next().is_err(),
1460 "Failed for case: {} - path: {} - expected no confirmation but got one",
1461 description,
1462 path
1463 );
1464 auth.await.unwrap();
1465 }
1466 }
1467 }
1468
1469 #[gpui::test]
1470 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1471 init_test(cx);
1472 let fs = project::FakeFs::new(cx.executor());
1473 fs.insert_tree(
1474 "/project",
1475 json!({
1476 "existing.txt": "content",
1477 ".zed": {
1478 "settings.json": "{}"
1479 }
1480 }),
1481 )
1482 .await;
1483 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1484 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1485 let context_server_registry =
1486 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1487 let model = Arc::new(FakeLanguageModel::default());
1488 let thread = cx.new(|cx| {
1489 Thread::new(
1490 project.clone(),
1491 cx.new(|_cx| ProjectContext::default()),
1492 context_server_registry.clone(),
1493 Templates::new(),
1494 Some(model.clone()),
1495 cx,
1496 )
1497 });
1498 let tool = Arc::new(EditFileTool::new(
1499 project.clone(),
1500 thread.downgrade(),
1501 language_registry,
1502 Templates::new(),
1503 ));
1504
1505 // Test different EditFileMode values
1506 let modes = vec![
1507 EditFileMode::Edit,
1508 EditFileMode::Create,
1509 EditFileMode::Overwrite,
1510 ];
1511
1512 for mode in modes {
1513 // Test .zed path with different modes
1514 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1515 let _auth = cx.update(|cx| {
1516 tool.authorize(
1517 &EditFileToolInput {
1518 display_description: "Edit settings".into(),
1519 path: "project/.zed/settings.json".into(),
1520 mode: mode.clone(),
1521 },
1522 &stream_tx,
1523 cx,
1524 )
1525 });
1526
1527 stream_rx.expect_authorization().await;
1528
1529 // Test outside path with different modes
1530 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1531 let _auth = cx.update(|cx| {
1532 tool.authorize(
1533 &EditFileToolInput {
1534 display_description: "Edit file".into(),
1535 path: "/outside/file.txt".into(),
1536 mode: mode.clone(),
1537 },
1538 &stream_tx,
1539 cx,
1540 )
1541 });
1542
1543 stream_rx.expect_authorization().await;
1544
1545 // Test normal path with different modes
1546 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1547 cx.update(|cx| {
1548 tool.authorize(
1549 &EditFileToolInput {
1550 display_description: "Edit file".into(),
1551 path: "project/normal.txt".into(),
1552 mode: mode.clone(),
1553 },
1554 &stream_tx,
1555 cx,
1556 )
1557 })
1558 .await
1559 .unwrap();
1560 assert!(stream_rx.try_next().is_err());
1561 }
1562 }
1563
1564 #[gpui::test]
1565 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1566 init_test(cx);
1567 let fs = project::FakeFs::new(cx.executor());
1568 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1569 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1570 let context_server_registry =
1571 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1572 let model = Arc::new(FakeLanguageModel::default());
1573 let thread = cx.new(|cx| {
1574 Thread::new(
1575 project.clone(),
1576 cx.new(|_cx| ProjectContext::default()),
1577 context_server_registry,
1578 Templates::new(),
1579 Some(model.clone()),
1580 cx,
1581 )
1582 });
1583 let tool = Arc::new(EditFileTool::new(
1584 project,
1585 thread.downgrade(),
1586 language_registry,
1587 Templates::new(),
1588 ));
1589
1590 cx.update(|cx| {
1591 // ...
1592 assert_eq!(
1593 tool.initial_title(
1594 Err(json!({
1595 "path": "src/main.rs",
1596 "display_description": "",
1597 "old_string": "old code",
1598 "new_string": "new code"
1599 })),
1600 cx
1601 ),
1602 "src/main.rs"
1603 );
1604 assert_eq!(
1605 tool.initial_title(
1606 Err(json!({
1607 "path": "",
1608 "display_description": "Fix error handling",
1609 "old_string": "old code",
1610 "new_string": "new code"
1611 })),
1612 cx
1613 ),
1614 "Fix error handling"
1615 );
1616 assert_eq!(
1617 tool.initial_title(
1618 Err(json!({
1619 "path": "src/main.rs",
1620 "display_description": "Fix error handling",
1621 "old_string": "old code",
1622 "new_string": "new code"
1623 })),
1624 cx
1625 ),
1626 "src/main.rs"
1627 );
1628 assert_eq!(
1629 tool.initial_title(
1630 Err(json!({
1631 "path": "",
1632 "display_description": "",
1633 "old_string": "old code",
1634 "new_string": "new code"
1635 })),
1636 cx
1637 ),
1638 DEFAULT_UI_TEXT
1639 );
1640 assert_eq!(
1641 tool.initial_title(Err(serde_json::Value::Null), cx),
1642 DEFAULT_UI_TEXT
1643 );
1644 });
1645 }
1646
// Checks that the diff emitted by `EditFileTool::run` always moves from
// `Diff::Pending` to `Diff::Finalized`, in three scenarios: the edit
// completes, the edit returns an error, and the edit task is dropped.
#[gpui::test]
async fn test_diff_finalization(cx: &mut TestAppContext) {
    init_test(cx);
    let fs = project::FakeFs::new(cx.executor());
    // A single empty `main.rs` at the filesystem root is the edit target in every scenario.
    fs.insert_tree("/", json!({"main.rs": ""})).await;

    let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
    let languages = project.read_with(cx, |project, _cx| project.languages().clone());
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project.clone(),
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry.clone(),
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // Ensure the diff is finalized after the edit completes.
    {
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        // The tool emits a fields update first, then the diff entity.
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        // The diff is still pending while the edit is in flight.
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        // NOTE(review): presumably parking lets the tool issue its completion
        // request so that there is a stream to end — confirm.
        cx.run_until_parked();
        model.end_last_completion_stream();
        edit.await.unwrap();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
    }

    // Ensure the diff is finalized if an error occurs while editing.
    {
        // NOTE(review): `forbid_requests` presumably makes the fake model reject
        // completion requests; the test only relies on `edit` returning an error.
        model.forbid_requests();
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        // The edit must fail, and the diff must still end up finalized.
        edit.await.unwrap_err();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        // Re-enable requests so the next scenario starts from the default model state.
        model.allow_requests();
    }

    // Ensure the diff is finalized if the tool call gets dropped.
    {
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            languages.clone(),
            Templates::new(),
        ));
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let edit = cx.update(|cx| {
            tool.run(
                EditFileToolInput {
                    display_description: "Edit file".into(),
                    path: path!("/main.rs").into(),
                    mode: EditFileMode::Edit,
                },
                stream_tx,
                cx,
            )
        });
        stream_rx.expect_update_fields().await;
        let diff = stream_rx.expect_diff().await;
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
        // Drop the task without awaiting it, then let the executor settle
        // before checking the diff state.
        drop(edit);
        cx.run_until_parked();
        diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
    }
}
1755
1756 fn init_test(cx: &mut TestAppContext) {
1757 cx.update(|cx| {
1758 let settings_store = SettingsStore::test(cx);
1759 cx.set_global(settings_store);
1760 language::init(cx);
1761 TelemetrySettings::register(cx);
1762 agent_settings::AgentSettings::register(cx);
1763 Project::init_settings(cx);
1764 });
1765 }
1766}