1use crate::{
2 AgentTool, Templates, Thread, ToolCallEventStream,
3 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
4};
5use acp_thread::Diff;
6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
7use anyhow::{Context as _, Result, anyhow};
8use cloud_llm_client::CompletionIntent;
9use collections::HashSet;
10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
11use indoc::formatdoc;
12use language::language_settings::{self, FormatOnSave};
13use language::{LanguageRegistry, ToPoint};
14use language_model::LanguageModelToolResultContent;
15use paths;
16use project::lsp_store::{FormatTrigger, LspFormatTarget};
17use project::{Project, ProjectPath};
18use schemars::JsonSchema;
19use serde::{Deserialize, Serialize};
20use settings::Settings;
21use smol::stream::StreamExt as _;
22use std::ffi::OsStr;
23use std::path::{Path, PathBuf};
24use std::sync::Arc;
25use ui::SharedString;
26use util::ResultExt;
27use util::rel_path::RelPath;
28
/// Title shown for an `edit_file` tool call when neither a path nor a
/// description can be extracted from the (possibly partial) tool input.
const DEFAULT_UI_TEXT: &str = "Editing file";
30
31/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
32///
33/// Before using this tool:
34///
35/// 1. Use the `read_file` tool to understand the file's contents and context
36///
37/// 2. Verify the directory path is correct (only applicable when creating new files):
38/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
39#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
40pub struct EditFileToolInput {
41 /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
42 ///
43 /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
44 ///
45 /// NEVER mention the file path in this description.
46 ///
47 /// <example>Fix API endpoint URLs</example>
48 /// <example>Update copyright year in `page_footer`</example>
49 ///
50 /// Make sure to include this field before all the others in the input object so that we can display it immediately.
51 pub display_description: String,
52
53 /// The full path of the file to create or modify in the project.
54 ///
55 /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
56 ///
57 /// The following examples assume we have two root directories in the project:
58 /// - /a/b/backend
59 /// - /c/d/frontend
60 ///
61 /// <example>
62 /// `backend/src/main.rs`
63 ///
64 /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
65 /// </example>
66 ///
67 /// <example>
68 /// `frontend/db.js`
69 /// </example>
70 pub path: PathBuf,
71 /// The mode of operation on the file. Possible values:
72 /// - 'edit': Make granular edits to an existing file.
73 /// - 'create': Create a new file if it doesn't exist.
74 /// - 'overwrite': Replace the entire contents of an existing file.
75 ///
76 /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
77 pub mode: EditFileMode,
78}
79
// Lenient view of an `edit_file` tool input that failed to parse as a full
// `EditFileToolInput` (presumably because the model is still streaming its
// arguments). Only the fields needed to title the tool call are read, and each
// one falls back to an empty string when absent.
//
// Plain `//` comments are used here on purpose: `///` docs would flow into the
// derived JSON schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Possibly incomplete path; `initial_title` checks it after trimming.
    #[serde(default)]
    path: String,
    // Possibly incomplete description; used as a title when `path` is empty.
    #[serde(default)]
    display_description: String,
}
87
// How `edit_file` treats the target path. Serialized as "edit", "create" or
// "overwrite" (`rename_all = "lowercase"`), and inlined into the parent schema
// rather than emitted as a separate definition (`#[schemars(inline)]`).
//
// Validation lives in `resolve_path`: `Edit`/`Overwrite` require an existing
// file, while `Create` requires that the file not exist yet but its parent does.
// In `run`, `Edit` goes through `EditAgent::edit`; `Create` and `Overwrite`
// both go through `EditAgent::overwrite`.
//
// Variant docs are deliberately `//` comments: adding `///` docs here would change
// the derived JSON schema shape.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    Edit,
    Create,
    Overwrite,
}
96
/// Result of a completed `edit_file` tool call.
///
/// The `serde` aliases accept older field names, presumably so that previously
/// persisted outputs still deserialize. `replay` rebuilds a finalized diff from
/// `input_path`, `old_text` and `new_text`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as the model supplied it in `EditFileToolInput::path`
    /// (not resolved to an absolute path).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Buffer text after the edits, optional formatting, and save.
    new_text: String,
    /// Buffer text captured before any edits were applied.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    /// Defaults to empty when missing from serialized data.
    #[serde(default)]
    diff: String,
    /// Raw output returned by the edit agent.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
