streaming_edit_file_tool.rs

   1use super::edit_file_tool::EditFileTool;
   2use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
   3use super::save_file_tool::SaveFileTool;
   4use crate::{
   5    AgentTool, Templates, Thread, ToolCallEventStream,
   6    edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher,
   7};
   8use acp_thread::Diff;
   9use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
  10use anyhow::{Context as _, Result, anyhow};
  11use collections::HashSet;
  12use futures::FutureExt as _;
  13use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  14use language::LanguageRegistry;
  15use language::language_settings::{self, FormatOnSave};
  16use language_model::LanguageModelToolResultContent;
  17use project::lsp_store::{FormatTrigger, LspFormatTarget};
  18use project::{Project, ProjectPath};
  19use schemars::JsonSchema;
  20use serde::{Deserialize, Serialize};
  21use std::ops::Range;
  22use std::path::PathBuf;
  23use std::sync::Arc;
  24use text::BufferSnapshot;
  25use ui::SharedString;
  26use util::ResultExt;
  27use util::rel_path::RelPath;
  28
  29const DEFAULT_UI_TEXT: &str = "Editing file";
  30
  31/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `move_path` tool instead.
  32///
  33/// Before using this tool:
  34///
  35/// 1. Use the `read_file` tool to understand the file's contents and context
  36///
  37/// 2. Verify the directory path is correct (only applicable when creating new files):
  38///    - Use the `list_directory` tool to verify the parent directory exists and is the correct location
  39#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
  40pub struct StreamingEditFileToolInput {
  41    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI.
  42    ///
  43    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
  44    ///
  45    /// NEVER mention the file path in this description.
  46    ///
  47    /// <example>Fix API endpoint URLs</example>
  48    /// <example>Update copyright year in `page_footer`</example>
  49    ///
  50    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
  51    pub display_description: String,
  52
  53    /// The full path of the file to create or modify in the project.
  54    ///
  55    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
  56    ///
  57    /// The following examples assume we have two root directories in the project:
  58    /// - /a/b/backend
  59    /// - /c/d/frontend
  60    ///
  61    /// <example>
  62    /// `backend/src/main.rs`
  63    ///
  64    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
  65    /// </example>
  66    ///
  67    /// <example>
  68    /// `frontend/db.js`
  69    /// </example>
  70    pub path: PathBuf,
  71
  72    /// The mode of operation on the file. Possible values:
  73    /// - 'create': Create a new file if it doesn't exist. Requires 'content' field.
  74    /// - 'overwrite': Replace the entire contents of an existing file. Requires 'content' field.
  75    /// - 'edit': Make granular edits to an existing file. Requires 'edits' field.
  76    ///
  77    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
  78    pub mode: StreamingEditFileMode,
  79
  80    /// The complete content for the new file (required for 'create' and 'overwrite' modes).
  81    /// This field should contain the entire file content.
  82    #[serde(default, skip_serializing_if = "Option::is_none")]
  83    pub content: Option<String>,
  84
  85    /// List of edit operations to apply sequentially (required for 'edit' mode).
  86    /// Each edit finds `old_text` in the file and replaces it with `new_text`.
  87    #[serde(default, skip_serializing_if = "Option::is_none")]
  88    pub edits: Option<Vec<EditOperation>>,
  89}
  90
  91#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
  92#[serde(rename_all = "snake_case")]
  93pub enum StreamingEditFileMode {
  94    /// Create a new file if it doesn't exist
  95    Create,
  96    /// Replace the entire contents of an existing file
  97    Overwrite,
  98    /// Make granular edits to an existing file
  99    Edit,
 100}
 101
 102/// A single edit operation that replaces old text with new text
 103#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
 104pub struct EditOperation {
 105    /// The exact text to find in the file. This will be matched using fuzzy matching
 106    /// to handle minor differences in whitespace or formatting.
 107    pub old_text: String,
 108    /// The text to replace it with
 109    pub new_text: String,
 110}
 111
 112#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
 113struct StreamingEditFileToolPartialInput {
 114    #[serde(default)]
 115    path: String,
 116    #[serde(default)]
 117    display_description: String,
 118}
 119
 120#[derive(Debug, Serialize, Deserialize)]
 121pub struct StreamingEditFileToolOutput {
 122    #[serde(alias = "original_path")]
 123    input_path: PathBuf,
 124    new_text: String,
 125    old_text: Arc<String>,
 126    #[serde(default)]
 127    diff: String,
 128}
 129
 130impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
 131    fn from(output: StreamingEditFileToolOutput) -> Self {
 132        if output.diff.is_empty() {
 133            "No edits were made.".into()
 134        } else {
 135            format!(
 136                "Edited {}:\n\n```diff\n{}\n```",
 137                output.input_path.display(),
 138                output.diff
 139            )
 140            .into()
 141        }
 142    }
 143}
 144
 145pub struct StreamingEditFileTool {
 146    thread: WeakEntity<Thread>,
 147    language_registry: Arc<LanguageRegistry>,
 148    project: Entity<Project>,
 149    #[allow(dead_code)]
 150    templates: Arc<Templates>,
 151}
 152
 153impl StreamingEditFileTool {
 154    pub fn new(
 155        project: Entity<Project>,
 156        thread: WeakEntity<Thread>,
 157        language_registry: Arc<LanguageRegistry>,
 158        templates: Arc<Templates>,
 159    ) -> Self {
 160        Self {
 161            project,
 162            thread,
 163            language_registry,
 164            templates,
 165        }
 166    }
 167
 168    pub fn with_thread(&self, new_thread: WeakEntity<Thread>) -> Self {
 169        Self {
 170            project: self.project.clone(),
 171            thread: new_thread,
 172            language_registry: self.language_registry.clone(),
 173            templates: self.templates.clone(),
 174        }
 175    }
 176
 177    fn authorize(
 178        &self,
 179        input: &StreamingEditFileToolInput,
 180        event_stream: &ToolCallEventStream,
 181        cx: &mut App,
 182    ) -> Task<Result<()>> {
 183        super::tool_permissions::authorize_file_edit(
 184            EditFileTool::NAME,
 185            &input.path,
 186            &input.display_description,
 187            &self.thread,
 188            event_stream,
 189            cx,
 190        )
 191    }
 192}
 193
 194impl AgentTool for StreamingEditFileTool {
 195    type Input = StreamingEditFileToolInput;
 196    type Output = StreamingEditFileToolOutput;
 197
 198    const NAME: &'static str = "streaming_edit_file";
 199
 200    fn kind() -> acp::ToolKind {
 201        acp::ToolKind::Edit
 202    }
 203
 204    fn initial_title(
 205        &self,
 206        input: Result<Self::Input, serde_json::Value>,
 207        cx: &mut App,
 208    ) -> SharedString {
 209        match input {
 210            Ok(input) => self
 211                .project
 212                .read(cx)
 213                .find_project_path(&input.path, cx)
 214                .and_then(|project_path| {
 215                    self.project
 216                        .read(cx)
 217                        .short_full_path_for_project_path(&project_path, cx)
 218                })
 219                .unwrap_or(input.path.to_string_lossy().into_owned())
 220                .into(),
 221            Err(raw_input) => {
 222                if let Some(input) =
 223                    serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
 224                {
 225                    let path = input.path.trim();
 226                    if !path.is_empty() {
 227                        return self
 228                            .project
 229                            .read(cx)
 230                            .find_project_path(&input.path, cx)
 231                            .and_then(|project_path| {
 232                                self.project
 233                                    .read(cx)
 234                                    .short_full_path_for_project_path(&project_path, cx)
 235                            })
 236                            .unwrap_or(input.path)
 237                            .into();
 238                    }
 239
 240                    let description = input.display_description.trim();
 241                    if !description.is_empty() {
 242                        return description.to_string().into();
 243                    }
 244                }
 245
 246                DEFAULT_UI_TEXT.into()
 247            }
 248        }
 249    }
 250
 251    fn run(
 252        self: Arc<Self>,
 253        input: Self::Input,
 254        event_stream: ToolCallEventStream,
 255        cx: &mut App,
 256    ) -> Task<Result<Self::Output>> {
 257        let Ok(project) = self
 258            .thread
 259            .read_with(cx, |thread, _cx| thread.project().clone())
 260        else {
 261            return Task::ready(Err(anyhow!("thread was dropped")));
 262        };
 263
 264        let project_path = match resolve_path(&input, project.clone(), cx) {
 265            Ok(path) => path,
 266            Err(err) => return Task::ready(Err(anyhow!(err))),
 267        };
 268
 269        let abs_path = project.read(cx).absolute_path(&project_path, cx);
 270        if let Some(abs_path) = abs_path.clone() {
 271            event_stream.update_fields(
 272                ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
 273            );
 274        }
 275
 276        let authorize = self.authorize(&input, &event_stream, cx);
 277
 278        cx.spawn(async move |cx: &mut AsyncApp| {
 279            authorize.await?;
 280
 281            let buffer = project
 282                .update(cx, |project, cx| {
 283                    project.open_buffer(project_path.clone(), cx)
 284                })
 285                .await?;
 286
 287            if let Some(abs_path) = abs_path.as_ref() {
 288                let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) =
 289                    self.thread.update(cx, |thread, cx| {
 290                        let last_read = thread.file_read_times.get(abs_path).copied();
 291                        let current = buffer
 292                            .read(cx)
 293                            .file()
 294                            .and_then(|file| file.disk_state().mtime());
 295                        let dirty = buffer.read(cx).is_dirty();
 296                        let has_save = thread.has_tool(SaveFileTool::NAME);
 297                        let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
 298                        (last_read, current, dirty, has_save, has_restore)
 299                    })?;
 300
 301                if is_dirty {
 302                    let message = match (has_save_tool, has_restore_tool) {
 303                        (true, true) => {
 304                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
 305                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
 306                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
 307                        }
 308                        (true, false) => {
 309                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
 310                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
 311                             If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
 312                        }
 313                        (false, true) => {
 314                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
 315                             If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
 316                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
 317                        }
 318                        (false, false) => {
 319                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
 320                             then ask them to save or revert the file manually and inform you when it's ok to proceed."
 321                        }
 322                    };
 323                    anyhow::bail!("{}", message);
 324                }
 325
 326                if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
 327                    if current != last_read {
 328                        anyhow::bail!(
 329                            "The file {} has been modified since you last read it. \
 330                             Please read the file again to get the current state before editing it.",
 331                            input.path.display()
 332                        );
 333                    }
 334                }
 335            }
 336
 337            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
 338            event_stream.update_diff(diff.clone());
 339            let _finalize_diff = util::defer({
 340                let diff = diff.downgrade();
 341                let mut cx = cx.clone();
 342                move || {
 343                    diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
 344                }
 345            });
 346
 347            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 348            let old_text = cx
 349                .background_spawn({
 350                    let old_snapshot = old_snapshot.clone();
 351                    async move { Arc::new(old_snapshot.text()) }
 352                })
 353                .await;
 354
 355            let action_log = self.thread.read_with(cx, |thread, _cx| thread.action_log().clone())?;
 356
 357            // Edit the buffer and report edits to the action log as part of the
 358            // same effect cycle, otherwise the edit will be reported as if the
 359            // user made it (due to the buffer subscription in action_log).
 360            match input.mode {
 361                StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => {
 362                    action_log.update(cx, |log, cx| {
 363                        log.buffer_created(buffer.clone(), cx);
 364                    });
 365                    let content = input.content.ok_or_else(|| {
 366                        anyhow!("'content' field is required for create and overwrite modes")
 367                    })?;
 368                    cx.update(|cx| {
 369                        buffer.update(cx, |buffer, cx| {
 370                            buffer.edit([(0..buffer.len(), content.as_str())], None, cx);
 371                        });
 372                        action_log.update(cx, |log, cx| {
 373                            log.buffer_edited(buffer.clone(), cx);
 374                        });
 375                    });
 376                }
 377                StreamingEditFileMode::Edit => {
 378                    action_log.update(cx, |log, cx| {
 379                        log.buffer_read(buffer.clone(), cx);
 380                    });
 381                    let edits = input.edits.ok_or_else(|| {
 382                        anyhow!("'edits' field is required for edit mode")
 383                    })?;
 384                    // apply_edits now handles buffer_edited internally in the same effect cycle
 385                    apply_edits(&buffer, &action_log, &edits, &diff, &event_stream, &abs_path, cx)?;
 386                }
 387            }
 388
 389            let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
 390                let settings = language_settings::language_settings(
 391                    buffer.language().map(|l| l.name()),
 392                    buffer.file(),
 393                    cx,
 394                );
 395                settings.format_on_save != FormatOnSave::Off
 396            });
 397
 398            if format_on_save_enabled {
 399                action_log.update(cx, |log, cx| {
 400                    log.buffer_edited(buffer.clone(), cx);
 401                });
 402
 403                let format_task = project.update(cx, |project, cx| {
 404                    project.format(
 405                        HashSet::from_iter([buffer.clone()]),
 406                        LspFormatTarget::Buffers,
 407                        false,
 408                        FormatTrigger::Save,
 409                        cx,
 410                    )
 411                });
 412                futures::select! {
 413                    result = format_task.fuse() => { result.log_err(); },
 414                    _ = event_stream.cancelled_by_user().fuse() => {
 415                        anyhow::bail!("Edit cancelled by user");
 416                    }
 417                };
 418            }
 419
 420            let save_task = project
 421                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx));
 422            futures::select! {
 423                result = save_task.fuse() => { result?; },
 424                _ = event_stream.cancelled_by_user().fuse() => {
 425                    anyhow::bail!("Edit cancelled by user");
 426                }
 427            };
 428
 429            action_log.update(cx, |log, cx| {
 430                log.buffer_edited(buffer.clone(), cx);
 431            });
 432
 433            if let Some(abs_path) = abs_path.as_ref() {
 434                if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
 435                    buffer.file().and_then(|file| file.disk_state().mtime())
 436                }) {
 437                    self.thread.update(cx, |thread, _| {
 438                        thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
 439                    })?;
 440                }
 441            }
 442
 443            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 444            let (new_text, unified_diff) = cx
 445                .background_spawn({
 446                    let new_snapshot = new_snapshot.clone();
 447                    let old_text = old_text.clone();
 448                    async move {
 449                        let new_text = new_snapshot.text();
 450                        let diff = language::unified_diff(&old_text, &new_text);
 451                        (new_text, diff)
 452                    }
 453                })
 454                .await;
 455
 456            let output = StreamingEditFileToolOutput {
 457                input_path: input.path,
 458                new_text,
 459                old_text,
 460                diff: unified_diff,
 461            };
 462
 463            Ok(output)
 464        })
 465    }
 466
 467    fn replay(
 468        &self,
 469        _input: Self::Input,
 470        output: Self::Output,
 471        event_stream: ToolCallEventStream,
 472        cx: &mut App,
 473    ) -> Result<()> {
 474        event_stream.update_diff(cx.new(|cx| {
 475            Diff::finalized(
 476                output.input_path.to_string_lossy().into_owned(),
 477                Some(output.old_text.to_string()),
 478                output.new_text,
 479                self.language_registry.clone(),
 480                cx,
 481            )
 482        }));
 483        Ok(())
 484    }
 485}
 486
 487fn apply_edits(
 488    buffer: &Entity<language::Buffer>,
 489    action_log: &Entity<action_log::ActionLog>,
 490    edits: &[EditOperation],
 491    diff: &Entity<Diff>,
 492    event_stream: &ToolCallEventStream,
 493    abs_path: &Option<PathBuf>,
 494    cx: &mut AsyncApp,
 495) -> Result<()> {
 496    let mut failed_edits = Vec::new();
 497    let mut ambiguous_edits = Vec::new();
 498    let mut resolved_edits: Vec<(Range<usize>, String)> = Vec::new();
 499
 500    // First pass: resolve all edits without applying them
 501    let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
 502    for (index, edit) in edits.iter().enumerate() {
 503        let result = resolve_edit(&snapshot, edit);
 504
 505        match result {
 506            Ok(Some((range, new_text))) => {
 507                // Reveal the range in the diff view
 508                let (start_anchor, end_anchor) = buffer.read_with(cx, |buffer, _cx| {
 509                    (
 510                        buffer.anchor_before(range.start),
 511                        buffer.anchor_after(range.end),
 512                    )
 513                });
 514                diff.update(cx, |card, cx| {
 515                    card.reveal_range(start_anchor..end_anchor, cx)
 516                });
 517                resolved_edits.push((range, new_text));
 518            }
 519            Ok(None) => {
 520                failed_edits.push(index);
 521            }
 522            Err(ranges) => {
 523                ambiguous_edits.push((index, ranges));
 524            }
 525        }
 526    }
 527
 528    // Check for errors before applying any edits
 529    if !failed_edits.is_empty() {
 530        let indices = failed_edits
 531            .iter()
 532            .map(|i| i.to_string())
 533            .collect::<Vec<_>>()
 534            .join(", ");
 535        anyhow::bail!(
 536            "Could not find matching text for edit(s) at index(es): {}. \
 537             The old_text did not match any content in the file. \
 538             Please read the file again to get the current content.",
 539            indices
 540        );
 541    }
 542
 543    if !ambiguous_edits.is_empty() {
 544        let details: Vec<String> = ambiguous_edits
 545            .iter()
 546            .map(|(index, ranges)| {
 547                let lines = ranges
 548                    .iter()
 549                    .map(|r| (snapshot.offset_to_point(r.start).row + 1).to_string())
 550                    .collect::<Vec<_>>()
 551                    .join(", ");
 552                format!("edit {}: matches at lines {}", index, lines)
 553            })
 554            .collect();
 555        anyhow::bail!(
 556            "Some edits matched multiple locations in the file:\n{}. \
 557             Please provide more context in old_text to uniquely identify the location.",
 558            details.join("\n")
 559        );
 560    }
 561
 562    // Sort edits by position so buffer.edit() can handle offset translation
 563    let mut edits_sorted = resolved_edits;
 564    edits_sorted.sort_by(|a, b| a.0.start.cmp(&b.0.start));
 565
 566    // Emit location for the earliest edit in the file
 567    if let Some((first_range, _)) = edits_sorted.first() {
 568        if let Some(abs_path) = abs_path.clone() {
 569            let line = snapshot.offset_to_point(first_range.start).row;
 570            event_stream.update_fields(
 571                ToolCallUpdateFields::new()
 572                    .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
 573            );
 574        }
 575    }
 576
 577    // Validate no overlaps (sorted ascending by start)
 578    for window in edits_sorted.windows(2) {
 579        if let [(earlier_range, _), (later_range, _)] = window
 580            && (earlier_range.end > later_range.start || earlier_range.start == later_range.start)
 581        {
 582            let earlier_start_line = snapshot.offset_to_point(earlier_range.start).row + 1;
 583            let earlier_end_line = snapshot.offset_to_point(earlier_range.end).row + 1;
 584            let later_start_line = snapshot.offset_to_point(later_range.start).row + 1;
 585            let later_end_line = snapshot.offset_to_point(later_range.end).row + 1;
 586            anyhow::bail!(
 587                "Conflicting edit ranges detected: lines {}-{} conflicts with lines {}-{}. \
 588                 Conflicting edit ranges are not allowed, as they would overwrite each other.",
 589                earlier_start_line,
 590                earlier_end_line,
 591                later_start_line,
 592                later_end_line,
 593            );
 594        }
 595    }
 596
 597    // Apply all edits in a single batch and report to action_log in the same
 598    // effect cycle. This prevents the buffer subscription from treating these
 599    // as user edits.
 600    if !edits_sorted.is_empty() {
 601        cx.update(|cx| {
 602            buffer.update(cx, |buffer, cx| {
 603                buffer.edit(
 604                    edits_sorted
 605                        .iter()
 606                        .map(|(range, new_text)| (range.clone(), new_text.as_str())),
 607                    None,
 608                    cx,
 609                );
 610            });
 611            action_log.update(cx, |log, cx| {
 612                log.buffer_edited(buffer.clone(), cx);
 613            });
 614        });
 615    }
 616
 617    Ok(())
 618}
 619
 620/// Resolves an edit operation by finding the matching text in the buffer.
 621/// Returns Ok(Some((range, new_text))) if a unique match is found,
 622/// Ok(None) if no match is found, or Err(ranges) if multiple matches are found.
 623fn resolve_edit(
 624    snapshot: &BufferSnapshot,
 625    edit: &EditOperation,
 626) -> std::result::Result<Option<(Range<usize>, String)>, Vec<Range<usize>>> {
 627    let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());
 628    matcher.push(&edit.old_text, None);
 629    let matches = matcher.finish();
 630
 631    if matches.is_empty() {
 632        return Ok(None);
 633    }
 634
 635    if matches.len() > 1 {
 636        return Err(matches);
 637    }
 638
 639    let match_range = matches.into_iter().next().expect("checked len above");
 640    Ok(Some((match_range, edit.new_text.clone())))
 641}
 642
 643fn resolve_path(
 644    input: &StreamingEditFileToolInput,
 645    project: Entity<Project>,
 646    cx: &mut App,
 647) -> Result<ProjectPath> {
 648    let project = project.read(cx);
 649
 650    match input.mode {
 651        StreamingEditFileMode::Edit | StreamingEditFileMode::Overwrite => {
 652            let path = project
 653                .find_project_path(&input.path, cx)
 654                .context("Can't edit file: path not found")?;
 655
 656            let entry = project
 657                .entry_for_path(&path, cx)
 658                .context("Can't edit file: path not found")?;
 659
 660            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
 661            Ok(path)
 662        }
 663
 664        StreamingEditFileMode::Create => {
 665            if let Some(path) = project.find_project_path(&input.path, cx) {
 666                anyhow::ensure!(
 667                    project.entry_for_path(&path, cx).is_none(),
 668                    "Can't create file: file already exists"
 669                );
 670            }
 671
 672            let parent_path = input
 673                .path
 674                .parent()
 675                .context("Can't create file: incorrect path")?;
 676
 677            let parent_project_path = project.find_project_path(&parent_path, cx);
 678
 679            let parent_entry = parent_project_path
 680                .as_ref()
 681                .and_then(|path| project.entry_for_path(path, cx))
 682                .context("Can't create file: parent directory doesn't exist")?;
 683
 684            anyhow::ensure!(
 685                parent_entry.is_dir(),
 686                "Can't create file: parent is not a directory"
 687            );
 688
 689            let file_name = input
 690                .path
 691                .file_name()
 692                .and_then(|file_name| file_name.to_str())
 693                .and_then(|file_name| RelPath::unix(file_name).ok())
 694                .context("Can't create file: invalid filename")?;
 695
 696            let new_file_path = parent_project_path.map(|parent| ProjectPath {
 697                path: parent.path.join(file_name),
 698                ..parent
 699            });
 700
 701            new_file_path.context("Can't create file")
 702        }
 703    }
 704}
 705
 706#[cfg(test)]
 707mod tests {
 708    use super::*;
 709    use crate::{ContextServerRegistry, Templates};
 710    use gpui::{TestAppContext, UpdateGlobal};
 711    use language_model::fake_provider::FakeLanguageModel;
 712    use prompt_store::ProjectContext;
 713    use serde_json::json;
 714    use settings::SettingsStore;
 715    use util::path;
 716
 717    #[gpui::test]
 718    async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
 719        init_test(cx);
 720
 721        let fs = project::FakeFs::new(cx.executor());
 722        fs.insert_tree("/root", json!({"dir": {}})).await;
 723        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 724        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 725        let context_server_registry =
 726            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 727        let model = Arc::new(FakeLanguageModel::default());
 728        let thread = cx.new(|cx| {
 729            crate::Thread::new(
 730                project.clone(),
 731                cx.new(|_cx| ProjectContext::default()),
 732                context_server_registry,
 733                Templates::new(),
 734                Some(model),
 735                cx,
 736            )
 737        });
 738
 739        let result = cx
 740            .update(|cx| {
 741                let input = StreamingEditFileToolInput {
 742                    display_description: "Create new file".into(),
 743                    path: "root/dir/new_file.txt".into(),
 744                    mode: StreamingEditFileMode::Create,
 745                    content: Some("Hello, World!".into()),
 746                    edits: None,
 747                };
 748                Arc::new(StreamingEditFileTool::new(
 749                    project.clone(),
 750                    thread.downgrade(),
 751                    language_registry,
 752                    Templates::new(),
 753                ))
 754                .run(input, ToolCallEventStream::test().0, cx)
 755            })
 756            .await;
 757
 758        assert!(result.is_ok());
 759        let output = result.unwrap();
 760        assert_eq!(output.new_text, "Hello, World!");
 761        assert!(!output.diff.is_empty());
 762    }
 763
 764    #[gpui::test]
 765    async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
 766        init_test(cx);
 767
 768        let fs = project::FakeFs::new(cx.executor());
 769        fs.insert_tree("/root", json!({"file.txt": "old content"}))
 770            .await;
 771        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 772        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 773        let context_server_registry =
 774            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 775        let model = Arc::new(FakeLanguageModel::default());
 776        let thread = cx.new(|cx| {
 777            crate::Thread::new(
 778                project.clone(),
 779                cx.new(|_cx| ProjectContext::default()),
 780                context_server_registry,
 781                Templates::new(),
 782                Some(model),
 783                cx,
 784            )
 785        });
 786
 787        let result = cx
 788            .update(|cx| {
 789                let input = StreamingEditFileToolInput {
 790                    display_description: "Overwrite file".into(),
 791                    path: "root/file.txt".into(),
 792                    mode: StreamingEditFileMode::Overwrite,
 793                    content: Some("new content".into()),
 794                    edits: None,
 795                };
 796                Arc::new(StreamingEditFileTool::new(
 797                    project.clone(),
 798                    thread.downgrade(),
 799                    language_registry,
 800                    Templates::new(),
 801                ))
 802                .run(input, ToolCallEventStream::test().0, cx)
 803            })
 804            .await;
 805
 806        assert!(result.is_ok());
 807        let output = result.unwrap();
 808        assert_eq!(output.new_text, "new content");
 809        assert_eq!(*output.old_text, "old content");
 810    }
 811
 812    #[gpui::test]
 813    async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
 814        init_test(cx);
 815
 816        let fs = project::FakeFs::new(cx.executor());
 817        fs.insert_tree(
 818            "/root",
 819            json!({
 820                "file.txt": "line 1\nline 2\nline 3\n"
 821            }),
 822        )
 823        .await;
 824        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 825        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 826        let context_server_registry =
 827            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 828        let model = Arc::new(FakeLanguageModel::default());
 829        let thread = cx.new(|cx| {
 830            crate::Thread::new(
 831                project.clone(),
 832                cx.new(|_cx| ProjectContext::default()),
 833                context_server_registry,
 834                Templates::new(),
 835                Some(model),
 836                cx,
 837            )
 838        });
 839
 840        let result = cx
 841            .update(|cx| {
 842                let input = StreamingEditFileToolInput {
 843                    display_description: "Edit lines".into(),
 844                    path: "root/file.txt".into(),
 845                    mode: StreamingEditFileMode::Edit,
 846                    content: None,
 847                    edits: Some(vec![EditOperation {
 848                        old_text: "line 2".into(),
 849                        new_text: "modified line 2".into(),
 850                    }]),
 851                };
 852                Arc::new(StreamingEditFileTool::new(
 853                    project.clone(),
 854                    thread.downgrade(),
 855                    language_registry,
 856                    Templates::new(),
 857                ))
 858                .run(input, ToolCallEventStream::test().0, cx)
 859            })
 860            .await;
 861
 862        assert!(result.is_ok());
 863        let output = result.unwrap();
 864        assert_eq!(output.new_text, "line 1\nmodified line 2\nline 3\n");
 865    }
 866
 867    #[gpui::test]
 868    async fn test_streaming_edit_multiple_nonoverlapping_edits(cx: &mut TestAppContext) {
 869        init_test(cx);
 870
 871        let fs = project::FakeFs::new(cx.executor());
 872        fs.insert_tree(
 873            "/root",
 874            json!({
 875                "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"
 876            }),
 877        )
 878        .await;
 879        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 880        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 881        let context_server_registry =
 882            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 883        let model = Arc::new(FakeLanguageModel::default());
 884        let thread = cx.new(|cx| {
 885            crate::Thread::new(
 886                project.clone(),
 887                cx.new(|_cx| ProjectContext::default()),
 888                context_server_registry,
 889                Templates::new(),
 890                Some(model),
 891                cx,
 892            )
 893        });
 894
 895        let result = cx
 896            .update(|cx| {
 897                let input = StreamingEditFileToolInput {
 898                    display_description: "Edit multiple lines".into(),
 899                    path: "root/file.txt".into(),
 900                    mode: StreamingEditFileMode::Edit,
 901                    content: None,
 902                    edits: Some(vec![
 903                        EditOperation {
 904                            old_text: "line 5".into(),
 905                            new_text: "modified line 5".into(),
 906                        },
 907                        EditOperation {
 908                            old_text: "line 1".into(),
 909                            new_text: "modified line 1".into(),
 910                        },
 911                    ]),
 912                };
 913                Arc::new(StreamingEditFileTool::new(
 914                    project.clone(),
 915                    thread.downgrade(),
 916                    language_registry,
 917                    Templates::new(),
 918                ))
 919                .run(input, ToolCallEventStream::test().0, cx)
 920            })
 921            .await;
 922
 923        assert!(result.is_ok());
 924        let output = result.unwrap();
 925        assert_eq!(
 926            output.new_text,
 927            "modified line 1\nline 2\nline 3\nline 4\nmodified line 5\n"
 928        );
 929    }
 930
 931    #[gpui::test]
 932    async fn test_streaming_edit_adjacent_edits(cx: &mut TestAppContext) {
 933        init_test(cx);
 934
 935        let fs = project::FakeFs::new(cx.executor());
 936        fs.insert_tree(
 937            "/root",
 938            json!({
 939                "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"
 940            }),
 941        )
 942        .await;
 943        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 944        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 945        let context_server_registry =
 946            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 947        let model = Arc::new(FakeLanguageModel::default());
 948        let thread = cx.new(|cx| {
 949            crate::Thread::new(
 950                project.clone(),
 951                cx.new(|_cx| ProjectContext::default()),
 952                context_server_registry,
 953                Templates::new(),
 954                Some(model),
 955                cx,
 956            )
 957        });
 958
 959        let result = cx
 960            .update(|cx| {
 961                let input = StreamingEditFileToolInput {
 962                    display_description: "Edit adjacent lines".into(),
 963                    path: "root/file.txt".into(),
 964                    mode: StreamingEditFileMode::Edit,
 965                    content: None,
 966                    edits: Some(vec![
 967                        EditOperation {
 968                            old_text: "line 2".into(),
 969                            new_text: "modified line 2".into(),
 970                        },
 971                        EditOperation {
 972                            old_text: "line 3".into(),
 973                            new_text: "modified line 3".into(),
 974                        },
 975                    ]),
 976                };
 977                Arc::new(StreamingEditFileTool::new(
 978                    project.clone(),
 979                    thread.downgrade(),
 980                    language_registry,
 981                    Templates::new(),
 982                ))
 983                .run(input, ToolCallEventStream::test().0, cx)
 984            })
 985            .await;
 986
 987        assert!(result.is_ok());
 988        let output = result.unwrap();
 989        assert_eq!(
 990            output.new_text,
 991            "line 1\nmodified line 2\nmodified line 3\nline 4\nline 5\n"
 992        );
 993    }
 994
 995    #[gpui::test]
 996    async fn test_streaming_edit_ascending_order_edits(cx: &mut TestAppContext) {
 997        init_test(cx);
 998
 999        let fs = project::FakeFs::new(cx.executor());
1000        fs.insert_tree(
1001            "/root",
1002            json!({
1003                "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"
1004            }),
1005        )
1006        .await;
1007        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1008        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1009        let context_server_registry =
1010            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1011        let model = Arc::new(FakeLanguageModel::default());
1012        let thread = cx.new(|cx| {
1013            crate::Thread::new(
1014                project.clone(),
1015                cx.new(|_cx| ProjectContext::default()),
1016                context_server_registry,
1017                Templates::new(),
1018                Some(model),
1019                cx,
1020            )
1021        });
1022
1023        let result = cx
1024            .update(|cx| {
1025                let input = StreamingEditFileToolInput {
1026                    display_description: "Edit multiple lines in ascending order".into(),
1027                    path: "root/file.txt".into(),
1028                    mode: StreamingEditFileMode::Edit,
1029                    content: None,
1030                    edits: Some(vec![
1031                        EditOperation {
1032                            old_text: "line 1".into(),
1033                            new_text: "modified line 1".into(),
1034                        },
1035                        EditOperation {
1036                            old_text: "line 5".into(),
1037                            new_text: "modified line 5".into(),
1038                        },
1039                    ]),
1040                };
1041                Arc::new(StreamingEditFileTool::new(
1042                    project.clone(),
1043                    thread.downgrade(),
1044                    language_registry,
1045                    Templates::new(),
1046                ))
1047                .run(input, ToolCallEventStream::test().0, cx)
1048            })
1049            .await;
1050
1051        assert!(result.is_ok());
1052        let output = result.unwrap();
1053        assert_eq!(
1054            output.new_text,
1055            "modified line 1\nline 2\nline 3\nline 4\nmodified line 5\n"
1056        );
1057    }
1058
1059    #[gpui::test]
1060    async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) {
1061        init_test(cx);
1062
1063        let fs = project::FakeFs::new(cx.executor());
1064        fs.insert_tree("/root", json!({})).await;
1065        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1066        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1067        let context_server_registry =
1068            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1069        let model = Arc::new(FakeLanguageModel::default());
1070        let thread = cx.new(|cx| {
1071            crate::Thread::new(
1072                project.clone(),
1073                cx.new(|_cx| ProjectContext::default()),
1074                context_server_registry,
1075                Templates::new(),
1076                Some(model),
1077                cx,
1078            )
1079        });
1080
1081        let result = cx
1082            .update(|cx| {
1083                let input = StreamingEditFileToolInput {
1084                    display_description: "Some edit".into(),
1085                    path: "root/nonexistent_file.txt".into(),
1086                    mode: StreamingEditFileMode::Edit,
1087                    content: None,
1088                    edits: Some(vec![EditOperation {
1089                        old_text: "foo".into(),
1090                        new_text: "bar".into(),
1091                    }]),
1092                };
1093                Arc::new(StreamingEditFileTool::new(
1094                    project,
1095                    thread.downgrade(),
1096                    language_registry,
1097                    Templates::new(),
1098                ))
1099                .run(input, ToolCallEventStream::test().0, cx)
1100            })
1101            .await;
1102
1103        assert_eq!(
1104            result.unwrap_err().to_string(),
1105            "Can't edit file: path not found"
1106        );
1107    }
1108
1109    #[gpui::test]
1110    async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) {
1111        init_test(cx);
1112
1113        let fs = project::FakeFs::new(cx.executor());
1114        fs.insert_tree("/root", json!({"file.txt": "hello world"}))
1115            .await;
1116        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1117        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1118        let context_server_registry =
1119            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1120        let model = Arc::new(FakeLanguageModel::default());
1121        let thread = cx.new(|cx| {
1122            crate::Thread::new(
1123                project.clone(),
1124                cx.new(|_cx| ProjectContext::default()),
1125                context_server_registry,
1126                Templates::new(),
1127                Some(model),
1128                cx,
1129            )
1130        });
1131
1132        let result = cx
1133            .update(|cx| {
1134                let input = StreamingEditFileToolInput {
1135                    display_description: "Edit file".into(),
1136                    path: "root/file.txt".into(),
1137                    mode: StreamingEditFileMode::Edit,
1138                    content: None,
1139                    edits: Some(vec![EditOperation {
1140                        old_text: "nonexistent text that is not in the file".into(),
1141                        new_text: "replacement".into(),
1142                    }]),
1143                };
1144                Arc::new(StreamingEditFileTool::new(
1145                    project,
1146                    thread.downgrade(),
1147                    language_registry,
1148                    Templates::new(),
1149                ))
1150                .run(input, ToolCallEventStream::test().0, cx)
1151            })
1152            .await;
1153
1154        assert!(result.is_err());
1155        assert!(
1156            result
1157                .unwrap_err()
1158                .to_string()
1159                .contains("Could not find matching text")
1160        );
1161    }
1162
1163    #[gpui::test]
1164    async fn test_streaming_edit_overlapping_edits_out_of_order(cx: &mut TestAppContext) {
1165        init_test(cx);
1166
1167        let fs = project::FakeFs::new(cx.executor());
1168        // Multi-line file so the line-based fuzzy matcher can resolve each edit.
1169        fs.insert_tree(
1170            "/root",
1171            json!({
1172                "file.txt": "line 1\nline 2\nline 3\nline 4\nline 5\n"
1173            }),
1174        )
1175        .await;
1176        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1177        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1178        let context_server_registry =
1179            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1180        let model = Arc::new(FakeLanguageModel::default());
1181        let thread = cx.new(|cx| {
1182            crate::Thread::new(
1183                project.clone(),
1184                cx.new(|_cx| ProjectContext::default()),
1185                context_server_registry,
1186                Templates::new(),
1187                Some(model),
1188                cx,
1189            )
1190        });
1191
1192        // Edit A spans lines 3-4, edit B spans lines 2-3. They overlap on
1193        // "line 3" and are given in descending file order so the ascending
1194        // sort must reorder them before the pairwise overlap check can
1195        // detect them correctly.
1196        let result = cx
1197            .update(|cx| {
1198                let input = StreamingEditFileToolInput {
1199                    display_description: "Overlapping edits".into(),
1200                    path: "root/file.txt".into(),
1201                    mode: StreamingEditFileMode::Edit,
1202                    content: None,
1203                    edits: Some(vec![
1204                        EditOperation {
1205                            old_text: "line 3\nline 4".into(),
1206                            new_text: "SECOND".into(),
1207                        },
1208                        EditOperation {
1209                            old_text: "line 2\nline 3".into(),
1210                            new_text: "FIRST".into(),
1211                        },
1212                    ]),
1213                };
1214                Arc::new(StreamingEditFileTool::new(
1215                    project,
1216                    thread.downgrade(),
1217                    language_registry,
1218                    Templates::new(),
1219                ))
1220                .run(input, ToolCallEventStream::test().0, cx)
1221            })
1222            .await;
1223
1224        let error = result.unwrap_err();
1225        let error_message = error.to_string();
1226        assert!(
1227            error_message.contains("Conflicting edit ranges detected"),
1228            "Expected 'Conflicting edit ranges detected' but got: {error_message}"
1229        );
1230    }
1231
1232    fn init_test(cx: &mut TestAppContext) {
1233        cx.update(|cx| {
1234            let settings_store = SettingsStore::test(cx);
1235            cx.set_global(settings_store);
1236            SettingsStore::update_global(cx, |store: &mut SettingsStore, cx| {
1237                store.update_user_settings(cx, |settings| {
1238                    settings
1239                        .project
1240                        .all_languages
1241                        .defaults
1242                        .ensure_final_newline_on_save = Some(false);
1243                });
1244            });
1245        });
1246    }
1247}