edit_file_tool.rs

  1use crate::{
  2    replace::{replace_exact, replace_with_flexible_indent},
  3    schema::json_schema_for,
  4};
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
  7use buffer_diff::{BufferDiff, BufferDiffSnapshot};
  8use editor::{Editor, EditorMode, MultiBuffer, PathKey};
  9use gpui::{
 10    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId, Task, WeakEntity,
 11};
 12use language::{
 13    Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
 14};
 15use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 16use project::Project;
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use std::{
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use ui::{Disclosure, Tooltip, Window, prelude::*};
 24use util::ResultExt;
 25use workspace::Workspace;
 26
/// The `edit_file` tool: edits a file in the project by replacing the
/// model-provided `old_string` with `new_string`, then saves the buffer.
pub struct EditFileTool;
 28
 29#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 30pub struct EditFileToolInput {
 31    /// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
 32    ///
 33    /// <example>Fix API endpoint URLs</example>
 34    /// <example>Update copyright year in `page_footer`</example>
 35    ///
 36    /// Make sure to include this field before all the others in the input object
 37    /// so that we can display it immediately.
 38    pub display_description: String,
 39
 40    /// The full path of the file to modify in the project.
 41    ///
 42    /// WARNING: When specifying which file path need changing, you MUST
 43    /// start each path with one of the project's root directories.
 44    ///
 45    /// The following examples assume we have two root directories in the project:
 46    /// - backend
 47    /// - frontend
 48    ///
 49    /// <example>
 50    /// `backend/src/main.rs`
 51    ///
 52    /// Notice how the file path starts with root-1. Without that, the path
 53    /// would be ambiguous and the call would fail!
 54    /// </example>
 55    ///
 56    /// <example>
 57    /// `frontend/db.js`
 58    /// </example>
 59    pub path: PathBuf,
 60
 61    /// The text to replace.
 62    pub old_string: String,
 63
 64    /// The text to replace it with.
 65    pub new_string: String,
 66}
 67
 68#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 69struct PartialInput {
 70    #[serde(default)]
 71    path: String,
 72    #[serde(default)]
 73    display_description: String,
 74    #[serde(default)]
 75    old_string: String,
 76    #[serde(default)]
 77    new_string: String,
 78}
 79
/// Fallback label for an edit whose input provides neither a description nor a
/// path (e.g. while the input is still streaming or fails to parse).
const DEFAULT_UI_TEXT: &str = "Editing file";
 81
 82impl Tool for EditFileTool {
 83    fn name(&self) -> String {
 84        "edit_file".into()
 85    }
 86
 87    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 88        false
 89    }
 90
 91    fn description(&self) -> String {
 92        include_str!("edit_file_tool/description.md").to_string()
 93    }
 94
 95    fn icon(&self) -> IconName {
 96        IconName::Pencil
 97    }
 98
 99    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
