1use crate::{AgentTool, Thread, ToolCallEventStream};
2use acp_thread::Diff;
3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
6use cloud_llm_client::CompletionIntent;
7use collections::HashSet;
8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
9use indoc::formatdoc;
10use language::language_settings::{self, FormatOnSave};
11use language::{LanguageRegistry, ToPoint};
12use language_model::LanguageModelToolResultContent;
13use paths;
14use project::lsp_store::{FormatTrigger, LspFormatTarget};
15use project::{Project, ProjectPath};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::Settings;
19use smol::stream::StreamExt as _;
20use std::path::{Path, PathBuf};
21use std::sync::Arc;
22use ui::SharedString;
23use util::ResultExt;
24
/// Fallback tool-call title used by `initial_title` when neither a description
/// nor a path can be recovered from the (possibly partial) tool input.
const DEFAULT_UI_TEXT: &str = "Editing file";
26
// NOTE: this type derives `JsonSchema`, so the `///` doc comments below are picked up
// by schemars as descriptions in the generated JSON schema (presumably the schema the
// model sees for this tool — confirm in `AgentTool`). Treat that text as prompt text;
// plain `//` comments like this one are not included in the schema.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    // Keep this field declared first: the doc above asks the model to emit it first so
    // `initial_title` can show it while the rest of the input is still arriving.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    // Resolved against the project's worktrees by `resolve_path`; `authorize` also
    // inspects it for settings folders and paths outside the project.
    pub path: PathBuf,
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    // `run` calls `EditAgent::edit` for `Edit` and `EditAgent::overwrite` otherwise.
    pub mode: EditFileMode,
}
75
// Lenient view of the tool input used by `initial_title` when the raw JSON cannot be
// parsed as a full `EditFileToolInput` (presumably because it is still streaming —
// confirm with the caller). Missing fields default to empty strings, and other keys
// such as `mode` are ignored because unknown fields are not denied.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
83
// How `EditFileTool::run` treats the target file. `rename_all = "lowercase"` makes the
// variants (de)serialize as "edit", "create" and "overwrite".
// `resolve_path` requires an existing file for `Edit`/`Overwrite`, and a path with
// nothing at it (inside an existing directory) for `Create`.
// Plain `//` comments are used here so the derived JSON schema is unchanged.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum EditFileMode {
    Edit,
    Create,
    Overwrite,
}
91
/// Result of a completed `edit_file` call.
///
/// `replay` rebuilds a finalized diff view from `input_path`, `old_text` and
/// `new_text`. The `From` conversion into `LanguageModelToolResultContent` turns
/// `diff` into the text reported back to the model.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as given in the tool input (not resolved to an absolute path).
    /// The `original_path` alias also accepts data serialized under that name
    /// (presumably an earlier field name).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Buffer text after the edit, optional formatting, and save.
    new_text: String,
    /// Buffer text captured before the edit agent started.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`. Empty means "no edits were made".
    /// Defaults to empty when absent from serialized data.
    #[serde(default)]
    diff: String,
    /// Raw output of the edit agent. Also accepted under the `raw_output` alias.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
