streaming_edit_file_tool.rs

  1use crate::{
  2    AgentTool, Templates, Thread, ToolCallEventStream, ToolPermissionDecision,
  3    decide_permission_from_settings, edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher,
  4};
  5use acp_thread::Diff;
  6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
  7use anyhow::{Context as _, Result, anyhow};
  8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  9use language::LanguageRegistry;
 10use language_model::LanguageModelToolResultContent;
 11use paths;
 12use project::{Project, ProjectPath};
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use settings::Settings;
 16use std::ffi::OsStr;
 17use std::ops::Range;
 18use std::path::{Path, PathBuf};
 19use std::sync::Arc;
 20use text::BufferSnapshot;
 21use ui::SharedString;
 22use util::rel_path::RelPath;
 23
/// Fallback tool-call title used by `initial_title` when the streamed input has
/// neither a usable path nor a display description yet.
const DEFAULT_UI_TEXT: &str = "Editing file";
 25
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
///    - Use the `list_directory` tool to verify the parent directory exists and is the correct location
// NOTE: this type derives `JsonSchema`, so the `///` doc comments on it and on its
// fields are presumably emitted as schema descriptions that the model reads. Treat
// them as prompt text (rewording changes model behavior). Maintainer-only notes
// belong in plain `//` comments like this one.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    // Declared first deliberately: `initial_title` can show it from a partial input
    // while the rest of the object is still streaming in.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    // Resolved against the project by `resolve_path`; it is not an absolute path.
    pub path: PathBuf,

    /// The mode of operation on the file. Possible values:
    /// - 'create': Create a new file if it doesn't exist. Requires 'content' field.
    /// - 'overwrite': Replace the entire contents of an existing file. Requires 'content' field.
    /// - 'edit': Make granular edits to an existing file. Requires 'edits' field.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: StreamingEditFileMode,

    /// The complete content for the new file (required for 'create' and 'overwrite' modes).
    /// This field should contain the entire file content.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    // Optional at the type level; its presence for the matching `mode` is checked
    // when the tool runs.
    pub content: Option<String>,

    /// List of edit operations to apply sequentially (required for 'edit' mode).
    /// Each edit finds `old_text` in the file and replaces it with `new_text`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    // NOTE(review): `apply_edits` matches every `old_text` against the file as it was
    // before any of these edits, so a later edit cannot target text introduced by an
    // earlier one, despite "sequentially" in the schema text above.
    pub edits: Option<Vec<EditOperation>>,
}
 85
// The `///` comments on these variants become part of the derived JSON schema, so
// treat them as prompt text. With `rename_all = "snake_case"`, the serialized values
// are "create", "overwrite" and "edit".
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StreamingEditFileMode {
    /// Create a new file if it doesn't exist
    // `resolve_path` fails if anything already exists at the path, and requires an
    // existing parent directory.
    Create,
    /// Replace the entire contents of an existing file
    // `resolve_path` requires an existing file entry (not a directory).
    Overwrite,
    /// Make granular edits to an existing file
    // `resolve_path` requires an existing file entry (not a directory).
    Edit,
}
 96
