edit_file_tool.rs

   1use crate::{AgentTool, Thread, ToolCallEventStream};
   2use acp_thread::Diff;
   3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
   4use anyhow::{Context as _, Result, anyhow};
   5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
   9use indoc::formatdoc;
  10use language::language_settings::{self, FormatOnSave};
  11use language::{LanguageRegistry, ToPoint};
  12use language_model::LanguageModelToolResultContent;
  13use paths;
  14use project::lsp_store::{FormatTrigger, LspFormatTarget};
  15use project::{Project, ProjectPath};
  16use schemars::JsonSchema;
  17use serde::{Deserialize, Serialize};
  18use settings::Settings;
  19use smol::stream::StreamExt as _;
  20use std::path::{Path, PathBuf};
  21use std::sync::Arc;
  22use ui::SharedString;
  23use util::ResultExt;
  24
/// Fallback tool-call title, used by `initial_title` when neither a path nor a
/// description can be recovered from the (possibly partial) tool input.
const DEFAULT_UI_TEXT: &str = "Editing file";
  26
// NOTE: the `///` doc comments on this struct and its fields are picked up by the
// `JsonSchema` derive as schema descriptions. Presumably that schema is what the model
// receives as the tool's instructions (confirm against how `AgentTool` builds the tool
// schema), so treat any wording change below as a prompt change, not a cosmetic one.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
///    - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
  75
// Lenient view of the tool input, used by `initial_title` when the full
// `EditFileToolInput` could not be parsed. Both fields fall back to an empty
// string when absent, so an incomplete JSON object still deserializes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Path as received so far; empty when the field is missing.
    #[serde(default)]
    path: String,
    // Description as received so far; empty when the field is missing.
    #[serde(default)]
    display_description: String,
}
  83
// How the tool should treat the target file. Variants are (de)serialized in
// lowercase: `edit`, `create`, `overwrite`.
//
// `resolve_path` validates `Edit` and `Overwrite` against an existing file and
// `Create` against a missing file in an existing directory. `run` sends only
// `Edit` through the granular edit path; `Create` and `Overwrite` both go
// through the overwrite path.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    Edit,
    Create,
    Overwrite,
}
  92
/// Result of a completed `edit_file` call, serializable so that it can be
/// stored and later fed to `replay`.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// Path exactly as given in the tool input. Also accepted under the key
    /// `original_path` when deserializing.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Buffer text after the edit was applied and the buffer was saved.
    new_text: String,
    /// Buffer text captured before the edit agent ran.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; defaults to empty when the
    /// field is missing from serialized data.
    #[serde(default)]
    diff: String,
    /// Output of the edit agent. Also accepted under the key `raw_output`
    /// when deserializing.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
 104
 105impl From<EditFileToolOutput> for LanguageModelToolResultContent {
 106    fn from(output: EditFileToolOutput) -> Self {
 107        if output.diff.is_empty() {
 108            "No edits were made.".into()
 109        } else {
 110            format!(
 111                "Edited {}:\n\n```diff\n{}\n```",
 112                output.input_path.display(),
 113                output.diff
 114            )
 115            .into()
 116        }
 117    }
 118}
 119
