edit_file_tool.rs

  1use crate::{
  2    replace::{replace_exact, replace_with_flexible_indent},
  3    schema::json_schema_for,
  4};
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
  7use buffer_diff::{BufferDiff, BufferDiffSnapshot};
  8use editor::{Editor, EditorMode, MultiBuffer, PathKey};
  9use gpui::{
 10    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId, Task, WeakEntity,
 11};
 12use language::{
 13    Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
 14    language_settings::SoftWrap,
 15};
 16use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 17use project::Project;
 18use schemars::JsonSchema;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    path::{Path, PathBuf},
 22    sync::Arc,
 23};
 24use ui::{Disclosure, Tooltip, Window, prelude::*};
 25use util::ResultExt;
 26use workspace::Workspace;
 27
/// Tool, exposed under the name `edit_file`, that replaces one occurrence of a given
/// string in a project file (matching exactly, or with flexible indentation as a fallback).
pub struct EditFileTool;
 29
 30#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 31pub struct EditFileToolInput {
 32    /// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
 33    ///
 34    /// <example>Fix API endpoint URLs</example>
 35    /// <example>Update copyright year in `page_footer`</example>
 36    ///
 37    /// Make sure to include this field before all the others in the input object
 38    /// so that we can display it immediately.
 39    pub display_description: String,
 40
 41    /// The full path of the file to modify in the project.
 42    ///
 43    /// WARNING: When specifying which file path need changing, you MUST
 44    /// start each path with one of the project's root directories.
 45    ///
 46    /// The following examples assume we have two root directories in the project:
 47    /// - backend
 48    /// - frontend
 49    ///
 50    /// <example>
 51    /// `backend/src/main.rs`
 52    ///
 53    /// Notice how the file path starts with root-1. Without that, the path
 54    /// would be ambiguous and the call would fail!
 55    /// </example>
 56    ///
 57    /// <example>
 58    /// `frontend/db.js`
 59    /// </example>
 60    pub path: PathBuf,
 61
 62    /// The text to replace.
 63    pub old_string: String,
 64
 65    /// The text to replace it with.
 66    pub new_string: String,
 67}
 68
// Lenient counterpart of `EditFileToolInput` used while the tool input is still
// streaming in: every field is `#[serde(default)]`, so an object that is missing some
// (or all) keys still deserializes, with the missing fields left as empty strings.
//
// Plain `//` comments are used on purpose: the type derives `JsonSchema`, and `///`
// comments would presumably be picked up as schema descriptions — confirm before changing.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    // Kept as a raw `String` (not `PathBuf`); only used for display.
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
    #[serde(default)]
    old_string: String,
    #[serde(default)]
    new_string: String,
}
 80
/// Fallback label returned by `still_streaming_ui_text` when the partial input has
/// neither a non-empty description nor a non-empty path.
const DEFAULT_UI_TEXT: &str = "Editing file";
 82
 83impl Tool for EditFileTool {
 84    fn name(&self) -> String {
 85        "edit_file".into()
 86    }
 87
 88    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 89        false
 90    }
 91
 92    fn description(&self) -> String {
 93        include_str!("edit_file_tool/description.md").to_string()
 94    }
 95
 96    fn icon(&self) -> IconName {
 97        IconName::Pencil
 98    }
 99
