streaming_edit_file_tool.rs

  1use crate::{
  2    AgentTool, Templates, Thread, ToolCallEventStream, ToolPermissionDecision,
  3    decide_permission_from_settings, edit_agent::streaming_fuzzy_matcher::StreamingFuzzyMatcher,
  4};
  5use acp_thread::Diff;
  6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
  7use anyhow::{Context as _, Result, anyhow};
  8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  9use language::{Anchor, LanguageRegistry, ToPoint};
 10use language_model::LanguageModelToolResultContent;
 11use paths;
 12use project::{Project, ProjectPath};
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use settings::Settings;
 16use std::ffi::OsStr;
 17use std::ops::Range;
 18use std::path::{Path, PathBuf};
 19use std::sync::Arc;
 20use text::BufferSnapshot;
 21use ui::SharedString;
 22use util::rel_path::RelPath;
 23
/// Fallback tool-call title, used by `initial_title` when the partially streamed input has
/// neither a non-empty `path` nor a non-empty `display_description` yet.
const DEFAULT_UI_TEXT: &str = "Editing file";
 25
// NOTE: the `JsonSchema` derive (schemars) turns the `///` text on this type and its fields into
// schema `description`s. That text is written as instructions to the model, so treat it like a
// string literal: editing it changes what the model is told.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
///    - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct StreamingEditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,

    /// The mode of operation on the file. Possible values:
    /// - 'create': Create a new file if it doesn't exist. Requires 'content' field.
    /// - 'overwrite': Replace the entire contents of an existing file. Requires 'content' field.
    /// - 'edit': Make granular edits to an existing file. Requires 'edits' field.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: StreamingEditFileMode,

    // `content` and `edits` are both optional in the wire format (absent => `None`, and `None` is
    // omitted when serializing). Which one is required depends on `mode`; that is checked when
    // the tool runs, not at deserialization time.
    /// The complete content for the new file (required for 'create' and 'overwrite' modes).
    /// This field should contain the entire file content.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// List of edit operations to apply sequentially (required for 'edit' mode).
    /// Each edit finds `old_text` in the file and replaces it with `new_text`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub edits: Option<Vec<EditOperation>>,
}
 85
// The `///` text on the variants becomes schema descriptions via the `JsonSchema` derive.
// With `rename_all = "snake_case"` the variants (de)serialize as "create", "overwrite" and "edit".
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum StreamingEditFileMode {
    /// Create a new file if it doesn't exist
    Create,
    /// Replace the entire contents of an existing file
    Overwrite,
    /// Make granular edits to an existing file
    Edit,
}
 96