100        json_schema_for::<EditFileToolInput>(format)
101    }
102
103    fn ui_text(&self, input: &serde_json::Value) -> String {
104        match serde_json::from_value::<EditFileToolInput>(input.clone()) {
105            Ok(input) => input.display_description,
106            Err(_) => "Editing file".to_string(),
107        }
108    }
109
110    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
111        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
112            let description = input.display_description.trim();
113            if !description.is_empty() {
114                return description.to_string();
115            }
116
117            let path = input.path.trim();
118            if !path.is_empty() {
119                return path.to_string();
120            }
121        }
122
123        DEFAULT_UI_TEXT.to_string()
124    }
125
126    fn run(
127        self: Arc<Self>,
128        input: serde_json::Value,
129        _messages: &[LanguageModelRequestMessage],
130        project: Entity<Project>,
131        action_log: Entity<ActionLog>,
132        window: Option<AnyWindowHandle>,
133        cx: &mut App,
134    ) -> ToolResult {
135        let input = match serde_json::from_value::<EditFileToolInput>(input) {
136            Ok(input) => input,
137            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
138        };
139
140        let card = window.and_then(|window| {
141            window
142                .update(cx, |_, window, cx| {
143                    cx.new(|cx| {
144                        EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
145                    })
146                })
147                .ok()
148        });
149
150        let card_clone = card.clone();
151        let task = cx.spawn(async move |cx: &mut AsyncApp| {
152            let project_path = project.read_with(cx, |project, cx| {
153                project
154                    .find_project_path(&input.path, cx)
155                    .context("Path not found in project")
156            })??;
157
158            let buffer = project
159                .update(cx, |project, cx| {
160                    project.open_buffer(project_path.clone(), cx)
161                })?
162                .await?;
163
164            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
165
166            if input.old_string.is_empty() {
167                return Err(anyhow!(
168                    "`old_string` can't be empty, use another tool if you want to create a file."
169                ));
170            }
171
172            if input.old_string == input.new_string {
173                return Err(anyhow!(
174                    "The `old_string` and `new_string` are identical, so no changes would be made."
175                ));
176            }
177            let old_string = input.old_string.clone();
178
179            let result = cx
180                .background_spawn(async move {
181                    // Try to match exactly
182                    let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
183                        .await
184                        // If that fails, try being flexible about indentation
185                        .or_else(|| {
186                            replace_with_flexible_indent(
187                                &input.old_string,
188                                &input.new_string,
189                                &snapshot,
190                            )
191                        })?;
192
193                    if diff.edits.is_empty() {
194                        return None;
195                    }
196
197                    let old_text = snapshot.text();
198
199                    Some((old_text, diff))
200                })
201                .await;
202
203            let Some((old_text, diff)) = result else {
204                let err = buffer.read_with(cx, |buffer, _cx| {
205                    let file_exists = buffer
206                        .file()
207                        .map_or(false, |file| file.disk_state().exists());
208
209                    if !file_exists {
210                        anyhow!("{} does not exist", input.path.display())
211                    } else if buffer.is_empty() {
212                        anyhow!(
213                            "{} is empty, so the provided `old_string` wasn't found.",
214                            input.path.display()
215                        )
216                    } else {
217                        let old_string_with_buffer = format!(
218                            "old_string:\n\n{}\n\n-------file-------\n\n{}",
219                            &old_string,
220                            buffer.text()
221                        );
222                        let path = {
223                            use std::collections::hash_map::DefaultHasher;
224                            use std::hash::{Hash, Hasher};
225
226                            let mut hasher = DefaultHasher::new();
227                            old_string_with_buffer.hash(&mut hasher);
228
229                            PathBuf::from(format!("failed_tool_{}.txt", hasher.finish()))
230                        };
231                        std::fs::write(path, old_string_with_buffer).unwrap();
232                        anyhow!("Failed to match the provided `old_string`")
233                    }
234                })?;
235
236                return Err(err);
237            };
238
239            let snapshot = cx.update(|cx| {
240                action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
241
242                let snapshot = buffer.update(cx, |buffer, cx| {
243                    buffer.finalize_last_transaction();
244                    buffer.apply_diff(diff, cx);
245                    buffer.finalize_last_transaction();
246                    buffer.snapshot()
247                });
248                action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
249                snapshot
250            })?;
251
252            project
253                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
254                .await?;
255
256            let new_text = snapshot.text();
257            let diff_str = cx
258                .background_spawn({
259                    let old_text = old_text.clone();
260                    let new_text = new_text.clone();
261                    async move { language::unified_diff(&old_text, &new_text) }
262                })
263                .await;
264
265            if let Some(card) = card_clone {
266                card.update(cx, |card, cx| {
267                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
268                })
269                .log_err();
270            }
271
272            Ok(format!(
273                "Edited {}:\n\n```diff\n{}\n```",
274                input.path.display(),
275                diff_str
276            ))
277        });
278
279        ToolResult {
280            output: task,
281            card: card.map(AnyToolCard::from),
282        }
283    }
284}
285
/// UI card for an `edit_file` call: a header showing the edited path and,
/// once `set_diff` has run, a read-only preview of the changed hunks.
pub struct EditFileToolCard {
    /// Path as given in the tool input; shown in the header and resolved again
    /// when the user clicks "Jump to File".
    path: PathBuf,
    /// Read-only editor that renders `multibuffer` in the preview area.
    editor: Entity<Editor>,
    /// Holds excerpts of the edited text with the diff attached (filled by `set_diff`).
    multibuffer: Entity<MultiBuffer>,
    /// Used to look up the language registry when building the preview.
    project: Entity<Project>,
    /// Task spawned by `set_diff`; stored so it stays alive (presumably dropping
    /// a gpui `Task` cancels it — confirm).
    diff_task: Option<Task<Result<()>>>,
    /// Whether the preview is shown (toggled by the header disclosure).
    preview_expanded: bool,
    /// Whether the preview uses full height instead of a capped maximum height.
    full_height_expanded: bool,
    /// `editor`'s entity id, used to make element ids unique per card.
    editor_unique_id: EntityId,
}
296
297impl EditFileToolCard {
298    fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
299        let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
300        let editor = cx.new(|cx| {
301            let mut editor = Editor::new(
302                EditorMode::Full {
303                    scale_ui_elements_with_buffer_font_size: false,
304                    show_active_line_background: false,
305                    sized_by_content: true,
306                },
307                multibuffer.clone(),
308                Some(project.clone()),
309                window,
310                cx,
311            );
312            editor.set_show_scrollbars(false, cx);
313            editor.set_show_gutter(false, cx);
314            editor.disable_inline_diagnostics();
315            editor.disable_scrolling(cx);
316            editor.disable_expand_excerpt_buttons(cx);
317            editor.set_show_breakpoints(false, cx);
318            editor.set_show_code_actions(false, cx);
319            editor.set_show_git_diff_gutter(false, cx);
320            editor.set_expand_all_diff_hunks(cx);
321            editor
322        });
323        Self {
324            editor_unique_id: editor.entity_id(),
325            path,
326            project,
327            editor,
328            multibuffer,
329            diff_task: None,
330            preview_expanded: true,
331            full_height_expanded: false,
332        }
333    }
334
335    fn set_diff(
336        &mut self,
337        path: Arc<Path>,
338        old_text: String,
339        new_text: String,
340        cx: &mut Context<Self>,
341    ) {
342        let language_registry = self.project.read(cx).languages().clone();
343        self.diff_task = Some(cx.spawn(async move |this, cx| {
344            let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
345            let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
346
347            this.update(cx, |this, cx| {
348                this.multibuffer.update(cx, |multibuffer, cx| {
349                    let snapshot = buffer.read(cx).snapshot();
350                    let diff = buffer_diff.read(cx);
351                    let diff_hunk_ranges = diff
352                        .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
353                        .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
354                        .collect::<Vec<_>>();
355                    let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
356                        PathKey::for_buffer(&buffer, cx),
357                        buffer,
358                        diff_hunk_ranges,
359                        editor::DEFAULT_MULTIBUFFER_CONTEXT,
360                        cx,
361                    );
362                    debug_assert!(is_newly_added);
363                    multibuffer.add_diff(buffer_diff, cx);
364                });
365                cx.notify();
366            })
367        }));
368    }
369}
370
371impl ToolCard for EditFileToolCard {
372    fn render(
373        &mut self,
374        status: &ToolUseStatus,
375        window: &mut Window,
376        workspace: WeakEntity<Workspace>,
377        cx: &mut Context<Self>,
378    ) -> impl IntoElement {
379        let failed = matches!(status, ToolUseStatus::Error(_));
380
381        let path_label_button = h_flex()
382            .id(("edit-tool-path-label-button", self.editor_unique_id))
383            .w_full()
384            .max_w_full()
385            .px_1()
386            .gap_0p5()
387            .cursor_pointer()
388            .rounded_sm()
389            .opacity(0.8)
390            .hover(|label| {
391                label
392                    .opacity(1.)
393                    .bg(cx.theme().colors().element_hover.opacity(0.5))
394            })
395            .tooltip(Tooltip::text("Jump to File"))
396            .child(
397                h_flex()
398                    .child(
399                        Icon::new(IconName::Pencil)
400                            .size(IconSize::XSmall)
401                            .color(Color::Muted),
402                    )
403                    .child(
404                        div()
405                            .text_size(rems(0.8125))
406                            .child(self.path.display().to_string())
407                            .ml_1p5()
408                            .mr_0p5(),
409                    )
410                    .child(
411                        Icon::new(IconName::ArrowUpRight)
412                            .size(IconSize::XSmall)
413                            .color(Color::Ignored),
414                    ),
415            )
416            .on_click({
417                let path = self.path.clone();
418                let workspace = workspace.clone();
419                move |_, window, cx| {
420                    workspace
421                        .update(cx, {
422                            |workspace, cx| {
423                                let Some(project_path) =
424                                    workspace.project().read(cx).find_project_path(&path, cx)
425                                else {
426                                    return;
427                                };
428                                let open_task =
429                                    workspace.open_path(project_path, None, true, window, cx);
430                                window
431                                    .spawn(cx, async move |cx| {
432                                        let item = open_task.await?;
433                                        if let Some(active_editor) = item.downcast::<Editor>() {
434                                            active_editor
435                                                .update_in(cx, |editor, window, cx| {
436                                                    editor.go_to_singleton_buffer_point(
437                                                        language::Point::new(0, 0),
438                                                        window,
439                                                        cx,
440                                                    );
441                                                })
442                                                .log_err();
443                                        }
444                                        anyhow::Ok(())
445                                    })
446                                    .detach_and_log_err(cx);
447                            }
448                        })
449                        .ok();
450                }
451            })
452            .into_any_element();
453
454        let codeblock_header_bg = cx
455            .theme()
456            .colors()
457            .element_background
458            .blend(cx.theme().colors().editor_foreground.opacity(0.025));
459
460        let codeblock_header = h_flex()
461            .flex_none()
462            .p_1()
463            .gap_1()
464            .justify_between()
465            .rounded_t_md()
466            .when(!failed, |header| header.bg(codeblock_header_bg))
467            .child(path_label_button)
468            .map(|container| {
469                if failed {
470                    container.child(
471                        Icon::new(IconName::Close)
472                            .size(IconSize::Small)
473                            .color(Color::Error),
474                    )
475                } else {
476                    container.child(
477                        Disclosure::new(
478                            ("edit-file-disclosure", self.editor_unique_id),
479                            self.preview_expanded,
480                        )
481                        .opened_icon(IconName::ChevronUp)
482                        .closed_icon(IconName::ChevronDown)
483                        .on_click(cx.listener(
484                            move |this, _event, _window, _cx| {
485                                this.preview_expanded = !this.preview_expanded;
486                            },
487                        )),
488                    )
489                }
490            });
491
492        let editor = self.editor.update(cx, |editor, cx| {
493            editor.render(window, cx).into_any_element()
494        });
495
496        let (full_height_icon, full_height_tooltip_label) = if self.full_height_expanded {
497            (IconName::ChevronUp, "Collapse Code Block")
498        } else {
499            (IconName::ChevronDown, "Expand Code Block")
500        };
501
502        let gradient_overlay = div()
503            .absolute()
504            .bottom_0()
505            .left_0()
506            .w_full()
507            .h_2_5()
508            .rounded_b_lg()
509            .bg(gpui::linear_gradient(
510                0.,
511                gpui::linear_color_stop(cx.theme().colors().editor_background, 0.),
512                gpui::linear_color_stop(cx.theme().colors().editor_background.opacity(0.), 1.),
513            ));
514
515        let border_color = cx.theme().colors().border.opacity(0.6);
516
517        v_flex()
518            .mb_2()
519            .border_1()
520            .when(failed, |card| card.border_dashed())
521            .border_color(border_color)
522            .rounded_lg()
523            .overflow_hidden()
524            .child(codeblock_header)
525            .when(!failed && self.preview_expanded, |card| {
526                card.child(
527                    v_flex()
528                        .relative()
529                        .overflow_hidden()
530                        .border_t_1()
531                        .border_color(border_color)
532                        .bg(cx.theme().colors().editor_background)
533                        .map(|editor_container| {
534                            if self.full_height_expanded {
535                                editor_container.h_full()
536                            } else {
537                                editor_container.max_h_64()
538                            }
539                        })
540                        .child(div().pl_1().child(editor))
541                        .when(!self.full_height_expanded, |editor_container| {
542                            editor_container.child(gradient_overlay)
543                        }),
544                )
545            })
546            .when(!failed && self.preview_expanded, |card| {
547                card.child(
548                    h_flex()
549                        .id(("edit-tool-card-inner-hflex", self.editor_unique_id))
550                        .flex_none()
551                        .cursor_pointer()
552                        .h_5()
553                        .justify_center()
554                        .rounded_b_md()
555                        .border_t_1()
556                        .border_color(border_color)
557                        .bg(cx.theme().colors().editor_background)
558                        .hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
559                        .child(
560                            Icon::new(full_height_icon)
561                                .size(IconSize::Small)
562                                .color(Color::Muted),
563                        )
564                        .tooltip(Tooltip::text(full_height_tooltip_label))
565                        .on_click(cx.listener(move |this, _event, _window, _cx| {
566                            this.full_height_expanded = !this.full_height_expanded;
567                        })),
568                )
569            })
570    }
571}
572
573async fn build_buffer(
574    mut text: String,
575    path: Arc<Path>,
576    language_registry: &Arc<language::LanguageRegistry>,
577    cx: &mut AsyncApp,
578) -> Result<Entity<Buffer>> {
579    let line_ending = LineEnding::detect(&text);
580    LineEnding::normalize(&mut text);
581    let text = Rope::from(text);
582    let language = cx
583        .update(|_cx| language_registry.language_for_file_path(&path))?
584        .await
585        .ok();
586    let buffer = cx.new(|cx| {
587        let buffer = TextBuffer::new_normalized(
588            0,
589            cx.entity_id().as_non_zero_u64().into(),
590            line_ending,
591            text,
592        );
593        let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
594        buffer.set_language(language, cx);
595        buffer
596    })?;
597    Ok(buffer)
598}
599
600async fn build_buffer_diff(
601    mut old_text: String,
602    buffer: &Entity<Buffer>,
603    language_registry: &Arc<LanguageRegistry>,
604    cx: &mut AsyncApp,
605) -> Result<Entity<BufferDiff>> {
606    LineEnding::normalize(&mut old_text);
607
608    let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
609
610    let base_buffer = cx
611        .update(|cx| {
612            Buffer::build_snapshot(
613                old_text.clone().into(),
614                buffer.language().cloned(),
615                Some(language_registry.clone()),
616                cx,
617            )
618        })?
619        .await;
620
621    let diff_snapshot = cx
622        .update(|cx| {
623            BufferDiffSnapshot::new_with_base_buffer(
624                buffer.text.clone(),
625                Some(old_text.into()),
626                base_buffer,
627                cx,
628            )
629        })?
630        .await;
631
632    cx.new(|cx| {
633        let mut diff = BufferDiff::new(&buffer.text, cx);
634        diff.set_snapshot(diff_snapshot, &buffer.text, cx);
635        diff
636    })
637}
638
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a complete tool input with the given `path` and
    /// `display_description`, plus fixed placeholder edit strings.
    fn input_with(path: &str, description: &str) -> serde_json::Value {
        json!({
            "path": path,
            "display_description": description,
            "old_string": "old code",
            "new_string": "new code"
        })
    }

    #[test]
    fn still_streaming_ui_text_with_path() {
        let input = input_with("src/main.rs", "");
        assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        let input = input_with("", "Fix error handling");
        assert_eq!(
            EditFileTool.still_streaming_ui_text(&input),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        // The description takes precedence over the path.
        let input = input_with("src/main.rs", "Fix error handling");
        assert_eq!(
            EditFileTool.still_streaming_ui_text(&input),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let input = input_with("", "");
        assert_eq!(
            EditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT,
        );
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let input = serde_json::Value::Null;
        assert_eq!(
            EditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT,
        );
    }
}