103
104impl From<EditFileToolOutput> for LanguageModelToolResultContent {
105 fn from(output: EditFileToolOutput) -> Self {
106 if output.diff.is_empty() {
107 "No edits were made.".into()
108 } else {
109 format!(
110 "Edited {}:\n\n```diff\n{}\n```",
111 output.input_path.display(),
112 output.diff
113 )
114 .into()
115 }
116 }
117}
118
/// Agent tool that creates, edits, or overwrites a single file in the project.
pub struct EditFileTool {
    /// Weak handle to the owning thread. It supplies the project, model, and action
    /// log; operations fail with "thread was dropped" once the thread is gone.
    thread: WeakEntity<Thread>,
    /// Registry handed to `Diff::finalized` when replaying a past call.
    language_registry: Arc<LanguageRegistry>,
}
123
124impl EditFileTool {
125 pub fn new(thread: WeakEntity<Thread>, language_registry: Arc<LanguageRegistry>) -> Self {
126 Self {
127 thread,
128 language_registry,
129 }
130 }
131
132 fn authorize(
133 &self,
134 input: &EditFileToolInput,
135 event_stream: &ToolCallEventStream,
136 cx: &mut App,
137 ) -> Task<Result<()>> {
138 if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
139 return Task::ready(Ok(()));
140 }
141
142 // If any path component matches the local settings folder, then this could affect
143 // the editor in ways beyond the project source, so prompt.
144 let local_settings_folder = paths::local_settings_folder_relative_path();
145 let path = Path::new(&input.path);
146 if path
147 .components()
148 .any(|component| component.as_os_str() == local_settings_folder.as_os_str())
149 {
150 return event_stream.authorize(
151 format!("{} (local settings)", input.display_description),
152 cx,
153 );
154 }
155
156 // It's also possible that the global config dir is configured to be inside the project,
157 // so check for that edge case too.
158 if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
159 && canonical_path.starts_with(paths::config_dir())
160 {
161 return event_stream.authorize(
162 format!("{} (global settings)", input.display_description),
163 cx,
164 );
165 }
166
167 // Check if path is inside the global config directory
168 // First check if it's already inside project - if not, try to canonicalize
169 let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
170 thread.project().read(cx).find_project_path(&input.path, cx)
171 }) else {
172 return Task::ready(Err(anyhow!("thread was dropped")));
173 };
174
175 // If the path is inside the project, and it's not one of the above edge cases,
176 // then no confirmation is necessary. Otherwise, confirmation is necessary.
177 if project_path.is_some() {
178 Task::ready(Ok(()))
179 } else {
180 event_stream.authorize(&input.display_description, cx)
181 }
182 }
183}
184
185impl AgentTool for EditFileTool {
186 type Input = EditFileToolInput;
187 type Output = EditFileToolOutput;
188
189 fn name() -> &'static str {
190 "edit_file"
191 }
192
193 fn kind() -> acp::ToolKind {
194 acp::ToolKind::Edit
195 }
196
197 fn initial_title(&self, input: Result<Self::Input, serde_json::Value>) -> SharedString {
198 match input {
199 Ok(input) => input.display_description.into(),
200 Err(raw_input) => {
201 if let Some(input) =
202 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
203 {
204 let description = input.display_description.trim();
205 if !description.is_empty() {
206 return description.to_string().into();
207 }
208
209 let path = input.path.trim().to_string();
210 if !path.is_empty() {
211 return path.into();
212 }
213 }
214
215 DEFAULT_UI_TEXT.into()
216 }
217 }
218 }
219
220 fn run(
221 self: Arc<Self>,
222 input: Self::Input,
223 event_stream: ToolCallEventStream,
224 cx: &mut App,
225 ) -> Task<Result<Self::Output>> {
226 let Ok(project) = self
227 .thread
228 .read_with(cx, |thread, _cx| thread.project().clone())
229 else {
230 return Task::ready(Err(anyhow!("thread was dropped")));
231 };
232 let project_path = match resolve_path(&input, project.clone(), cx) {
233 Ok(path) => path,
234 Err(err) => return Task::ready(Err(anyhow!(err))),
235 };
236 let abs_path = project.read(cx).absolute_path(&project_path, cx);
237 if let Some(abs_path) = abs_path.clone() {
238 event_stream.update_fields(ToolCallUpdateFields {
239 locations: Some(vec![acp::ToolCallLocation {
240 path: abs_path,
241 line: None,
242 }]),
243 ..Default::default()
244 });
245 }
246
247 let authorize = self.authorize(&input, &event_stream, cx);
248 cx.spawn(async move |cx: &mut AsyncApp| {
249 authorize.await?;
250
251 let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
252 let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
253 (request, thread.model().cloned(), thread.action_log().clone())
254 })?;
255 let request = request?;
256 let model = model.context("No language model configured")?;
257
258 let edit_format = EditFormat::from_model(model.clone())?;
259 let edit_agent = EditAgent::new(
260 model,
261 project.clone(),
262 action_log.clone(),
263 // TODO: move edit agent to this crate so we can use our templates
264 assistant_tools::templates::Templates::new(),
265 edit_format,
266 );
267
268 let buffer = project
269 .update(cx, |project, cx| {
270 project.open_buffer(project_path.clone(), cx)
271 })?
272 .await?;
273
274 let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
275 event_stream.update_diff(diff.clone());
276 let _finalize_diff = util::defer({
277 let diff = diff.downgrade();
278 let mut cx = cx.clone();
279 move || {
280 diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
281 }
282 });
283
284 let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
285 let old_text = cx
286 .background_spawn({
287 let old_snapshot = old_snapshot.clone();
288 async move { Arc::new(old_snapshot.text()) }
289 })
290 .await;
291
292
293 let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
294 edit_agent.edit(
295 buffer.clone(),
296 input.display_description.clone(),
297 &request,
298 cx,
299 )
300 } else {
301 edit_agent.overwrite(
302 buffer.clone(),
303 input.display_description.clone(),
304 &request,
305 cx,
306 )
307 };
308
309 let mut hallucinated_old_text = false;
310 let mut ambiguous_ranges = Vec::new();
311 let mut emitted_location = false;
312 while let Some(event) = events.next().await {
313 match event {
314 EditAgentOutputEvent::Edited(range) => {
315 if !emitted_location {
316 let line = buffer.update(cx, |buffer, _cx| {
317 range.start.to_point(&buffer.snapshot()).row
318 }).ok();
319 if let Some(abs_path) = abs_path.clone() {
320 event_stream.update_fields(ToolCallUpdateFields {
321 locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
322 ..Default::default()
323 });
324 }
325 emitted_location = true;
326 }
327 },
328 EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
329 EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
330 EditAgentOutputEvent::ResolvingEditRange(range) => {
331 diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
332 // if !emitted_location {
333 // let line = buffer.update(cx, |buffer, _cx| {
334 // range.start.to_point(&buffer.snapshot()).row
335 // }).ok();
336 // if let Some(abs_path) = abs_path.clone() {
337 // event_stream.update_fields(ToolCallUpdateFields {
338 // locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
339 // ..Default::default()
340 // });
341 // }
342 // }
343 }
344 }
345 }
346
347 // If format_on_save is enabled, format the buffer
348 let format_on_save_enabled = buffer
349 .read_with(cx, |buffer, cx| {
350 let settings = language_settings::language_settings(
351 buffer.language().map(|l| l.name()),
352 buffer.file(),
353 cx,
354 );
355 settings.format_on_save != FormatOnSave::Off
356 })
357 .unwrap_or(false);
358
359 let edit_agent_output = output.await?;
360
361 if format_on_save_enabled {
362 action_log.update(cx, |log, cx| {
363 log.buffer_edited(buffer.clone(), cx);
364 })?;
365
366 let format_task = project.update(cx, |project, cx| {
367 project.format(
368 HashSet::from_iter([buffer.clone()]),
369 LspFormatTarget::Buffers,
370 false, // Don't push to history since the tool did it.
371 FormatTrigger::Save,
372 cx,
373 )
374 })?;
375 format_task.await.log_err();
376 }
377
378 project
379 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
380 .await?;
381
382 action_log.update(cx, |log, cx| {
383 log.buffer_edited(buffer.clone(), cx);
384 })?;
385
386 let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
387 let (new_text, unified_diff) = cx
388 .background_spawn({
389 let new_snapshot = new_snapshot.clone();
390 let old_text = old_text.clone();
391 async move {
392 let new_text = new_snapshot.text();
393 let diff = language::unified_diff(&old_text, &new_text);
394 (new_text, diff)
395 }
396 })
397 .await;
398
399 let input_path = input.path.display();
400 if unified_diff.is_empty() {
401 anyhow::ensure!(
402 !hallucinated_old_text,
403 formatdoc! {"
404 Some edits were produced but none of them could be applied.
405 Read the relevant sections of {input_path} again so that
406 I can perform the requested edits.
407 "}
408 );
409 anyhow::ensure!(
410 ambiguous_ranges.is_empty(),
411 {
412 let line_numbers = ambiguous_ranges
413 .iter()
414 .map(|range| range.start.to_string())
415 .collect::<Vec<_>>()
416 .join(", ");
417 formatdoc! {"
418 <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
419 relevant sections of {input_path} again and extend <old_text> so
420 that I can perform the requested edits.
421 "}
422 }
423 );
424 }
425
426 Ok(EditFileToolOutput {
427 input_path: input.path,
428 new_text,
429 old_text,
430 diff: unified_diff,
431 edit_agent_output,
432 })
433 })
434 }
435
436 fn replay(
437 &self,
438 _input: Self::Input,
439 output: Self::Output,
440 event_stream: ToolCallEventStream,
441 cx: &mut App,
442 ) -> Result<()> {
443 event_stream.update_diff(cx.new(|cx| {
444 Diff::finalized(
445 output.input_path,
446 Some(output.old_text.to_string()),
447 output.new_text,
448 self.language_registry.clone(),
449 cx,
450 )
451 }));
452 Ok(())
453 }
454}
455
456/// Validate that the file path is valid, meaning:
457///
458/// - For `edit` and `overwrite`, the path must point to an existing file.
459/// - For `create`, the file must not already exist, but it's parent dir must exist.
460fn resolve_path(
461 input: &EditFileToolInput,
462 project: Entity<Project>,
463 cx: &mut App,
464) -> Result<ProjectPath> {
465 let project = project.read(cx);
466
467 match input.mode {
468 EditFileMode::Edit | EditFileMode::Overwrite => {
469 let path = project
470 .find_project_path(&input.path, cx)
471 .context("Can't edit file: path not found")?;
472
473 let entry = project
474 .entry_for_path(&path, cx)
475 .context("Can't edit file: path not found")?;
476
477 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
478 Ok(path)
479 }
480
481 EditFileMode::Create => {
482 if let Some(path) = project.find_project_path(&input.path, cx) {
483 anyhow::ensure!(
484 project.entry_for_path(&path, cx).is_none(),
485 "Can't create file: file already exists"
486 );
487 }
488
489 let parent_path = input
490 .path
491 .parent()
492 .context("Can't create file: incorrect path")?;
493
494 let parent_project_path = project.find_project_path(&parent_path, cx);
495
496 let parent_entry = parent_project_path
497 .as_ref()
498 .and_then(|path| project.entry_for_path(path, cx))
499 .context("Can't create file: parent directory doesn't exist")?;
500
501 anyhow::ensure!(
502 parent_entry.is_dir(),
503 "Can't create file: parent is not a directory"
504 );
505
506 let file_name = input
507 .path
508 .file_name()
509 .context("Can't create file: invalid filename")?;
510
511 let new_file_path = parent_project_path.map(|parent| ProjectPath {
512 path: Arc::from(parent.path.join(file_name)),
513 ..parent
514 });
515
516 new_file_path.context("Can't create file")
517 }
518 }
519}
520
521#[cfg(test)]
522mod tests {
523 use super::*;
524 use crate::{ContextServerRegistry, Templates};
525 use client::TelemetrySettings;
526 use fs::Fs;
527 use gpui::{TestAppContext, UpdateGlobal};
528 use language_model::fake_provider::FakeLanguageModel;
529 use prompt_store::ProjectContext;
530 use serde_json::json;
531 use settings::SettingsStore;
532 use util::path;
533
534 #[gpui::test]
535 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
536 init_test(cx);
537
538 let fs = project::FakeFs::new(cx.executor());
539 fs.insert_tree("/root", json!({})).await;
540 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
541 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
542 let context_server_registry =
543 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
544 let model = Arc::new(FakeLanguageModel::default());
545 let thread = cx.new(|cx| {
546 Thread::new(
547 project,
548 cx.new(|_cx| ProjectContext::default()),
549 context_server_registry,
550 Templates::new(),
551 Some(model),
552 cx,
553 )
554 });
555 let result = cx
556 .update(|cx| {
557 let input = EditFileToolInput {
558 display_description: "Some edit".into(),
559 path: "root/nonexistent_file.txt".into(),
560 mode: EditFileMode::Edit,
561 };
562 Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
563 input,
564 ToolCallEventStream::test().0,
565 cx,
566 )
567 })
568 .await;
569 assert_eq!(
570 result.unwrap_err().to_string(),
571 "Can't edit file: path not found"
572 );
573 }
574
575 #[gpui::test]
576 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
577 let mode = &EditFileMode::Create;
578
579 let result = test_resolve_path(mode, "root/new.txt", cx);
580 assert_resolved_path_eq(result.await, "new.txt");
581
582 let result = test_resolve_path(mode, "new.txt", cx);
583 assert_resolved_path_eq(result.await, "new.txt");
584
585 let result = test_resolve_path(mode, "dir/new.txt", cx);
586 assert_resolved_path_eq(result.await, "dir/new.txt");
587
588 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
589 assert_eq!(
590 result.await.unwrap_err().to_string(),
591 "Can't create file: file already exists"
592 );
593
594 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
595 assert_eq!(
596 result.await.unwrap_err().to_string(),
597 "Can't create file: parent directory doesn't exist"
598 );
599 }
600
601 #[gpui::test]
602 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
603 let mode = &EditFileMode::Edit;
604
605 let path_with_root = "root/dir/subdir/existing.txt";
606 let path_without_root = "dir/subdir/existing.txt";
607 let result = test_resolve_path(mode, path_with_root, cx);
608 assert_resolved_path_eq(result.await, path_without_root);
609
610 let result = test_resolve_path(mode, path_without_root, cx);
611 assert_resolved_path_eq(result.await, path_without_root);
612
613 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
614 assert_eq!(
615 result.await.unwrap_err().to_string(),
616 "Can't edit file: path not found"
617 );
618
619 let result = test_resolve_path(mode, "root/dir", cx);
620 assert_eq!(
621 result.await.unwrap_err().to_string(),
622 "Can't edit file: path is a directory"
623 );
624 }
625
626 async fn test_resolve_path(
627 mode: &EditFileMode,
628 path: &str,
629 cx: &mut TestAppContext,
630 ) -> anyhow::Result<ProjectPath> {
631 init_test(cx);
632
633 let fs = project::FakeFs::new(cx.executor());
634 fs.insert_tree(
635 "/root",
636 json!({
637 "dir": {
638 "subdir": {
639 "existing.txt": "hello"
640 }
641 }
642 }),
643 )
644 .await;
645 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
646
647 let input = EditFileToolInput {
648 display_description: "Some edit".into(),
649 path: path.into(),
650 mode: mode.clone(),
651 };
652
653 cx.update(|cx| resolve_path(&input, project, cx))
654 }
655
656 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
657 let actual = path
658 .expect("Should return valid path")
659 .path
660 .to_str()
661 .unwrap()
662 .replace("\\", "/"); // Naive Windows paths normalization
663 assert_eq!(actual, expected);
664 }
665
/// Verifies the tool's format-on-save behavior:
/// - with `format_on_save` on, the saved file is formatted through a fake LSP formatter;
/// - the action log does not report the buffer as stale afterwards;
/// - with `format_on_save` off, the saved file is left unformatted.
#[gpui::test]
async fn test_format_on_save(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({"src": {}})).await;

    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

    // Set up a Rust language with LSP formatting support
    let rust_language = Arc::new(language::Language::new(
        language::LanguageConfig {
            name: "Rust".into(),
            matcher: language::LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        },
        None,
    ));

    // Register the language and fake LSP
    let language_registry = project.read_with(cx, |project, _| project.languages().clone());
    language_registry.add(rust_language);

    let mut fake_language_servers = language_registry.register_fake_lsp(
        "Rust",
        language::FakeLspAdapter {
            capabilities: lsp::ServerCapabilities {
                document_formatting_provider: Some(lsp::OneOf::Left(true)),
                ..Default::default()
            },
            ..Default::default()
        },
    );

    // Create the file
    fs.save(
        path!("/root/src/main.rs").as_ref(),
        &"initial content".into(),
        language::LineEnding::Unix,
    )
    .await
    .unwrap();

    // Open the buffer to trigger LSP initialization
    let buffer = project
        .update(cx, |project, cx| {
            project.open_local_buffer(path!("/root/src/main.rs"), cx)
        })
        .await
        .unwrap();

    // Register the buffer with language servers. The handle is bound to a named
    // variable so it lives until the end of the test (presumably the registration
    // is released when it is dropped).
    let _handle = project.update(cx, |project, cx| {
        project.register_buffer_with_language_servers(&buffer, cx)
    });

    const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
    const FORMATTED_CONTENT: &str =
        "This file was formatted by the fake formatter in the test.\n";

    // Get the fake language server and set up formatting handler. The handler
    // replaces the range 0:0..1:0 (the first line) with FORMATTED_CONTENT, so any
    // formatter run is easy to recognize in the saved file.
    let fake_language_server = fake_language_servers.next().await.unwrap();
    fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
        |_, _| async move {
            Ok(Some(vec![lsp::TextEdit {
                range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
                new_text: FORMATTED_CONTENT.to_string(),
            }]))
        }
    });

    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // First, test with format_on_save enabled
    cx.update(|cx| {
        SettingsStore::update_global(cx, |store, cx| {
            store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                cx,
                |settings| {
                    settings.defaults.format_on_save = Some(FormatOnSave::On);
                    settings.defaults.formatter =
                        Some(language::language_settings::SelectedFormatter::Auto);
                },
            );
        });
    });

    // Have the model stream unformatted content
    let edit_result = {
        let edit_task = cx.update(|cx| {
            let input = EditFileToolInput {
                display_description: "Create main function".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Overwrite,
            };
            Arc::new(EditFileTool::new(
                thread.downgrade(),
                language_registry.clone(),
            ))
            .run(input, ToolCallEventStream::test().0, cx)
        });

        // Stream the unformatted content. Parking first lets the tool reach the
        // point where it has issued its completion request to the fake model.
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(edit_result.is_ok());

    // Wait for any async operations (e.g. formatting) to complete
    cx.executor().run_until_parked();

    // Read the file to verify it was formatted automatically
    let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
    assert_eq!(
        // Ignore carriage returns on Windows
        new_content.replace("\r\n", "\n"),
        FORMATTED_CONTENT,
        "Code should be formatted when format_on_save is enabled"
    );

    // Formatting happens after the tool records its own edit, so the action log
    // must not treat the formatter's changes as an external modification.
    let stale_buffer_count = thread
        .read_with(cx, |thread, _cx| thread.action_log.clone())
        .read_with(cx, |log, cx| log.stale_buffers(cx).count());

    assert_eq!(
        stale_buffer_count, 0,
        "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
        This causes the agent to think the file was modified externally when it was just formatted.",
        stale_buffer_count
    );

    // Next, test with format_on_save disabled
    cx.update(|cx| {
        SettingsStore::update_global(cx, |store, cx| {
            store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                cx,
                |settings| {
                    settings.defaults.format_on_save = Some(FormatOnSave::Off);
                },
            );
        });
    });

    // Stream unformatted edits again
    let edit_result = {
        let edit_task = cx.update(|cx| {
            let input = EditFileToolInput {
                display_description: "Update main function".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Overwrite,
            };
            Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
                input,
                ToolCallEventStream::test().0,
                cx,
            )
        });

        // Stream the unformatted content
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(edit_result.is_ok());

    // Wait for any async operations (e.g. formatting) to complete
    cx.executor().run_until_parked();

    // Verify the file was not formatted
    let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
    assert_eq!(
        // Ignore carriage returns on Windows
        new_content.replace("\r\n", "\n"),
        UNFORMATTED_CONTENT,
        "Code should not be formatted when format_on_save is disabled"
    );
}
863
/// Verifies that saving through `EditFileTool` strips trailing whitespace when
/// `remove_trailing_whitespace_on_save` is enabled, and keeps it when disabled.
#[gpui::test]
async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
    init_test(cx);

    let fs = project::FakeFs::new(cx.executor());
    fs.insert_tree("/root", json!({"src": {}})).await;

    // Create the file with placeholder content. The trailing whitespace is
    // introduced later by the content the fake model streams.
    fs.save(
        path!("/root/src/main.rs").as_ref(),
        &"initial content".into(),
        language::LineEnding::Unix,
    )
    .await
    .unwrap();

    let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
    let context_server_registry =
        cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
    let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
    let model = Arc::new(FakeLanguageModel::default());
    let thread = cx.new(|cx| {
        Thread::new(
            project,
            cx.new(|_cx| ProjectContext::default()),
            context_server_registry,
            Templates::new(),
            Some(model.clone()),
            cx,
        )
    });

    // First, test with remove_trailing_whitespace_on_save enabled
    cx.update(|cx| {
        SettingsStore::update_global(cx, |store, cx| {
            store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                cx,
                |settings| {
                    settings.defaults.remove_trailing_whitespace_on_save = Some(true);
                },
            );
        });
    });

    // Each of the first two lines ends with a space.
    const CONTENT_WITH_TRAILING_WHITESPACE: &str =
        "fn main() { \n println!(\"Hello!\"); \n}\n";

    // Have the model stream content that contains trailing whitespace
    let edit_result = {
        let edit_task = cx.update(|cx| {
            let input = EditFileToolInput {
                display_description: "Create main function".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Overwrite,
            };
            Arc::new(EditFileTool::new(
                thread.downgrade(),
                language_registry.clone(),
            ))
            .run(input, ToolCallEventStream::test().0, cx)
        });

        // Stream the content with trailing whitespace. Parking first lets the tool
        // reach the point where it has issued its completion request.
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(
            CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
        );
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(edit_result.is_ok());

    // Wait for any async operations (e.g. formatting) to complete
    cx.executor().run_until_parked();

    // Read the file to verify trailing whitespace was removed automatically
    assert_eq!(
        // Ignore carriage returns on Windows
        fs.load(path!("/root/src/main.rs").as_ref())
            .await
            .unwrap()
            .replace("\r\n", "\n"),
        "fn main() {\n println!(\"Hello!\");\n}\n",
        "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
    );

    // Next, test with remove_trailing_whitespace_on_save disabled
    cx.update(|cx| {
        SettingsStore::update_global(cx, |store, cx| {
            store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                cx,
                |settings| {
                    settings.defaults.remove_trailing_whitespace_on_save = Some(false);
                },
            );
        });
    });

    // Stream edits again with trailing whitespace
    let edit_result = {
        let edit_task = cx.update(|cx| {
            let input = EditFileToolInput {
                display_description: "Update main function".into(),
                path: "root/src/main.rs".into(),
                mode: EditFileMode::Overwrite,
            };
            Arc::new(EditFileTool::new(thread.downgrade(), language_registry)).run(
                input,
                ToolCallEventStream::test().0,
                cx,
            )
        });

        // Stream the content with trailing whitespace
        cx.executor().run_until_parked();
        model.send_last_completion_stream_text_chunk(
            CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
        );
        model.end_last_completion_stream();

        edit_task.await
    };
    assert!(edit_result.is_ok());

    // Wait for any async operations (e.g. formatting) to complete
    cx.executor().run_until_parked();

    // With the setting off, the streamed content must have been saved verbatim,
    // trailing whitespace included.
    let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
    assert_eq!(
        // Ignore carriage returns on Windows
        final_content.replace("\r\n", "\n"),
        CONTENT_WITH_TRAILING_WHITESPACE,
        "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
    );
}
1002
1003 #[gpui::test]
1004 async fn test_authorize(cx: &mut TestAppContext) {
1005 init_test(cx);
1006 let fs = project::FakeFs::new(cx.executor());
1007 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1008 let context_server_registry =
1009 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1010 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1011 let model = Arc::new(FakeLanguageModel::default());
1012 let thread = cx.new(|cx| {
1013 Thread::new(
1014 project,
1015 cx.new(|_cx| ProjectContext::default()),
1016 context_server_registry,
1017 Templates::new(),
1018 Some(model.clone()),
1019 cx,
1020 )
1021 });
1022 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1023 fs.insert_tree("/root", json!({})).await;
1024
1025 // Test 1: Path with .zed component should require confirmation
1026 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1027 let _auth = cx.update(|cx| {
1028 tool.authorize(
1029 &EditFileToolInput {
1030 display_description: "test 1".into(),
1031 path: ".zed/settings.json".into(),
1032 mode: EditFileMode::Edit,
1033 },
1034 &stream_tx,
1035 cx,
1036 )
1037 });
1038
1039 let event = stream_rx.expect_authorization().await;
1040 assert_eq!(
1041 event.tool_call.fields.title,
1042 Some("test 1 (local settings)".into())
1043 );
1044
1045 // Test 2: Path outside project should require confirmation
1046 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1047 let _auth = cx.update(|cx| {
1048 tool.authorize(
1049 &EditFileToolInput {
1050 display_description: "test 2".into(),
1051 path: "/etc/hosts".into(),
1052 mode: EditFileMode::Edit,
1053 },
1054 &stream_tx,
1055 cx,
1056 )
1057 });
1058
1059 let event = stream_rx.expect_authorization().await;
1060 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1061
1062 // Test 3: Relative path without .zed should not require confirmation
1063 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1064 cx.update(|cx| {
1065 tool.authorize(
1066 &EditFileToolInput {
1067 display_description: "test 3".into(),
1068 path: "root/src/main.rs".into(),
1069 mode: EditFileMode::Edit,
1070 },
1071 &stream_tx,
1072 cx,
1073 )
1074 })
1075 .await
1076 .unwrap();
1077 assert!(stream_rx.try_next().is_err());
1078
1079 // Test 4: Path with .zed in the middle should require confirmation
1080 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1081 let _auth = cx.update(|cx| {
1082 tool.authorize(
1083 &EditFileToolInput {
1084 display_description: "test 4".into(),
1085 path: "root/.zed/tasks.json".into(),
1086 mode: EditFileMode::Edit,
1087 },
1088 &stream_tx,
1089 cx,
1090 )
1091 });
1092 let event = stream_rx.expect_authorization().await;
1093 assert_eq!(
1094 event.tool_call.fields.title,
1095 Some("test 4 (local settings)".into())
1096 );
1097
1098 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1099 cx.update(|cx| {
1100 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1101 settings.always_allow_tool_actions = true;
1102 agent_settings::AgentSettings::override_global(settings, cx);
1103 });
1104
1105 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1106 cx.update(|cx| {
1107 tool.authorize(
1108 &EditFileToolInput {
1109 display_description: "test 5.1".into(),
1110 path: ".zed/settings.json".into(),
1111 mode: EditFileMode::Edit,
1112 },
1113 &stream_tx,
1114 cx,
1115 )
1116 })
1117 .await
1118 .unwrap();
1119 assert!(stream_rx.try_next().is_err());
1120
1121 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1122 cx.update(|cx| {
1123 tool.authorize(
1124 &EditFileToolInput {
1125 display_description: "test 5.2".into(),
1126 path: "/etc/hosts".into(),
1127 mode: EditFileMode::Edit,
1128 },
1129 &stream_tx,
1130 cx,
1131 )
1132 })
1133 .await
1134 .unwrap();
1135 assert!(stream_rx.try_next().is_err());
1136 }
1137
1138 #[gpui::test]
1139 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1140 init_test(cx);
1141 let fs = project::FakeFs::new(cx.executor());
1142 fs.insert_tree("/project", json!({})).await;
1143 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1144 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1145 let context_server_registry =
1146 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1147 let model = Arc::new(FakeLanguageModel::default());
1148 let thread = cx.new(|cx| {
1149 Thread::new(
1150 project,
1151 cx.new(|_cx| ProjectContext::default()),
1152 context_server_registry,
1153 Templates::new(),
1154 Some(model.clone()),
1155 cx,
1156 )
1157 });
1158 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1159
1160 // Test global config paths - these should require confirmation if they exist and are outside the project
1161 let test_cases = vec![
1162 (
1163 "/etc/hosts",
1164 true,
1165 "System file should require confirmation",
1166 ),
1167 (
1168 "/usr/local/bin/script",
1169 true,
1170 "System bin file should require confirmation",
1171 ),
1172 (
1173 "project/normal_file.rs",
1174 false,
1175 "Normal project file should not require confirmation",
1176 ),
1177 ];
1178
1179 for (path, should_confirm, description) in test_cases {
1180 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1181 let auth = cx.update(|cx| {
1182 tool.authorize(
1183 &EditFileToolInput {
1184 display_description: "Edit file".into(),
1185 path: path.into(),
1186 mode: EditFileMode::Edit,
1187 },
1188 &stream_tx,
1189 cx,
1190 )
1191 });
1192
1193 if should_confirm {
1194 stream_rx.expect_authorization().await;
1195 } else {
1196 auth.await.unwrap();
1197 assert!(
1198 stream_rx.try_next().is_err(),
1199 "Failed for case: {} - path: {} - expected no confirmation but got one",
1200 description,
1201 path
1202 );
1203 }
1204 }
1205 }
1206
1207 #[gpui::test]
1208 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1209 init_test(cx);
1210 let fs = project::FakeFs::new(cx.executor());
1211
1212 // Create multiple worktree directories
1213 fs.insert_tree(
1214 "/workspace/frontend",
1215 json!({
1216 "src": {
1217 "main.js": "console.log('frontend');"
1218 }
1219 }),
1220 )
1221 .await;
1222 fs.insert_tree(
1223 "/workspace/backend",
1224 json!({
1225 "src": {
1226 "main.rs": "fn main() {}"
1227 }
1228 }),
1229 )
1230 .await;
1231 fs.insert_tree(
1232 "/workspace/shared",
1233 json!({
1234 ".zed": {
1235 "settings.json": "{}"
1236 }
1237 }),
1238 )
1239 .await;
1240
1241 // Create project with multiple worktrees
1242 let project = Project::test(
1243 fs.clone(),
1244 [
1245 path!("/workspace/frontend").as_ref(),
1246 path!("/workspace/backend").as_ref(),
1247 path!("/workspace/shared").as_ref(),
1248 ],
1249 cx,
1250 )
1251 .await;
1252 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1253 let context_server_registry =
1254 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1255 let model = Arc::new(FakeLanguageModel::default());
1256 let thread = cx.new(|cx| {
1257 Thread::new(
1258 project.clone(),
1259 cx.new(|_cx| ProjectContext::default()),
1260 context_server_registry.clone(),
1261 Templates::new(),
1262 Some(model.clone()),
1263 cx,
1264 )
1265 });
1266 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1267
1268 // Test files in different worktrees
1269 let test_cases = vec![
1270 ("frontend/src/main.js", false, "File in first worktree"),
1271 ("backend/src/main.rs", false, "File in second worktree"),
1272 (
1273 "shared/.zed/settings.json",
1274 true,
1275 ".zed file in third worktree",
1276 ),
1277 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1278 (
1279 "../outside/file.txt",
1280 true,
1281 "Relative path outside worktrees",
1282 ),
1283 ];
1284
1285 for (path, should_confirm, description) in test_cases {
1286 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1287 let auth = cx.update(|cx| {
1288 tool.authorize(
1289 &EditFileToolInput {
1290 display_description: "Edit file".into(),
1291 path: path.into(),
1292 mode: EditFileMode::Edit,
1293 },
1294 &stream_tx,
1295 cx,
1296 )
1297 });
1298
1299 if should_confirm {
1300 stream_rx.expect_authorization().await;
1301 } else {
1302 auth.await.unwrap();
1303 assert!(
1304 stream_rx.try_next().is_err(),
1305 "Failed for case: {} - path: {} - expected no confirmation but got one",
1306 description,
1307 path
1308 );
1309 }
1310 }
1311 }
1312
1313 #[gpui::test]
1314 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1315 init_test(cx);
1316 let fs = project::FakeFs::new(cx.executor());
1317 fs.insert_tree(
1318 "/project",
1319 json!({
1320 ".zed": {
1321 "settings.json": "{}"
1322 },
1323 "src": {
1324 ".zed": {
1325 "local.json": "{}"
1326 }
1327 }
1328 }),
1329 )
1330 .await;
1331 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1332 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1333 let context_server_registry =
1334 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1335 let model = Arc::new(FakeLanguageModel::default());
1336 let thread = cx.new(|cx| {
1337 Thread::new(
1338 project.clone(),
1339 cx.new(|_cx| ProjectContext::default()),
1340 context_server_registry.clone(),
1341 Templates::new(),
1342 Some(model.clone()),
1343 cx,
1344 )
1345 });
1346 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1347
1348 // Test edge cases
1349 let test_cases = vec![
1350 // Empty path - find_project_path returns Some for empty paths
1351 ("", false, "Empty path is treated as project root"),
1352 // Root directory
1353 ("/", true, "Root directory should be outside project"),
1354 // Parent directory references - find_project_path resolves these
1355 (
1356 "project/../other",
1357 false,
1358 "Path with .. is resolved by find_project_path",
1359 ),
1360 (
1361 "project/./src/file.rs",
1362 false,
1363 "Path with . should work normally",
1364 ),
1365 // Windows-style paths (if on Windows)
1366 #[cfg(target_os = "windows")]
1367 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1368 #[cfg(target_os = "windows")]
1369 ("project\\src\\main.rs", false, "Windows-style project path"),
1370 ];
1371
1372 for (path, should_confirm, description) in test_cases {
1373 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1374 let auth = cx.update(|cx| {
1375 tool.authorize(
1376 &EditFileToolInput {
1377 display_description: "Edit file".into(),
1378 path: path.into(),
1379 mode: EditFileMode::Edit,
1380 },
1381 &stream_tx,
1382 cx,
1383 )
1384 });
1385
1386 if should_confirm {
1387 stream_rx.expect_authorization().await;
1388 } else {
1389 auth.await.unwrap();
1390 assert!(
1391 stream_rx.try_next().is_err(),
1392 "Failed for case: {} - path: {} - expected no confirmation but got one",
1393 description,
1394 path
1395 );
1396 }
1397 }
1398 }
1399
1400 #[gpui::test]
1401 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1402 init_test(cx);
1403 let fs = project::FakeFs::new(cx.executor());
1404 fs.insert_tree(
1405 "/project",
1406 json!({
1407 "existing.txt": "content",
1408 ".zed": {
1409 "settings.json": "{}"
1410 }
1411 }),
1412 )
1413 .await;
1414 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1415 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1416 let context_server_registry =
1417 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1418 let model = Arc::new(FakeLanguageModel::default());
1419 let thread = cx.new(|cx| {
1420 Thread::new(
1421 project.clone(),
1422 cx.new(|_cx| ProjectContext::default()),
1423 context_server_registry.clone(),
1424 Templates::new(),
1425 Some(model.clone()),
1426 cx,
1427 )
1428 });
1429 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1430
1431 // Test different EditFileMode values
1432 let modes = vec![
1433 EditFileMode::Edit,
1434 EditFileMode::Create,
1435 EditFileMode::Overwrite,
1436 ];
1437
1438 for mode in modes {
1439 // Test .zed path with different modes
1440 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1441 let _auth = cx.update(|cx| {
1442 tool.authorize(
1443 &EditFileToolInput {
1444 display_description: "Edit settings".into(),
1445 path: "project/.zed/settings.json".into(),
1446 mode: mode.clone(),
1447 },
1448 &stream_tx,
1449 cx,
1450 )
1451 });
1452
1453 stream_rx.expect_authorization().await;
1454
1455 // Test outside path with different modes
1456 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1457 let _auth = cx.update(|cx| {
1458 tool.authorize(
1459 &EditFileToolInput {
1460 display_description: "Edit file".into(),
1461 path: "/outside/file.txt".into(),
1462 mode: mode.clone(),
1463 },
1464 &stream_tx,
1465 cx,
1466 )
1467 });
1468
1469 stream_rx.expect_authorization().await;
1470
1471 // Test normal path with different modes
1472 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1473 cx.update(|cx| {
1474 tool.authorize(
1475 &EditFileToolInput {
1476 display_description: "Edit file".into(),
1477 path: "project/normal.txt".into(),
1478 mode: mode.clone(),
1479 },
1480 &stream_tx,
1481 cx,
1482 )
1483 })
1484 .await
1485 .unwrap();
1486 assert!(stream_rx.try_next().is_err());
1487 }
1488 }
1489
1490 #[gpui::test]
1491 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1492 init_test(cx);
1493 let fs = project::FakeFs::new(cx.executor());
1494 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1495 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1496 let context_server_registry =
1497 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1498 let model = Arc::new(FakeLanguageModel::default());
1499 let thread = cx.new(|cx| {
1500 Thread::new(
1501 project.clone(),
1502 cx.new(|_cx| ProjectContext::default()),
1503 context_server_registry,
1504 Templates::new(),
1505 Some(model.clone()),
1506 cx,
1507 )
1508 });
1509 let tool = Arc::new(EditFileTool::new(thread.downgrade(), language_registry));
1510
1511 assert_eq!(
1512 tool.initial_title(Err(json!({
1513 "path": "src/main.rs",
1514 "display_description": "",
1515 "old_string": "old code",
1516 "new_string": "new code"
1517 }))),
1518 "src/main.rs"
1519 );
1520 assert_eq!(
1521 tool.initial_title(Err(json!({
1522 "path": "",
1523 "display_description": "Fix error handling",
1524 "old_string": "old code",
1525 "new_string": "new code"
1526 }))),
1527 "Fix error handling"
1528 );
1529 assert_eq!(
1530 tool.initial_title(Err(json!({
1531 "path": "src/main.rs",
1532 "display_description": "Fix error handling",
1533 "old_string": "old code",
1534 "new_string": "new code"
1535 }))),
1536 "Fix error handling"
1537 );
1538 assert_eq!(
1539 tool.initial_title(Err(json!({
1540 "path": "",
1541 "display_description": "",
1542 "old_string": "old code",
1543 "new_string": "new code"
1544 }))),
1545 DEFAULT_UI_TEXT
1546 );
1547 assert_eq!(
1548 tool.initial_title(Err(serde_json::Value::Null)),
1549 DEFAULT_UI_TEXT
1550 );
1551 }
1552
1553 #[gpui::test]
1554 async fn test_diff_finalization(cx: &mut TestAppContext) {
1555 init_test(cx);
1556 let fs = project::FakeFs::new(cx.executor());
1557 fs.insert_tree("/", json!({"main.rs": ""})).await;
1558
1559 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1560 let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1561 let context_server_registry =
1562 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1563 let model = Arc::new(FakeLanguageModel::default());
1564 let thread = cx.new(|cx| {
1565 Thread::new(
1566 project.clone(),
1567 cx.new(|_cx| ProjectContext::default()),
1568 context_server_registry.clone(),
1569 Templates::new(),
1570 Some(model.clone()),
1571 cx,
1572 )
1573 });
1574
1575 // Ensure the diff is finalized after the edit completes.
1576 {
1577 let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
1578 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1579 let edit = cx.update(|cx| {
1580 tool.run(
1581 EditFileToolInput {
1582 display_description: "Edit file".into(),
1583 path: path!("/main.rs").into(),
1584 mode: EditFileMode::Edit,
1585 },
1586 stream_tx,
1587 cx,
1588 )
1589 });
1590 stream_rx.expect_update_fields().await;
1591 let diff = stream_rx.expect_diff().await;
1592 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1593 cx.run_until_parked();
1594 model.end_last_completion_stream();
1595 edit.await.unwrap();
1596 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1597 }
1598
1599 // Ensure the diff is finalized if an error occurs while editing.
1600 {
1601 model.forbid_requests();
1602 let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
1603 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1604 let edit = cx.update(|cx| {
1605 tool.run(
1606 EditFileToolInput {
1607 display_description: "Edit file".into(),
1608 path: path!("/main.rs").into(),
1609 mode: EditFileMode::Edit,
1610 },
1611 stream_tx,
1612 cx,
1613 )
1614 });
1615 stream_rx.expect_update_fields().await;
1616 let diff = stream_rx.expect_diff().await;
1617 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1618 edit.await.unwrap_err();
1619 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1620 model.allow_requests();
1621 }
1622
1623 // Ensure the diff is finalized if the tool call gets dropped.
1624 {
1625 let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
1626 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1627 let edit = cx.update(|cx| {
1628 tool.run(
1629 EditFileToolInput {
1630 display_description: "Edit file".into(),
1631 path: path!("/main.rs").into(),
1632 mode: EditFileMode::Edit,
1633 },
1634 stream_tx,
1635 cx,
1636 )
1637 });
1638 stream_rx.expect_update_fields().await;
1639 let diff = stream_rx.expect_diff().await;
1640 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1641 drop(edit);
1642 cx.run_until_parked();
1643 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1644 }
1645 }
1646
1647 fn init_test(cx: &mut TestAppContext) {
1648 cx.update(|cx| {
1649 let settings_store = SettingsStore::test(cx);
1650 cx.set_global(settings_store);
1651 language::init(cx);
1652 TelemetrySettings::register(cx);
1653 agent_settings::AgentSettings::register(cx);
1654 Project::init_settings(cx);
1655 });
1656 }
1657}