streaming_edit_file_tool.rs

  1use super::edit_file_tool::EditFileTool;
  2use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
  3use super::save_file_tool::SaveFileTool;
  4use crate::{
  5    AgentTool, Templates, Thread, ToolCallEventStream, ToolPermissionDecision,
  6    decide_permission_from_settings, edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher,
  7};
  8use acp_thread::Diff;
  9use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
 10use anyhow::{Context as _, Result, anyhow};
 11use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
 12use language::LanguageRegistry;
 13use language_model::LanguageModelToolResultContent;
 14use paths;
 15use project::{Project, ProjectPath};
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use settings::Settings;
 19use std::ffi::OsStr;
 20use std::ops::Range;
 21use std::path::{Path, PathBuf};
 22use std::sync::Arc;
 23use text::BufferSnapshot;
 24use ui::SharedString;
 25use util::rel_path::RelPath;
 26
/// Fallback title returned by `initial_title` for a partially streamed input
/// that cannot be parsed yet, or that has neither a non-blank path nor a
/// non-blank description.
const DEFAULT_UI_TEXT: &str = "Editing file";
 28
// NOTE: the `JsonSchema` derive turns the `///` docs on this struct and its
// fields into schema descriptions that the model reads, so editing them
// changes the tool prompt. Plain `//` comments like this one do not.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
///    - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
    // Declared first on purpose: the doc below asks the model to send this
    // field first, so the title can be shown while the rest is still streaming.
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    // Resolved against the project's worktrees by `resolve_path`; which checks
    // apply (must exist vs. must not exist) depends on `mode`.
    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,

    /// The mode of operation on the file. Possible values:
    /// - 'create': Create a new file if it doesn't exist. Requires 'content' field.
    /// - 'overwrite': Replace the entire contents of an existing file. Requires 'content' field.
    /// - 'edit': Make granular edits to an existing file. Requires 'edits' field.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: StreamingEditFileMode,

    // Only read in create/overwrite mode; `run` fails if it is missing there.
    /// The complete content for the new file (required for 'create' and 'overwrite' modes).
    /// This field should contain the entire file content.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    // Only read in edit mode; `run` fails if it is missing there.
    // NOTE(review): despite "sequentially" below, `apply_edits` matches every
    // `old_text` against the file as it was before any edit in the batch, so an
    // edit cannot target text introduced by an earlier one.
    /// List of edit operations to apply sequentially (required for 'edit' mode).
    /// Each edit finds `old_text` in the file and replaces it with `new_text`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub edits: Option<Vec<EditOperation>>,
}
 88
// How `StreamingEditFileToolInput::path` is treated. Serialized in snake_case,
// so the model sends "create", "overwrite" or "edit". The variant `///` docs
// become schema descriptions via `JsonSchema`, so they are model-facing text.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StreamingEditFileMode {
    /// Create a new file if it doesn't exist
    Create,
    /// Replace the entire contents of an existing file
    Overwrite,
    /// Make granular edits to an existing file
    Edit,
}
 99