/// A single edit operation that replaces old text with new text
// Each operation is located by `resolve_edit`. If any operation matches nothing, or
// matches more than one place, `apply_edits` returns an error before modifying the
// buffer.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditOperation {
    /// The exact text to find in the file. This will be matched using fuzzy matching
    /// to handle minor differences in whitespace or formatting.
    pub old_text: String,
    /// The text to replace it with
    pub new_text: String,
}
106
/// Lenient view of a tool input that is still streaming in.
///
/// `initial_title` falls back to this when the raw JSON does not yet deserialize as
/// a full [`StreamingEditFileToolInput`]. Missing fields default to empty strings.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct StreamingEditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
114
/// Result of a completed run.
///
/// It is converted into the model-facing tool result, and `replay` takes it again
/// to rebuild a finalized diff view.
#[derive(Debug, Serialize, Deserialize)]
pub struct StreamingEditFileToolOutput {
    /// The path exactly as supplied in the tool input.
    // `alias` also accepts the key `original_path` when deserializing. Presumably
    // that is a former name of this field; verify before removing it.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit was applied and the buffer saved.
    new_text: String,
    /// Full buffer text before the edit.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`.
    // Defaults to empty when missing from serialized data. An empty diff is reported
    // to the model as "No edits were made.".
    #[serde(default)]
    diff: String,
}
124
125impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
126    fn from(output: StreamingEditFileToolOutput) -> Self {
127        if output.diff.is_empty() {
128            "No edits were made.".into()
129        } else {
130            format!(
131                "Edited {}:\n\n```diff\n{}\n```",
132                output.input_path.display(),
133                output.diff
134            )
135            .into()
136        }
137    }
138}
139
/// Agent tool that creates, overwrites, or applies granular edits to a single
/// project file in one call.
pub struct StreamingEditFileTool {
    /// Owning thread. It is used for the project lookup during authorization, for
    /// the per-file read times, for tool availability checks, and for the action log.
    thread: WeakEntity<Thread>,
    /// Used to build finalized diffs when replaying a past output.
    language_registry: Arc<LanguageRegistry>,
    /// Used to turn the input path into a short display title.
    project: Entity<Project>,
    // Stored by `new` but not read anywhere in this file, hence the `dead_code`
    // allowance. Presumably kept for constructor parity with the other edit tool;
    // confirm before removing.
    #[allow(dead_code)]
    templates: Arc<Templates>,
}
147
148impl StreamingEditFileTool {
149    pub fn new(
150        project: Entity<Project>,
151        thread: WeakEntity<Thread>,
152        language_registry: Arc<LanguageRegistry>,
153        templates: Arc<Templates>,
154    ) -> Self {
155        Self {
156            project,
157            thread,
158            language_registry,
159            templates,
160        }
161    }
162
163    fn authorize(
164        &self,
165        input: &StreamingEditFileToolInput,
166        event_stream: &ToolCallEventStream,
167        cx: &mut App,
168    ) -> Task<Result<()>> {
169        let path_str = input.path.to_string_lossy();
170        let settings = agent_settings::AgentSettings::get_global(cx);
171        let decision = decide_permission_from_settings(Self::name(), &path_str, settings);
172
173        match decision {
174            ToolPermissionDecision::Allow => return Task::ready(Ok(())),
175            ToolPermissionDecision::Deny(reason) => {
176                return Task::ready(Err(anyhow!("{}", reason)));
177            }
178            ToolPermissionDecision::Confirm => {}
179        }
180
181        let local_settings_folder = paths::local_settings_folder_name();
182        let path = Path::new(&input.path);
183        if path.components().any(|component| {
184            component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
185        }) {
186            let context = crate::ToolPermissionContext {
187                tool_name: "edit_file".to_string(),
188                input_value: path_str.to_string(),
189            };
190            return event_stream.authorize(
191                format!("{} (local settings)", input.display_description),
192                context,
193                cx,
194            );
195        }
196
197        if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
198            && canonical_path.starts_with(paths::config_dir())
199        {
200            let context = crate::ToolPermissionContext {
201                tool_name: "edit_file".to_string(),
202                input_value: path_str.to_string(),
203            };
204            return event_stream.authorize(
205                format!("{} (global settings)", input.display_description),
206                context,
207                cx,
208            );
209        }
210
211        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
212            thread.project().read(cx).find_project_path(&input.path, cx)
213        }) else {
214            return Task::ready(Err(anyhow!("thread was dropped")));
215        };
216
217        if project_path.is_some() {
218            Task::ready(Ok(()))
219        } else {
220            let context = crate::ToolPermissionContext {
221                tool_name: "edit_file".to_string(),
222                input_value: path_str.to_string(),
223            };
224            event_stream.authorize(&input.display_description, context, cx)
225        }
226    }
227}
228
229impl AgentTool for StreamingEditFileTool {
230    type Input = StreamingEditFileToolInput;
231    type Output = StreamingEditFileToolOutput;
232
233    fn name() -> &'static str {
234        "streaming_edit_file"
235    }
236
237    fn kind() -> acp::ToolKind {
238        acp::ToolKind::Edit
239    }
240
241    fn initial_title(
242        &self,
243        input: Result<Self::Input, serde_json::Value>,
244        cx: &mut App,
245    ) -> SharedString {
246        match input {
247            Ok(input) => self
248                .project
249                .read(cx)
250                .find_project_path(&input.path, cx)
251                .and_then(|project_path| {
252                    self.project
253                        .read(cx)
254                        .short_full_path_for_project_path(&project_path, cx)
255                })
256                .unwrap_or(input.path.to_string_lossy().into_owned())
257                .into(),
258            Err(raw_input) => {
259                if let Some(input) =
260                    serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
261                {
262                    let path = input.path.trim();
263                    if !path.is_empty() {
264                        return self
265                            .project
266                            .read(cx)
267                            .find_project_path(&input.path, cx)
268                            .and_then(|project_path| {
269                                self.project
270                                    .read(cx)
271                                    .short_full_path_for_project_path(&project_path, cx)
272                            })
273                            .unwrap_or(input.path)
274                            .into();
275                    }
276
277                    let description = input.display_description.trim();
278                    if !description.is_empty() {
279                        return description.to_string().into();
280                    }
281                }
282
283                DEFAULT_UI_TEXT.into()
284            }
285        }
286    }
287
288    fn run(
289        self: Arc<Self>,
290        input: Self::Input,
291        event_stream: ToolCallEventStream,
292        cx: &mut App,
293    ) -> Task<Result<Self::Output>> {
294        let Ok(project) = self
295            .thread
296            .read_with(cx, |thread, _cx| thread.project().clone())
297        else {
298            return Task::ready(Err(anyhow!("thread was dropped")));
299        };
300
301        let project_path = match resolve_path(&input, project.clone(), cx) {
302            Ok(path) => path,
303            Err(err) => return Task::ready(Err(anyhow!(err))),
304        };
305
306        let abs_path = project.read(cx).absolute_path(&project_path, cx);
307        if let Some(abs_path) = abs_path.clone() {
308            event_stream.update_fields(
309                ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
310            );
311        }
312
313        let authorize = self.authorize(&input, &event_stream, cx);
314
315        cx.spawn(async move |cx: &mut AsyncApp| {
316            authorize.await?;
317
318            let buffer = project
319                .update(cx, |project, cx| {
320                    project.open_buffer(project_path.clone(), cx)
321                })
322                .await?;
323
324            if let Some(abs_path) = abs_path.as_ref() {
325                let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) =
326                    self.thread.update(cx, |thread, cx| {
327                        let last_read = thread.file_read_times.get(abs_path).copied();
328                        let current = buffer
329                            .read(cx)
330                            .file()
331                            .and_then(|file| file.disk_state().mtime());
332                        let dirty = buffer.read(cx).is_dirty();
333                        let has_save = thread.has_tool("save_file");
334                        let has_restore = thread.has_tool("restore_file_from_disk");
335                        (last_read, current, dirty, has_save, has_restore)
336                    })?;
337
338                if is_dirty {
339                    let message = match (has_save_tool, has_restore_tool) {
340                        (true, true) => {
341                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
342                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
343                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
344                        }
345                        (true, false) => {
346                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
347                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
348                             If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
349                        }
350                        (false, true) => {
351                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
352                             If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
353                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
354                        }
355                        (false, false) => {
356                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
357                             then ask them to save or revert the file manually and inform you when it's ok to proceed."
358                        }
359                    };
360                    anyhow::bail!("{}", message);
361                }
362
363                if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
364                    if current != last_read {
365                        anyhow::bail!(
366                            "The file {} has been modified since you last read it. \
367                             Please read the file again to get the current state before editing it.",
368                            input.path.display()
369                        );
370                    }
371                }
372            }
373
374            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
375            event_stream.update_diff(diff.clone());
376            let _finalize_diff = util::defer({
377                let diff = diff.downgrade();
378                let mut cx = cx.clone();
379                move || {
380                    diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
381                }
382            });
383
384            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
385            let old_text = cx
386                .background_spawn({
387                    let old_snapshot = old_snapshot.clone();
388                    async move { Arc::new(old_snapshot.text()) }
389                })
390                .await;
391
392            let action_log = self.thread.read_with(cx, |thread, _cx| thread.action_log().clone())?;
393
394            // Edit the buffer and report edits to the action log as part of the
395            // same effect cycle, otherwise the edit will be reported as if the
396            // user made it (due to the buffer subscription in action_log).
397            match input.mode {
398                StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => {
399                    action_log.update(cx, |log, cx| {
400                        log.buffer_created(buffer.clone(), cx);
401                    });
402                    let content = input.content.ok_or_else(|| {
403                        anyhow!("'content' field is required for create and overwrite modes")
404                    })?;
405                    cx.update(|cx| {
406                        buffer.update(cx, |buffer, cx| {
407                            buffer.edit([(0..buffer.len(), content.as_str())], None, cx);
408                        });
409                        action_log.update(cx, |log, cx| {
410                            log.buffer_edited(buffer.clone(), cx);
411                        });
412                    });
413                }
414                StreamingEditFileMode::Edit => {
415                    action_log.update(cx, |log, cx| {
416                        log.buffer_read(buffer.clone(), cx);
417                    });
418                    let edits = input.edits.ok_or_else(|| {
419                        anyhow!("'edits' field is required for edit mode")
420                    })?;
421                    // apply_edits now handles buffer_edited internally in the same effect cycle
422                    apply_edits(&buffer, &action_log, &edits, &diff, &event_stream, &abs_path, cx)?;
423                }
424            }
425
426            project
427                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
428                .await?;
429
430            action_log.update(cx, |log, cx| {
431                log.buffer_edited(buffer.clone(), cx);
432            });
433
434            if let Some(abs_path) = abs_path.as_ref() {
435                if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
436                    buffer.file().and_then(|file| file.disk_state().mtime())
437                }) {
438                    self.thread.update(cx, |thread, _| {
439                        thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
440                    })?;
441                }
442            }
443
444            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
445            let (new_text, unified_diff) = cx
446                .background_spawn({
447                    let new_snapshot = new_snapshot.clone();
448                    let old_text = old_text.clone();
449                    async move {
450                        let new_text = new_snapshot.text();
451                        let diff = language::unified_diff(&old_text, &new_text);
452                        (new_text, diff)
453                    }
454                })
455                .await;
456
457            let output = StreamingEditFileToolOutput {
458                input_path: input.path,
459                new_text,
460                old_text,
461                diff: unified_diff,
462            };
463
464            Ok(output)
465        })
466    }
467
468    fn replay(
469        &self,
470        _input: Self::Input,
471        output: Self::Output,
472        event_stream: ToolCallEventStream,
473        cx: &mut App,
474    ) -> Result<()> {
475        event_stream.update_diff(cx.new(|cx| {
476            Diff::finalized(
477                output.input_path.to_string_lossy().into_owned(),
478                Some(output.old_text.to_string()),
479                output.new_text,
480                self.language_registry.clone(),
481                cx,
482            )
483        }));
484        Ok(())
485    }
486}
487
488fn apply_edits(
489    buffer: &Entity<language::Buffer>,
490    action_log: &Entity<action_log::ActionLog>,
491    edits: &[EditOperation],
492    diff: &Entity<Diff>,
493    event_stream: &ToolCallEventStream,
494    abs_path: &Option<PathBuf>,
495    cx: &mut AsyncApp,
496) -> Result<()> {
497    let mut failed_edits = Vec::new();
498    let mut ambiguous_edits = Vec::new();
499    let mut resolved_edits: Vec<(Range<usize>, String)> = Vec::new();
500    let mut first_edit_line: Option<u32> = None;
501
502    // First pass: resolve all edits without applying them
503    for (index, edit) in edits.iter().enumerate() {
504        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
505        let result = resolve_edit(&snapshot, edit);
506
507        match result {
508            Ok(Some((range, new_text))) => {
509                if first_edit_line.is_none() {
510                    first_edit_line = Some(snapshot.offset_to_point(range.start).row);
511                }
512                // Reveal the range in the diff view
513                let start_anchor =
514                    buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(range.start));
515                let end_anchor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_after(range.end));
516                diff.update(cx, |card, cx| {
517                    card.reveal_range(start_anchor..end_anchor, cx)
518                });
519                resolved_edits.push((range, new_text));
520            }
521            Ok(None) => {
522                failed_edits.push(index);
523            }
524            Err(ranges) => {
525                ambiguous_edits.push((index, ranges));
526            }
527        }
528    }
529
530    // Check for errors before applying any edits
531    if !failed_edits.is_empty() {
532        let indices = failed_edits
533            .iter()
534            .map(|i| i.to_string())
535            .collect::<Vec<_>>()
536            .join(", ");
537        anyhow::bail!(
538            "Could not find matching text for edit(s) at index(es): {}. \
539             The old_text did not match any content in the file. \
540             Please read the file again to get the current content.",
541            indices
542        );
543    }
544
545    if !ambiguous_edits.is_empty() {
546        let details: Vec<String> = ambiguous_edits
547            .iter()
548            .map(|(index, ranges)| {
549                let lines = ranges
550                    .iter()
551                    .map(|r| r.start.to_string())
552                    .collect::<Vec<_>>()
553                    .join(", ");
554                format!("edit {}: matches at lines {}", index, lines)
555            })
556            .collect();
557        anyhow::bail!(
558            "Some edits matched multiple locations in the file:\n{}. \
559             Please provide more context in old_text to uniquely identify the location.",
560            details.join("\n")
561        );
562    }
563
564    // Emit location for the first edit
565    if let Some(line) = first_edit_line {
566        if let Some(abs_path) = abs_path.clone() {
567            event_stream.update_fields(
568                ToolCallUpdateFields::new()
569                    .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
570            );
571        }
572    }
573
574    // Second pass: apply all edits and report to action_log in the same effect cycle.
575    // This prevents the buffer subscription from treating these as user edits.
576    if !resolved_edits.is_empty() {
577        cx.update(|cx| {
578            buffer.update(cx, |buffer, cx| {
579                // Apply edits in reverse order so offsets remain valid
580                let mut edits_sorted: Vec<_> = resolved_edits.into_iter().collect();
581                edits_sorted.sort_by(|a, b| b.0.start.cmp(&a.0.start));
582                for (range, new_text) in edits_sorted {
583                    buffer.edit([(range, new_text.as_str())], None, cx);
584                }
585            });
586            action_log.update(cx, |log, cx| {
587                log.buffer_edited(buffer.clone(), cx);
588            });
589        });
590    }
591
592    Ok(())
593}
594
595/// Resolves an edit operation by finding the matching text in the buffer.
596/// Returns Ok(Some((range, new_text))) if a unique match is found,
597/// Ok(None) if no match is found, or Err(ranges) if multiple matches are found.
598fn resolve_edit(
599    snapshot: &BufferSnapshot,
600    edit: &EditOperation,
601) -> std::result::Result<Option<(Range<usize>, String)>, Vec<Range<usize>>> {
602    let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());
603    matcher.push(&edit.old_text, None);
604    let matches = matcher.finish();
605
606    if matches.is_empty() {
607        return Ok(None);
608    }
609
610    if matches.len() > 1 {
611        return Err(matches);
612    }
613
614    let match_range = matches.into_iter().next().expect("checked len above");
615    Ok(Some((match_range, edit.new_text.clone())))
616}
617
618fn resolve_path(
619    input: &StreamingEditFileToolInput,
620    project: Entity<Project>,
621    cx: &mut App,
622) -> Result<ProjectPath> {
623    let project = project.read(cx);
624
625    match input.mode {
626        StreamingEditFileMode::Edit | StreamingEditFileMode::Overwrite => {
627            let path = project
628                .find_project_path(&input.path, cx)
629                .context("Can't edit file: path not found")?;
630
631            let entry = project
632                .entry_for_path(&path, cx)
633                .context("Can't edit file: path not found")?;
634
635            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
636            Ok(path)
637        }
638
639        StreamingEditFileMode::Create => {
640            if let Some(path) = project.find_project_path(&input.path, cx) {
641                anyhow::ensure!(
642                    project.entry_for_path(&path, cx).is_none(),
643                    "Can't create file: file already exists"
644                );
645            }
646
647            let parent_path = input
648                .path
649                .parent()
650                .context("Can't create file: incorrect path")?;
651
652            let parent_project_path = project.find_project_path(&parent_path, cx);
653
654            let parent_entry = parent_project_path
655                .as_ref()
656                .and_then(|path| project.entry_for_path(path, cx))
657                .context("Can't create file: parent directory doesn't exist")?;
658
659            anyhow::ensure!(
660                parent_entry.is_dir(),
661                "Can't create file: parent is not a directory"
662            );
663
664            let file_name = input
665                .path
666                .file_name()
667                .and_then(|file_name| file_name.to_str())
668                .and_then(|file_name| RelPath::unix(file_name).ok())
669                .context("Can't create file: invalid filename")?;
670
671            let new_file_path = parent_project_path.map(|parent| ProjectPath {
672                path: parent.path.join(file_name),
673                ..parent
674            });
675
676            new_file_path.context("Can't create file")
677        }
678    }
679}
680
681#[cfg(test)]
682mod tests {
683    use super::*;
684    use crate::{ContextServerRegistry, Templates};
685    use gpui::TestAppContext;
686    use language_model::fake_provider::FakeLanguageModel;
687    use prompt_store::ProjectContext;
688    use serde_json::json;
689    use settings::SettingsStore;
690    use util::path;
691
692    #[gpui::test]
693    async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
694        init_test(cx);
695
696        let fs = project::FakeFs::new(cx.executor());
697        fs.insert_tree("/root", json!({"dir": {}})).await;
698        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
699        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
700        let context_server_registry =
701            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
702        let model = Arc::new(FakeLanguageModel::default());
703        let thread = cx.new(|cx| {
704            crate::Thread::new(
705                project.clone(),
706                cx.new(|_cx| ProjectContext::default()),
707                context_server_registry,
708                Templates::new(),
709                Some(model),
710                cx,
711            )
712        });
713
714        let result = cx
715            .update(|cx| {
716                let input = StreamingEditFileToolInput {
717                    display_description: "Create new file".into(),
718                    path: "root/dir/new_file.txt".into(),
719                    mode: StreamingEditFileMode::Create,
720                    content: Some("Hello, World!".into()),
721                    edits: None,
722                };
723                Arc::new(StreamingEditFileTool::new(
724                    project.clone(),
725                    thread.downgrade(),
726                    language_registry,
727                    Templates::new(),
728                ))
729                .run(input, ToolCallEventStream::test().0, cx)
730            })
731            .await;
732
733        assert!(result.is_ok());
734        let output = result.unwrap();
735        assert_eq!(output.new_text, "Hello, World!");
736        assert!(!output.diff.is_empty());
737    }
738
739    #[gpui::test]
740    async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
741        init_test(cx);
742
743        let fs = project::FakeFs::new(cx.executor());
744        fs.insert_tree("/root", json!({"file.txt": "old content"}))
745            .await;
746        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
747        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
748        let context_server_registry =
749            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
750        let model = Arc::new(FakeLanguageModel::default());
751        let thread = cx.new(|cx| {
752            crate::Thread::new(
753                project.clone(),
754                cx.new(|_cx| ProjectContext::default()),
755                context_server_registry,
756                Templates::new(),
757                Some(model),
758                cx,
759            )
760        });
761
762        let result = cx
763            .update(|cx| {
764                let input = StreamingEditFileToolInput {
765                    display_description: "Overwrite file".into(),
766                    path: "root/file.txt".into(),
767                    mode: StreamingEditFileMode::Overwrite,
768                    content: Some("new content".into()),
769                    edits: None,
770                };
771                Arc::new(StreamingEditFileTool::new(
772                    project.clone(),
773                    thread.downgrade(),
774                    language_registry,
775                    Templates::new(),
776                ))
777                .run(input, ToolCallEventStream::test().0, cx)
778            })
779            .await;
780
781        assert!(result.is_ok());
782        let output = result.unwrap();
783        assert_eq!(output.new_text, "new content");
784        assert_eq!(*output.old_text, "old content");
785    }
786
787    #[gpui::test]
788    async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
789        init_test(cx);
790
791        let fs = project::FakeFs::new(cx.executor());
792        fs.insert_tree(
793            "/root",
794            json!({
795                "file.txt": "line 1\nline 2\nline 3\n"
796            }),
797        )
798        .await;
799        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
800        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
801        let context_server_registry =
802            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
803        let model = Arc::new(FakeLanguageModel::default());
804        let thread = cx.new(|cx| {
805            crate::Thread::new(
806                project.clone(),
807                cx.new(|_cx| ProjectContext::default()),
808                context_server_registry,
809                Templates::new(),
810                Some(model),
811                cx,
812            )
813        });
814
815        let result = cx
816            .update(|cx| {
817                let input = StreamingEditFileToolInput {
818                    display_description: "Edit lines".into(),
819                    path: "root/file.txt".into(),
820                    mode: StreamingEditFileMode::Edit,
821                    content: None,
822                    edits: Some(vec![EditOperation {
823                        old_text: "line 2".into(),
824                        new_text: "modified line 2".into(),
825                    }]),
826                };
827                Arc::new(StreamingEditFileTool::new(
828                    project.clone(),
829                    thread.downgrade(),
830                    language_registry,
831                    Templates::new(),
832                ))
833                .run(input, ToolCallEventStream::test().0, cx)
834            })
835            .await;
836
837        assert!(result.is_ok());
838        let output = result.unwrap();
839        assert_eq!(output.new_text, "line 1\nmodified line 2\nline 3\n");
840    }
841
842    #[gpui::test]
843    async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) {
844        init_test(cx);
845
846        let fs = project::FakeFs::new(cx.executor());
847        fs.insert_tree("/root", json!({})).await;
848        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
849        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
850        let context_server_registry =
851            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
852        let model = Arc::new(FakeLanguageModel::default());
853        let thread = cx.new(|cx| {
854            crate::Thread::new(
855                project.clone(),
856                cx.new(|_cx| ProjectContext::default()),
857                context_server_registry,
858                Templates::new(),
859                Some(model),
860                cx,
861            )
862        });
863
864        let result = cx
865            .update(|cx| {
866                let input = StreamingEditFileToolInput {
867                    display_description: "Some edit".into(),
868                    path: "root/nonexistent_file.txt".into(),
869                    mode: StreamingEditFileMode::Edit,
870                    content: None,
871                    edits: Some(vec![EditOperation {
872                        old_text: "foo".into(),
873                        new_text: "bar".into(),
874                    }]),
875                };
876                Arc::new(StreamingEditFileTool::new(
877                    project,
878                    thread.downgrade(),
879                    language_registry,
880                    Templates::new(),
881                ))
882                .run(input, ToolCallEventStream::test().0, cx)
883            })
884            .await;
885
886        assert_eq!(
887            result.unwrap_err().to_string(),
888            "Can't edit file: path not found"
889        );
890    }
891
892    #[gpui::test]
893    async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) {
894        init_test(cx);
895
896        let fs = project::FakeFs::new(cx.executor());
897        fs.insert_tree("/root", json!({"file.txt": "hello world"}))
898            .await;
899        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
900        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
901        let context_server_registry =
902            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
903        let model = Arc::new(FakeLanguageModel::default());
904        let thread = cx.new(|cx| {
905            crate::Thread::new(
906                project.clone(),
907                cx.new(|_cx| ProjectContext::default()),
908                context_server_registry,
909                Templates::new(),
910                Some(model),
911                cx,
912            )
913        });
914
915        let result = cx
916            .update(|cx| {
917                let input = StreamingEditFileToolInput {
918                    display_description: "Edit file".into(),
919                    path: "root/file.txt".into(),
920                    mode: StreamingEditFileMode::Edit,
921                    content: None,
922                    edits: Some(vec![EditOperation {
923                        old_text: "nonexistent text that is not in the file".into(),
924                        new_text: "replacement".into(),
925                    }]),
926                };
927                Arc::new(StreamingEditFileTool::new(
928                    project,
929                    thread.downgrade(),
930                    language_registry,
931                    Templates::new(),
932                ))
933                .run(input, ToolCallEventStream::test().0, cx)
934            })
935            .await;
936
937        assert!(result.is_err());
938        assert!(
939            result
940                .unwrap_err()
941                .to_string()
942                .contains("Could not find matching text")
943        );
944    }
945
946    fn init_test(cx: &mut TestAppContext) {
947        cx.update(|cx| {
948            let settings_store = SettingsStore::test(cx);
949            cx.set_global(settings_store);
950        });
951    }
952}