108
109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
110 fn from(output: EditFileToolOutput) -> Self {
111 if output.diff.is_empty() {
112 "No edits were made.".into()
113 } else {
114 format!(
115 "Edited {}:\n\n```diff\n{}\n```",
116 output.input_path.display(),
117 output.diff
118 )
119 .into()
120 }
121 }
122}
123
/// The `edit_file` agent tool: creates, overwrites, or edits a project file by
/// delegating the actual text changes to an [`EditAgent`].
pub struct EditFileTool {
    /// Owning thread. It is held weakly, and calls fail with "thread was dropped"
    /// once it is gone. It supplies the project, model, action log and completion
    /// request.
    thread: WeakEntity<Thread>,
    /// Used by `replay` to build a finalized [`Diff`] for a stored output.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to render short project-relative paths.
    project: Entity<Project>,
    /// Prompt templates handed to the [`EditAgent`].
    templates: Arc<Templates>,
}
130
131impl EditFileTool {
132 pub fn new(
133 project: Entity<Project>,
134 thread: WeakEntity<Thread>,
135 language_registry: Arc<LanguageRegistry>,
136 templates: Arc<Templates>,
137 ) -> Self {
138 Self {
139 project,
140 thread,
141 language_registry,
142 templates,
143 }
144 }
145
146 fn authorize(
147 &self,
148 input: &EditFileToolInput,
149 event_stream: &ToolCallEventStream,
150 cx: &mut App,
151 ) -> Task<Result<()>> {
152 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
153 return Task::ready(Ok(()));
154 }
155
156 // If any path component matches the local settings folder, then this could affect
157 // the editor in ways beyond the project source, so prompt.
158 let local_settings_folder = paths::local_settings_folder_name();
159 let path = Path::new(&input.path);
160 if path.components().any(|component| {
161 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
162 }) {
163 return event_stream.authorize(
164 format!("{} (local settings)", input.display_description),
165 cx,
166 );
167 }
168
169 // It's also possible that the global config dir is configured to be inside the project,
170 // so check for that edge case too.
171 // TODO this is broken when remoting
172 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
173 && canonical_path.starts_with(paths::config_dir())
174 {
175 return event_stream.authorize(
176 format!("{} (global settings)", input.display_description),
177 cx,
178 );
179 }
180
181 // Check if path is inside the global config directory
182 // First check if it's already inside project - if not, try to canonicalize
183 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
184 thread.project().read(cx).find_project_path(&input.path, cx)
185 }) else {
186 return Task::ready(Err(anyhow!("thread was dropped")));
187 };
188
189 // If the path is inside the project, and it's not one of the above edge cases,
190 // then no confirmation is necessary. Otherwise, confirmation is necessary.
191 if project_path.is_some() {
192 Task::ready(Ok(()))
193 } else {
194 event_stream.authorize(&input.display_description, cx)
195 }
196 }
197}
198
199impl AgentTool for EditFileTool {
200 type Input = EditFileToolInput;
201 type Output = EditFileToolOutput;
202
203 fn name() -> &'static str {
204 "edit_file"
205 }
206
207 fn kind() -> acp::ToolKind {
208 acp::ToolKind::Edit
209 }
210
211 fn initial_title(
212 &self,
213 input: Result<Self::Input, serde_json::Value>,
214 cx: &mut App,
215 ) -> SharedString {
216 match input {
217 Ok(input) => self
218 .project
219 .read(cx)
220 .find_project_path(&input.path, cx)
221 .and_then(|project_path| {
222 self.project
223 .read(cx)
224 .short_full_path_for_project_path(&project_path, cx)
225 })
226 .unwrap_or(input.path.to_string_lossy().into_owned())
227 .into(),
228 Err(raw_input) => {
229 if let Some(input) =
230 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
231 {
232 let path = input.path.trim();
233 if !path.is_empty() {
234 return self
235 .project
236 .read(cx)
237 .find_project_path(&input.path, cx)
238 .and_then(|project_path| {
239 self.project
240 .read(cx)
241 .short_full_path_for_project_path(&project_path, cx)
242 })
243 .unwrap_or(input.path)
244 .into();
245 }
246
247 let description = input.display_description.trim();
248 if !description.is_empty() {
249 return description.to_string().into();
250 }
251 }
252
253 DEFAULT_UI_TEXT.into()
254 }
255 }
256 }
257
258 fn run(
259 self: Arc<Self>,
260 input: Self::Input,
261 event_stream: ToolCallEventStream,
262 cx: &mut App,
263 ) -> Task<Result<Self::Output>> {
264 let Ok(project) = self
265 .thread
266 .read_with(cx, |thread, _cx| thread.project().clone())
267 else {
268 return Task::ready(Err(anyhow!("thread was dropped")));
269 };
270 let project_path = match resolve_path(&input, project.clone(), cx) {
271 Ok(path) => path,
272 Err(err) => return Task::ready(Err(anyhow!(err))),
273 };
274 let abs_path = project.read(cx).absolute_path(&project_path, cx);
275 if let Some(abs_path) = abs_path.clone() {
276 event_stream.update_fields(ToolCallUpdateFields {
277 locations: Some(vec![acp::ToolCallLocation {
278 path: abs_path,
279 line: None,
280 meta: None,
281 }]),
282 ..Default::default()
283 });
284 }
285
286 let authorize = self.authorize(&input, &event_stream, cx);
287 cx.spawn(async move |cx: &mut AsyncApp| {
288 authorize.await?;
289
290 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
291 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
292 (request, thread.model().cloned(), thread.action_log().clone())
293 })?;
294 let request = request?;
295 let model = model.context("No language model configured")?;
296
297 let edit_format = EditFormat::from_model(model.clone())?;
298 let edit_agent = EditAgent::new(
299 model,
300 project.clone(),
301 action_log.clone(),
302 self.templates.clone(),
303 edit_format,
304 );
305
306 let buffer = project
307 .update(cx, |project, cx| {
308 project.open_buffer(project_path.clone(), cx)
309 })?
310 .await?;
311
312 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
313 event_stream.update_diff(diff.clone());
314 let _finalize_diff = util::defer({
315 let diff = diff.downgrade();
316 let mut cx = cx.clone();
317 move || {
318 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
319 }
320 });
321
322 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
323 let old_text = cx
324 .background_spawn({
325 let old_snapshot = old_snapshot.clone();
326 async move { Arc::new(old_snapshot.text()) }
327 })
328 .await;
329
330
331 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
332 edit_agent.edit(
333 buffer.clone(),
334 input.display_description.clone(),
335 &request,
336 cx,
337 )
338 } else {
339 edit_agent.overwrite(
340 buffer.clone(),
341 input.display_description.clone(),
342 &request,
343 cx,
344 )
345 };
346
347 let mut hallucinated_old_text = false;
348 let mut ambiguous_ranges = Vec::new();
349 let mut emitted_location = false;
350 while let Some(event) = events.next().await {
351 match event {
352 EditAgentOutputEvent::Edited(range) => {
353 if !emitted_location {
354 let line = buffer.update(cx, |buffer, _cx| {
355 range.start.to_point(&buffer.snapshot()).row
356 }).ok();
357 if let Some(abs_path) = abs_path.clone() {
358 event_stream.update_fields(ToolCallUpdateFields {
359 locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
360 ..Default::default()
361 });
362 }
363 emitted_location = true;
364 }
365 },
366 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
367 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
368 EditAgentOutputEvent::ResolvingEditRange(range) => {
369 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
370 // if !emitted_location {
371 // let line = buffer.update(cx, |buffer, _cx| {
372 // range.start.to_point(&buffer.snapshot()).row
373 // }).ok();
374 // if let Some(abs_path) = abs_path.clone() {
375 // event_stream.update_fields(ToolCallUpdateFields {
376 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
377 // ..Default::default()
378 // });
379 // }
380 // }
381 }
382 }
383 }
384
385 // If format_on_save is enabled, format the buffer
386 let format_on_save_enabled = buffer
387 .read_with(cx, |buffer, cx| {
388 let settings = language_settings::language_settings(
389 buffer.language().map(|l| l.name()),
390 buffer.file(),
391 cx,
392 );
393 settings.format_on_save != FormatOnSave::Off
394 })
395 .unwrap_or(false);
396
397 let edit_agent_output = output.await?;
398
399 if format_on_save_enabled {
400 action_log.update(cx, |log, cx| {
401 log.buffer_edited(buffer.clone(), cx);
402 })?;
403
404 let format_task = project.update(cx, |project, cx| {
405 project.format(
406 HashSet::from_iter([buffer.clone()]),
407 LspFormatTarget::Buffers,
408 false, // Don't push to history since the tool did it.
409 FormatTrigger::Save,
410 cx,
411 )
412 })?;
413 format_task.await.log_err();
414 }
415
416 project
417 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
418 .await?;
419
420 action_log.update(cx, |log, cx| {
421 log.buffer_edited(buffer.clone(), cx);
422 })?;
423
424 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
425 let (new_text, unified_diff) = cx
426 .background_spawn({
427 let new_snapshot = new_snapshot.clone();
428 let old_text = old_text.clone();
429 async move {
430 let new_text = new_snapshot.text();
431 let diff = language::unified_diff(&old_text, &new_text);
432 (new_text, diff)
433 }
434 })
435 .await;
436
437 let input_path = input.path.display();
438 if unified_diff.is_empty() {
439 anyhow::ensure!(
440 !hallucinated_old_text,
441 formatdoc! {"
442 Some edits were produced but none of them could be applied.
443 Read the relevant sections of {input_path} again so that
444 I can perform the requested edits.
445 "}
446 );
447 anyhow::ensure!(
448 ambiguous_ranges.is_empty(),
449 {
450 let line_numbers = ambiguous_ranges
451 .iter()
452 .map(|range| range.start.to_string())
453 .collect::<Vec<_>>()
454 .join(", ");
455 formatdoc! {"
456 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
457 relevant sections of {input_path} again and extend <old_text> so
458 that I can perform the requested edits.
459 "}
460 }
461 );
462 }
463
464 Ok(EditFileToolOutput {
465 input_path: input.path,
466 new_text,
467 old_text,
468 diff: unified_diff,
469 edit_agent_output,
470 })
471 })
472 }
473
474 fn replay(
475 &self,
476 _input: Self::Input,
477 output: Self::Output,
478 event_stream: ToolCallEventStream,
479 cx: &mut App,
480 ) -> Result<()> {
481 event_stream.update_diff(cx.new(|cx| {
482 Diff::finalized(
483 output.input_path.to_string_lossy().into_owned(),
484 Some(output.old_text.to_string()),
485 output.new_text,
486 self.language_registry.clone(),
487 cx,
488 )
489 }));
490 Ok(())
491 }
492}
493
494/// Validate that the file path is valid, meaning:
495///
496/// - For `edit` and `overwrite`, the path must point to an existing file.
497/// - For `create`, the file must not already exist, but it's parent dir must exist.
498fn resolve_path(
499 input: &EditFileToolInput,
500 project: Entity<Project>,
501 cx: &mut App,
502) -> Result<ProjectPath> {
503 let project = project.read(cx);
504
505 match input.mode {
506 EditFileMode::Edit | EditFileMode::Overwrite => {
507 let path = project
508 .find_project_path(&input.path, cx)
509 .context("Can't edit file: path not found")?;
510
511 let entry = project
512 .entry_for_path(&path, cx)
513 .context("Can't edit file: path not found")?;
514
515 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
516 Ok(path)
517 }
518
519 EditFileMode::Create => {
520 if let Some(path) = project.find_project_path(&input.path, cx) {
521 anyhow::ensure!(
522 project.entry_for_path(&path, cx).is_none(),
523 "Can't create file: file already exists"
524 );
525 }
526
527 let parent_path = input
528 .path
529 .parent()
530 .context("Can't create file: incorrect path")?;
531
532 let parent_project_path = project.find_project_path(&parent_path, cx);
533
534 let parent_entry = parent_project_path
535 .as_ref()
536 .and_then(|path| project.entry_for_path(path, cx))
537 .context("Can't create file: parent directory doesn't exist")?;
538
539 anyhow::ensure!(
540 parent_entry.is_dir(),
541 "Can't create file: parent is not a directory"
542 );
543
544 let file_name = input
545 .path
546 .file_name()
547 .and_then(|file_name| file_name.to_str())
548 .and_then(|file_name| RelPath::unix(file_name).ok())
549 .context("Can't create file: invalid filename")?;
550
551 let new_file_path = parent_project_path.map(|parent| ProjectPath {
552 path: parent.path.join(file_name),
553 ..parent
554 });
555
556 new_file_path.context("Can't create file")
557 }
558 }
559}
560
561#[cfg(test)]
562mod tests {
563 use super::*;
564 use crate::{ContextServerRegistry, Templates};
565 use fs::Fs;
566 use gpui::{TestAppContext, UpdateGlobal};
567 use language_model::fake_provider::FakeLanguageModel;
568 use prompt_store::ProjectContext;
569 use serde_json::json;
570 use settings::SettingsStore;
571 use util::{path, rel_path::rel_path};
572
573 #[gpui::test]
574 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
575 init_test(cx);
576
577 let fs = project::FakeFs::new(cx.executor());
578 fs.insert_tree("/root", json!({})).await;
579 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
580 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
581 let context_server_registry =
582 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
583 let model = Arc::new(FakeLanguageModel::default());
584 let thread = cx.new(|cx| {
585 Thread::new(
586 project.clone(),
587 cx.new(|_cx| ProjectContext::default()),
588 context_server_registry,
589 Templates::new(),
590 Some(model),
591 cx,
592 )
593 });
594 let result = cx
595 .update(|cx| {
596 let input = EditFileToolInput {
597 display_description: "Some edit".into(),
598 path: "root/nonexistent_file.txt".into(),
599 mode: EditFileMode::Edit,
600 };
601 Arc::new(EditFileTool::new(
602 project,
603 thread.downgrade(),
604 language_registry,
605 Templates::new(),
606 ))
607 .run(input, ToolCallEventStream::test().0, cx)
608 })
609 .await;
610 assert_eq!(
611 result.unwrap_err().to_string(),
612 "Can't edit file: path not found"
613 );
614 }
615
616 #[gpui::test]
617 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
618 let mode = &EditFileMode::Create;
619
620 let result = test_resolve_path(mode, "root/new.txt", cx);
621 assert_resolved_path_eq(result.await, rel_path("new.txt"));
622
623 let result = test_resolve_path(mode, "new.txt", cx);
624 assert_resolved_path_eq(result.await, rel_path("new.txt"));
625
626 let result = test_resolve_path(mode, "dir/new.txt", cx);
627 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
628
629 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
630 assert_eq!(
631 result.await.unwrap_err().to_string(),
632 "Can't create file: file already exists"
633 );
634
635 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
636 assert_eq!(
637 result.await.unwrap_err().to_string(),
638 "Can't create file: parent directory doesn't exist"
639 );
640 }
641
642 #[gpui::test]
643 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
644 let mode = &EditFileMode::Edit;
645
646 let path_with_root = "root/dir/subdir/existing.txt";
647 let path_without_root = "dir/subdir/existing.txt";
648 let result = test_resolve_path(mode, path_with_root, cx);
649 assert_resolved_path_eq(result.await, rel_path(path_without_root));
650
651 let result = test_resolve_path(mode, path_without_root, cx);
652 assert_resolved_path_eq(result.await, rel_path(path_without_root));
653
654 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
655 assert_eq!(
656 result.await.unwrap_err().to_string(),
657 "Can't edit file: path not found"
658 );
659
660 let result = test_resolve_path(mode, "root/dir", cx);
661 assert_eq!(
662 result.await.unwrap_err().to_string(),
663 "Can't edit file: path is a directory"
664 );
665 }
666
667 async fn test_resolve_path(
668 mode: &EditFileMode,
669 path: &str,
670 cx: &mut TestAppContext,
671 ) -> anyhow::Result<ProjectPath> {
672 init_test(cx);
673
674 let fs = project::FakeFs::new(cx.executor());
675 fs.insert_tree(
676 "/root",
677 json!({
678 "dir": {
679 "subdir": {
680 "existing.txt": "hello"
681 }
682 }
683 }),
684 )
685 .await;
686 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
687
688 let input = EditFileToolInput {
689 display_description: "Some edit".into(),
690 path: path.into(),
691 mode: mode.clone(),
692 };
693
694 cx.update(|cx| resolve_path(&input, project, cx))
695 }
696
697 #[track_caller]
698 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
699 let actual = path.expect("Should return valid path").path;
700 assert_eq!(actual.as_ref(), expected);
701 }
702
703 #[gpui::test]
704 async fn test_format_on_save(cx: &mut TestAppContext) {
705 init_test(cx);
706
707 let fs = project::FakeFs::new(cx.executor());
708 fs.insert_tree("/root", json!({"src": {}})).await;
709
710 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
711
712 // Set up a Rust language with LSP formatting support
713 let rust_language = Arc::new(language::Language::new(
714 language::LanguageConfig {
715 name: "Rust".into(),
716 matcher: language::LanguageMatcher {
717 path_suffixes: vec!["rs".to_string()],
718 ..Default::default()
719 },
720 ..Default::default()
721 },
722 None,
723 ));
724
725 // Register the language and fake LSP
726 let language_registry = project.read_with(cx, |project, _| project.languages().clone());
727 language_registry.add(rust_language);
728
729 let mut fake_language_servers = language_registry.register_fake_lsp(
730 "Rust",
731 language::FakeLspAdapter {
732 capabilities: lsp::ServerCapabilities {
733 document_formatting_provider: Some(lsp::OneOf::Left(true)),
734 ..Default::default()
735 },
736 ..Default::default()
737 },
738 );
739
740 // Create the file
741 fs.save(
742 path!("/root/src/main.rs").as_ref(),
743 &"initial content".into(),
744 language::LineEnding::Unix,
745 )
746 .await
747 .unwrap();
748
749 // Open the buffer to trigger LSP initialization
750 let buffer = project
751 .update(cx, |project, cx| {
752 project.open_local_buffer(path!("/root/src/main.rs"), cx)
753 })
754 .await
755 .unwrap();
756
757 // Register the buffer with language servers
758 let _handle = project.update(cx, |project, cx| {
759 project.register_buffer_with_language_servers(&buffer, cx)
760 });
761
762 const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
763 const FORMATTED_CONTENT: &str =
764 "This file was formatted by the fake formatter in the test.\n";
765
766 // Get the fake language server and set up formatting handler
767 let fake_language_server = fake_language_servers.next().await.unwrap();
768 fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
769 |_, _| async move {
770 Ok(Some(vec![lsp::TextEdit {
771 range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
772 new_text: FORMATTED_CONTENT.to_string(),
773 }]))
774 }
775 });
776
777 let context_server_registry =
778 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
779 let model = Arc::new(FakeLanguageModel::default());
780 let thread = cx.new(|cx| {
781 Thread::new(
782 project.clone(),
783 cx.new(|_cx| ProjectContext::default()),
784 context_server_registry,
785 Templates::new(),
786 Some(model.clone()),
787 cx,
788 )
789 });
790
791 // First, test with format_on_save enabled
792 cx.update(|cx| {
793 SettingsStore::update_global(cx, |store, cx| {
794 store.update_user_settings(cx, |settings| {
795 settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
796 settings.project.all_languages.defaults.formatter =
797 Some(language::language_settings::FormatterList::default());
798 });
799 });
800 });
801
802 // Have the model stream unformatted content
803 let edit_result = {
804 let edit_task = cx.update(|cx| {
805 let input = EditFileToolInput {
806 display_description: "Create main function".into(),
807 path: "root/src/main.rs".into(),
808 mode: EditFileMode::Overwrite,
809 };
810 Arc::new(EditFileTool::new(
811 project.clone(),
812 thread.downgrade(),
813 language_registry.clone(),
814 Templates::new(),
815 ))
816 .run(input, ToolCallEventStream::test().0, cx)
817 });
818
819 // Stream the unformatted content
820 cx.executor().run_until_parked();
821 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
822 model.end_last_completion_stream();
823
824 edit_task.await
825 };
826 assert!(edit_result.is_ok());
827
828 // Wait for any async operations (e.g. formatting) to complete
829 cx.executor().run_until_parked();
830
831 // Read the file to verify it was formatted automatically
832 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
833 assert_eq!(
834 // Ignore carriage returns on Windows
835 new_content.replace("\r\n", "\n"),
836 FORMATTED_CONTENT,
837 "Code should be formatted when format_on_save is enabled"
838 );
839
840 let stale_buffer_count = thread
841 .read_with(cx, |thread, _cx| thread.action_log.clone())
842 .read_with(cx, |log, cx| log.stale_buffers(cx).count());
843
844 assert_eq!(
845 stale_buffer_count, 0,
846 "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
847 This causes the agent to think the file was modified externally when it was just formatted.",
848 stale_buffer_count
849 );
850
851 // Next, test with format_on_save disabled
852 cx.update(|cx| {
853 SettingsStore::update_global(cx, |store, cx| {
854 store.update_user_settings(cx, |settings| {
855 settings.project.all_languages.defaults.format_on_save =
856 Some(FormatOnSave::Off);
857 });
858 });
859 });
860
861 // Stream unformatted edits again
862 let edit_result = {
863 let edit_task = cx.update(|cx| {
864 let input = EditFileToolInput {
865 display_description: "Update main function".into(),
866 path: "root/src/main.rs".into(),
867 mode: EditFileMode::Overwrite,
868 };
869 Arc::new(EditFileTool::new(
870 project.clone(),
871 thread.downgrade(),
872 language_registry,
873 Templates::new(),
874 ))
875 .run(input, ToolCallEventStream::test().0, cx)
876 });
877
878 // Stream the unformatted content
879 cx.executor().run_until_parked();
880 model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
881 model.end_last_completion_stream();
882
883 edit_task.await
884 };
885 assert!(edit_result.is_ok());
886
887 // Wait for any async operations (e.g. formatting) to complete
888 cx.executor().run_until_parked();
889
890 // Verify the file was not formatted
891 let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
892 assert_eq!(
893 // Ignore carriage returns on Windows
894 new_content.replace("\r\n", "\n"),
895 UNFORMATTED_CONTENT,
896 "Code should not be formatted when format_on_save is disabled"
897 );
898 }
899
900 #[gpui::test]
901 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
902 init_test(cx);
903
904 let fs = project::FakeFs::new(cx.executor());
905 fs.insert_tree("/root", json!({"src": {}})).await;
906
907 // Create a simple file with trailing whitespace
908 fs.save(
909 path!("/root/src/main.rs").as_ref(),
910 &"initial content".into(),
911 language::LineEnding::Unix,
912 )
913 .await
914 .unwrap();
915
916 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
917 let context_server_registry =
918 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
919 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
920 let model = Arc::new(FakeLanguageModel::default());
921 let thread = cx.new(|cx| {
922 Thread::new(
923 project.clone(),
924 cx.new(|_cx| ProjectContext::default()),
925 context_server_registry,
926 Templates::new(),
927 Some(model.clone()),
928 cx,
929 )
930 });
931
932 // First, test with remove_trailing_whitespace_on_save enabled
933 cx.update(|cx| {
934 SettingsStore::update_global(cx, |store, cx| {
935 store.update_user_settings(cx, |settings| {
936 settings
937 .project
938 .all_languages
939 .defaults
940 .remove_trailing_whitespace_on_save = Some(true);
941 });
942 });
943 });
944
945 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
946 "fn main() { \n println!(\"Hello!\"); \n}\n";
947
948 // Have the model stream content that contains trailing whitespace
949 let edit_result = {
950 let edit_task = cx.update(|cx| {
951 let input = EditFileToolInput {
952 display_description: "Create main function".into(),
953 path: "root/src/main.rs".into(),
954 mode: EditFileMode::Overwrite,
955 };
956 Arc::new(EditFileTool::new(
957 project.clone(),
958 thread.downgrade(),
959 language_registry.clone(),
960 Templates::new(),
961 ))
962 .run(input, ToolCallEventStream::test().0, cx)
963 });
964
965 // Stream the content with trailing whitespace
966 cx.executor().run_until_parked();
967 model.send_last_completion_stream_text_chunk(
968 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
969 );
970 model.end_last_completion_stream();
971
972 edit_task.await
973 };
974 assert!(edit_result.is_ok());
975
976 // Wait for any async operations (e.g. formatting) to complete
977 cx.executor().run_until_parked();
978
979 // Read the file to verify trailing whitespace was removed automatically
980 assert_eq!(
981 // Ignore carriage returns on Windows
982 fs.load(path!("/root/src/main.rs").as_ref())
983 .await
984 .unwrap()
985 .replace("\r\n", "\n"),
986 "fn main() {\n println!(\"Hello!\");\n}\n",
987 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
988 );
989
990 // Next, test with remove_trailing_whitespace_on_save disabled
991 cx.update(|cx| {
992 SettingsStore::update_global(cx, |store, cx| {
993 store.update_user_settings(cx, |settings| {
994 settings
995 .project
996 .all_languages
997 .defaults
998 .remove_trailing_whitespace_on_save = Some(false);
999 });
1000 });
1001 });
1002
1003 // Stream edits again with trailing whitespace
1004 let edit_result = {
1005 let edit_task = cx.update(|cx| {
1006 let input = EditFileToolInput {
1007 display_description: "Update main function".into(),
1008 path: "root/src/main.rs".into(),
1009 mode: EditFileMode::Overwrite,
1010 };
1011 Arc::new(EditFileTool::new(
1012 project.clone(),
1013 thread.downgrade(),
1014 language_registry,
1015 Templates::new(),
1016 ))
1017 .run(input, ToolCallEventStream::test().0, cx)
1018 });
1019
1020 // Stream the content with trailing whitespace
1021 cx.executor().run_until_parked();
1022 model.send_last_completion_stream_text_chunk(
1023 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1024 );
1025 model.end_last_completion_stream();
1026
1027 edit_task.await
1028 };
1029 assert!(edit_result.is_ok());
1030
1031 // Wait for any async operations (e.g. formatting) to complete
1032 cx.executor().run_until_parked();
1033
1034 // Verify the file still has trailing whitespace
1035 // Read the file again - it should still have trailing whitespace
1036 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1037 assert_eq!(
1038 // Ignore carriage returns on Windows
1039 final_content.replace("\r\n", "\n"),
1040 CONTENT_WITH_TRAILING_WHITESPACE,
1041 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1042 );
1043 }
1044
1045 #[gpui::test]
1046 async fn test_authorize(cx: &mut TestAppContext) {
1047 init_test(cx);
1048 let fs = project::FakeFs::new(cx.executor());
1049 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1050 let context_server_registry =
1051 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1052 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1053 let model = Arc::new(FakeLanguageModel::default());
1054 let thread = cx.new(|cx| {
1055 Thread::new(
1056 project.clone(),
1057 cx.new(|_cx| ProjectContext::default()),
1058 context_server_registry,
1059 Templates::new(),
1060 Some(model.clone()),
1061 cx,
1062 )
1063 });
1064 let tool = Arc::new(EditFileTool::new(
1065 project.clone(),
1066 thread.downgrade(),
1067 language_registry,
1068 Templates::new(),
1069 ));
1070 fs.insert_tree("/root", json!({})).await;
1071
1072 // Test 1: Path with .zed component should require confirmation
1073 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1074 let _auth = cx.update(|cx| {
1075 tool.authorize(
1076 &EditFileToolInput {
1077 display_description: "test 1".into(),
1078 path: ".zed/settings.json".into(),
1079 mode: EditFileMode::Edit,
1080 },
1081 &stream_tx,
1082 cx,
1083 )
1084 });
1085
1086 let event = stream_rx.expect_authorization().await;
1087 assert_eq!(
1088 event.tool_call.fields.title,
1089 Some("test 1 (local settings)".into())
1090 );
1091
1092 // Test 2: Path outside project should require confirmation
1093 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1094 let _auth = cx.update(|cx| {
1095 tool.authorize(
1096 &EditFileToolInput {
1097 display_description: "test 2".into(),
1098 path: "/etc/hosts".into(),
1099 mode: EditFileMode::Edit,
1100 },
1101 &stream_tx,
1102 cx,
1103 )
1104 });
1105
1106 let event = stream_rx.expect_authorization().await;
1107 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1108
1109 // Test 3: Relative path without .zed should not require confirmation
1110 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1111 cx.update(|cx| {
1112 tool.authorize(
1113 &EditFileToolInput {
1114 display_description: "test 3".into(),
1115 path: "root/src/main.rs".into(),
1116 mode: EditFileMode::Edit,
1117 },
1118 &stream_tx,
1119 cx,
1120 )
1121 })
1122 .await
1123 .unwrap();
1124 assert!(stream_rx.try_next().is_err());
1125
1126 // Test 4: Path with .zed in the middle should require confirmation
1127 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1128 let _auth = cx.update(|cx| {
1129 tool.authorize(
1130 &EditFileToolInput {
1131 display_description: "test 4".into(),
1132 path: "root/.zed/tasks.json".into(),
1133 mode: EditFileMode::Edit,
1134 },
1135 &stream_tx,
1136 cx,
1137 )
1138 });
1139 let event = stream_rx.expect_authorization().await;
1140 assert_eq!(
1141 event.tool_call.fields.title,
1142 Some("test 4 (local settings)".into())
1143 );
1144
1145 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1146 cx.update(|cx| {
1147 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1148 settings.always_allow_tool_actions = true;
1149 agent_settings::AgentSettings::override_global(settings, cx);
1150 });
1151
1152 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1153 cx.update(|cx| {
1154 tool.authorize(
1155 &EditFileToolInput {
1156 display_description: "test 5.1".into(),
1157 path: ".zed/settings.json".into(),
1158 mode: EditFileMode::Edit,
1159 },
1160 &stream_tx,
1161 cx,
1162 )
1163 })
1164 .await
1165 .unwrap();
1166 assert!(stream_rx.try_next().is_err());
1167
1168 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1169 cx.update(|cx| {
1170 tool.authorize(
1171 &EditFileToolInput {
1172 display_description: "test 5.2".into(),
1173 path: "/etc/hosts".into(),
1174 mode: EditFileMode::Edit,
1175 },
1176 &stream_tx,
1177 cx,
1178 )
1179 })
1180 .await
1181 .unwrap();
1182 assert!(stream_rx.try_next().is_err());
1183 }
1184
1185 #[gpui::test]
1186 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1187 init_test(cx);
1188 let fs = project::FakeFs::new(cx.executor());
1189 fs.insert_tree("/project", json!({})).await;
1190 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1191 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1192 let context_server_registry =
1193 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1194 let model = Arc::new(FakeLanguageModel::default());
1195 let thread = cx.new(|cx| {
1196 Thread::new(
1197 project.clone(),
1198 cx.new(|_cx| ProjectContext::default()),
1199 context_server_registry,
1200 Templates::new(),
1201 Some(model.clone()),
1202 cx,
1203 )
1204 });
1205 let tool = Arc::new(EditFileTool::new(
1206 project.clone(),
1207 thread.downgrade(),
1208 language_registry,
1209 Templates::new(),
1210 ));
1211
1212 // Test global config paths - these should require confirmation if they exist and are outside the project
1213 let test_cases = vec![
1214 (
1215 "/etc/hosts",
1216 true,
1217 "System file should require confirmation",
1218 ),
1219 (
1220 "/usr/local/bin/script",
1221 true,
1222 "System bin file should require confirmation",
1223 ),
1224 (
1225 "project/normal_file.rs",
1226 false,
1227 "Normal project file should not require confirmation",
1228 ),
1229 ];
1230
1231 for (path, should_confirm, description) in test_cases {
1232 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1233 let auth = cx.update(|cx| {
1234 tool.authorize(
1235 &EditFileToolInput {
1236 display_description: "Edit file".into(),
1237 path: path.into(),
1238 mode: EditFileMode::Edit,
1239 },
1240 &stream_tx,
1241 cx,
1242 )
1243 });
1244
1245 if should_confirm {
1246 stream_rx.expect_authorization().await;
1247 } else {
1248 auth.await.unwrap();
1249 assert!(
1250 stream_rx.try_next().is_err(),
1251 "Failed for case: {} - path: {} - expected no confirmation but got one",
1252 description,
1253 path
1254 );
1255 }
1256 }
1257 }
1258
1259 #[gpui::test]
1260 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1261 init_test(cx);
1262 let fs = project::FakeFs::new(cx.executor());
1263
1264 // Create multiple worktree directories
1265 fs.insert_tree(
1266 "/workspace/frontend",
1267 json!({
1268 "src": {
1269 "main.js": "console.log('frontend');"
1270 }
1271 }),
1272 )
1273 .await;
1274 fs.insert_tree(
1275 "/workspace/backend",
1276 json!({
1277 "src": {
1278 "main.rs": "fn main() {}"
1279 }
1280 }),
1281 )
1282 .await;
1283 fs.insert_tree(
1284 "/workspace/shared",
1285 json!({
1286 ".zed": {
1287 "settings.json": "{}"
1288 }
1289 }),
1290 )
1291 .await;
1292
1293 // Create project with multiple worktrees
1294 let project = Project::test(
1295 fs.clone(),
1296 [
1297 path!("/workspace/frontend").as_ref(),
1298 path!("/workspace/backend").as_ref(),
1299 path!("/workspace/shared").as_ref(),
1300 ],
1301 cx,
1302 )
1303 .await;
1304 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1305 let context_server_registry =
1306 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1307 let model = Arc::new(FakeLanguageModel::default());
1308 let thread = cx.new(|cx| {
1309 Thread::new(
1310 project.clone(),
1311 cx.new(|_cx| ProjectContext::default()),
1312 context_server_registry.clone(),
1313 Templates::new(),
1314 Some(model.clone()),
1315 cx,
1316 )
1317 });
1318 let tool = Arc::new(EditFileTool::new(
1319 project.clone(),
1320 thread.downgrade(),
1321 language_registry,
1322 Templates::new(),
1323 ));
1324
1325 // Test files in different worktrees
1326 let test_cases = vec![
1327 ("frontend/src/main.js", false, "File in first worktree"),
1328 ("backend/src/main.rs", false, "File in second worktree"),
1329 (
1330 "shared/.zed/settings.json",
1331 true,
1332 ".zed file in third worktree",
1333 ),
1334 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1335 (
1336 "../outside/file.txt",
1337 true,
1338 "Relative path outside worktrees",
1339 ),
1340 ];
1341
1342 for (path, should_confirm, description) in test_cases {
1343 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1344 let auth = cx.update(|cx| {
1345 tool.authorize(
1346 &EditFileToolInput {
1347 display_description: "Edit file".into(),
1348 path: path.into(),
1349 mode: EditFileMode::Edit,
1350 },
1351 &stream_tx,
1352 cx,
1353 )
1354 });
1355
1356 if should_confirm {
1357 stream_rx.expect_authorization().await;
1358 } else {
1359 auth.await.unwrap();
1360 assert!(
1361 stream_rx.try_next().is_err(),
1362 "Failed for case: {} - path: {} - expected no confirmation but got one",
1363 description,
1364 path
1365 );
1366 }
1367 }
1368 }
1369
1370 #[gpui::test]
1371 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1372 init_test(cx);
1373 let fs = project::FakeFs::new(cx.executor());
1374 fs.insert_tree(
1375 "/project",
1376 json!({
1377 ".zed": {
1378 "settings.json": "{}"
1379 },
1380 "src": {
1381 ".zed": {
1382 "local.json": "{}"
1383 }
1384 }
1385 }),
1386 )
1387 .await;
1388 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1389 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1390 let context_server_registry =
1391 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1392 let model = Arc::new(FakeLanguageModel::default());
1393 let thread = cx.new(|cx| {
1394 Thread::new(
1395 project.clone(),
1396 cx.new(|_cx| ProjectContext::default()),
1397 context_server_registry.clone(),
1398 Templates::new(),
1399 Some(model.clone()),
1400 cx,
1401 )
1402 });
1403 let tool = Arc::new(EditFileTool::new(
1404 project.clone(),
1405 thread.downgrade(),
1406 language_registry,
1407 Templates::new(),
1408 ));
1409
1410 // Test edge cases
1411 let test_cases = vec![
1412 // Empty path - find_project_path returns Some for empty paths
1413 ("", false, "Empty path is treated as project root"),
1414 // Root directory
1415 ("/", true, "Root directory should be outside project"),
1416 // Parent directory references - find_project_path resolves these
1417 (
1418 "project/../other",
1419 true,
1420 "Path with .. that goes outside of root directory",
1421 ),
1422 (
1423 "project/./src/file.rs",
1424 false,
1425 "Path with . should work normally",
1426 ),
1427 // Windows-style paths (if on Windows)
1428 #[cfg(target_os = "windows")]
1429 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1430 #[cfg(target_os = "windows")]
1431 ("project\\src\\main.rs", false, "Windows-style project path"),
1432 ];
1433
1434 for (path, should_confirm, description) in test_cases {
1435 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1436 let auth = cx.update(|cx| {
1437 tool.authorize(
1438 &EditFileToolInput {
1439 display_description: "Edit file".into(),
1440 path: path.into(),
1441 mode: EditFileMode::Edit,
1442 },
1443 &stream_tx,
1444 cx,
1445 )
1446 });
1447
1448 cx.run_until_parked();
1449
1450 if should_confirm {
1451 stream_rx.expect_authorization().await;
1452 } else {
1453 assert!(
1454 stream_rx.try_next().is_err(),
1455 "Failed for case: {} - path: {} - expected no confirmation but got one",
1456 description,
1457 path
1458 );
1459 auth.await.unwrap();
1460 }
1461 }
1462 }
1463
1464 #[gpui::test]
1465 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1466 init_test(cx);
1467 let fs = project::FakeFs::new(cx.executor());
1468 fs.insert_tree(
1469 "/project",
1470 json!({
1471 "existing.txt": "content",
1472 ".zed": {
1473 "settings.json": "{}"
1474 }
1475 }),
1476 )
1477 .await;
1478 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1479 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1480 let context_server_registry =
1481 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1482 let model = Arc::new(FakeLanguageModel::default());
1483 let thread = cx.new(|cx| {
1484 Thread::new(
1485 project.clone(),
1486 cx.new(|_cx| ProjectContext::default()),
1487 context_server_registry.clone(),
1488 Templates::new(),
1489 Some(model.clone()),
1490 cx,
1491 )
1492 });
1493 let tool = Arc::new(EditFileTool::new(
1494 project.clone(),
1495 thread.downgrade(),
1496 language_registry,
1497 Templates::new(),
1498 ));
1499
1500 // Test different EditFileMode values
1501 let modes = vec![
1502 EditFileMode::Edit,
1503 EditFileMode::Create,
1504 EditFileMode::Overwrite,
1505 ];
1506
1507 for mode in modes {
1508 // Test .zed path with different modes
1509 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1510 let _auth = cx.update(|cx| {
1511 tool.authorize(
1512 &EditFileToolInput {
1513 display_description: "Edit settings".into(),
1514 path: "project/.zed/settings.json".into(),
1515 mode: mode.clone(),
1516 },
1517 &stream_tx,
1518 cx,
1519 )
1520 });
1521
1522 stream_rx.expect_authorization().await;
1523
1524 // Test outside path with different modes
1525 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1526 let _auth = cx.update(|cx| {
1527 tool.authorize(
1528 &EditFileToolInput {
1529 display_description: "Edit file".into(),
1530 path: "/outside/file.txt".into(),
1531 mode: mode.clone(),
1532 },
1533 &stream_tx,
1534 cx,
1535 )
1536 });
1537
1538 stream_rx.expect_authorization().await;
1539
1540 // Test normal path with different modes
1541 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1542 cx.update(|cx| {
1543 tool.authorize(
1544 &EditFileToolInput {
1545 display_description: "Edit file".into(),
1546 path: "project/normal.txt".into(),
1547 mode: mode.clone(),
1548 },
1549 &stream_tx,
1550 cx,
1551 )
1552 })
1553 .await
1554 .unwrap();
1555 assert!(stream_rx.try_next().is_err());
1556 }
1557 }
1558
1559 #[gpui::test]
1560 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1561 init_test(cx);
1562 let fs = project::FakeFs::new(cx.executor());
1563 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1564 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1565 let context_server_registry =
1566 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1567 let model = Arc::new(FakeLanguageModel::default());
1568 let thread = cx.new(|cx| {
1569 Thread::new(
1570 project.clone(),
1571 cx.new(|_cx| ProjectContext::default()),
1572 context_server_registry,
1573 Templates::new(),
1574 Some(model.clone()),
1575 cx,
1576 )
1577 });
1578 let tool = Arc::new(EditFileTool::new(
1579 project,
1580 thread.downgrade(),
1581 language_registry,
1582 Templates::new(),
1583 ));
1584
1585 cx.update(|cx| {
1586 // ...
1587 assert_eq!(
1588 tool.initial_title(
1589 Err(json!({
1590 "path": "src/main.rs",
1591 "display_description": "",
1592 "old_string": "old code",
1593 "new_string": "new code"
1594 })),
1595 cx
1596 ),
1597 "src/main.rs"
1598 );
1599 assert_eq!(
1600 tool.initial_title(
1601 Err(json!({
1602 "path": "",
1603 "display_description": "Fix error handling",
1604 "old_string": "old code",
1605 "new_string": "new code"
1606 })),
1607 cx
1608 ),
1609 "Fix error handling"
1610 );
1611 assert_eq!(
1612 tool.initial_title(
1613 Err(json!({
1614 "path": "src/main.rs",
1615 "display_description": "Fix error handling",
1616 "old_string": "old code",
1617 "new_string": "new code"
1618 })),
1619 cx
1620 ),
1621 "src/main.rs"
1622 );
1623 assert_eq!(
1624 tool.initial_title(
1625 Err(json!({
1626 "path": "",
1627 "display_description": "",
1628 "old_string": "old code",
1629 "new_string": "new code"
1630 })),
1631 cx
1632 ),
1633 DEFAULT_UI_TEXT
1634 );
1635 assert_eq!(
1636 tool.initial_title(Err(serde_json::Value::Null), cx),
1637 DEFAULT_UI_TEXT
1638 );
1639 });
1640 }
1641
    /// Checks that the `Diff` emitted by `EditFileTool::run` starts as
    /// `Diff::Pending` and ends as `Diff::Finalized` in three scenarios:
    /// - the edit completes normally;
    /// - the edit fails because model requests are forbidden;
    /// - the in-flight `run` task is dropped.
    #[gpui::test]
    async fn test_diff_finalization(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        // A single empty file at the filesystem root, which is also the worktree root.
        fs.insert_tree("/", json!({"main.rs": ""})).await;

        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // Ensure the diff is finalized after the edit completes.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            // The test expects a fields update first, then the diff entity.
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            // Let the tool issue its completion request, then end that stream
            // without sending any chunks so `run` can finish.
            cx.run_until_parked();
            model.end_last_completion_stream();
            edit.await.unwrap();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }

        // Ensure the diff is finalized if an error occurs while editing.
        {
            // With requests forbidden, the test expects `run` to return an error.
            model.forbid_requests();
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            edit.await.unwrap_err();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
            // Re-enable requests for the next scenario; this line is skipped if an
            // assertion above panics.
            model.allow_requests();
        }

        // Ensure the diff is finalized if the tool call gets dropped.
        {
            let tool = Arc::new(EditFileTool::new(
                project.clone(),
                thread.downgrade(),
                languages.clone(),
                Templates::new(),
            ));
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let edit = cx.update(|cx| {
                tool.run(
                    EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path!("/main.rs").into(),
                        mode: EditFileMode::Edit,
                    },
                    stream_tx,
                    cx,
                )
            });
            stream_rx.expect_update_fields().await;
            let diff = stream_rx.expect_diff().await;
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
            // Drop the task before it completes, then let the executor run any
            // follow-up work before checking the diff state.
            drop(edit);
            cx.run_until_parked();
            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
        }
    }
1750
1751 fn init_test(cx: &mut TestAppContext) {
1752 cx.update(|cx| {
1753 let settings_store = SettingsStore::test(cx);
1754 cx.set_global(settings_store);
1755 });
1756 }
1757}