// Both fields are required when deserializing (no `#[serde(default)]`).
// `old_text` must resolve to exactly one location in the file (see
// `resolve_edit`); if any edit in a batch fails to, none of them is applied.
/// A single edit operation that replaces old text with new text
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditOperation {
    /// The exact text to find in the file. This will be matched using fuzzy matching
    /// to handle minor differences in whitespace or formatting.
    pub old_text: String,
    /// The text to replace it with
    pub new_text: String,
}
109
110#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
111struct StreamingEditFileToolPartialInput {
112    #[serde(default)]
113    path: String,
114    #[serde(default)]
115    display_description: String,
116}
117
/// Result of a completed edit.
///
/// Consumed by `replay` (to rebuild a finalized diff) and by the
/// `From<StreamingEditFileToolOutput>` impl that summarizes it for the model.
#[derive(Debug, Serialize, Deserialize)]
pub struct StreamingEditFileToolOutput {
    /// Path exactly as given in the tool input. `original_path` is also
    /// accepted when deserializing — presumably for outputs serialized under
    /// an older field name.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit was applied and the buffer saved.
    new_text: String,
    /// Full buffer text before the edit. Kept in an `Arc` so it can be shared
    /// with the background diff task without copying.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty means nothing changed.
    /// Defaults to empty when absent during deserialization.
    #[serde(default)]
    diff: String,
}
127
128impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
129    fn from(output: StreamingEditFileToolOutput) -> Self {
130        if output.diff.is_empty() {
131            "No edits were made.".into()
132        } else {
133            format!(
134                "Edited {}:\n\n```diff\n{}\n```",
135                output.input_path.display(),
136                output.diff
137            )
138            .into()
139        }
140    }
141}
142
/// Agent tool that creates, overwrites, or edits a project file from a single
/// tool call, then saves the buffer and reports a unified diff.
pub struct StreamingEditFileTool {
    /// Owning thread, held weakly; operations fail with "thread was dropped"
    /// once it is gone. `run` also takes the project to edit from here.
    thread: WeakEntity<Thread>,
    /// Used by `replay` to build a finalized diff view.
    language_registry: Arc<LanguageRegistry>,
    /// Used by `initial_title` to turn input paths into short display paths.
    project: Entity<Project>,
    // Never read in this file (hence the allowance); presumably kept so the
    // constructor matches the other edit tool's — TODO confirm.
    #[allow(dead_code)]
    templates: Arc<Templates>,
}
150
151impl StreamingEditFileTool {
152    pub fn new(
153        project: Entity<Project>,
154        thread: WeakEntity<Thread>,
155        language_registry: Arc<LanguageRegistry>,
156        templates: Arc<Templates>,
157    ) -> Self {
158        Self {
159            project,
160            thread,
161            language_registry,
162            templates,
163        }
164    }
165
166    fn authorize(
167        &self,
168        input: &StreamingEditFileToolInput,
169        event_stream: &ToolCallEventStream,
170        cx: &mut App,
171    ) -> Task<Result<()>> {
172        let path_str = input.path.to_string_lossy();
173        let settings = agent_settings::AgentSettings::get_global(cx);
174        let decision = decide_permission_from_settings(Self::NAME, &path_str, settings);
175
176        match decision {
177            ToolPermissionDecision::Allow => return Task::ready(Ok(())),
178            ToolPermissionDecision::Deny(reason) => {
179                return Task::ready(Err(anyhow!("{}", reason)));
180            }
181            ToolPermissionDecision::Confirm => {}
182        }
183
184        let local_settings_folder = paths::local_settings_folder_name();
185        let path = Path::new(&input.path);
186        if path.components().any(|component| {
187            component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
188        }) {
189            let context = crate::ToolPermissionContext {
190                tool_name: EditFileTool::NAME.to_string(),
191                input_value: path_str.to_string(),
192            };
193            return event_stream.authorize(
194                format!("{} (local settings)", input.display_description),
195                context,
196                cx,
197            );
198        }
199
200        if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
201            && canonical_path.starts_with(paths::config_dir())
202        {
203            let context = crate::ToolPermissionContext {
204                tool_name: EditFileTool::NAME.to_string(),
205                input_value: path_str.to_string(),
206            };
207            return event_stream.authorize(
208                format!("{} (global settings)", input.display_description),
209                context,
210                cx,
211            );
212        }
213
214        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
215            thread.project().read(cx).find_project_path(&input.path, cx)
216        }) else {
217            return Task::ready(Err(anyhow!("thread was dropped")));
218        };
219
220        if project_path.is_some() {
221            Task::ready(Ok(()))
222        } else {
223            let context = crate::ToolPermissionContext {
224                tool_name: EditFileTool::NAME.to_string(),
225                input_value: path_str.to_string(),
226            };
227            event_stream.authorize(&input.display_description, context, cx)
228        }
229    }
230}
231
232impl AgentTool for StreamingEditFileTool {
233    type Input = StreamingEditFileToolInput;
234    type Output = StreamingEditFileToolOutput;
235
236    const NAME: &'static str = "streaming_edit_file";
237
238    fn kind() -> acp::ToolKind {
239        acp::ToolKind::Edit
240    }
241
242    fn initial_title(
243        &self,
244        input: Result<Self::Input, serde_json::Value>,
245        cx: &mut App,
246    ) -> SharedString {
247        match input {
248            Ok(input) => self
249                .project
250                .read(cx)
251                .find_project_path(&input.path, cx)
252                .and_then(|project_path| {
253                    self.project
254                        .read(cx)
255                        .short_full_path_for_project_path(&project_path, cx)
256                })
257                .unwrap_or(input.path.to_string_lossy().into_owned())
258                .into(),
259            Err(raw_input) => {
260                if let Some(input) =
261                    serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
262                {
263                    let path = input.path.trim();
264                    if !path.is_empty() {
265                        return self
266                            .project
267                            .read(cx)
268                            .find_project_path(&input.path, cx)
269                            .and_then(|project_path| {
270                                self.project
271                                    .read(cx)
272                                    .short_full_path_for_project_path(&project_path, cx)
273                            })
274                            .unwrap_or(input.path)
275                            .into();
276                    }
277
278                    let description = input.display_description.trim();
279                    if !description.is_empty() {
280                        return description.to_string().into();
281                    }
282                }
283
284                DEFAULT_UI_TEXT.into()
285            }
286        }
287    }
288
289    fn run(
290        self: Arc<Self>,
291        input: Self::Input,
292        event_stream: ToolCallEventStream,
293        cx: &mut App,
294    ) -> Task<Result<Self::Output>> {
295        let Ok(project) = self
296            .thread
297            .read_with(cx, |thread, _cx| thread.project().clone())
298        else {
299            return Task::ready(Err(anyhow!("thread was dropped")));
300        };
301
302        let project_path = match resolve_path(&input, project.clone(), cx) {
303            Ok(path) => path,
304            Err(err) => return Task::ready(Err(anyhow!(err))),
305        };
306
307        let abs_path = project.read(cx).absolute_path(&project_path, cx);
308        if let Some(abs_path) = abs_path.clone() {
309            event_stream.update_fields(
310                ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
311            );
312        }
313
314        let authorize = self.authorize(&input, &event_stream, cx);
315
316        cx.spawn(async move |cx: &mut AsyncApp| {
317            authorize.await?;
318
319            let buffer = project
320                .update(cx, |project, cx| {
321                    project.open_buffer(project_path.clone(), cx)
322                })
323                .await?;
324
325            if let Some(abs_path) = abs_path.as_ref() {
326                let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) =
327                    self.thread.update(cx, |thread, cx| {
328                        let last_read = thread.file_read_times.get(abs_path).copied();
329                        let current = buffer
330                            .read(cx)
331                            .file()
332                            .and_then(|file| file.disk_state().mtime());
333                        let dirty = buffer.read(cx).is_dirty();
334                        let has_save = thread.has_tool(SaveFileTool::NAME);
335                        let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
336                        (last_read, current, dirty, has_save, has_restore)
337                    })?;
338
339                if is_dirty {
340                    let message = match (has_save_tool, has_restore_tool) {
341                        (true, true) => {
342                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
343                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
344                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
345                        }
346                        (true, false) => {
347                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
348                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
349                             If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
350                        }
351                        (false, true) => {
352                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
353                             If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
354                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
355                        }
356                        (false, false) => {
357                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
358                             then ask them to save or revert the file manually and inform you when it's ok to proceed."
359                        }
360                    };
361                    anyhow::bail!("{}", message);
362                }
363
364                if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
365                    if current != last_read {
366                        anyhow::bail!(
367                            "The file {} has been modified since you last read it. \
368                             Please read the file again to get the current state before editing it.",
369                            input.path.display()
370                        );
371                    }
372                }
373            }
374
375            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
376            event_stream.update_diff(diff.clone());
377            let _finalize_diff = util::defer({
378                let diff = diff.downgrade();
379                let mut cx = cx.clone();
380                move || {
381                    diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
382                }
383            });
384
385            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
386            let old_text = cx
387                .background_spawn({
388                    let old_snapshot = old_snapshot.clone();
389                    async move { Arc::new(old_snapshot.text()) }
390                })
391                .await;
392
393            let action_log = self.thread.read_with(cx, |thread, _cx| thread.action_log().clone())?;
394
395            // Edit the buffer and report edits to the action log as part of the
396            // same effect cycle, otherwise the edit will be reported as if the
397            // user made it (due to the buffer subscription in action_log).
398            match input.mode {
399                StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => {
400                    action_log.update(cx, |log, cx| {
401                        log.buffer_created(buffer.clone(), cx);
402                    });
403                    let content = input.content.ok_or_else(|| {
404                        anyhow!("'content' field is required for create and overwrite modes")
405                    })?;
406                    cx.update(|cx| {
407                        buffer.update(cx, |buffer, cx| {
408                            buffer.edit([(0..buffer.len(), content.as_str())], None, cx);
409                        });
410                        action_log.update(cx, |log, cx| {
411                            log.buffer_edited(buffer.clone(), cx);
412                        });
413                    });
414                }
415                StreamingEditFileMode::Edit => {
416                    action_log.update(cx, |log, cx| {
417                        log.buffer_read(buffer.clone(), cx);
418                    });
419                    let edits = input.edits.ok_or_else(|| {
420                        anyhow!("'edits' field is required for edit mode")
421                    })?;
422                    // apply_edits now handles buffer_edited internally in the same effect cycle
423                    apply_edits(&buffer, &action_log, &edits, &diff, &event_stream, &abs_path, cx)?;
424                }
425            }
426
427            project
428                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
429                .await?;
430
431            action_log.update(cx, |log, cx| {
432                log.buffer_edited(buffer.clone(), cx);
433            });
434
435            if let Some(abs_path) = abs_path.as_ref() {
436                if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
437                    buffer.file().and_then(|file| file.disk_state().mtime())
438                }) {
439                    self.thread.update(cx, |thread, _| {
440                        thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
441                    })?;
442                }
443            }
444
445            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
446            let (new_text, unified_diff) = cx
447                .background_spawn({
448                    let new_snapshot = new_snapshot.clone();
449                    let old_text = old_text.clone();
450                    async move {
451                        let new_text = new_snapshot.text();
452                        let diff = language::unified_diff(&old_text, &new_text);
453                        (new_text, diff)
454                    }
455                })
456                .await;
457
458            let output = StreamingEditFileToolOutput {
459                input_path: input.path,
460                new_text,
461                old_text,
462                diff: unified_diff,
463            };
464
465            Ok(output)
466        })
467    }
468
469    fn replay(
470        &self,
471        _input: Self::Input,
472        output: Self::Output,
473        event_stream: ToolCallEventStream,
474        cx: &mut App,
475    ) -> Result<()> {
476        event_stream.update_diff(cx.new(|cx| {
477            Diff::finalized(
478                output.input_path.to_string_lossy().into_owned(),
479                Some(output.old_text.to_string()),
480                output.new_text,
481                self.language_registry.clone(),
482                cx,
483            )
484        }));
485        Ok(())
486    }
487}
488
489fn apply_edits(
490    buffer: &Entity<language::Buffer>,
491    action_log: &Entity<action_log::ActionLog>,
492    edits: &[EditOperation],
493    diff: &Entity<Diff>,
494    event_stream: &ToolCallEventStream,
495    abs_path: &Option<PathBuf>,
496    cx: &mut AsyncApp,
497) -> Result<()> {
498    let mut failed_edits = Vec::new();
499    let mut ambiguous_edits = Vec::new();
500    let mut resolved_edits: Vec<(Range<usize>, String)> = Vec::new();
501    let mut first_edit_line: Option<u32> = None;
502
503    // First pass: resolve all edits without applying them
504    for (index, edit) in edits.iter().enumerate() {
505        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
506        let result = resolve_edit(&snapshot, edit);
507
508        match result {
509            Ok(Some((range, new_text))) => {
510                if first_edit_line.is_none() {
511                    first_edit_line = Some(snapshot.offset_to_point(range.start).row);
512                }
513                // Reveal the range in the diff view
514                let start_anchor =
515                    buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(range.start));
516                let end_anchor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_after(range.end));
517                diff.update(cx, |card, cx| {
518                    card.reveal_range(start_anchor..end_anchor, cx)
519                });
520                resolved_edits.push((range, new_text));
521            }
522            Ok(None) => {
523                failed_edits.push(index);
524            }
525            Err(ranges) => {
526                ambiguous_edits.push((index, ranges));
527            }
528        }
529    }
530
531    // Check for errors before applying any edits
532    if !failed_edits.is_empty() {
533        let indices = failed_edits
534            .iter()
535            .map(|i| i.to_string())
536            .collect::<Vec<_>>()
537            .join(", ");
538        anyhow::bail!(
539            "Could not find matching text for edit(s) at index(es): {}. \
540             The old_text did not match any content in the file. \
541             Please read the file again to get the current content.",
542            indices
543        );
544    }
545
546    if !ambiguous_edits.is_empty() {
547        let details: Vec<String> = ambiguous_edits
548            .iter()
549            .map(|(index, ranges)| {
550                let lines = ranges
551                    .iter()
552                    .map(|r| r.start.to_string())
553                    .collect::<Vec<_>>()
554                    .join(", ");
555                format!("edit {}: matches at lines {}", index, lines)
556            })
557            .collect();
558        anyhow::bail!(
559            "Some edits matched multiple locations in the file:\n{}. \
560             Please provide more context in old_text to uniquely identify the location.",
561            details.join("\n")
562        );
563    }
564
565    // Emit location for the first edit
566    if let Some(line) = first_edit_line {
567        if let Some(abs_path) = abs_path.clone() {
568            event_stream.update_fields(
569                ToolCallUpdateFields::new()
570                    .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
571            );
572        }
573    }
574
575    // Second pass: apply all edits and report to action_log in the same effect cycle.
576    // This prevents the buffer subscription from treating these as user edits.
577    if !resolved_edits.is_empty() {
578        cx.update(|cx| {
579            buffer.update(cx, |buffer, cx| {
580                // Apply edits in reverse order so offsets remain valid
581                let mut edits_sorted: Vec<_> = resolved_edits.into_iter().collect();
582                edits_sorted.sort_by(|a, b| b.0.start.cmp(&a.0.start));
583                for (range, new_text) in edits_sorted {
584                    buffer.edit([(range, new_text.as_str())], None, cx);
585                }
586            });
587            action_log.update(cx, |log, cx| {
588                log.buffer_edited(buffer.clone(), cx);
589            });
590        });
591    }
592
593    Ok(())
594}
595
596/// Resolves an edit operation by finding the matching text in the buffer.
597/// Returns Ok(Some((range, new_text))) if a unique match is found,
598/// Ok(None) if no match is found, or Err(ranges) if multiple matches are found.
599fn resolve_edit(
600    snapshot: &BufferSnapshot,
601    edit: &EditOperation,
602) -> std::result::Result<Option<(Range<usize>, String)>, Vec<Range<usize>>> {
603    let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());
604    matcher.push(&edit.old_text, None);
605    let matches = matcher.finish();
606
607    if matches.is_empty() {
608        return Ok(None);
609    }
610
611    if matches.len() > 1 {
612        return Err(matches);
613    }
614
615    let match_range = matches.into_iter().next().expect("checked len above");
616    Ok(Some((match_range, edit.new_text.clone())))
617}
618
619fn resolve_path(
620    input: &StreamingEditFileToolInput,
621    project: Entity<Project>,
622    cx: &mut App,
623) -> Result<ProjectPath> {
624    let project = project.read(cx);
625
626    match input.mode {
627        StreamingEditFileMode::Edit | StreamingEditFileMode::Overwrite => {
628            let path = project
629                .find_project_path(&input.path, cx)
630                .context("Can't edit file: path not found")?;
631
632            let entry = project
633                .entry_for_path(&path, cx)
634                .context("Can't edit file: path not found")?;
635
636            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
637            Ok(path)
638        }
639
640        StreamingEditFileMode::Create => {
641            if let Some(path) = project.find_project_path(&input.path, cx) {
642                anyhow::ensure!(
643                    project.entry_for_path(&path, cx).is_none(),
644                    "Can't create file: file already exists"
645                );
646            }
647
648            let parent_path = input
649                .path
650                .parent()
651                .context("Can't create file: incorrect path")?;
652
653            let parent_project_path = project.find_project_path(&parent_path, cx);
654
655            let parent_entry = parent_project_path
656                .as_ref()
657                .and_then(|path| project.entry_for_path(path, cx))
658                .context("Can't create file: parent directory doesn't exist")?;
659
660            anyhow::ensure!(
661                parent_entry.is_dir(),
662                "Can't create file: parent is not a directory"
663            );
664
665            let file_name = input
666                .path
667                .file_name()
668                .and_then(|file_name| file_name.to_str())
669                .and_then(|file_name| RelPath::unix(file_name).ok())
670                .context("Can't create file: invalid filename")?;
671
672            let new_file_path = parent_project_path.map(|parent| ProjectPath {
673                path: parent.path.join(file_name),
674                ..parent
675            });
676
677            new_file_path.context("Can't create file")
678        }
679    }
680}
681
682#[cfg(test)]
683mod tests {
684    use super::*;
685    use crate::{ContextServerRegistry, Templates};
686    use gpui::TestAppContext;
687    use language_model::fake_provider::FakeLanguageModel;
688    use prompt_store::ProjectContext;
689    use serde_json::json;
690    use settings::SettingsStore;
691    use util::path;
692
693    #[gpui::test]
694    async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
695        init_test(cx);
696
697        let fs = project::FakeFs::new(cx.executor());
698        fs.insert_tree("/root", json!({"dir": {}})).await;
699        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
700        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
701        let context_server_registry =
702            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
703        let model = Arc::new(FakeLanguageModel::default());
704        let thread = cx.new(|cx| {
705            crate::Thread::new(
706                project.clone(),
707                cx.new(|_cx| ProjectContext::default()),
708                context_server_registry,
709                Templates::new(),
710                Some(model),
711                cx,
712            )
713        });
714
715        let result = cx
716            .update(|cx| {
717                let input = StreamingEditFileToolInput {
718                    display_description: "Create new file".into(),
719                    path: "root/dir/new_file.txt".into(),
720                    mode: StreamingEditFileMode::Create,
721                    content: Some("Hello, World!".into()),
722                    edits: None,
723                };
724                Arc::new(StreamingEditFileTool::new(
725                    project.clone(),
726                    thread.downgrade(),
727                    language_registry,
728                    Templates::new(),
729                ))
730                .run(input, ToolCallEventStream::test().0, cx)
731            })
732            .await;
733
734        assert!(result.is_ok());
735        let output = result.unwrap();
736        assert_eq!(output.new_text, "Hello, World!");
737        assert!(!output.diff.is_empty());
738    }
739
740    #[gpui::test]
741    async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
742        init_test(cx);
743
744        let fs = project::FakeFs::new(cx.executor());
745        fs.insert_tree("/root", json!({"file.txt": "old content"}))
746            .await;
747        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
748        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
749        let context_server_registry =
750            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
751        let model = Arc::new(FakeLanguageModel::default());
752        let thread = cx.new(|cx| {
753            crate::Thread::new(
754                project.clone(),
755                cx.new(|_cx| ProjectContext::default()),
756                context_server_registry,
757                Templates::new(),
758                Some(model),
759                cx,
760            )
761        });
762
763        let result = cx
764            .update(|cx| {
765                let input = StreamingEditFileToolInput {
766                    display_description: "Overwrite file".into(),
767                    path: "root/file.txt".into(),
768                    mode: StreamingEditFileMode::Overwrite,
769                    content: Some("new content".into()),
770                    edits: None,
771                };
772                Arc::new(StreamingEditFileTool::new(
773                    project.clone(),
774                    thread.downgrade(),
775                    language_registry,
776                    Templates::new(),
777                ))
778                .run(input, ToolCallEventStream::test().0, cx)
779            })
780            .await;
781
782        assert!(result.is_ok());
783        let output = result.unwrap();
784        assert_eq!(output.new_text, "new content");
785        assert_eq!(*output.old_text, "old content");
786    }
787
788    #[gpui::test]
789    async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
790        init_test(cx);
791
792        let fs = project::FakeFs::new(cx.executor());
793        fs.insert_tree(
794            "/root",
795            json!({
796                "file.txt": "line 1\nline 2\nline 3\n"
797            }),
798        )
799        .await;
800        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
801        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
802        let context_server_registry =
803            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
804        let model = Arc::new(FakeLanguageModel::default());
805        let thread = cx.new(|cx| {
806            crate::Thread::new(
807                project.clone(),
808                cx.new(|_cx| ProjectContext::default()),
809                context_server_registry,
810                Templates::new(),
811                Some(model),
812                cx,
813            )
814        });
815
816        let result = cx
817            .update(|cx| {
818                let input = StreamingEditFileToolInput {
819                    display_description: "Edit lines".into(),
820                    path: "root/file.txt".into(),
821                    mode: StreamingEditFileMode::Edit,
822                    content: None,
823                    edits: Some(vec![EditOperation {
824                        old_text: "line 2".into(),
825                        new_text: "modified line 2".into(),
826                    }]),
827                };
828                Arc::new(StreamingEditFileTool::new(
829                    project.clone(),
830                    thread.downgrade(),
831                    language_registry,
832                    Templates::new(),
833                ))
834                .run(input, ToolCallEventStream::test().0, cx)
835            })
836            .await;
837
838        assert!(result.is_ok());
839        let output = result.unwrap();
840        assert_eq!(output.new_text, "line 1\nmodified line 2\nline 3\n");
841    }
842
843    #[gpui::test]
844    async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) {
845        init_test(cx);
846
847        let fs = project::FakeFs::new(cx.executor());
848        fs.insert_tree("/root", json!({})).await;
849        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
850        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
851        let context_server_registry =
852            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
853        let model = Arc::new(FakeLanguageModel::default());
854        let thread = cx.new(|cx| {
855            crate::Thread::new(
856                project.clone(),
857                cx.new(|_cx| ProjectContext::default()),
858                context_server_registry,
859                Templates::new(),
860                Some(model),
861                cx,
862            )
863        });
864
865        let result = cx
866            .update(|cx| {
867                let input = StreamingEditFileToolInput {
868                    display_description: "Some edit".into(),
869                    path: "root/nonexistent_file.txt".into(),
870                    mode: StreamingEditFileMode::Edit,
871                    content: None,
872                    edits: Some(vec![EditOperation {
873                        old_text: "foo".into(),
874                        new_text: "bar".into(),
875                    }]),
876                };
877                Arc::new(StreamingEditFileTool::new(
878                    project,
879                    thread.downgrade(),
880                    language_registry,
881                    Templates::new(),
882                ))
883                .run(input, ToolCallEventStream::test().0, cx)
884            })
885            .await;
886
887        assert_eq!(
888            result.unwrap_err().to_string(),
889            "Can't edit file: path not found"
890        );
891    }
892
893    #[gpui::test]
894    async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) {
895        init_test(cx);
896
897        let fs = project::FakeFs::new(cx.executor());
898        fs.insert_tree("/root", json!({"file.txt": "hello world"}))
899            .await;
900        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
901        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
902        let context_server_registry =
903            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
904        let model = Arc::new(FakeLanguageModel::default());
905        let thread = cx.new(|cx| {
906            crate::Thread::new(
907                project.clone(),
908                cx.new(|_cx| ProjectContext::default()),
909                context_server_registry,
910                Templates::new(),
911                Some(model),
912                cx,
913            )
914        });
915
916        let result = cx
917            .update(|cx| {
918                let input = StreamingEditFileToolInput {
919                    display_description: "Edit file".into(),
920                    path: "root/file.txt".into(),
921                    mode: StreamingEditFileMode::Edit,
922                    content: None,
923                    edits: Some(vec![EditOperation {
924                        old_text: "nonexistent text that is not in the file".into(),
925                        new_text: "replacement".into(),
926                    }]),
927                };
928                Arc::new(StreamingEditFileTool::new(
929                    project,
930                    thread.downgrade(),
931                    language_registry,
932                    Templates::new(),
933                ))
934                .run(input, ToolCallEventStream::test().0, cx)
935            })
936            .await;
937
938        assert!(result.is_err());
939        assert!(
940            result
941                .unwrap_err()
942                .to_string()
943                .contains("Could not find matching text")
944        );
945    }
946
947    fn init_test(cx: &mut TestAppContext) {
948        cx.update(|cx| {
949            let settings_store = SettingsStore::test(cx);
950            cx.set_global(settings_store);
951        });
952    }
953}