/// Agent tool that creates, edits, or overwrites a file in the project.
pub struct EditFileTool {
    /// Weak handle to the owning thread; operations fail with
    /// "thread was dropped" once the thread is gone.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when a stored tool call is replayed.
    language_registry: Arc<LanguageRegistry>,
    /// Project used to resolve input paths into display titles.
    project: Entity<Project>,
}
 125
 126impl EditFileTool {
 127    pub fn new(
 128        project: Entity<Project>,
 129        thread: WeakEntity<Thread>,
 130        language_registry: Arc<LanguageRegistry>,
 131    ) -> Self {
 132        Self {
 133            project,
 134            thread,
 135            language_registry,
 136        }
 137    }
 138
 139    fn authorize(
 140        &self,
 141        input: &EditFileToolInput,
 142        event_stream: &ToolCallEventStream,
 143        cx: &mut App,
 144    ) -> Task<Result<()>> {
 145        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
 146            return Task::ready(Ok(()));
 147        }
 148
 149        // If any path component matches the local settings folder, then this could affect
 150        // the editor in ways beyond the project source, so prompt.
 151        let local_settings_folder = paths::local_settings_folder_relative_path();
 152        let path = Path::new(&input.path);
 153        if path
 154            .components()
 155            .any(|component| component.as_os_str() == local_settings_folder.as_os_str())
 156        {
 157            return event_stream.authorize(
 158                format!("{} (local settings)", input.display_description),
 159                cx,
 160            );
 161        }
 162
 163        // It's also possible that the global config dir is configured to be inside the project,
 164        // so check for that edge case too.
 165        if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
 166            && canonical_path.starts_with(paths::config_dir())
 167        {
 168            return event_stream.authorize(
 169                format!("{} (global settings)", input.display_description),
 170                cx,
 171            );
 172        }
 173
 174        // Check if path is inside the global config directory
 175        // First check if it's already inside project - if not, try to canonicalize
 176        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
 177            thread.project().read(cx).find_project_path(&input.path, cx)
 178        }) else {
 179            return Task::ready(Err(anyhow!("thread was dropped")));
 180        };
 181
 182        // If the path is inside the project, and it's not one of the above edge cases,
 183        // then no confirmation is necessary. Otherwise, confirmation is necessary.
 184        if project_path.is_some() {
 185            Task::ready(Ok(()))
 186        } else {
 187            event_stream.authorize(&input.display_description, cx)
 188        }
 189    }
 190}
 191
 192impl AgentTool for EditFileTool {
 193    type Input = EditFileToolInput;
 194    type Output = EditFileToolOutput;
 195
    /// Identifier under which this tool is exposed: `edit_file`.
    fn name() -> &'static str {
        "edit_file"
    }
 199
    /// Reports this tool as an edit operation (`acp::ToolKind::Edit`).
    fn kind() -> acp::ToolKind {
        acp::ToolKind::Edit
    }
 203
 204    fn initial_title(
 205        &self,
 206        input: Result<Self::Input, serde_json::Value>,
 207        cx: &mut App,
 208    ) -> SharedString {
 209        match input {
 210            Ok(input) => self
 211                .project
 212                .read(cx)
 213                .find_project_path(&input.path, cx)
 214                .and_then(|project_path| {
 215                    self.project
 216                        .read(cx)
 217                        .short_full_path_for_project_path(&project_path, cx)
 218                })
 219                .unwrap_or(Path::new(&input.path).into())
 220                .to_string_lossy()
 221                .to_string()
 222                .into(),
 223            Err(raw_input) => {
 224                if let Some(input) =
 225                    serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
 226                {
 227                    let path = input.path.trim();
 228                    if !path.is_empty() {
 229                        return self
 230                            .project
 231                            .read(cx)
 232                            .find_project_path(&input.path, cx)
 233                            .and_then(|project_path| {
 234                                self.project
 235                                    .read(cx)
 236                                    .short_full_path_for_project_path(&project_path, cx)
 237                            })
 238                            .unwrap_or(Path::new(&input.path).into())
 239                            .to_string_lossy()
 240                            .to_string()
 241                            .into();
 242                    }
 243
 244                    let description = input.display_description.trim();
 245                    if !description.is_empty() {
 246                        return description.to_string().into();
 247                    }
 248                }
 249
 250                DEFAULT_UI_TEXT.into()
 251            }
 252        }
 253    }
 254
 255    fn run(
 256        self: Arc<Self>,
 257        input: Self::Input,
 258        event_stream: ToolCallEventStream,
 259        cx: &mut App,
 260    ) -> Task<Result<Self::Output>> {
 261        let Ok(project) = self
 262            .thread
 263            .read_with(cx, |thread, _cx| thread.project().clone())
 264        else {
 265            return Task::ready(Err(anyhow!("thread was dropped")));
 266        };
 267        let project_path = match resolve_path(&input, project.clone(), cx) {
 268            Ok(path) => path,
 269            Err(err) => return Task::ready(Err(anyhow!(err))),
 270        };
 271        let abs_path = project.read(cx).absolute_path(&project_path, cx);
 272        if let Some(abs_path) = abs_path.clone() {
 273            event_stream.update_fields(ToolCallUpdateFields {
 274                locations: Some(vec![acp::ToolCallLocation {
 275                    path: abs_path,
 276                    line: None,
 277                    meta: None,
 278                }]),
 279                ..Default::default()
 280            });
 281        }
 282
 283        let authorize = self.authorize(&input, &event_stream, cx);
 284        cx.spawn(async move |cx: &mut AsyncApp| {
 285            authorize.await?;
 286
 287            let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
 288                let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
 289                (request, thread.model().cloned(), thread.action_log().clone())
 290            })?;
 291            let request = request?;
 292            let model = model.context("No language model configured")?;
 293
 294            let edit_format = EditFormat::from_model(model.clone())?;
 295            let edit_agent = EditAgent::new(
 296                model,
 297                project.clone(),
 298                action_log.clone(),
 299                // TODO: move edit agent to this crate so we can use our templates
 300                assistant_tools::templates::Templates::new(),
 301                edit_format,
 302            );
 303
 304            let buffer = project
 305                .update(cx, |project, cx| {
 306                    project.open_buffer(project_path.clone(), cx)
 307                })?
 308                .await?;
 309
 310            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
 311            event_stream.update_diff(diff.clone());
 312            let _finalize_diff = util::defer({
 313               let diff = diff.downgrade();
 314               let mut cx = cx.clone();
 315               move || {
 316                   diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
 317               }
 318            });
 319
 320            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 321            let old_text = cx
 322                .background_spawn({
 323                    let old_snapshot = old_snapshot.clone();
 324                    async move { Arc::new(old_snapshot.text()) }
 325                })
 326                .await;
 327
 328
 329            let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
 330                edit_agent.edit(
 331                    buffer.clone(),
 332                    input.display_description.clone(),
 333                    &request,
 334                    cx,
 335                )
 336            } else {
 337                edit_agent.overwrite(
 338                    buffer.clone(),
 339                    input.display_description.clone(),
 340                    &request,
 341                    cx,
 342                )
 343            };
 344
 345            let mut hallucinated_old_text = false;
 346            let mut ambiguous_ranges = Vec::new();
 347            let mut emitted_location = false;
 348            while let Some(event) = events.next().await {
 349                match event {
 350                    EditAgentOutputEvent::Edited(range) => {
 351                        if !emitted_location {
 352                            let line = buffer.update(cx, |buffer, _cx| {
 353                                range.start.to_point(&buffer.snapshot()).row
 354                            }).ok();
 355                            if let Some(abs_path) = abs_path.clone() {
 356                                event_stream.update_fields(ToolCallUpdateFields {
 357                                    locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
 358                                    ..Default::default()
 359                                });
 360                            }
 361                            emitted_location = true;
 362                        }
 363                    },
 364                    EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
 365                    EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
 366                    EditAgentOutputEvent::ResolvingEditRange(range) => {
 367                        diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
 368                        // if !emitted_location {
 369                        //     let line = buffer.update(cx, |buffer, _cx| {
 370                        //         range.start.to_point(&buffer.snapshot()).row
 371                        //     }).ok();
 372                        //     if let Some(abs_path) = abs_path.clone() {
 373                        //         event_stream.update_fields(ToolCallUpdateFields {
 374                        //             locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
 375                        //             ..Default::default()
 376                        //         });
 377                        //     }
 378                        // }
 379                    }
 380                }
 381            }
 382
 383            // If format_on_save is enabled, format the buffer
 384            let format_on_save_enabled = buffer
 385                .read_with(cx, |buffer, cx| {
 386                    let settings = language_settings::language_settings(
 387                        buffer.language().map(|l| l.name()),
 388                        buffer.file(),
 389                        cx,
 390                    );
 391                    settings.format_on_save != FormatOnSave::Off
 392                })
 393                .unwrap_or(false);
 394
 395            let edit_agent_output = output.await?;
 396
 397            if format_on_save_enabled {
 398                action_log.update(cx, |log, cx| {
 399                    log.buffer_edited(buffer.clone(), cx);
 400                })?;
 401
 402                let format_task = project.update(cx, |project, cx| {
 403                    project.format(
 404                        HashSet::from_iter([buffer.clone()]),
 405                        LspFormatTarget::Buffers,
 406                        false, // Don't push to history since the tool did it.
 407                        FormatTrigger::Save,
 408                        cx,
 409                    )
 410                })?;
 411                format_task.await.log_err();
 412            }
 413
 414            project
 415                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
 416                .await?;
 417
 418            action_log.update(cx, |log, cx| {
 419                log.buffer_edited(buffer.clone(), cx);
 420            })?;
 421
 422            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 423            let (new_text, unified_diff) = cx
 424                .background_spawn({
 425                    let new_snapshot = new_snapshot.clone();
 426                    let old_text = old_text.clone();
 427                    async move {
 428                        let new_text = new_snapshot.text();
 429                        let diff = language::unified_diff(&old_text, &new_text);
 430                        (new_text, diff)
 431                    }
 432                })
 433                .await;
 434
 435            let input_path = input.path.display();
 436            if unified_diff.is_empty() {
 437                anyhow::ensure!(
 438                    !hallucinated_old_text,
 439                    formatdoc! {"
 440                        Some edits were produced but none of them could be applied.
 441                        Read the relevant sections of {input_path} again so that
 442                        I can perform the requested edits.
 443                    "}
 444                );
 445                anyhow::ensure!(
 446                    ambiguous_ranges.is_empty(),
 447                    {
 448                        let line_numbers = ambiguous_ranges
 449                            .iter()
 450                            .map(|range| range.start.to_string())
 451                            .collect::<Vec<_>>()
 452                            .join(", ");
 453                        formatdoc! {"
 454                            <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
 455                            relevant sections of {input_path} again and extend <old_text> so
 456                            that I can perform the requested edits.
 457                        "}
 458                    }
 459                );
 460            }
 461
 462            Ok(EditFileToolOutput {
 463                input_path: input.path,
 464                new_text,
 465                old_text,
 466                diff: unified_diff,
 467                edit_agent_output,
 468            })
 469        })
 470    }
 471
 472    fn replay(
 473        &self,
 474        _input: Self::Input,
 475        output: Self::Output,
 476        event_stream: ToolCallEventStream,
 477        cx: &mut App,
 478    ) -> Result<()> {
 479        event_stream.update_diff(cx.new(|cx| {
 480            Diff::finalized(
 481                output.input_path,
 482                Some(output.old_text.to_string()),
 483                output.new_text,
 484                self.language_registry.clone(),
 485                cx,
 486            )
 487        }));
 488        Ok(())
 489    }
 490}
 491
 492/// Validate that the file path is valid, meaning:
 493///
 494/// - For `edit` and `overwrite`, the path must point to an existing file.
 495/// - For `create`, the file must not already exist, but it's parent dir must exist.
 496fn resolve_path(
 497    input: &EditFileToolInput,
 498    project: Entity<Project>,
 499    cx: &mut App,
 500) -> Result<ProjectPath> {
 501    let project = project.read(cx);
 502
 503    match input.mode {
 504        EditFileMode::Edit | EditFileMode::Overwrite => {
 505            let path = project
 506                .find_project_path(&input.path, cx)
 507                .context("Can't edit file: path not found")?;
 508
 509            let entry = project
 510                .entry_for_path(&path, cx)
 511                .context("Can't edit file: path not found")?;
 512
 513            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
 514            Ok(path)
 515        }
 516
 517        EditFileMode::Create => {
 518            if let Some(path) = project.find_project_path(&input.path, cx) {
 519                anyhow::ensure!(
 520                    project.entry_for_path(&path, cx).is_none(),
 521                    "Can't create file: file already exists"
 522                );
 523            }
 524
 525            let parent_path = input
 526                .path
 527                .parent()
 528                .context("Can't create file: incorrect path")?;
 529
 530            let parent_project_path = project.find_project_path(&parent_path, cx);
 531
 532            let parent_entry = parent_project_path
 533                .as_ref()
 534                .and_then(|path| project.entry_for_path(path, cx))
 535                .context("Can't create file: parent directory doesn't exist")?;
 536
 537            anyhow::ensure!(
 538                parent_entry.is_dir(),
 539                "Can't create file: parent is not a directory"
 540            );
 541
 542            let file_name = input
 543                .path
 544                .file_name()
 545                .context("Can't create file: invalid filename")?;
 546
 547            let new_file_path = parent_project_path.map(|parent| ProjectPath {
 548                path: Arc::from(parent.path.join(file_name)),
 549                ..parent
 550            });
 551
 552            new_file_path.context("Can't create file")
 553        }
 554    }
 555}
 556
 557#[cfg(test)]
 558mod tests {
 559    use super::*;
 560    use crate::{ContextServerRegistry, Templates};
 561    use client::TelemetrySettings;
 562    use fs::Fs;
 563    use gpui::{TestAppContext, UpdateGlobal};
 564    use language_model::fake_provider::FakeLanguageModel;
 565    use prompt_store::ProjectContext;
 566    use serde_json::json;
 567    use settings::SettingsStore;
 568    use util::path;
 569
 570    #[gpui::test]
 571    async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
 572        init_test(cx);
 573
 574        let fs = project::FakeFs::new(cx.executor());
 575        fs.insert_tree("/root", json!({})).await;
 576        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 577        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 578        let context_server_registry =
 579            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 580        let model = Arc::new(FakeLanguageModel::default());
 581        let thread = cx.new(|cx| {
 582            Thread::new(
 583                project.clone(),
 584                cx.new(|_cx| ProjectContext::default()),
 585                context_server_registry,
 586                Templates::new(),
 587                Some(model),
 588                cx,
 589            )
 590        });
 591        let result = cx
 592            .update(|cx| {
 593                let input = EditFileToolInput {
 594                    display_description: "Some edit".into(),
 595                    path: "root/nonexistent_file.txt".into(),
 596                    mode: EditFileMode::Edit,
 597                };
 598                Arc::new(EditFileTool::new(
 599                    project,
 600                    thread.downgrade(),
 601                    language_registry,
 602                ))
 603                .run(input, ToolCallEventStream::test().0, cx)
 604            })
 605            .await;
 606        assert_eq!(
 607            result.unwrap_err().to_string(),
 608            "Can't edit file: path not found"
 609        );
 610    }
 611
 612    #[gpui::test]
 613    async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
 614        let mode = &EditFileMode::Create;
 615
 616        let result = test_resolve_path(mode, "root/new.txt", cx);
 617        assert_resolved_path_eq(result.await, "new.txt");
 618
 619        let result = test_resolve_path(mode, "new.txt", cx);
 620        assert_resolved_path_eq(result.await, "new.txt");
 621
 622        let result = test_resolve_path(mode, "dir/new.txt", cx);
 623        assert_resolved_path_eq(result.await, "dir/new.txt");
 624
 625        let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
 626        assert_eq!(
 627            result.await.unwrap_err().to_string(),
 628            "Can't create file: file already exists"
 629        );
 630
 631        let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
 632        assert_eq!(
 633            result.await.unwrap_err().to_string(),
 634            "Can't create file: parent directory doesn't exist"
 635        );
 636    }
 637
 638    #[gpui::test]
 639    async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
 640        let mode = &EditFileMode::Edit;
 641
 642        let path_with_root = "root/dir/subdir/existing.txt";
 643        let path_without_root = "dir/subdir/existing.txt";
 644        let result = test_resolve_path(mode, path_with_root, cx);
 645        assert_resolved_path_eq(result.await, path_without_root);
 646
 647        let result = test_resolve_path(mode, path_without_root, cx);
 648        assert_resolved_path_eq(result.await, path_without_root);
 649
 650        let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
 651        assert_eq!(
 652            result.await.unwrap_err().to_string(),
 653            "Can't edit file: path not found"
 654        );
 655
 656        let result = test_resolve_path(mode, "root/dir", cx);
 657        assert_eq!(
 658            result.await.unwrap_err().to_string(),
 659            "Can't edit file: path is a directory"
 660        );
 661    }
 662
 663    async fn test_resolve_path(
 664        mode: &EditFileMode,
 665        path: &str,
 666        cx: &mut TestAppContext,
 667    ) -> anyhow::Result<ProjectPath> {
 668        init_test(cx);
 669
 670        let fs = project::FakeFs::new(cx.executor());
 671        fs.insert_tree(
 672            "/root",
 673            json!({
 674                "dir": {
 675                    "subdir": {
 676                        "existing.txt": "hello"
 677                    }
 678                }
 679            }),
 680        )
 681        .await;
 682        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 683
 684        let input = EditFileToolInput {
 685            display_description: "Some edit".into(),
 686            path: path.into(),
 687            mode: mode.clone(),
 688        };
 689
 690        cx.update(|cx| resolve_path(&input, project, cx))
 691    }
 692
 693    fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
 694        let actual = path
 695            .expect("Should return valid path")
 696            .path
 697            .to_str()
 698            .unwrap()
 699            .replace("\\", "/"); // Naive Windows paths normalization
 700        assert_eq!(actual, expected);
 701    }
 702
 703    #[gpui::test]
 704    async fn test_format_on_save(cx: &mut TestAppContext) {
 705        init_test(cx);
 706
 707        let fs = project::FakeFs::new(cx.executor());
 708        fs.insert_tree("/root", json!({"src": {}})).await;
 709
 710        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 711
 712        // Set up a Rust language with LSP formatting support
 713        let rust_language = Arc::new(language::Language::new(
 714            language::LanguageConfig {
 715                name: "Rust".into(),
 716                matcher: language::LanguageMatcher {
 717                    path_suffixes: vec!["rs".to_string()],
 718                    ..Default::default()
 719                },
 720                ..Default::default()
 721            },
 722            None,
 723        ));
 724
 725        // Register the language and fake LSP
 726        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
 727        language_registry.add(rust_language);
 728
 729        let mut fake_language_servers = language_registry.register_fake_lsp(
 730            "Rust",
 731            language::FakeLspAdapter {
 732                capabilities: lsp::ServerCapabilities {
 733                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
 734                    ..Default::default()
 735                },
 736                ..Default::default()
 737            },
 738        );
 739
 740        // Create the file
 741        fs.save(
 742            path!("/root/src/main.rs").as_ref(),
 743            &"initial content".into(),
 744            language::LineEnding::Unix,
 745        )
 746        .await
 747        .unwrap();
 748
 749        // Open the buffer to trigger LSP initialization
 750        let buffer = project
 751            .update(cx, |project, cx| {
 752                project.open_local_buffer(path!("/root/src/main.rs"), cx)
 753            })
 754            .await
 755            .unwrap();
 756
 757        // Register the buffer with language servers
 758        let _handle = project.update(cx, |project, cx| {
 759            project.register_buffer_with_language_servers(&buffer, cx)
 760        });
 761
 762        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
 763        const FORMATTED_CONTENT: &str =
 764            "This file was formatted by the fake formatter in the test.\n";
 765
 766        // Get the fake language server and set up formatting handler
 767        let fake_language_server = fake_language_servers.next().await.unwrap();
 768        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
 769            |_, _| async move {
 770                Ok(Some(vec![lsp::TextEdit {
 771                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
 772                    new_text: FORMATTED_CONTENT.to_string(),
 773                }]))
 774            }
 775        });
 776
 777        let context_server_registry =
 778            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 779        let model = Arc::new(FakeLanguageModel::default());
 780        let thread = cx.new(|cx| {
 781            Thread::new(
 782                project.clone(),
 783                cx.new(|_cx| ProjectContext::default()),
 784                context_server_registry,
 785                Templates::new(),
 786                Some(model.clone()),
 787                cx,
 788            )
 789        });
 790
 791        // First, test with format_on_save enabled
 792        cx.update(|cx| {
 793            SettingsStore::update_global(cx, |store, cx| {
 794                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
 795                    cx,
 796                    |settings| {
 797                        settings.defaults.format_on_save = Some(FormatOnSave::On);
 798                        settings.defaults.formatter =
 799                            Some(language::language_settings::SelectedFormatter::Auto);
 800                    },
 801                );
 802            });
 803        });
 804
 805        // Have the model stream unformatted content
 806        let edit_result = {
 807            let edit_task = cx.update(|cx| {
 808                let input = EditFileToolInput {
 809                    display_description: "Create main function".into(),
 810                    path: "root/src/main.rs".into(),
 811                    mode: EditFileMode::Overwrite,
 812                };
 813                Arc::new(EditFileTool::new(
 814                    project.clone(),
 815                    thread.downgrade(),
 816                    language_registry.clone(),
 817                ))
 818                .run(input, ToolCallEventStream::test().0, cx)
 819            });
 820
 821            // Stream the unformatted content
 822            cx.executor().run_until_parked();
 823            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
 824            model.end_last_completion_stream();
 825
 826            edit_task.await
 827        };
 828        assert!(edit_result.is_ok());
 829
 830        // Wait for any async operations (e.g. formatting) to complete
 831        cx.executor().run_until_parked();
 832
 833        // Read the file to verify it was formatted automatically
 834        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
 835        assert_eq!(
 836            // Ignore carriage returns on Windows
 837            new_content.replace("\r\n", "\n"),
 838            FORMATTED_CONTENT,
 839            "Code should be formatted when format_on_save is enabled"
 840        );
 841
 842        let stale_buffer_count = thread
 843            .read_with(cx, |thread, _cx| thread.action_log.clone())
 844            .read_with(cx, |log, cx| log.stale_buffers(cx).count());
 845
 846        assert_eq!(
 847            stale_buffer_count, 0,
 848            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
 849             This causes the agent to think the file was modified externally when it was just formatted.",
 850            stale_buffer_count
 851        );
 852
 853        // Next, test with format_on_save disabled
 854        cx.update(|cx| {
 855            SettingsStore::update_global(cx, |store, cx| {
 856                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
 857                    cx,
 858                    |settings| {
 859                        settings.defaults.format_on_save = Some(FormatOnSave::Off);
 860                    },
 861                );
 862            });
 863        });
 864
 865        // Stream unformatted edits again
 866        let edit_result = {
 867            let edit_task = cx.update(|cx| {
 868                let input = EditFileToolInput {
 869                    display_description: "Update main function".into(),
 870                    path: "root/src/main.rs".into(),
 871                    mode: EditFileMode::Overwrite,
 872                };
 873                Arc::new(EditFileTool::new(
 874                    project.clone(),
 875                    thread.downgrade(),
 876                    language_registry,
 877                ))
 878                .run(input, ToolCallEventStream::test().0, cx)
 879            });
 880
 881            // Stream the unformatted content
 882            cx.executor().run_until_parked();
 883            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
 884            model.end_last_completion_stream();
 885
 886            edit_task.await
 887        };
 888        assert!(edit_result.is_ok());
 889
 890        // Wait for any async operations (e.g. formatting) to complete
 891        cx.executor().run_until_parked();
 892
 893        // Verify the file was not formatted
 894        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
 895        assert_eq!(
 896            // Ignore carriage returns on Windows
 897            new_content.replace("\r\n", "\n"),
 898            UNFORMATTED_CONTENT,
 899            "Code should not be formatted when format_on_save is disabled"
 900        );
 901    }
 902
 903    #[gpui::test]
 904    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
 905        init_test(cx);
 906
 907        let fs = project::FakeFs::new(cx.executor());
 908        fs.insert_tree("/root", json!({"src": {}})).await;
 909
 910        // Create a simple file with trailing whitespace
 911        fs.save(
 912            path!("/root/src/main.rs").as_ref(),
 913            &"initial content".into(),
 914            language::LineEnding::Unix,
 915        )
 916        .await
 917        .unwrap();
 918
 919        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 920        let context_server_registry =
 921            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 922        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 923        let model = Arc::new(FakeLanguageModel::default());
 924        let thread = cx.new(|cx| {
 925            Thread::new(
 926                project.clone(),
 927                cx.new(|_cx| ProjectContext::default()),
 928                context_server_registry,
 929                Templates::new(),
 930                Some(model.clone()),
 931                cx,
 932            )
 933        });
 934
 935        // First, test with remove_trailing_whitespace_on_save enabled
 936        cx.update(|cx| {
 937            SettingsStore::update_global(cx, |store, cx| {
 938                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
 939                    cx,
 940                    |settings| {
 941                        settings.defaults.remove_trailing_whitespace_on_save = Some(true);
 942                    },
 943                );
 944            });
 945        });
 946
 947        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
 948            "fn main() {  \n    println!(\"Hello!\");  \n}\n";
 949
 950        // Have the model stream content that contains trailing whitespace
 951        let edit_result = {
 952            let edit_task = cx.update(|cx| {
 953                let input = EditFileToolInput {
 954                    display_description: "Create main function".into(),
 955                    path: "root/src/main.rs".into(),
 956                    mode: EditFileMode::Overwrite,
 957                };
 958                Arc::new(EditFileTool::new(
 959                    project.clone(),
 960                    thread.downgrade(),
 961                    language_registry.clone(),
 962                ))
 963                .run(input, ToolCallEventStream::test().0, cx)
 964            });
 965
 966            // Stream the content with trailing whitespace
 967            cx.executor().run_until_parked();
 968            model.send_last_completion_stream_text_chunk(
 969                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
 970            );
 971            model.end_last_completion_stream();
 972
 973            edit_task.await
 974        };
 975        assert!(edit_result.is_ok());
 976
 977        // Wait for any async operations (e.g. formatting) to complete
 978        cx.executor().run_until_parked();
 979
 980        // Read the file to verify trailing whitespace was removed automatically
 981        assert_eq!(
 982            // Ignore carriage returns on Windows
 983            fs.load(path!("/root/src/main.rs").as_ref())
 984                .await
 985                .unwrap()
 986                .replace("\r\n", "\n"),
 987            "fn main() {\n    println!(\"Hello!\");\n}\n",
 988            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
 989        );
 990
 991        // Next, test with remove_trailing_whitespace_on_save disabled
 992        cx.update(|cx| {
 993            SettingsStore::update_global(cx, |store, cx| {
 994                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
 995                    cx,
 996                    |settings| {
 997                        settings.defaults.remove_trailing_whitespace_on_save = Some(false);
 998                    },
 999                );
1000            });
1001        });
1002
1003        // Stream edits again with trailing whitespace
1004        let edit_result = {
1005            let edit_task = cx.update(|cx| {
1006                let input = EditFileToolInput {
1007                    display_description: "Update main function".into(),
1008                    path: "root/src/main.rs".into(),
1009                    mode: EditFileMode::Overwrite,
1010                };
1011                Arc::new(EditFileTool::new(
1012                    project.clone(),
1013                    thread.downgrade(),
1014                    language_registry,
1015                ))
1016                .run(input, ToolCallEventStream::test().0, cx)
1017            });
1018
1019            // Stream the content with trailing whitespace
1020            cx.executor().run_until_parked();
1021            model.send_last_completion_stream_text_chunk(
1022                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1023            );
1024            model.end_last_completion_stream();
1025
1026            edit_task.await
1027        };
1028        assert!(edit_result.is_ok());
1029
1030        // Wait for any async operations (e.g. formatting) to complete
1031        cx.executor().run_until_parked();
1032
1033        // Verify the file still has trailing whitespace
1034        // Read the file again - it should still have trailing whitespace
1035        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1036        assert_eq!(
1037            // Ignore carriage returns on Windows
1038            final_content.replace("\r\n", "\n"),
1039            CONTENT_WITH_TRAILING_WHITESPACE,
1040            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1041        );
1042    }
1043
1044    #[gpui::test]
1045    async fn test_authorize(cx: &mut TestAppContext) {
1046        init_test(cx);
1047        let fs = project::FakeFs::new(cx.executor());
1048        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1049        let context_server_registry =
1050            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1051        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1052        let model = Arc::new(FakeLanguageModel::default());
1053        let thread = cx.new(|cx| {
1054            Thread::new(
1055                project.clone(),
1056                cx.new(|_cx| ProjectContext::default()),
1057                context_server_registry,
1058                Templates::new(),
1059                Some(model.clone()),
1060                cx,
1061            )
1062        });
1063        let tool = Arc::new(EditFileTool::new(
1064            project.clone(),
1065            thread.downgrade(),
1066            language_registry,
1067        ));
1068        fs.insert_tree("/root", json!({})).await;
1069
1070        // Test 1: Path with .zed component should require confirmation
1071        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1072        let _auth = cx.update(|cx| {
1073            tool.authorize(
1074                &EditFileToolInput {
1075                    display_description: "test 1".into(),
1076                    path: ".zed/settings.json".into(),
1077                    mode: EditFileMode::Edit,
1078                },
1079                &stream_tx,
1080                cx,
1081            )
1082        });
1083
1084        let event = stream_rx.expect_authorization().await;
1085        assert_eq!(
1086            event.tool_call.fields.title,
1087            Some("test 1 (local settings)".into())
1088        );
1089
1090        // Test 2: Path outside project should require confirmation
1091        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1092        let _auth = cx.update(|cx| {
1093            tool.authorize(
1094                &EditFileToolInput {
1095                    display_description: "test 2".into(),
1096                    path: "/etc/hosts".into(),
1097                    mode: EditFileMode::Edit,
1098                },
1099                &stream_tx,
1100                cx,
1101            )
1102        });
1103
1104        let event = stream_rx.expect_authorization().await;
1105        assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1106
1107        // Test 3: Relative path without .zed should not require confirmation
1108        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1109        cx.update(|cx| {
1110            tool.authorize(
1111                &EditFileToolInput {
1112                    display_description: "test 3".into(),
1113                    path: "root/src/main.rs".into(),
1114                    mode: EditFileMode::Edit,
1115                },
1116                &stream_tx,
1117                cx,
1118            )
1119        })
1120        .await
1121        .unwrap();
1122        assert!(stream_rx.try_next().is_err());
1123
1124        // Test 4: Path with .zed in the middle should require confirmation
1125        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1126        let _auth = cx.update(|cx| {
1127            tool.authorize(
1128                &EditFileToolInput {
1129                    display_description: "test 4".into(),
1130                    path: "root/.zed/tasks.json".into(),
1131                    mode: EditFileMode::Edit,
1132                },
1133                &stream_tx,
1134                cx,
1135            )
1136        });
1137        let event = stream_rx.expect_authorization().await;
1138        assert_eq!(
1139            event.tool_call.fields.title,
1140            Some("test 4 (local settings)".into())
1141        );
1142
1143        // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1144        cx.update(|cx| {
1145            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1146            settings.always_allow_tool_actions = true;
1147            agent_settings::AgentSettings::override_global(settings, cx);
1148        });
1149
1150        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1151        cx.update(|cx| {
1152            tool.authorize(
1153                &EditFileToolInput {
1154                    display_description: "test 5.1".into(),
1155                    path: ".zed/settings.json".into(),
1156                    mode: EditFileMode::Edit,
1157                },
1158                &stream_tx,
1159                cx,
1160            )
1161        })
1162        .await
1163        .unwrap();
1164        assert!(stream_rx.try_next().is_err());
1165
1166        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1167        cx.update(|cx| {
1168            tool.authorize(
1169                &EditFileToolInput {
1170                    display_description: "test 5.2".into(),
1171                    path: "/etc/hosts".into(),
1172                    mode: EditFileMode::Edit,
1173                },
1174                &stream_tx,
1175                cx,
1176            )
1177        })
1178        .await
1179        .unwrap();
1180        assert!(stream_rx.try_next().is_err());
1181    }
1182
1183    #[gpui::test]
1184    async fn test_authorize_global_config(cx: &mut TestAppContext) {
1185        init_test(cx);
1186        let fs = project::FakeFs::new(cx.executor());
1187        fs.insert_tree("/project", json!({})).await;
1188        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1189        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1190        let context_server_registry =
1191            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1192        let model = Arc::new(FakeLanguageModel::default());
1193        let thread = cx.new(|cx| {
1194            Thread::new(
1195                project.clone(),
1196                cx.new(|_cx| ProjectContext::default()),
1197                context_server_registry,
1198                Templates::new(),
1199                Some(model.clone()),
1200                cx,
1201            )
1202        });
1203        let tool = Arc::new(EditFileTool::new(
1204            project.clone(),
1205            thread.downgrade(),
1206            language_registry,
1207        ));
1208
1209        // Test global config paths - these should require confirmation if they exist and are outside the project
1210        let test_cases = vec![
1211            (
1212                "/etc/hosts",
1213                true,
1214                "System file should require confirmation",
1215            ),
1216            (
1217                "/usr/local/bin/script",
1218                true,
1219                "System bin file should require confirmation",
1220            ),
1221            (
1222                "project/normal_file.rs",
1223                false,
1224                "Normal project file should not require confirmation",
1225            ),
1226        ];
1227
1228        for (path, should_confirm, description) in test_cases {
1229            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1230            let auth = cx.update(|cx| {
1231                tool.authorize(
1232                    &EditFileToolInput {
1233                        display_description: "Edit file".into(),
1234                        path: path.into(),
1235                        mode: EditFileMode::Edit,
1236                    },
1237                    &stream_tx,
1238                    cx,
1239                )
1240            });
1241
1242            if should_confirm {
1243                stream_rx.expect_authorization().await;
1244            } else {
1245                auth.await.unwrap();
1246                assert!(
1247                    stream_rx.try_next().is_err(),
1248                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1249                    description,
1250                    path
1251                );
1252            }
1253        }
1254    }
1255
1256    #[gpui::test]
1257    async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1258        init_test(cx);
1259        let fs = project::FakeFs::new(cx.executor());
1260
1261        // Create multiple worktree directories
1262        fs.insert_tree(
1263            "/workspace/frontend",
1264            json!({
1265                "src": {
1266                    "main.js": "console.log('frontend');"
1267                }
1268            }),
1269        )
1270        .await;
1271        fs.insert_tree(
1272            "/workspace/backend",
1273            json!({
1274                "src": {
1275                    "main.rs": "fn main() {}"
1276                }
1277            }),
1278        )
1279        .await;
1280        fs.insert_tree(
1281            "/workspace/shared",
1282            json!({
1283                ".zed": {
1284                    "settings.json": "{}"
1285                }
1286            }),
1287        )
1288        .await;
1289
1290        // Create project with multiple worktrees
1291        let project = Project::test(
1292            fs.clone(),
1293            [
1294                path!("/workspace/frontend").as_ref(),
1295                path!("/workspace/backend").as_ref(),
1296                path!("/workspace/shared").as_ref(),
1297            ],
1298            cx,
1299        )
1300        .await;
1301        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1302        let context_server_registry =
1303            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1304        let model = Arc::new(FakeLanguageModel::default());
1305        let thread = cx.new(|cx| {
1306            Thread::new(
1307                project.clone(),
1308                cx.new(|_cx| ProjectContext::default()),
1309                context_server_registry.clone(),
1310                Templates::new(),
1311                Some(model.clone()),
1312                cx,
1313            )
1314        });
1315        let tool = Arc::new(EditFileTool::new(
1316            project.clone(),
1317            thread.downgrade(),
1318            language_registry,
1319        ));
1320
1321        // Test files in different worktrees
1322        let test_cases = vec![
1323            ("frontend/src/main.js", false, "File in first worktree"),
1324            ("backend/src/main.rs", false, "File in second worktree"),
1325            (
1326                "shared/.zed/settings.json",
1327                true,
1328                ".zed file in third worktree",
1329            ),
1330            ("/etc/hosts", true, "Absolute path outside all worktrees"),
1331            (
1332                "../outside/file.txt",
1333                true,
1334                "Relative path outside worktrees",
1335            ),
1336        ];
1337
1338        for (path, should_confirm, description) in test_cases {
1339            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1340            let auth = cx.update(|cx| {
1341                tool.authorize(
1342                    &EditFileToolInput {
1343                        display_description: "Edit file".into(),
1344                        path: path.into(),
1345                        mode: EditFileMode::Edit,
1346                    },
1347                    &stream_tx,
1348                    cx,
1349                )
1350            });
1351
1352            if should_confirm {
1353                stream_rx.expect_authorization().await;
1354            } else {
1355                auth.await.unwrap();
1356                assert!(
1357                    stream_rx.try_next().is_err(),
1358                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1359                    description,
1360                    path
1361                );
1362            }
1363        }
1364    }
1365
1366    #[gpui::test]
1367    async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1368        init_test(cx);
1369        let fs = project::FakeFs::new(cx.executor());
1370        fs.insert_tree(
1371            "/project",
1372            json!({
1373                ".zed": {
1374                    "settings.json": "{}"
1375                },
1376                "src": {
1377                    ".zed": {
1378                        "local.json": "{}"
1379                    }
1380                }
1381            }),
1382        )
1383        .await;
1384        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1385        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1386        let context_server_registry =
1387            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1388        let model = Arc::new(FakeLanguageModel::default());
1389        let thread = cx.new(|cx| {
1390            Thread::new(
1391                project.clone(),
1392                cx.new(|_cx| ProjectContext::default()),
1393                context_server_registry.clone(),
1394                Templates::new(),
1395                Some(model.clone()),
1396                cx,
1397            )
1398        });
1399        let tool = Arc::new(EditFileTool::new(
1400            project.clone(),
1401            thread.downgrade(),
1402            language_registry,
1403        ));
1404
1405        // Test edge cases
1406        let test_cases = vec![
1407            // Empty path - find_project_path returns Some for empty paths
1408            ("", false, "Empty path is treated as project root"),
1409            // Root directory
1410            ("/", true, "Root directory should be outside project"),
1411            // Parent directory references - find_project_path resolves these
1412            (
1413                "project/../other",
1414                false,
1415                "Path with .. is resolved by find_project_path",
1416            ),
1417            (
1418                "project/./src/file.rs",
1419                false,
1420                "Path with . should work normally",
1421            ),
1422            // Windows-style paths (if on Windows)
1423            #[cfg(target_os = "windows")]
1424            ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1425            #[cfg(target_os = "windows")]
1426            ("project\\src\\main.rs", false, "Windows-style project path"),
1427        ];
1428
1429        for (path, should_confirm, description) in test_cases {
1430            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1431            let auth = cx.update(|cx| {
1432                tool.authorize(
1433                    &EditFileToolInput {
1434                        display_description: "Edit file".into(),
1435                        path: path.into(),
1436                        mode: EditFileMode::Edit,
1437                    },
1438                    &stream_tx,
1439                    cx,
1440                )
1441            });
1442
1443            if should_confirm {
1444                stream_rx.expect_authorization().await;
1445            } else {
1446                auth.await.unwrap();
1447                assert!(
1448                    stream_rx.try_next().is_err(),
1449                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1450                    description,
1451                    path
1452                );
1453            }
1454        }
1455    }
1456
1457    #[gpui::test]
1458    async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1459        init_test(cx);
1460        let fs = project::FakeFs::new(cx.executor());
1461        fs.insert_tree(
1462            "/project",
1463            json!({
1464                "existing.txt": "content",
1465                ".zed": {
1466                    "settings.json": "{}"
1467                }
1468            }),
1469        )
1470        .await;
1471        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1472        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1473        let context_server_registry =
1474            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1475        let model = Arc::new(FakeLanguageModel::default());
1476        let thread = cx.new(|cx| {
1477            Thread::new(
1478                project.clone(),
1479                cx.new(|_cx| ProjectContext::default()),
1480                context_server_registry.clone(),
1481                Templates::new(),
1482                Some(model.clone()),
1483                cx,
1484            )
1485        });
1486        let tool = Arc::new(EditFileTool::new(
1487            project.clone(),
1488            thread.downgrade(),
1489            language_registry,
1490        ));
1491
1492        // Test different EditFileMode values
1493        let modes = vec![
1494            EditFileMode::Edit,
1495            EditFileMode::Create,
1496            EditFileMode::Overwrite,
1497        ];
1498
1499        for mode in modes {
1500            // Test .zed path with different modes
1501            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1502            let _auth = cx.update(|cx| {
1503                tool.authorize(
1504                    &EditFileToolInput {
1505                        display_description: "Edit settings".into(),
1506                        path: "project/.zed/settings.json".into(),
1507                        mode: mode.clone(),
1508                    },
1509                    &stream_tx,
1510                    cx,
1511                )
1512            });
1513
1514            stream_rx.expect_authorization().await;
1515
1516            // Test outside path with different modes
1517            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1518            let _auth = cx.update(|cx| {
1519                tool.authorize(
1520                    &EditFileToolInput {
1521                        display_description: "Edit file".into(),
1522                        path: "/outside/file.txt".into(),
1523                        mode: mode.clone(),
1524                    },
1525                    &stream_tx,
1526                    cx,
1527                )
1528            });
1529
1530            stream_rx.expect_authorization().await;
1531
1532            // Test normal path with different modes
1533            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1534            cx.update(|cx| {
1535                tool.authorize(
1536                    &EditFileToolInput {
1537                        display_description: "Edit file".into(),
1538                        path: "project/normal.txt".into(),
1539                        mode: mode.clone(),
1540                    },
1541                    &stream_tx,
1542                    cx,
1543                )
1544            })
1545            .await
1546            .unwrap();
1547            assert!(stream_rx.try_next().is_err());
1548        }
1549    }
1550
1551    #[gpui::test]
1552    async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1553        init_test(cx);
1554        let fs = project::FakeFs::new(cx.executor());
1555        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1556        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1557        let context_server_registry =
1558            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1559        let model = Arc::new(FakeLanguageModel::default());
1560        let thread = cx.new(|cx| {
1561            Thread::new(
1562                project.clone(),
1563                cx.new(|_cx| ProjectContext::default()),
1564                context_server_registry,
1565                Templates::new(),
1566                Some(model.clone()),
1567                cx,
1568            )
1569        });
1570        let tool = Arc::new(EditFileTool::new(
1571            project,
1572            thread.downgrade(),
1573            language_registry,
1574        ));
1575
1576        cx.update(|cx| {
1577            // ...
1578            assert_eq!(
1579                tool.initial_title(
1580                    Err(json!({
1581                        "path": "src/main.rs",
1582                        "display_description": "",
1583                        "old_string": "old code",
1584                        "new_string": "new code"
1585                    })),
1586                    cx
1587                ),
1588                "src/main.rs"
1589            );
1590            assert_eq!(
1591                tool.initial_title(
1592                    Err(json!({
1593                        "path": "",
1594                        "display_description": "Fix error handling",
1595                        "old_string": "old code",
1596                        "new_string": "new code"
1597                    })),
1598                    cx
1599                ),
1600                "Fix error handling"
1601            );
1602            assert_eq!(
1603                tool.initial_title(
1604                    Err(json!({
1605                        "path": "src/main.rs",
1606                        "display_description": "Fix error handling",
1607                        "old_string": "old code",
1608                        "new_string": "new code"
1609                    })),
1610                    cx
1611                ),
1612                "src/main.rs"
1613            );
1614            assert_eq!(
1615                tool.initial_title(
1616                    Err(json!({
1617                        "path": "",
1618                        "display_description": "",
1619                        "old_string": "old code",
1620                        "new_string": "new code"
1621                    })),
1622                    cx
1623                ),
1624                DEFAULT_UI_TEXT
1625            );
1626            assert_eq!(
1627                tool.initial_title(Err(serde_json::Value::Null), cx),
1628                DEFAULT_UI_TEXT
1629            );
1630        });
1631    }
1632
1633    #[gpui::test]
1634    async fn test_diff_finalization(cx: &mut TestAppContext) {
1635        init_test(cx);
1636        let fs = project::FakeFs::new(cx.executor());
1637        fs.insert_tree("/", json!({"main.rs": ""})).await;
1638
1639        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1640        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1641        let context_server_registry =
1642            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1643        let model = Arc::new(FakeLanguageModel::default());
1644        let thread = cx.new(|cx| {
1645            Thread::new(
1646                project.clone(),
1647                cx.new(|_cx| ProjectContext::default()),
1648                context_server_registry.clone(),
1649                Templates::new(),
1650                Some(model.clone()),
1651                cx,
1652            )
1653        });
1654
1655        // Ensure the diff is finalized after the edit completes.
1656        {
1657            let tool = Arc::new(EditFileTool::new(
1658                project.clone(),
1659                thread.downgrade(),
1660                languages.clone(),
1661            ));
1662            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1663            let edit = cx.update(|cx| {
1664                tool.run(
1665                    EditFileToolInput {
1666                        display_description: "Edit file".into(),
1667                        path: path!("/main.rs").into(),
1668                        mode: EditFileMode::Edit,
1669                    },
1670                    stream_tx,
1671                    cx,
1672                )
1673            });
1674            stream_rx.expect_update_fields().await;
1675            let diff = stream_rx.expect_diff().await;
1676            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1677            cx.run_until_parked();
1678            model.end_last_completion_stream();
1679            edit.await.unwrap();
1680            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1681        }
1682
1683        // Ensure the diff is finalized if an error occurs while editing.
1684        {
1685            model.forbid_requests();
1686            let tool = Arc::new(EditFileTool::new(
1687                project.clone(),
1688                thread.downgrade(),
1689                languages.clone(),
1690            ));
1691            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1692            let edit = cx.update(|cx| {
1693                tool.run(
1694                    EditFileToolInput {
1695                        display_description: "Edit file".into(),
1696                        path: path!("/main.rs").into(),
1697                        mode: EditFileMode::Edit,
1698                    },
1699                    stream_tx,
1700                    cx,
1701                )
1702            });
1703            stream_rx.expect_update_fields().await;
1704            let diff = stream_rx.expect_diff().await;
1705            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1706            edit.await.unwrap_err();
1707            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1708            model.allow_requests();
1709        }
1710
1711        // Ensure the diff is finalized if the tool call gets dropped.
1712        {
1713            let tool = Arc::new(EditFileTool::new(
1714                project.clone(),
1715                thread.downgrade(),
1716                languages.clone(),
1717            ));
1718            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1719            let edit = cx.update(|cx| {
1720                tool.run(
1721                    EditFileToolInput {
1722                        display_description: "Edit file".into(),
1723                        path: path!("/main.rs").into(),
1724                        mode: EditFileMode::Edit,
1725                    },
1726                    stream_tx,
1727                    cx,
1728                )
1729            });
1730            stream_rx.expect_update_fields().await;
1731            let diff = stream_rx.expect_diff().await;
1732            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1733            drop(edit);
1734            cx.run_until_parked();
1735            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1736        }
1737    }
1738
1739    fn init_test(cx: &mut TestAppContext) {
1740        cx.update(|cx| {
1741            let settings_store = SettingsStore::test(cx);
1742            cx.set_global(settings_store);
1743            language::init(cx);
1744            TelemetrySettings::register(cx);
1745            agent_settings::AgentSettings::register(cx);
1746            Project::init_settings(cx);
1747        });
1748    }
1749}