// The `///` text here is emitted into the JSON schema by the `JsonSchema` derive.
// At run time `old_text` is located with `StreamingFuzzyMatcher`; the edit is applied only when
// exactly one location matches. Zero or several matches are reported as errors.
/// A single edit operation that replaces old text with new text
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditOperation {
    /// The exact text to find in the file. This will be matched using fuzzy matching
    /// to handle minor differences in whitespace or formatting.
    pub old_text: String,
    /// The text to replace it with
    pub new_text: String,
}
106
// Lenient view of the tool input. `initial_title` deserializes it from the raw JSON when the
// input does not parse as a full `StreamingEditFileToolInput` (e.g. while it is still incomplete).
// Both fields default to the empty string when missing.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct StreamingEditFileToolPartialInput {
    // Path as typed so far; may be empty.
    #[serde(default)]
    path: String,
    // Description as typed so far; may be empty.
    #[serde(default)]
    display_description: String,
}
114
/// Result of a successful run.
///
/// Holds the path from the input, the full buffer text before and after the change, and a unified
/// diff between the two. It is also what `replay` uses to rebuild the diff view.
#[derive(Debug, Serialize, Deserialize)]
pub struct StreamingEditFileToolOutput {
    /// Path exactly as given in the tool input.
    ///
    /// When deserializing, the key `original_path` is accepted as well.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the change was applied and the buffer saved.
    new_text: String,
    /// Full buffer text before the change.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`. Defaults to empty when missing during
    /// deserialization. An empty diff is rendered to the model as "no edits".
    #[serde(default)]
    diff: String,
}
124
125impl From<StreamingEditFileToolOutput> for LanguageModelToolResultContent {
126    fn from(output: StreamingEditFileToolOutput) -> Self {
127        if output.diff.is_empty() {
128            "No edits were made.".into()
129        } else {
130            format!(
131                "Edited {}:\n\n```diff\n{}\n```",
132                output.input_path.display(),
133                output.diff
134            )
135            .into()
136        }
137    }
138}
139
/// Agent tool that creates, overwrites, or applies search-and-replace edits to a project file,
/// and then saves it.
pub struct StreamingEditFileTool {
    /// Weak handle to the agent thread.
    ///
    /// Used for project lookups during authorization, the dirty and read-time checks, and the
    /// action log.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when replaying a recorded tool call.
    language_registry: Arc<LanguageRegistry>,
    /// Project used by `initial_title` to compute display paths.
    project: Entity<Project>,
    // Stored but never read by this tool, hence `allow(dead_code)`. Presumably kept so the
    // constructor matches the other agent tools — TODO confirm and drop if unneeded.
    #[allow(dead_code)]
    templates: Arc<Templates>,
}
147
148impl StreamingEditFileTool {
149    pub fn new(
150        project: Entity<Project>,
151        thread: WeakEntity<Thread>,
152        language_registry: Arc<LanguageRegistry>,
153        templates: Arc<Templates>,
154    ) -> Self {
155        Self {
156            project,
157            thread,
158            language_registry,
159            templates,
160        }
161    }
162
163    fn authorize(
164        &self,
165        input: &StreamingEditFileToolInput,
166        event_stream: &ToolCallEventStream,
167        cx: &mut App,
168    ) -> Task<Result<()>> {
169        let path_str = input.path.to_string_lossy();
170        let settings = agent_settings::AgentSettings::get_global(cx);
171        let decision = decide_permission_from_settings(Self::name(), &path_str, settings);
172
173        match decision {
174            ToolPermissionDecision::Allow => return Task::ready(Ok(())),
175            ToolPermissionDecision::Deny(reason) => {
176                return Task::ready(Err(anyhow!("{}", reason)));
177            }
178            ToolPermissionDecision::Confirm => {}
179        }
180
181        let local_settings_folder = paths::local_settings_folder_name();
182        let path = Path::new(&input.path);
183        if path.components().any(|component| {
184            component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
185        }) {
186            let context = crate::ToolPermissionContext {
187                tool_name: "edit_file".to_string(),
188                input_value: path_str.to_string(),
189            };
190            return event_stream.authorize(
191                format!("{} (local settings)", input.display_description),
192                context,
193                cx,
194            );
195        }
196
197        if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
198            && canonical_path.starts_with(paths::config_dir())
199        {
200            let context = crate::ToolPermissionContext {
201                tool_name: "edit_file".to_string(),
202                input_value: path_str.to_string(),
203            };
204            return event_stream.authorize(
205                format!("{} (global settings)", input.display_description),
206                context,
207                cx,
208            );
209        }
210
211        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
212            thread.project().read(cx).find_project_path(&input.path, cx)
213        }) else {
214            return Task::ready(Err(anyhow!("thread was dropped")));
215        };
216
217        if project_path.is_some() {
218            Task::ready(Ok(()))
219        } else {
220            let context = crate::ToolPermissionContext {
221                tool_name: "edit_file".to_string(),
222                input_value: path_str.to_string(),
223            };
224            event_stream.authorize(&input.display_description, context, cx)
225        }
226    }
227}
228
229impl AgentTool for StreamingEditFileTool {
230    type Input = StreamingEditFileToolInput;
231    type Output = StreamingEditFileToolOutput;
232
233    fn name() -> &'static str {
234        "streaming_edit_file"
235    }
236
237    fn kind() -> acp::ToolKind {
238        acp::ToolKind::Edit
239    }
240
241    fn initial_title(
242        &self,
243        input: Result<Self::Input, serde_json::Value>,
244        cx: &mut App,
245    ) -> SharedString {
246        match input {
247            Ok(input) => self
248                .project
249                .read(cx)
250                .find_project_path(&input.path, cx)
251                .and_then(|project_path| {
252                    self.project
253                        .read(cx)
254                        .short_full_path_for_project_path(&project_path, cx)
255                })
256                .unwrap_or(input.path.to_string_lossy().into_owned())
257                .into(),
258            Err(raw_input) => {
259                if let Some(input) =
260                    serde_json::from_value::<StreamingEditFileToolPartialInput>(raw_input).ok()
261                {
262                    let path = input.path.trim();
263                    if !path.is_empty() {
264                        return self
265                            .project
266                            .read(cx)
267                            .find_project_path(&input.path, cx)
268                            .and_then(|project_path| {
269                                self.project
270                                    .read(cx)
271                                    .short_full_path_for_project_path(&project_path, cx)
272                            })
273                            .unwrap_or(input.path)
274                            .into();
275                    }
276
277                    let description = input.display_description.trim();
278                    if !description.is_empty() {
279                        return description.to_string().into();
280                    }
281                }
282
283                DEFAULT_UI_TEXT.into()
284            }
285        }
286    }
287
288    fn run(
289        self: Arc<Self>,
290        input: Self::Input,
291        event_stream: ToolCallEventStream,
292        cx: &mut App,
293    ) -> Task<Result<Self::Output>> {
294        let Ok(project) = self
295            .thread
296            .read_with(cx, |thread, _cx| thread.project().clone())
297        else {
298            return Task::ready(Err(anyhow!("thread was dropped")));
299        };
300
301        let project_path = match resolve_path(&input, project.clone(), cx) {
302            Ok(path) => path,
303            Err(err) => return Task::ready(Err(anyhow!(err))),
304        };
305
306        let abs_path = project.read(cx).absolute_path(&project_path, cx);
307        if let Some(abs_path) = abs_path.clone() {
308            event_stream.update_fields(
309                ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
310            );
311        }
312
313        let authorize = self.authorize(&input, &event_stream, cx);
314
315        cx.spawn(async move |cx: &mut AsyncApp| {
316            authorize.await?;
317
318            let buffer = project
319                .update(cx, |project, cx| {
320                    project.open_buffer(project_path.clone(), cx)
321                })
322                .await?;
323
324            if let Some(abs_path) = abs_path.as_ref() {
325                let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) =
326                    self.thread.update(cx, |thread, cx| {
327                        let last_read = thread.file_read_times.get(abs_path).copied();
328                        let current = buffer
329                            .read(cx)
330                            .file()
331                            .and_then(|file| file.disk_state().mtime());
332                        let dirty = buffer.read(cx).is_dirty();
333                        let has_save = thread.has_tool("save_file");
334                        let has_restore = thread.has_tool("restore_file_from_disk");
335                        (last_read, current, dirty, has_save, has_restore)
336                    })?;
337
338                if is_dirty {
339                    let message = match (has_save_tool, has_restore_tool) {
340                        (true, true) => {
341                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
342                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
343                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
344                        }
345                        (true, false) => {
346                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
347                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
348                             If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
349                        }
350                        (false, true) => {
351                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
352                             If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
353                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
354                        }
355                        (false, false) => {
356                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
357                             then ask them to save or revert the file manually and inform you when it's ok to proceed."
358                        }
359                    };
360                    anyhow::bail!("{}", message);
361                }
362
363                if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
364                    if current != last_read {
365                        anyhow::bail!(
366                            "The file {} has been modified since you last read it. \
367                             Please read the file again to get the current state before editing it.",
368                            input.path.display()
369                        );
370                    }
371                }
372            }
373
374            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
375            event_stream.update_diff(diff.clone());
376            let _finalize_diff = util::defer({
377                let diff = diff.downgrade();
378                let mut cx = cx.clone();
379                move || {
380                    diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
381                }
382            });
383
384            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
385            let old_text = cx
386                .background_spawn({
387                    let old_snapshot = old_snapshot.clone();
388                    async move { Arc::new(old_snapshot.text()) }
389                })
390                .await;
391
392            match input.mode {
393                StreamingEditFileMode::Create | StreamingEditFileMode::Overwrite => {
394                    let content = input.content.ok_or_else(|| {
395                        anyhow!("'content' field is required for create and overwrite modes")
396                    })?;
397                    buffer.update(cx, |buffer, cx| {
398                        buffer.edit([(0..buffer.len(), content.as_str())], None, cx);
399                    });
400                }
401                StreamingEditFileMode::Edit => {
402                    let edits = input.edits.ok_or_else(|| {
403                        anyhow!("'edits' field is required for edit mode")
404                    })?;
405                    apply_edits(&buffer, &edits, &diff, &event_stream, &abs_path, cx)?;
406                }
407            }
408
409            let action_log = self.thread.read_with(cx, |thread, _cx| thread.action_log().clone())?;
410
411            action_log.update(cx, |log, cx| {
412                log.buffer_edited(buffer.clone(), cx);
413            });
414
415            project
416                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
417                .await?;
418
419            action_log.update(cx, |log, cx| {
420                log.buffer_edited(buffer.clone(), cx);
421            });
422
423            if let Some(abs_path) = abs_path.as_ref() {
424                if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
425                    buffer.file().and_then(|file| file.disk_state().mtime())
426                }) {
427                    self.thread.update(cx, |thread, _| {
428                        thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
429                    })?;
430                }
431            }
432
433            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
434            let (new_text, unified_diff) = cx
435                .background_spawn({
436                    let new_snapshot = new_snapshot.clone();
437                    let old_text = old_text.clone();
438                    async move {
439                        let new_text = new_snapshot.text();
440                        let diff = language::unified_diff(&old_text, &new_text);
441                        (new_text, diff)
442                    }
443                })
444                .await;
445
446            let output = StreamingEditFileToolOutput {
447                input_path: input.path,
448                new_text,
449                old_text,
450                diff: unified_diff,
451            };
452
453            Ok(output)
454        })
455    }
456
457    fn replay(
458        &self,
459        _input: Self::Input,
460        output: Self::Output,
461        event_stream: ToolCallEventStream,
462        cx: &mut App,
463    ) -> Result<()> {
464        event_stream.update_diff(cx.new(|cx| {
465            Diff::finalized(
466                output.input_path.to_string_lossy().into_owned(),
467                Some(output.old_text.to_string()),
468                output.new_text,
469                self.language_registry.clone(),
470                cx,
471            )
472        }));
473        Ok(())
474    }
475}
476
477fn apply_edits(
478    buffer: &Entity<language::Buffer>,
479    edits: &[EditOperation],
480    diff: &Entity<Diff>,
481    event_stream: &ToolCallEventStream,
482    abs_path: &Option<PathBuf>,
483    cx: &mut AsyncApp,
484) -> Result<()> {
485    let mut emitted_location = false;
486    let mut failed_edits = Vec::new();
487    let mut ambiguous_edits = Vec::new();
488
489    for (index, edit) in edits.iter().enumerate() {
490        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
491        let result = apply_single_edit(buffer, &snapshot, edit, diff, cx);
492
493        match result {
494            Ok(Some(range)) => {
495                if !emitted_location {
496                    let line = buffer.update(cx, |buffer, _cx| {
497                        range.start.to_point(&buffer.snapshot()).row
498                    });
499                    if let Some(abs_path) = abs_path.clone() {
500                        event_stream.update_fields(
501                            ToolCallUpdateFields::new()
502                                .locations(vec![ToolCallLocation::new(abs_path).line(Some(line))]),
503                        );
504                    }
505                    emitted_location = true;
506                }
507            }
508            Ok(None) => {
509                failed_edits.push(index);
510            }
511            Err(ranges) => {
512                ambiguous_edits.push((index, ranges));
513            }
514        }
515    }
516
517    if !failed_edits.is_empty() {
518        let indices = failed_edits
519            .iter()
520            .map(|i| i.to_string())
521            .collect::<Vec<_>>()
522            .join(", ");
523        anyhow::bail!(
524            "Could not find matching text for edit(s) at index(es): {}. \
525             The old_text did not match any content in the file. \
526             Please read the file again to get the current content.",
527            indices
528        );
529    }
530
531    if !ambiguous_edits.is_empty() {
532        let details: Vec<String> = ambiguous_edits
533            .iter()
534            .map(|(index, ranges)| {
535                let lines = ranges
536                    .iter()
537                    .map(|r| r.start.to_string())
538                    .collect::<Vec<_>>()
539                    .join(", ");
540                format!("edit {}: matches at lines {}", index, lines)
541            })
542            .collect();
543        anyhow::bail!(
544            "Some edits matched multiple locations in the file:\n{}. \
545             Please provide more context in old_text to uniquely identify the location.",
546            details.join("\n")
547        );
548    }
549
550    Ok(())
551}
552
553fn apply_single_edit(
554    buffer: &Entity<language::Buffer>,
555    snapshot: &BufferSnapshot,
556    edit: &EditOperation,
557    diff: &Entity<Diff>,
558    cx: &mut AsyncApp,
559) -> std::result::Result<Option<Range<Anchor>>, Vec<Range<usize>>> {
560    let mut matcher = StreamingFuzzyMatcher::new(snapshot.clone());
561    matcher.push(&edit.old_text, None);
562    let matches = matcher.finish();
563
564    if matches.is_empty() {
565        return Ok(None);
566    }
567
568    if matches.len() > 1 {
569        return Err(matches);
570    }
571
572    let match_range = matches.into_iter().next().expect("checked len above");
573
574    let start_anchor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_before(match_range.start));
575    let end_anchor = buffer.read_with(cx, |buffer, _cx| buffer.anchor_after(match_range.end));
576
577    diff.update(cx, |card, cx| {
578        card.reveal_range(start_anchor..end_anchor, cx)
579    });
580
581    buffer.update(cx, |buffer, cx| {
582        buffer.edit([(match_range.clone(), edit.new_text.as_str())], None, cx);
583    });
584
585    let new_end = buffer.read_with(cx, |buffer, _cx| {
586        buffer.anchor_after(match_range.start + edit.new_text.len())
587    });
588
589    Ok(Some(start_anchor..new_end))
590}
591
592fn resolve_path(
593    input: &StreamingEditFileToolInput,
594    project: Entity<Project>,
595    cx: &mut App,
596) -> Result<ProjectPath> {
597    let project = project.read(cx);
598
599    match input.mode {
600        StreamingEditFileMode::Edit | StreamingEditFileMode::Overwrite => {
601            let path = project
602                .find_project_path(&input.path, cx)
603                .context("Can't edit file: path not found")?;
604
605            let entry = project
606                .entry_for_path(&path, cx)
607                .context("Can't edit file: path not found")?;
608
609            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
610            Ok(path)
611        }
612
613        StreamingEditFileMode::Create => {
614            if let Some(path) = project.find_project_path(&input.path, cx) {
615                anyhow::ensure!(
616                    project.entry_for_path(&path, cx).is_none(),
617                    "Can't create file: file already exists"
618                );
619            }
620
621            let parent_path = input
622                .path
623                .parent()
624                .context("Can't create file: incorrect path")?;
625
626            let parent_project_path = project.find_project_path(&parent_path, cx);
627
628            let parent_entry = parent_project_path
629                .as_ref()
630                .and_then(|path| project.entry_for_path(path, cx))
631                .context("Can't create file: parent directory doesn't exist")?;
632
633            anyhow::ensure!(
634                parent_entry.is_dir(),
635                "Can't create file: parent is not a directory"
636            );
637
638            let file_name = input
639                .path
640                .file_name()
641                .and_then(|file_name| file_name.to_str())
642                .and_then(|file_name| RelPath::unix(file_name).ok())
643                .context("Can't create file: invalid filename")?;
644
645            let new_file_path = parent_project_path.map(|parent| ProjectPath {
646                path: parent.path.join(file_name),
647                ..parent
648            });
649
650            new_file_path.context("Can't create file")
651        }
652    }
653}
654
655#[cfg(test)]
656mod tests {
657    use super::*;
658    use crate::{ContextServerRegistry, Templates};
659    use gpui::TestAppContext;
660    use language_model::fake_provider::FakeLanguageModel;
661    use prompt_store::ProjectContext;
662    use serde_json::json;
663    use settings::SettingsStore;
664    use util::path;
665
666    #[gpui::test]
667    async fn test_streaming_edit_create_file(cx: &mut TestAppContext) {
668        init_test(cx);
669
670        let fs = project::FakeFs::new(cx.executor());
671        fs.insert_tree("/root", json!({"dir": {}})).await;
672        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
673        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
674        let context_server_registry =
675            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
676        let model = Arc::new(FakeLanguageModel::default());
677        let thread = cx.new(|cx| {
678            crate::Thread::new(
679                project.clone(),
680                cx.new(|_cx| ProjectContext::default()),
681                context_server_registry,
682                Templates::new(),
683                Some(model),
684                cx,
685            )
686        });
687
688        let result = cx
689            .update(|cx| {
690                let input = StreamingEditFileToolInput {
691                    display_description: "Create new file".into(),
692                    path: "root/dir/new_file.txt".into(),
693                    mode: StreamingEditFileMode::Create,
694                    content: Some("Hello, World!".into()),
695                    edits: None,
696                };
697                Arc::new(StreamingEditFileTool::new(
698                    project.clone(),
699                    thread.downgrade(),
700                    language_registry,
701                    Templates::new(),
702                ))
703                .run(input, ToolCallEventStream::test().0, cx)
704            })
705            .await;
706
707        assert!(result.is_ok());
708        let output = result.unwrap();
709        assert_eq!(output.new_text, "Hello, World!");
710        assert!(!output.diff.is_empty());
711    }
712
713    #[gpui::test]
714    async fn test_streaming_edit_overwrite_file(cx: &mut TestAppContext) {
715        init_test(cx);
716
717        let fs = project::FakeFs::new(cx.executor());
718        fs.insert_tree("/root", json!({"file.txt": "old content"}))
719            .await;
720        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
721        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
722        let context_server_registry =
723            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
724        let model = Arc::new(FakeLanguageModel::default());
725        let thread = cx.new(|cx| {
726            crate::Thread::new(
727                project.clone(),
728                cx.new(|_cx| ProjectContext::default()),
729                context_server_registry,
730                Templates::new(),
731                Some(model),
732                cx,
733            )
734        });
735
736        let result = cx
737            .update(|cx| {
738                let input = StreamingEditFileToolInput {
739                    display_description: "Overwrite file".into(),
740                    path: "root/file.txt".into(),
741                    mode: StreamingEditFileMode::Overwrite,
742                    content: Some("new content".into()),
743                    edits: None,
744                };
745                Arc::new(StreamingEditFileTool::new(
746                    project.clone(),
747                    thread.downgrade(),
748                    language_registry,
749                    Templates::new(),
750                ))
751                .run(input, ToolCallEventStream::test().0, cx)
752            })
753            .await;
754
755        assert!(result.is_ok());
756        let output = result.unwrap();
757        assert_eq!(output.new_text, "new content");
758        assert_eq!(*output.old_text, "old content");
759    }
760
761    #[gpui::test]
762    async fn test_streaming_edit_granular_edits(cx: &mut TestAppContext) {
763        init_test(cx);
764
765        let fs = project::FakeFs::new(cx.executor());
766        fs.insert_tree(
767            "/root",
768            json!({
769                "file.txt": "line 1\nline 2\nline 3\n"
770            }),
771        )
772        .await;
773        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
774        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
775        let context_server_registry =
776            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
777        let model = Arc::new(FakeLanguageModel::default());
778        let thread = cx.new(|cx| {
779            crate::Thread::new(
780                project.clone(),
781                cx.new(|_cx| ProjectContext::default()),
782                context_server_registry,
783                Templates::new(),
784                Some(model),
785                cx,
786            )
787        });
788
789        let result = cx
790            .update(|cx| {
791                let input = StreamingEditFileToolInput {
792                    display_description: "Edit lines".into(),
793                    path: "root/file.txt".into(),
794                    mode: StreamingEditFileMode::Edit,
795                    content: None,
796                    edits: Some(vec![EditOperation {
797                        old_text: "line 2".into(),
798                        new_text: "modified line 2".into(),
799                    }]),
800                };
801                Arc::new(StreamingEditFileTool::new(
802                    project.clone(),
803                    thread.downgrade(),
804                    language_registry,
805                    Templates::new(),
806                ))
807                .run(input, ToolCallEventStream::test().0, cx)
808            })
809            .await;
810
811        assert!(result.is_ok());
812        let output = result.unwrap();
813        assert_eq!(output.new_text, "line 1\nmodified line 2\nline 3\n");
814    }
815
816    #[gpui::test]
817    async fn test_streaming_edit_nonexistent_file(cx: &mut TestAppContext) {
818        init_test(cx);
819
820        let fs = project::FakeFs::new(cx.executor());
821        fs.insert_tree("/root", json!({})).await;
822        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
823        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
824        let context_server_registry =
825            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
826        let model = Arc::new(FakeLanguageModel::default());
827        let thread = cx.new(|cx| {
828            crate::Thread::new(
829                project.clone(),
830                cx.new(|_cx| ProjectContext::default()),
831                context_server_registry,
832                Templates::new(),
833                Some(model),
834                cx,
835            )
836        });
837
838        let result = cx
839            .update(|cx| {
840                let input = StreamingEditFileToolInput {
841                    display_description: "Some edit".into(),
842                    path: "root/nonexistent_file.txt".into(),
843                    mode: StreamingEditFileMode::Edit,
844                    content: None,
845                    edits: Some(vec![EditOperation {
846                        old_text: "foo".into(),
847                        new_text: "bar".into(),
848                    }]),
849                };
850                Arc::new(StreamingEditFileTool::new(
851                    project,
852                    thread.downgrade(),
853                    language_registry,
854                    Templates::new(),
855                ))
856                .run(input, ToolCallEventStream::test().0, cx)
857            })
858            .await;
859
860        assert_eq!(
861            result.unwrap_err().to_string(),
862            "Can't edit file: path not found"
863        );
864    }
865
866    #[gpui::test]
867    async fn test_streaming_edit_failed_match(cx: &mut TestAppContext) {
868        init_test(cx);
869
870        let fs = project::FakeFs::new(cx.executor());
871        fs.insert_tree("/root", json!({"file.txt": "hello world"}))
872            .await;
873        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
874        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
875        let context_server_registry =
876            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
877        let model = Arc::new(FakeLanguageModel::default());
878        let thread = cx.new(|cx| {
879            crate::Thread::new(
880                project.clone(),
881                cx.new(|_cx| ProjectContext::default()),
882                context_server_registry,
883                Templates::new(),
884                Some(model),
885                cx,
886            )
887        });
888
889        let result = cx
890            .update(|cx| {
891                let input = StreamingEditFileToolInput {
892                    display_description: "Edit file".into(),
893                    path: "root/file.txt".into(),
894                    mode: StreamingEditFileMode::Edit,
895                    content: None,
896                    edits: Some(vec![EditOperation {
897                        old_text: "nonexistent text that is not in the file".into(),
898                        new_text: "replacement".into(),
899                    }]),
900                };
901                Arc::new(StreamingEditFileTool::new(
902                    project,
903                    thread.downgrade(),
904                    language_registry,
905                    Templates::new(),
906                ))
907                .run(input, ToolCallEventStream::test().0, cx)
908            })
909            .await;
910
911        assert!(result.is_err());
912        assert!(
913            result
914                .unwrap_err()
915                .to_string()
916                .contains("Could not find matching text")
917        );
918    }
919
920    fn init_test(cx: &mut TestAppContext) {
921        cx.update(|cx| {
922            let settings_store = SettingsStore::test(cx);
923            cx.set_global(settings_store);
924        });
925    }
926}