streaming_edit_file_tool.rs

   1use super::edit_file_tool::EditFileTool;
   2use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
   3use super::save_file_tool::SaveFileTool;
   4use crate::{
   5    AgentTool, Templates, Thread, ToolCallEventStream,
   6    edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher,
   7};
   8use acp_thread::Diff;
   9use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
  10use anyhow::{Context as _, Result, anyhow};
  11use collections::HashSet;
  12use futures::FutureExt as _;
  13use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  14use language::LanguageRegistry;
  15use language::language_settings::{self, FormatOnSave};
  16use language_model::LanguageModelToolResultContent;
  17use project::lsp_store::{FormatTrigger, LspFormatTarget};
  18use project::{Project, ProjectPath};
  19use schemars::JsonSchema;
  20use serde::{Deserialize, Serialize};
  21use std::ops::Range;
  22use std::path::PathBuf;
  23use std::sync::Arc;
  24use text::BufferSnapshot;
  25use ui::SharedString;
  26use util::ResultExt;
  27use util::rel_path::RelPath;
  28
/// Fallback title returned by `initial_title` when the raw tool input yields
/// neither a non-blank path nor a non-blank display description.
const DEFAULT_UI_TEXT: &str = "Editing file";
  30
// NOTE(review): this type derives `JsonSchema`, and schemars turns `///` doc
// comments into schema descriptions. The doc text on this struct and its fields
// is therefore presumably part of what the model sees — treat any rewording as
// a prompt change, not as a cosmetic edit. Maintainer notes belong in `//`
// comments like this one.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
///    - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,

    /// The mode of operation on the file. Possible values:
    /// - 'create': Create a new file if it doesn't exist. Requires 'content' field.
    /// - 'overwrite': Replace the entire contents of an existing file. Requires 'content' field.
    /// - 'edit': Make granular edits to an existing file. Requires 'edits' field.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: StreamingEditFileMode,

    // `content` and `edits` are both optional at the serde level; which one is
    // actually required depends on `mode` and is enforced at run time by the
    // tool's `run` method.
    /// The complete content for the new file (required for 'create' and 'overwrite' modes).
    /// This field should contain the entire file content.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// List of edit operations to apply sequentially (required for 'edit' mode).
    /// Each edit finds `old_text` in the file and replaces it with `new_text`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub edits: Option<Vec<EditOperation>>,
}
  90
// Serialized with `rename_all = "snake_case"`, so the wire values are
// `create`, `overwrite`, and `edit`.
// `resolve_path` requires an existing file for `Edit`/`Overwrite` and a
// non-existing file inside an existing directory for `Create`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StreamingEditFileMode {
    /// Create a new file if it doesn't exist
    Create,
    /// Replace the entire contents of an existing file
    Overwrite,
    /// Make granular edits to an existing file
    Edit,
}
 101
// Each operation is resolved by `resolve_edit` against a snapshot of the
// buffer. `apply_edits` rejects the whole batch when any `old_text` has no
// match or more than one match.
/// A single edit operation that replaces old text with new text
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditOperation {
    /// The exact text to find in the file. This will be matched using fuzzy matching
    /// to handle minor differences in whitespace or formatting.
    pub old_text: String,
    /// The text to replace it with
    pub new_text: String,
}
 111
// Lenient view of the tool input, used by `initial_title` when the raw JSON
// could not be deserialized as `StreamingEditFileToolInput`.
// NOTE(review): presumably this happens while the arguments are still
// streaming in — confirm against the caller.
// Both fields fall back to an empty string when absent.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct StreamingEditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
 119
/// Result of a completed streaming edit, as produced by `run` and consumed by
/// `replay`.
#[derive(Debug, Serialize, Deserialize)]
pub struct StreamingEditFileToolOutput {
    /// The path exactly as supplied in the tool input. Also accepted under the
    /// key `original_path` when deserializing.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit was applied and saved.
    new_text: String,
    /// Full buffer text captured before the edit was applied.
    old_text: Arc<String>,
    /// Unified diff between `old_text` and `new_text`; defaults to an empty
    /// string when missing from serialized data.
    #[serde(default)]
    diff: String,
}
 129
 130impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
 131    fn from(output: StreamingEditFileToolOutput) -> Self {
 132        if output.diff.is_empty() {
 133            "No edits were made.".into()
 134        } else {
 135            format!(
 136                "Edited {}:\n\n```diff\n{}\n```",
 137                output.input_path.display(),
 138                output.diff
 139            )
 140            .into()
 141        }
 142    }
 143}
 144
/// Agent tool that creates, overwrites, or granularly edits a file in the
/// project and reports the result as a diff.
pub struct StreamingEditFileTool {
    /// Weak handle to the owning thread; `run` reads the project, action log,
    /// recorded file read times, and available tools through it.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when replaying a stored result.
    language_registry: Arc<LanguageRegistry>,
    /// Project used by `initial_title` to resolve display paths.
    project: Entity<Project>,
    /// Stored at construction but not read anywhere in the visible code, hence
    /// the `dead_code` allowance.
    #[allow(dead_code)]
    templates: Arc<Templates>,
}
 152