100    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
101        json_schema_for::<EditFileToolInput>(format)
102    }
103
104    fn ui_text(&self, input: &serde_json::Value) -> String {
105        match serde_json::from_value::<EditFileToolInput>(input.clone()) {
106            Ok(input) => input.display_description,
107            Err(_) => "Editing file".to_string(),
108        }
109    }
110
111    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
112        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
113            let description = input.display_description.trim();
114            if !description.is_empty() {
115                return description.to_string();
116            }
117
118            let path = input.path.trim();
119            if !path.is_empty() {
120                return path.to_string();
121            }
122        }
123
124        DEFAULT_UI_TEXT.to_string()
125    }
126
127    fn run(
128        self: Arc<Self>,
129        input: serde_json::Value,
130        _messages: &[LanguageModelRequestMessage],
131        project: Entity<Project>,
132        action_log: Entity<ActionLog>,
133        window: Option<AnyWindowHandle>,
134        cx: &mut App,
135    ) -> ToolResult {
136        let input = match serde_json::from_value::<EditFileToolInput>(input) {
137            Ok(input) => input,
138            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
139        };
140
141        let card = window.and_then(|window| {
142            window
143                .update(cx, |_, window, cx| {
144                    cx.new(|cx| {
145                        EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
146                    })
147                })
148                .ok()
149        });
150
151        let card_clone = card.clone();
152        let task = cx.spawn(async move |cx: &mut AsyncApp| {
153            let project_path = project.read_with(cx, |project, cx| {
154                project
155                    .find_project_path(&input.path, cx)
156                    .context("Path not found in project")
157            })??;
158
159            let buffer = project
160                .update(cx, |project, cx| {
161                    project.open_buffer(project_path.clone(), cx)
162                })?
163                .await?;
164
165            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
166
167            if input.old_string.is_empty() {
168                return Err(anyhow!(
169                    "`old_string` can't be empty, use another tool if you want to create a file."
170                ));
171            }
172
173            if input.old_string == input.new_string {
174                return Err(anyhow!(
175                    "The `old_string` and `new_string` are identical, so no changes would be made."
176                ));
177            }
178
179            let result = cx
180                .background_spawn(async move {
181                    // Try to match exactly
182                    let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
183                        .await
184                        // If that fails, try being flexible about indentation
185                        .or_else(|| {
186                            replace_with_flexible_indent(
187                                &input.old_string,
188                                &input.new_string,
189                                &snapshot,
190                            )
191                        })?;
192
193                    if diff.edits.is_empty() {
194                        return None;
195                    }
196
197                    let old_text = snapshot.text();
198
199                    Some((old_text, diff))
200                })
201                .await;
202
203            let Some((old_text, diff)) = result else {
204                let err = buffer.read_with(cx, |buffer, _cx| {
205                    let file_exists = buffer
206                        .file()
207                        .map_or(false, |file| file.disk_state().exists());
208
209                    if !file_exists {
210                        anyhow!("{} does not exist", input.path.display())
211                    } else if buffer.is_empty() {
212                        anyhow!(
213                            "{} is empty, so the provided `old_string` wasn't found.",
214                            input.path.display()
215                        )
216                    } else {
217                        anyhow!("Failed to match the provided `old_string`")
218                    }
219                })?;
220
221                return Err(err);
222            };
223
224            let snapshot = cx.update(|cx| {
225                action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
226
227                let snapshot = buffer.update(cx, |buffer, cx| {
228                    buffer.finalize_last_transaction();
229                    buffer.apply_diff(diff, cx);
230                    buffer.finalize_last_transaction();
231                    buffer.snapshot()
232                });
233                action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
234                snapshot
235            })?;
236
237            project
238                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
239                .await?;
240
241            let new_text = snapshot.text();
242            let diff_str = cx
243                .background_spawn({
244                    let old_text = old_text.clone();
245                    let new_text = new_text.clone();
246                    async move { language::unified_diff(&old_text, &new_text) }
247                })
248                .await;
249
250            if let Some(card) = card_clone {
251                card.update(cx, |card, cx| {
252                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
253                })
254                .log_err();
255            }
256
257            Ok(format!(
258                "Edited {}:\n\n```diff\n{}\n```",
259                input.path.display(),
260                diff_str
261            ))
262        });
263
264        ToolResult {
265            output: task,
266            card: card.map(AnyToolCard::from),
267        }
268    }
269}
270
/// UI card for a single `edit_file` call: a header with the file path and, below it,
/// either the error text (on failure) or a read-only diff preview of the edit.
pub struct EditFileToolCard {
    /// Path as given in the tool input; shown in the header and used to jump to the file.
    path: PathBuf,
    /// Read-only editor that renders `multibuffer`.
    editor: Entity<Editor>,
    /// Holds the excerpts and diff installed by `set_diff`; empty until then.
    multibuffer: Entity<MultiBuffer>,
    /// Project the edit belongs to; `set_diff` reads its language registry.
    project: Entity<Project>,
    /// Task spawned by `set_diff` that builds the preview. Stored on the card,
    /// presumably to keep the task alive — confirm against gpui `Task` drop semantics.
    diff_task: Option<Task<Result<()>>>,
    /// Whether the diff preview is shown (only relevant when the call did not fail).
    preview_expanded: bool,
    /// Whether the error details are shown (only relevant when the call failed).
    error_expanded: bool,
    /// Whether the preview is shown at full height instead of the collapsed height.
    full_height_expanded: bool,
    /// Number of rows in `multibuffer` after `set_diff` ran; `None` before that.
    total_lines: Option<u32>,
    /// Entity id of `editor`, used to build element ids that are unique per card.
    editor_unique_id: EntityId,
}
283
284impl EditFileToolCard {
285    fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
286        let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
287        let editor = cx.new(|cx| {
288            let mut editor = Editor::new(
289                EditorMode::Full {
290                    scale_ui_elements_with_buffer_font_size: false,
291                    show_active_line_background: false,
292                    sized_by_content: true,
293                },
294                multibuffer.clone(),
295                Some(project.clone()),
296                window,
297                cx,
298            );
299            editor.set_show_gutter(false, cx);
300            editor.disable_inline_diagnostics();
301            editor.disable_expand_excerpt_buttons(cx);
302            editor.set_soft_wrap_mode(SoftWrap::None, cx);
303            editor.scroll_manager.set_forbid_vertical_scroll(true);
304            editor.set_show_scrollbars(false, cx);
305            editor.set_read_only(true);
306            editor.set_show_breakpoints(false, cx);
307            editor.set_show_code_actions(false, cx);
308            editor.set_show_git_diff_gutter(false, cx);
309            editor.set_expand_all_diff_hunks(cx);
310            editor
311        });
312        Self {
313            editor_unique_id: editor.entity_id(),
314            path,
315            project,
316            editor,
317            multibuffer,
318            diff_task: None,
319            preview_expanded: true,
320            error_expanded: false,
321            full_height_expanded: false,
322            total_lines: None,
323        }
324    }
325
326    fn set_diff(
327        &mut self,
328        path: Arc<Path>,
329        old_text: String,
330        new_text: String,
331        cx: &mut Context<Self>,
332    ) {
333        let language_registry = self.project.read(cx).languages().clone();
334        self.diff_task = Some(cx.spawn(async move |this, cx| {
335            let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
336            let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
337
338            this.update(cx, |this, cx| {
339                this.total_lines = this.multibuffer.update(cx, |multibuffer, cx| {
340                    let snapshot = buffer.read(cx).snapshot();
341                    let diff = buffer_diff.read(cx);
342                    let diff_hunk_ranges = diff
343                        .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
344                        .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
345                        .collect::<Vec<_>>();
346                    let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
347                        PathKey::for_buffer(&buffer, cx),
348                        buffer,
349                        diff_hunk_ranges,
350                        editor::DEFAULT_MULTIBUFFER_CONTEXT,
351                        cx,
352                    );
353                    debug_assert!(is_newly_added);
354                    multibuffer.add_diff(buffer_diff, cx);
355                    let end = multibuffer.len(cx);
356                    Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1)
357                });
358
359                cx.notify();
360            })
361        }));
362    }
363}
364
impl ToolCard for EditFileToolCard {
    /// Renders the card for the given tool-use `status`.
    ///
    /// Layout: a header (clickable path label plus a disclosure toggle) followed by
    /// either the error text (status is `Error` and the error section is expanded) or
    /// the diff preview (any other status and the preview is expanded).
    fn render(
        &mut self,
        status: &ToolUseStatus,
        window: &mut Window,
        workspace: WeakEntity<Workspace>,
        cx: &mut Context<Self>,
    ) -> impl IntoElement {
        // Every status other than `Error` is treated as "not failed".
        let (failed, error_message) = match status {
            ToolUseStatus::Error(err) => (true, Some(err.to_string())),
            _ => (false, None),
        };

        // Path label; clicking it opens the file in the workspace.
        let path_label_button = h_flex()
            .id(("edit-tool-path-label-button", self.editor_unique_id))
            .w_full()
            .max_w_full()
            .px_1()
            .gap_0p5()
            .cursor_pointer()
            .rounded_sm()
            .opacity(0.8)
            .hover(|label| {
                label
                    .opacity(1.)
                    .bg(cx.theme().colors().element_hover.opacity(0.5))
            })
            .tooltip(Tooltip::text("Jump to File"))
            .child(
                h_flex()
                    .child(
                        Icon::new(IconName::Pencil)
                            .size(IconSize::XSmall)
                            .color(Color::Muted),
                    )
                    .child(
                        div()
                            .text_size(rems(0.8125))
                            .child(self.path.display().to_string())
                            .ml_1p5()
                            .mr_0p5(),
                    )
                    .child(
                        Icon::new(IconName::ArrowUpRight)
                            .size(IconSize::XSmall)
                            .color(Color::Ignored),
                    ),
            )
            .on_click({
                let path = self.path.clone();
                let workspace = workspace.clone();
                move |_, window, cx| {
                    workspace
                        .update(cx, {
                            |workspace, cx| {
                                // Nothing happens if the path no longer resolves in the project.
                                let Some(project_path) =
                                    workspace.project().read(cx).find_project_path(&path, cx)
                                else {
                                    return;
                                };
                                let open_task =
                                    workspace.open_path(project_path, None, true, window, cx);
                                window
                                    .spawn(cx, async move |cx| {
                                        let item = open_task.await?;
                                        // If the opened item is an editor, move to row 0, column 0.
                                        if let Some(active_editor) = item.downcast::<Editor>() {
                                            active_editor
                                                .update_in(cx, |editor, window, cx| {
                                                    editor.go_to_singleton_buffer_point(
                                                        language::Point::new(0, 0),
                                                        window,
                                                        cx,
                                                    );
                                                })
                                                .log_err();
                                        }
                                        anyhow::Ok(())
                                    })
                                    .detach_and_log_err(cx);
                            }
                        })
                        .ok();
                }
            })
            .into_any_element();

        let codeblock_header_bg = cx
            .theme()
            .colors()
            .element_background
            .blend(cx.theme().colors().editor_foreground.opacity(0.025));

        // Header: on failure it has no background and shows an error icon plus a toggle
        // for the error text; otherwise it shows a toggle for the diff preview.
        let codeblock_header = h_flex()
            .flex_none()
            .p_1()
            .gap_1()
            .justify_between()
            .rounded_t_md()
            .when(!failed, |header| header.bg(codeblock_header_bg))
            .child(path_label_button)
            .map(|container| {
                if failed {
                    container.child(
                        h_flex()
                            .gap_1()
                            .child(
                                Icon::new(IconName::Close)
                                    .size(IconSize::Small)
                                    .color(Color::Error),
                            )
                            .child(
                                Disclosure::new(
                                    ("edit-file-error-disclosure", self.editor_unique_id),
                                    self.error_expanded,
                                )
                                .opened_icon(IconName::ChevronUp)
                                .closed_icon(IconName::ChevronDown)
                                .on_click(cx.listener(
                                    move |this, _event, _window, _cx| {
                                        this.error_expanded = !this.error_expanded;
                                    },
                                )),
                            ),
                    )
                } else {
                    container.child(
                        Disclosure::new(
                            ("edit-file-disclosure", self.editor_unique_id),
                            self.preview_expanded,
                        )
                        .opened_icon(IconName::ChevronUp)
                        .closed_icon(IconName::ChevronDown)
                        .on_click(cx.listener(
                            move |this, _event, _window, _cx| {
                                this.preview_expanded = !this.preview_expanded;
                            },
                        )),
                    )
                }
            });

        // Render the preview editor and capture its line height, which sizes the
        // collapsed preview below. The default value is used when `style()` is `None`.
        let (editor, editor_line_height) = self.editor.update(cx, |editor, cx| {
            let line_height = editor
                .style()
                .map(|style| style.text.line_height_in_pixels(window.rem_size()))
                .unwrap_or_default();

            let element = editor.render(window, cx);
            (element.into_any_element(), line_height)
        });

        let (full_height_icon, full_height_tooltip_label) = if self.full_height_expanded {
            (IconName::ChevronUp, "Collapse Code Block")
        } else {
            (IconName::ChevronDown, "Expand Code Block")
        };

        // Overlay pinned to the bottom of the preview: a gradient between the editor
        // background colour and its fully transparent variant.
        let gradient_overlay = div()
            .absolute()
            .bottom_0()
            .left_0()
            .w_full()
            .h_2_5()
            .rounded_b_lg()
            .bg(gpui::linear_gradient(
                0.,
                gpui::linear_color_stop(cx.theme().colors().editor_background, 0.),
                gpui::linear_color_stop(cx.theme().colors().editor_background.opacity(0.), 1.),
            ));

        let border_color = cx.theme().colors().border.opacity(0.6);

        // Height of the collapsed preview, in editor lines. Only previews with more rows
        // than this get the expand/collapse button; an unknown row count counts as 0.
        const DEFAULT_COLLAPSED_LINES: u32 = 10;
        let is_collapsible = self.total_lines.unwrap_or(0) > DEFAULT_COLLAPSED_LINES;

        v_flex()
            .mb_2()
            .border_1()
            .when(failed, |card| card.border_dashed())
            .border_color(border_color)
            .rounded_lg()
            .overflow_hidden()
            .child(codeblock_header)
            // Failure: optional error details section.
            .when(failed && self.error_expanded, |card| {
                card.child(
                    v_flex()
                        .p_2()
                        .gap_1()
                        .border_t_1()
                        .border_dashed()
                        .border_color(border_color)
                        .bg(cx.theme().colors().editor_background)
                        .rounded_b_md()
                        .child(
                            Label::new("Error")
                                .size(LabelSize::XSmall)
                                .color(Color::Error),
                        )
                        .child(
                            div()
                                .rounded_md()
                                .text_ui_sm(cx)
                                .bg(cx.theme().colors().editor_background)
                                .children(
                                    error_message
                                        .map(|error| div().child(error).into_any_element()),
                                ),
                        ),
                )
            })
            // Success: optional diff preview, plus the expand button when collapsible.
            .when(!failed && self.preview_expanded, |card| {
                card.child(
                    v_flex()
                        .relative()
                        .map(|editor_container| {
                            if self.full_height_expanded {
                                editor_container.h_full()
                            } else {
                                editor_container
                                    .h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
                            }
                        })
                        .overflow_hidden()
                        .border_t_1()
                        .border_color(border_color)
                        .bg(cx.theme().colors().editor_background)
                        .child(div().pl_1().child(editor))
                        .when(
                            !self.full_height_expanded && is_collapsible,
                            |editor_container| editor_container.child(gradient_overlay),
                        ),
                )
                .when(is_collapsible, |editor_container| {
                    editor_container.child(
                        h_flex()
                            .id(("expand-button", self.editor_unique_id))
                            .flex_none()
                            .cursor_pointer()
                            .h_5()
                            .justify_center()
                            .rounded_b_md()
                            .border_t_1()
                            .border_color(border_color)
                            .bg(cx.theme().colors().editor_background)
                            .hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
                            .child(
                                Icon::new(full_height_icon)
                                    .size(IconSize::Small)
                                    .color(Color::Muted),
                            )
                            .tooltip(Tooltip::text(full_height_tooltip_label))
                            .on_click(cx.listener(move |this, _event, _window, _cx| {
                                this.full_height_expanded = !this.full_height_expanded;
                            })),
                    )
                })
            })
    }
}
624
625async fn build_buffer(
626    mut text: String,
627    path: Arc<Path>,
628    language_registry: &Arc<language::LanguageRegistry>,
629    cx: &mut AsyncApp,
630) -> Result<Entity<Buffer>> {
631    let line_ending = LineEnding::detect(&text);
632    LineEnding::normalize(&mut text);
633    let text = Rope::from(text);
634    let language = cx
635        .update(|_cx| language_registry.language_for_file_path(&path))?
636        .await
637        .ok();
638    let buffer = cx.new(|cx| {
639        let buffer = TextBuffer::new_normalized(
640            0,
641            cx.entity_id().as_non_zero_u64().into(),
642            line_ending,
643            text,
644        );
645        let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
646        buffer.set_language(language, cx);
647        buffer
648    })?;
649    Ok(buffer)
650}
651
652async fn build_buffer_diff(
653    mut old_text: String,
654    buffer: &Entity<Buffer>,
655    language_registry: &Arc<LanguageRegistry>,
656    cx: &mut AsyncApp,
657) -> Result<Entity<BufferDiff>> {
658    LineEnding::normalize(&mut old_text);
659
660    let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
661
662    let base_buffer = cx
663        .update(|cx| {
664            Buffer::build_snapshot(
665                old_text.clone().into(),
666                buffer.language().cloned(),
667                Some(language_registry.clone()),
668                cx,
669            )
670        })?
671        .await;
672
673    let diff_snapshot = cx
674        .update(|cx| {
675            BufferDiffSnapshot::new_with_base_buffer(
676                buffer.text.clone(),
677                Some(old_text.into()),
678                base_buffer,
679                cx,
680            )
681        })?
682        .await;
683
684    cx.new(|cx| {
685        let mut diff = BufferDiff::new(&buffer.text, cx);
686        diff.set_snapshot(diff_snapshot, &buffer.text, cx);
687        diff
688    })
689}
690
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Returns the streaming label for an input that has the given `path` and
    /// `display_description` and fixed, non-empty `old_string` / `new_string` values.
    fn streaming_label(path: &str, display_description: &str) -> String {
        let input = json!({
            "path": path,
            "display_description": display_description,
            "old_string": "old code",
            "new_string": "new code"
        });

        EditFileTool.still_streaming_ui_text(&input)
    }

    #[test]
    fn still_streaming_ui_text_with_path() {
        assert_eq!(streaming_label("src/main.rs", ""), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        assert_eq!(streaming_label("", "Fix error handling"), "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        // The description wins over the path when both are present.
        assert_eq!(
            streaming_label("src/main.rs", "Fix error handling"),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        assert_eq!(streaming_label("", ""), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let null_input = serde_json::Value::Null;

        assert_eq!(
            EditFileTool.still_streaming_ui_text(&null_input),
            DEFAULT_UI_TEXT,
        );
    }
}