impl StreamingEditFileTool {
    /// Builds the tool from the project, a weak handle to the owning thread,
    /// the language registry, and the prompt templates. All four values are
    /// stored as-is.
    pub fn new(
        project: Entity<Project>,
        thread: WeakEntity<Thread>,
        language_registry: Arc<LanguageRegistry>,
        templates: Arc<Templates>,
    ) -> Self {
        Self {
            project,
            thread,
            language_registry,
            templates,
        }
    }

    /// Requests authorization to edit `input.path` by delegating to
    /// `edit_file_tool::authorize_file_edit`, returning its task unchanged.
    ///
    /// The tool name passed along is `EditFileTool::NAME`, not this tool's own
    /// `NAME`.
    /// NOTE(review): presumably intentional so that both edit tools share the
    /// same permission configuration — confirm before changing.
    fn authorize(
        &self,
        input: &StreamingEditFileToolInput,
        event_stream: &ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<()>> {
        super::edit_file_tool::authorize_file_edit(
            EditFileTool::NAME,
            &input.path,
            &input.display_description,
            &self.thread,
            event_stream,
            cx,
        )
    }
}
 184
 185impl AgentTool for StreamingEditFileTool {
    /// Fully parsed tool arguments.
    type Input = StreamingEditFileToolInput;
    /// Result returned by `run` and accepted by `replay`.
    type Output = StreamingEditFileToolOutput;

    /// Name of this tool. Note that `authorize` passes `EditFileTool::NAME`
    /// rather than this constant when requesting authorization.
    const NAME: &'static str = "streaming_edit_file";
 190
    /// Reports this tool as an edit-kind tool (`acp::ToolKind::Edit`).
    fn kind() -> acp::ToolKind {
        acp::ToolKind::Edit
    }
 194
 195    fn initial_title(
 196        &self,
 197        input: Result<Self::Input, serde_json::Value>,
 198        cx: &mut App,
 199    ) -> SharedString {
 200        match input {
 201            Ok(input) => self
 202                .project
 203                .read(cx)
 204                .find_project_path(&input.path, cx)
 205                .and_then(|project_path| {
 206                    self.project
 207                        .read(cx)
 208                        .short_full_path_for_project_path(&project_path, cx)
 209                })
 210                .unwrap_or(input.path.to_string_lossy().into_owned())
 211                .into(),
 212            Err(raw_input) => {
 213                if let Some(input) =
 214                    serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
 215                {
 216                    let path = input.path.trim();
 217                    if !path.is_empty() {
 218                        return self
 219                            .project
 220                            .read(cx)
 221                            .find_project_path(&input.path, cx)
 222                            .and_then(|project_path| {
 223                                self.project
 224                                    .read(cx)
 225                                    .short_full_path_for_project_path(&project_path, cx)
 226                            })
 227                            .unwrap_or(input.path)
 228                            .into();
 229                    }
 230
 231                    let description = input.display_description.trim();
 232                    if !description.is_empty() {
 233                        return description.to_string().into();
 234                    }
 235                }
 236
 237                DEFAULT_UI_TEXT.into()
 238            }
 239        }
 240    }
 241
 242    fn run(
 243        self: Arc<Self>,
 244        input: Self::Input,
 245        event_stream: ToolCallEventStream,
 246        cx: &mut App,
 247    ) -> Task<Result<Self::Output>> {
 248        let Ok(project) = self
 249            .thread
 250            .read_with(cx, |thread, _cx| thread.project().clone())
 251        else {
 252            return Task::ready(Err(anyhow!("thread was dropped")));
 253        };
 254
 255        let project_path = match resolve_path(&input, project.clone(), cx) {
 256            Ok(path) => path,
 257            Err(err) => return Task::ready(Err(anyhow!(err))),
 258        };
 259
 260        let abs_path = project.read(cx).absolute_path(&project_path, cx);
 261        if let Some(abs_path) = abs_path.clone() {
 262            event_stream.update_fields(
 263                ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
 264            );
 265        }
 266
 267        let authorize = self.authorize(&input, &event_stream, cx);
 268
 269        cx.spawn(async move |cx: &mut AsyncApp| {
 270            authorize.await?;
 271
 272            let buffer = project
 273                .update(cx, |project, cx| {
 274                    project.open_buffer(project_path.clone(), cx)
 275                })
 276                .await?;
 277
 278            if let Some(abs_path) = abs_path.as_ref() {
 279                let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) =
 280                    self.thread.update(cx, |thread, cx| {
 281                        let last_read = thread.file_read_times.get(abs_path).copied();
 282                        let current = buffer
 283                            .read(cx)
 284                            .file()
 285                            .and_then(|file| file.disk_state().mtime());
 286                        let dirty = buffer.read(cx).is_dirty();
 287                        let has_save = thread.has_tool(SaveFileTool::NAME);
 288                        let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
 289                        (last_read, current, dirty, has_save, has_restore)
 290                    })?;
 291
 292                if is_dirty {
 293                    let message = match (has_save_tool, has_restore_tool) {
 294                        (true, true) => {
 295                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
 296                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
 297                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
 298                        }
 299                        (true, false) => {
 300                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
 301                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
 302                             If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
 303                        }
 304                        (false, true) => {
 305                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
 306                             If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
 307                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
 308                        }
 309                        (false, false) => {
 310                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
 311                             then ask them to save or revert the file manually and inform you when it's ok to proceed."
 312                        }
 313                    };
 314                    anyhow::bail!("{}", message);
 315                }
 316
 317                if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
 318                    if current != last_read {
 319                        anyhow::bail!(
 320                            "The file {} has been modified since you last read it. \
 321                             Please read the file again to get the current state before editing it.",
 322                            input.path.display()
 323                        );
 324                    }
 325                }
 326            }
 327
 328            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
 329            event_stream.update_diff(diff.clone());
 330            let _finalize_diff = util::defer({
 331                let diff = diff.downgrade();
 332                let mut cx = cx.clone();
 333                move || {
 334                    diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
 335                }
 336            });
 337
 338            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 339            let old_text = cx
 340                .background_spawn({
 341                    let old_snapshot = old_snapshot.clone();
 342                    async move { Arc::new(old_snapshot.text()) }
 343                })
 344                .await;
 345
 346            let action_log = self.thread.read_with(cx, |thread, _cx| thread.action_log().clone())?;
 347
 348            // Edit the buffer and report edits to the action log as part of the
 349            // same effect cycle, otherwise the edit will be reported as if the
 350            // user made it (due to the buffer subscription in action_log).
 351            match input.mode {
 352                StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => {
 353                    action_log.update(cx, |log, cx| {
 354                        log.buffer_created(buffer.clone(), cx);
 355                    });
 356                    let content = input.content.ok_or_else(|| {
 357                        anyhow!("'content' field is required for create and overwrite modes")
 358                    })?;
 359                    cx.update(|cx| {
 360                        buffer.update(cx, |buffer, cx| {
 361                            buffer.edit([(0..buffer.len(), content.as_str())], None, cx);
 362                        });
 363                        action_log.update(cx, |log, cx| {
 364                            log.buffer_edited(buffer.clone(), cx);
 365                        });
 366                    });
 367                }
 368                StreamingEditFileMode::Edit => {
 369                    action_log.update(cx, |log, cx| {
 370                        log.buffer_read(buffer.clone(), cx);
 371                    });
 372                    let edits = input.edits.ok_or_else(|| {
 373                        anyhow!("'edits' field is required for edit mode")
 374                    })?;
 375                    // apply_edits now handles buffer_edited internally in the same effect cycle
 376                    apply_edits(&buffer, &action_log, &edits, &diff, &event_stream, &abs_path, cx)?;
 377                }
 378            }
 379
 380            let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
 381                let settings = language_settings::language_settings(
 382                    buffer.language().map(|l| l.name()),
 383                    buffer.file(),
 384                    cx,
 385                );
 386                settings.format_on_save != FormatOnSave::Off
 387            });
 388
 389            if format_on_save_enabled {
 390                action_log.update(cx, |log, cx| {
 391                    log.buffer_edited(buffer.clone(), cx);
 392                });
 393
 394                let format_task = project.update(cx, |project, cx| {
 395                    project.format(
 396                        HashSet::from_iter([buffer.clone()]),
 397                        LspFormatTarget::Buffers,
 398                        false,
 399                        FormatTrigger::Save,
 400                        cx,
 401                    )
 402                });
 403                futures::select! {
 404                    result = format_task.fuse() => { result.log_err(); },
 405                    _ = event_stream.cancelled_by_user().fuse() => {
 406                        anyhow::bail!("Edit cancelled by user");
 407                    }
 408                };
 409            }
 410
 411            let save_task = project
 412                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx));
 413            futures::select! {
 414                result = save_task.fuse() => { result?; },
 415                _ = event_stream.cancelled_by_user().fuse() => {
 416                    anyhow::bail!("Edit cancelled by user");
 417                }
 418            };
 419
 420            action_log.update(cx, |log, cx| {
 421                log.buffer_edited(buffer.clone(), cx);
 422            });
 423
 424            if let Some(abs_path) = abs_path.as_ref() {
 425                if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
 426                    buffer.file().and_then(|file| file.disk_state().mtime())
 427                }) {
 428                    self.thread.update(cx, |thread, _| {
 429                        thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
 430                    })?;
 431                }
 432            }
 433
 434            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 435            let (new_text, unified_diff) = cx
 436                .background_spawn({
 437                    let new_snapshot = new_snapshot.clone();
 438                    let old_text = old_text.clone();
 439                    async move {
 440                        let new_text = new_snapshot.text();
 441                        let diff = language::unified_diff(&old_text, &new_text);
 442                        (new_text, diff)
 443                    }
 444                })
 445                .await;
 446
 447            let output = StreamingEditFileToolOutput {
 448                input_path: input.path,
 449                new_text,
 450                old_text,
 451                diff: unified_diff,
 452            };
 453
 454            Ok(output)
 455        })
 456    }
 457
 458    fn replay(
 459        &self,
 460        _input: Self::Input,
 461        output: Self::Output,
 462        event_stream: ToolCallEventStream,
 463        cx: &mut App,
 464    ) -> Result<()> {
 465        event_stream.update_diff(cx.new(|cx| {
 466            Diff::finalized(
 467                output.input_path.to_string_lossy().into_owned(),
 468                Some(output.old_text.to_string()),
 469                output.new_text,
 470                self.language_registry.clone(),
 471                cx,
 472            )
 473        }));
 474        Ok(())
 475    }
 476}
 477
 478fn apply_edits(
 479    buffer: &Entity<language::Buffer>,
 480    action_log: &Entity<action_log::ActionLog>,
 481    edits: &[EditOperation],
 482    diff: &Entity<Diff>,
 483    event_stream: &ToolCallEventStream,
 484    abs_path: &Option<PathBuf>,
 485    cx: &mut AsyncApp,
 486) -> Result<()> {
 487    let mut failed_edits = Vec::new();
 488    let mut ambiguous_edits = Vec::new();
 489    let mut resolved_edits: Vec<(Range<usize>, String)> = Vec::new();
 490
 491    // First pass: resolve all edits without applying them
 492    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 493    for (index, edit) in edits.iter().enumerate() {
 494        let result = resolve_edit(&snapshot, edit);
 495
 496        match result {
 497            Ok(Some((range, new_text))) => {
 498                // Reveal the range in the diff view
 499                let (start_anchor, end_anchor) = buffer.read_with(cx, |buffer, _cx| {
 500                    (
 501                        buffer.anchor_before(range.start),
 502                        buffer.anchor_after(range.end),
 503                    )
 504                });
 505                diff.update(cx, |card, cx| {
 506                    card.reveal_range(start_anchor..end_anchor, cx)
 507                });
 508                resolved_edits.push((range, new_text));
 509            }
 510            Ok(None) => {
 511                failed_edits.push(index);
 512            }
 513            Err(ranges) => {
 514                ambiguous_edits.push((index, ranges));
 515            }
 516        }
 517    }
 518
 519    // Check for errors before applying any edits
 520    if !failed_edits.is_empty() {
 521        let indices = failed_edits
 522            .iter()
 523            .map(|i| i.to_string())
 524            .collect::<Vec<_>>()
 525            .join(", ");
 526        anyhow::bail!(
 527            "Could not find matching text for edit(s) at index(es): {}. \
 528             The old_text did not match any content in the file. \
 529             Please read the file again to get the current content.",
 530            indices
 531        );
 532    }
 533
 534    if !ambiguous_edits.is_empty() {
 535        let details: Vec<String> = ambiguous_edits
 536            .iter()
 537            .map(|(index, ranges)| {
 538                let lines = ranges
 539                    .iter()
 540                    .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
 541                    .collect::<Vec<_>>()
 542                    .join(", ");
 543                format!("edit {}: matches at lines {}", index, lines)
 544            })
 545            .collect();
 546        anyhow::bail!(
 547            "Some edits matched multiple locations in the file:\n{}. \
 548             Please provide more context in old_text to uniquely identify the location.",
 549            details.join("\n")
 550        );
 551    }
 552
 553    // Sort edits by position so buffer.edit() can handle offset translation
 554    let mut edits_sorted = resolved_edits;
 555    edits_sorted.sort_by(|a, b| a.0.start.cmp(&b.0.start));
 556
 557    // Emit location for the earliest edit in the file
 558    if let Some((first_range, _)) = edits_sorted.first() {
 559        if let Some(abs_path) = abs_path.clone() {
 560            let line = snapshot.offset_to_point(first_range.start).row;
 561            event_stream.update_fields(
 562                ToolCallUpdateFields::new()
 563                    .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
 564            );
 565        }
 566    }
 567
 568    // Validate no overlaps (sorted ascending by start)
 569    for window in edits_sorted.windows(2) {
 570        if let [(earlier_range, _), (later_range, _)] = window
 571            && (earlier_range.end > later_range.start || earlier_range.start == later_range.start)
 572        {
 573            let earlier_start_line = snapshot.offset_to_point(earlier_range.start).row + 1;
 574            let earlier_end_line = snapshot.offset_to_point(earlier_range.end).row + 1;
 575            let later_start_line = snapshot.offset_to_point(later_range.start).row + 1;
 576            let later_end_line = snapshot.offset_to_point(later_range.end).row + 1;
 577            anyhow::bail!(
 578                "Conflicting edit ranges detected: lines {}-{} conflicts with lines {}-{}. \
 579                 Conflicting edit ranges are not allowed, as they would overwrite each other.",
 580                earlier_start_line,
 581                earlier_end_line,
 582                later_start_line,
 583                later_end_line,
 584            );
 585        }
 586    }
 587
 588    // Apply all edits in a single batch and report to action_log in the same
 589    // effect cycle. This prevents the buffer subscription from treating these
 590    // as user edits.
 591    if !edits_sorted.is_empty() {
 592        cx.update(|cx| {
 593            buffer.update(cx, |buffer, cx| {
 594                buffer.edit(
 595                    edits_sorted
 596                        .iter()
 597                        .map(|(range, new_text)| (range.clone(), new_text.as_str())),
 598                    None,
 599                    cx,
 600                );
 601            });
 602            action_log.update(cx, |log, cx| {
 603                log.buffer_edited(buffer.clone(), cx);
 604            });
 605        });
 606    }
 607
 608    Ok(())
 609}
 610
 611/// Resolves an edit operation by finding the matching text in the buffer.
 612/// Returns Ok(Some((range, new_text))) if a unique match is found,
 613/// Ok(None) if no match is found, or Err(ranges) if multiple matches are found.
 614fn resolve_edit(
 615    snapshot: &BufferSnapshot,
 616    edit: &EditOperation,
 617) -> std::result::Result<Option<(Range<usize>, String)>, Vec<Range<usize>>> {
 618    let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());
 619    matcher.push(&edit.old_text, None);
 620    let matches = matcher.finish();
 621
 622    if matches.is_empty() {
 623        return Ok(None);
 624    }
 625
 626    if matches.len() > 1 {
 627        return Err(matches);
 628    }
 629
 630    let match_range = matches.into_iter().next().expect("checked len above");
 631    Ok(Some((match_range, edit.new_text.clone())))
 632}
 633
 634fn resolve_path(
 635    input: &StreamingEditFileToolInput,
 636    project: Entity<Project>,
 637    cx: &mut App,
 638) -> Result<ProjectPath> {
 639    let project = project.read(cx);
 640
 641    match input.mode {
 642        StreamingEditFileMode::Edit | StreamingEditFileMode::Overwrite => {
 643            let path = project
 644                .find_project_path(&input.path, cx)
 645                .context("Can't edit file: path not found")?;
 646
 647            let entry = project
 648                .entry_for_path(&path, cx)
 649                .context("Can't edit file: path not found")?;
 650
 651            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
 652            Ok(path)
 653        }
 654
 655        StreamingEditFileMode::Create => {
 656            if let Some(path) = project.find_project_path(&input.path, cx) {
 657                anyhow::ensure!(
 658                    project.entry_for_path(&path, cx).is_none(),
 659                    "Can't create file: file already exists"
 660                );
 661            }
 662
 663            let parent_path = input
 664                .path
 665                .parent()
 666                .context("Can't create file: incorrect path")?;
 667
 668            let parent_project_path = project.find_project_path(&parent_path, cx);
 669
 670            let parent_entry = parent_project_path
 671                .as_ref()
 672                .and_then(|path| project.entry_for_path(path, cx))
 673                .context("Can't create file: parent directory doesn't exist")?;
 674
 675            anyhow::ensure!(
 676                parent_entry.is_dir(),
 677                "Can't create file: parent is not a directory"
 678            );
 679
 680            let file_name = input
 681                .path
 682                .file_name()
 683                .and_then(|file_name| file_name.to_str())
 684                .and_then(|file_name| RelPath::unix(file_name).ok())
 685                .context("Can't create file: invalid filename")?;
 686
 687            let new_file_path = parent_project_path.map(|parent| ProjectPath {
 688                path: parent.path.join(file_name),
 689                ..parent
 690            });
 691
 692            new_file_path.context("Can't create file")
 693        }
 694    }
 695}
 696
 697#[cfg(test)]
 698mod tests {
 699    use super::*;
 700    use crate::{ContextServerRegistry, Templates};
 701    use gpui::{TestAppContext, UpdateGlobal};
 702    use language_model::fake_provider::FakeLanguageModel;
 703    use prompt_store::ProjectContext;
 704    use serde_json::json;
 705    use settings::SettingsStore;
 706    use util::path;
 707
 708    #[gpui::test]
 709    async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
 710        init_test(cx);
 711
 712        let fs = project::FakeFs::new(cx.executor());
 713        fs.insert_tree("/root", json!({"dir": {}})).await;
 714        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 715        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 716        let context_server_registry =
 717            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 718        let model = Arc::new(FakeLanguageModel::default());
 719        let thread = cx.new(|cx| {
 720            crate::Thread::new(
 721                project.clone(),
 722                cx.new(|_cx| ProjectContext::default()),
 723                context_server_registry,
 724                Templates::new(),
 725                Some(model),
 726                cx,
 727            )
 728        });
 729
 730        let result = cx
 731            .update(|cx| {
 732                let input = StreamingEditFileToolInput {
 733                    display_description: "Create new file".into(),
 734                    path: "root/dir/new_file.txt".into(),
 735                    mode: StreamingEditFileMode::Create,
 736                    content: Some("Hello, World!".into()),
 737                    edits: None,
 738                };
 739                Arc::new(StreamingEditFileTool::new(
 740                    project.clone(),
 741                    thread.downgrade(),
 742                    language_registry,
 743                    Templates::new(),
 744                ))
 745                .run(input, ToolCallEventStream::test().0, cx)
 746            })
 747            .await;
 748
 749        assert!(result.is_ok());
 750        let output = result.unwrap();
 751        assert_eq!(output.new_text, "Hello, World!");
 752        assert!(!output.diff.is_empty());
 753    }
 754
 755    #[gpui::test]
 756    async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
 757        init_test(cx);
 758
 759        let fs = project::FakeFs::new(cx.executor());
 760        fs.insert_tree("/root", json!({"file.txt": "old content"}))
 761            .await;
 762        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 763        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 764        let context_server_registry =
 765            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 766        let model = Arc::new(FakeLanguageModel::default());
 767        let thread = cx.new(|cx| {
 768            crate::Thread::new(
 769                project.clone(),
 770                cx.new(|_cx| ProjectContext::default()),
 771                context_server_registry,
 772                Templates::new(),
 773                Some(model),
 774                cx,
 775            )
 776        });
 777
 778        let result = cx
 779            .update(|cx| {
 780                let input = StreamingEditFileToolInput {
 781                    display_description: "Overwrite file".into(),
 782                    path: "root/file.txt".into(),
 783                    mode: StreamingEditFileMode::Overwrite,
 784                    content: Some("new content".into()),
 785                    edits: None,
 786                };
 787                Arc::new(StreamingEditFileTool::new(
 788                    project.clone(),
 789                    thread.downgrade(),
 790                    language_registry,
 791                    Templates::new(),
 792                ))
 793                .run(input, ToolCallEventStream::test().0, cx)
 794            })
 795            .await;
 796
 797        assert!(result.is_ok());
 798        let output = result.unwrap();
 799        assert_eq!(output.new_text, "new content");
 800        assert_eq!(*output.old_text, "old content");
 801    }
 802
 803    #[gpui::test]
 804    async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
 805        init_test(cx);
 806
 807        let fs = project::FakeFs::new(cx.executor());
 808        fs.insert_tree(
 809            "/root",
 810            json!({
 811                "file.txt": "line 1\nline 2\nline 3\n"
 812            }),
 813        )
 814        .await;
 815        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 816        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 817        let context_server_registry =
 818            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 819        let model = Arc::new(FakeLanguageModel::default());
 820        let thread = cx.new(|cx| {
 821            crate::Thread::new(
 822                project.clone(),
 823                cx.new(|_cx| ProjectContext::default()),
 824                context_server_registry,
 825                Templates::new(),
 826                Some(model),
 827                cx,
 828            )
 829        });
 830
 831        let result = cx
 832            .update(|cx| {
 833                let input = StreamingEditFileToolInput {
 834                    display_description: "Edit lines".into(),
 835                    path: "root/file.txt".into(),
 836                    mode: StreamingEditFileMode::Edit,
 837                    content: None,
 838                    edits: Some(vec![EditOperation {
 839                        old_text: "line 2".into(),
 840                        new_text: "modified line 2".into(),
 841                    }]),
 842                };
 843                Arc::new(StreamingEditFileTool::new(
 844                    project.clone(),
 845                    thread.downgrade(),
 846                    language_registry,
 847                    Templates::new(),
 848                ))
 849                .run(input, ToolCallEventStream::test().0, cx)
 850            })
 851            .await;
 852
 853        assert!(result.is_ok());
 854        let output = result.unwrap();
 855        assert_eq!(output.new_text, "line 1\nmodified line 2\nline 3\n");
 856    }
 857
 858    #[gpui::test]
 859    async fn test_streaming_edit_multiple_nonoverlapping_edits(cx: &mut TestAppContext) {
 860        init_test(cx);
 861
 862        let fs = project::FakeFs::new(cx.executor());
 863        fs.insert_tree(
 864            "/root",
 865            json!({
 866                "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"
 867            }),
 868        )
 869        .await;
 870        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 871        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 872        let context_server_registry =
 873            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 874        let model = Arc::new(FakeLanguageModel::default());
 875        let thread = cx.new(|cx| {
 876            crate::Thread::new(
 877                project.clone(),
 878                cx.new(|_cx| ProjectContext::default()),
 879                context_server_registry,
 880                Templates::new(),
 881                Some(model),
 882                cx,
 883            )
 884        });
 885
 886        let result = cx
 887            .update(|cx| {
 888                let input = StreamingEditFileToolInput {
 889                    display_description: "Edit multiple lines".into(),
 890                    path: "root/file.txt".into(),
 891                    mode: StreamingEditFileMode::Edit,
 892                    content: None,
 893                    edits: Some(vec![
 894                        EditOperation {
 895                            old_text: "line 5".into(),
 896                            new_text: "modified line 5".into(),
 897                        },
 898                        EditOperation {
 899                            old_text: "line 1".into(),
 900                            new_text: "modified line 1".into(),
 901                        },
 902                    ]),
 903                };
 904                Arc::new(StreamingEditFileTool::new(
 905                    project.clone(),
 906                    thread.downgrade(),
 907                    language_registry,
 908                    Templates::new(),
 909                ))
 910                .run(input, ToolCallEventStream::test().0, cx)
 911            })
 912            .await;
 913
 914        assert!(result.is_ok());
 915        let output = result.unwrap();
 916        assert_eq!(
 917            output.new_text,
 918            "modified line 1\nline 2\nline 3\nline 4\nmodified line 5\n"
 919        );
 920    }
 921
 922    #[gpui::test]
 923    async fn test_streaming_edit_adjacent_edits(cx: &mut TestAppContext) {
 924        init_test(cx);
 925
 926        let fs = project::FakeFs::new(cx.executor());
 927        fs.insert_tree(
 928            "/root",
 929            json!({
 930                "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"
 931            }),
 932        )
 933        .await;
 934        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 935        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 936        let context_server_registry =
 937            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 938        let model = Arc::new(FakeLanguageModel::default());
 939        let thread = cx.new(|cx| {
 940            crate::Thread::new(
 941                project.clone(),
 942                cx.new(|_cx| ProjectContext::default()),
 943                context_server_registry,
 944                Templates::new(),
 945                Some(model),
 946                cx,
 947            )
 948        });
 949
 950        let result = cx
 951            .update(|cx| {
 952                let input = StreamingEditFileToolInput {
 953                    display_description: "Edit adjacent lines".into(),
 954                    path: "root/file.txt".into(),
 955                    mode: StreamingEditFileMode::Edit,
 956                    content: None,
 957                    edits: Some(vec![
 958                        EditOperation {
 959                            old_text: "line 2".into(),
 960                            new_text: "modified line 2".into(),
 961                        },
 962                        EditOperation {
 963                            old_text: "line 3".into(),
 964                            new_text: "modified line 3".into(),
 965                        },
 966                    ]),
 967                };
 968                Arc::new(StreamingEditFileTool::new(
 969                    project.clone(),
 970                    thread.downgrade(),
 971                    language_registry,
 972                    Templates::new(),
 973                ))
 974                .run(input, ToolCallEventStream::test().0, cx)
 975            })
 976            .await;
 977
 978        assert!(result.is_ok());
 979        let output = result.unwrap();
 980        assert_eq!(
 981            output.new_text,
 982            "line 1\nmodified line 2\nmodified line 3\nline 4\nline 5\n"
 983        );
 984    }
 985
 986    #[gpui::test]
 987    async fn test_streaming_edit_ascending_order_edits(cx: &mut TestAppContext) {
 988        init_test(cx);
 989
 990        let fs = project::FakeFs::new(cx.executor());
 991        fs.insert_tree(
 992            "/root",
 993            json!({
 994                "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"
 995            }),
 996        )
 997        .await;
 998        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 999        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1000        let context_server_registry =
1001            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1002        let model = Arc::new(FakeLanguageModel::default());
1003        let thread = cx.new(|cx| {
1004            crate::Thread::new(
1005                project.clone(),
1006                cx.new(|_cx| ProjectContext::default()),
1007                context_server_registry,
1008                Templates::new(),
1009                Some(model),
1010                cx,
1011            )
1012        });
1013
1014        let result = cx
1015            .update(|cx| {
1016                let input = StreamingEditFileToolInput {
1017                    display_description: "Edit multiple lines in ascending order".into(),
1018                    path: "root/file.txt".into(),
1019                    mode: StreamingEditFileMode::Edit,
1020                    content: None,
1021                    edits: Some(vec![
1022                        EditOperation {
1023                            old_text: "line 1".into(),
1024                            new_text: "modified line 1".into(),
1025                        },
1026                        EditOperation {
1027                            old_text: "line 5".into(),
1028                            new_text: "modified line 5".into(),
1029                        },
1030                    ]),
1031                };
1032                Arc::new(StreamingEditFileTool::new(
1033                    project.clone(),
1034                    thread.downgrade(),
1035                    language_registry,
1036                    Templates::new(),
1037                ))
1038                .run(input, ToolCallEventStream::test().0, cx)
1039            })
1040            .await;
1041
1042        assert!(result.is_ok());
1043        let output = result.unwrap();
1044        assert_eq!(
1045            output.new_text,
1046            "modified line 1\nline 2\nline 3\nline 4\nmodified line 5\n"
1047        );
1048    }
1049
1050    #[gpui::test]
1051    async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) {
1052        init_test(cx);
1053
1054        let fs = project::FakeFs::new(cx.executor());
1055        fs.insert_tree("/root", json!({})).await;
1056        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1057        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1058        let context_server_registry =
1059            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1060        let model = Arc::new(FakeLanguageModel::default());
1061        let thread = cx.new(|cx| {
1062            crate::Thread::new(
1063                project.clone(),
1064                cx.new(|_cx| ProjectContext::default()),
1065                context_server_registry,
1066                Templates::new(),
1067                Some(model),
1068                cx,
1069            )
1070        });
1071
1072        let result = cx
1073            .update(|cx| {
1074                let input = StreamingEditFileToolInput {
1075                    display_description: "Some edit".into(),
1076                    path: "root/nonexistent_file.txt".into(),
1077                    mode: StreamingEditFileMode::Edit,
1078                    content: None,
1079                    edits: Some(vec![EditOperation {
1080                        old_text: "foo".into(),
1081                        new_text: "bar".into(),
1082                    }]),
1083                };
1084                Arc::new(StreamingEditFileTool::new(
1085                    project,
1086                    thread.downgrade(),
1087                    language_registry,
1088                    Templates::new(),
1089                ))
1090                .run(input, ToolCallEventStream::test().0, cx)
1091            })
1092            .await;
1093
1094        assert_eq!(
1095            result.unwrap_err().to_string(),
1096            "Can't edit file: path not found"
1097        );
1098    }
1099
1100    #[gpui::test]
1101    async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) {
1102        init_test(cx);
1103
1104        let fs = project::FakeFs::new(cx.executor());
1105        fs.insert_tree("/root", json!({"file.txt": "hello world"}))
1106            .await;
1107        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1108        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1109        let context_server_registry =
1110            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1111        let model = Arc::new(FakeLanguageModel::default());
1112        let thread = cx.new(|cx| {
1113            crate::Thread::new(
1114                project.clone(),
1115                cx.new(|_cx| ProjectContext::default()),
1116                context_server_registry,
1117                Templates::new(),
1118                Some(model),
1119                cx,
1120            )
1121        });
1122
1123        let result = cx
1124            .update(|cx| {
1125                let input = StreamingEditFileToolInput {
1126                    display_description: "Edit file".into(),
1127                    path: "root/file.txt".into(),
1128                    mode: StreamingEditFileMode::Edit,
1129                    content: None,
1130                    edits: Some(vec![EditOperation {
1131                        old_text: "nonexistent text that is not in the file".into(),
1132                        new_text: "replacement".into(),
1133                    }]),
1134                };
1135                Arc::new(StreamingEditFileTool::new(
1136                    project,
1137                    thread.downgrade(),
1138                    language_registry,
1139                    Templates::new(),
1140                ))
1141                .run(input, ToolCallEventStream::test().0, cx)
1142            })
1143            .await;
1144
1145        assert!(result.is_err());
1146        assert!(
1147            result
1148                .unwrap_err()
1149                .to_string()
1150                .contains("Could not find matching text")
1151        );
1152    }
1153
1154    #[gpui::test]
1155    async fn test_streaming_edit_overlapping_edits_out_of_order(cx: &mut TestAppContext) {
1156        init_test(cx);
1157
1158        let fs = project::FakeFs::new(cx.executor());
1159        // Multi-line file so the line-based fuzzy matcher can resolve each edit.
1160        fs.insert_tree(
1161            "/root",
1162            json!({
1163                "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"
1164            }),
1165        )
1166        .await;
1167        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1168        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1169        let context_server_registry =
1170            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1171        let model = Arc::new(FakeLanguageModel::default());
1172        let thread = cx.new(|cx| {
1173            crate::Thread::new(
1174                project.clone(),
1175                cx.new(|_cx| ProjectContext::default()),
1176                context_server_registry,
1177                Templates::new(),
1178                Some(model),
1179                cx,
1180            )
1181        });
1182
1183        // Edit A spans lines 3-4, edit B spans lines 2-3. They overlap on
1184        // "line 3" and are given in descending file order so the ascending
1185        // sort must reorder them before the pairwise overlap check can
1186        // detect them correctly.
1187        let result = cx
1188            .update(|cx| {
1189                let input = StreamingEditFileToolInput {
1190                    display_description: "Overlapping edits".into(),
1191                    path: "root/file.txt".into(),
1192                    mode: StreamingEditFileMode::Edit,
1193                    content: None,
1194                    edits: Some(vec![
1195                        EditOperation {
1196                            old_text: "line 3\nline 4".into(),
1197                            new_text: "SECOND".into(),
1198                        },
1199                        EditOperation {
1200                            old_text: "line 2\nline 3".into(),
1201                            new_text: "FIRST".into(),
1202                        },
1203                    ]),
1204                };
1205                Arc::new(StreamingEditFileTool::new(
1206                    project,
1207                    thread.downgrade(),
1208                    language_registry,
1209                    Templates::new(),
1210                ))
1211                .run(input, ToolCallEventStream::test().0, cx)
1212            })
1213            .await;
1214
1215        let error = result.unwrap_err();
1216        let error_message = error.to_string();
1217        assert!(
1218            error_message.contains("Conflicting edit ranges detected"),
1219            "Expected 'Conflicting edit ranges detected' but got: {error_message}"
1220        );
1221    }
1222
1223    fn init_test(cx: &mut TestAppContext) {
1224        cx.update(|cx| {
1225            let settings_store = SettingsStore::test(cx);
1226            cx.set_global(settings_store);
1227            SettingsStore::update_global(cx, |store: &mut SettingsStore, cx| {
1228                store.update_user_settings(cx, |settings| {
1229                    settings
1230                        .project
1231                        .all_languages
1232                        .defaults
1233                        .ensure_final_newline_on_save = Some(false);
1234                });
1235            });
1236        });
1237    }
1238}