edit_file_tool.rs

  1use crate::{
  2    replace::{replace_exact, replace_with_flexible_indent},
  3    schema::json_schema_for,
  4};
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
  7use buffer_diff::{BufferDiff, BufferDiffSnapshot};
  8use editor::{Editor, EditorMode, MultiBuffer, PathKey};
  9use gpui::{
 10    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId, Task, WeakEntity,
 11};
 12use language::{
 13    Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
 14    language_settings::SoftWrap,
 15};
 16use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 17use project::Project;
 18use schemars::JsonSchema;
 19use serde::{Deserialize, Serialize};
 20use std::{
 21    path::{Path, PathBuf},
 22    sync::Arc,
 23};
 24use ui::{Disclosure, Tooltip, Window, prelude::*};
 25use util::ResultExt;
 26use workspace::Workspace;
 27
/// The `edit_file` tool: replaces `old_string` with `new_string` in a project file
/// (trying an exact match first, then an indentation-flexible one) and saves the file.
pub struct EditFileTool;
 29
 30#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 31pub struct EditFileToolInput {
 32    /// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
 33    ///
 34    /// <example>Fix API endpoint URLs</example>
 35    /// <example>Update copyright year in `page_footer`</example>
 36    ///
 37    /// Make sure to include this field before all the others in the input object
 38    /// so that we can display it immediately.
 39    pub display_description: String,
 40
 41    /// The full path of the file to modify in the project.
 42    ///
 43    /// WARNING: When specifying which file path need changing, you MUST
 44    /// start each path with one of the project's root directories.
 45    ///
 46    /// The following examples assume we have two root directories in the project:
 47    /// - backend
 48    /// - frontend
 49    ///
 50    /// <example>
 51    /// `backend/src/main.rs`
 52    ///
 53    /// Notice how the file path starts with root-1. Without that, the path
 54    /// would be ambiguous and the call would fail!
 55    /// </example>
 56    ///
 57    /// <example>
 58    /// `frontend/db.js`
 59    /// </example>
 60    pub path: PathBuf,
 61
 62    /// The text to replace.
 63    pub old_string: String,
 64
 65    /// The text to replace it with.
 66    pub new_string: String,
 67}
 68
// Lenient view of the tool input, used by `still_streaming_ui_text` while the
// model is still streaming the JSON. Every field carries `#[serde(default)]`, so
// an object with only some of the keys present still deserializes, and missing
// keys become empty strings. `path` is a plain `String` here (unlike the `PathBuf`
// in `EditFileToolInput`) because it is only displayed.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
    #[serde(default)]
    old_string: String,
    #[serde(default)]
    new_string: String,
}
 80
/// Fallback label for the tool call in the UI when the input yields neither a
/// description nor a path.
const DEFAULT_UI_TEXT: &str = "Editing file";
 82
 83impl Tool for EditFileTool {
 84    fn name(&self) -> String {
 85        "edit_file".into()
 86    }
 87
 88    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 89        false
 90    }
 91
 92    fn description(&self) -> String {
 93        include_str!("edit_file_tool/description.md").to_string()
 94    }
 95
 96    fn icon(&self) -> IconName {
 97        IconName::Pencil
 98    }
 99
100    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
101        json_schema_for::<EditFileToolInput>(format)
102    }
103
104    fn ui_text(&self, input: &serde_json::Value) -> String {
105        match serde_json::from_value::<EditFileToolInput>(input.clone()) {
106            Ok(input) => input.display_description,
107            Err(_) => "Editing file".to_string(),
108        }
109    }
110
111    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
112        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
113            let description = input.display_description.trim();
114            if !description.is_empty() {
115                return description.to_string();
116            }
117
118            let path = input.path.trim();
119            if !path.is_empty() {
120                return path.to_string();
121            }
122        }
123
124        DEFAULT_UI_TEXT.to_string()
125    }
126
127    fn run(
128        self: Arc<Self>,
129        input: serde_json::Value,
130        _messages: &[LanguageModelRequestMessage],
131        project: Entity<Project>,
132        action_log: Entity<ActionLog>,
133        window: Option<AnyWindowHandle>,
134        cx: &mut App,
135    ) -> ToolResult {
136        let input = match serde_json::from_value::<EditFileToolInput>(input) {
137            Ok(input) => input,
138            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
139        };
140
141        let card = window.and_then(|window| {
142            window
143                .update(cx, |_, window, cx| {
144                    cx.new(|cx| {
145                        EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
146                    })
147                })
148                .ok()
149        });
150
151        let card_clone = card.clone();
152        let task = cx.spawn(async move |cx: &mut AsyncApp| {
153            let project_path = project.read_with(cx, |project, cx| {
154                project
155                    .find_project_path(&input.path, cx)
156                    .context("Path not found in project")
157            })??;
158
159            let buffer = project
160                .update(cx, |project, cx| {
161                    project.open_buffer(project_path.clone(), cx)
162                })?
163                .await?;
164
165            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
166
167            if input.old_string.is_empty() {
168                return Err(anyhow!(
169                    "`old_string` can't be empty, use another tool if you want to create a file."
170                ));
171            }
172
173            if input.old_string == input.new_string {
174                return Err(anyhow!(
175                    "The `old_string` and `new_string` are identical, so no changes would be made."
176                ));
177            }
178
179            let result = cx
180                .background_spawn(async move {
181                    // Try to match exactly
182                    let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
183                        .await
184                        // If that fails, try being flexible about indentation
185                        .or_else(|| {
186                            replace_with_flexible_indent(
187                                &input.old_string,
188                                &input.new_string,
189                                &snapshot,
190                            )
191                        })?;
192
193                    if diff.edits.is_empty() {
194                        return None;
195                    }
196
197                    let old_text = snapshot.text();
198
199                    Some((old_text, diff))
200                })
201                .await;
202
203            let Some((old_text, diff)) = result else {
204                let err = buffer.read_with(cx, |buffer, _cx| {
205                    let file_exists = buffer
206                        .file()
207                        .map_or(false, |file| file.disk_state().exists());
208
209                    if !file_exists {
210                        anyhow!("{} does not exist", input.path.display())
211                    } else if buffer.is_empty() {
212                        anyhow!(
213                            "{} is empty, so the provided `old_string` wasn't found.",
214                            input.path.display()
215                        )
216                    } else {
217                        anyhow!("Failed to match the provided `old_string`")
218                    }
219                })?;
220
221                return Err(err);
222            };
223
224            let snapshot = cx.update(|cx| {
225                action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
226
227                let snapshot = buffer.update(cx, |buffer, cx| {
228                    buffer.finalize_last_transaction();
229                    buffer.apply_diff(diff, cx);
230                    buffer.finalize_last_transaction();
231                    buffer.snapshot()
232                });
233                action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
234                snapshot
235            })?;
236
237            project
238                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
239                .await?;
240
241            let new_text = snapshot.text();
242            let diff_str = cx
243                .background_spawn({
244                    let old_text = old_text.clone();
245                    let new_text = new_text.clone();
246                    async move { language::unified_diff(&old_text, &new_text) }
247                })
248                .await;
249
250            if let Some(card) = card_clone {
251                card.update(cx, |card, cx| {
252                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
253                })
254                .log_err();
255            }
256
257            Ok(format!(
258                "Edited {}:\n\n```diff\n{}\n```",
259                input.path.display(),
260                diff_str
261            ))
262        });
263
264        ToolResult {
265            output: task,
266            card: card.map(AnyToolCard::from),
267        }
268    }
269}
270
/// UI card for an `edit_file` tool use: a header showing the edited path and a
/// read-only, collapsible preview of the resulting diff (or the error, on failure).
pub struct EditFileToolCard {
    /// Path as given in the tool input; displayed in the header and resolved
    /// against the project by the "Jump to File" button.
    path: PathBuf,
    /// Read-only editor that renders `multibuffer` with all diff hunks expanded.
    editor: Entity<Editor>,
    /// Holds excerpts of the edited buffer around each diff hunk once `set_diff` runs.
    multibuffer: Entity<MultiBuffer>,
    /// Project used by the preview editor and for language lookup.
    project: Entity<Project>,
    /// Task that builds the diff preview; assigning a new one drops the previous task.
    diff_task: Option<Task<Result<()>>>,
    /// Whether the diff preview is visible (only rendered when the tool succeeded).
    preview_expanded: bool,
    /// Whether the error details are visible (only rendered when the tool failed).
    error_expanded: bool,
    /// Whether the preview shows all rows instead of the collapsed height.
    full_height_expanded: bool,
    /// Row count of the preview multibuffer, set once the diff has been built.
    total_lines: Option<u32>,
    /// Entity id of `editor`, used to make this card's element ids unique.
    editor_unique_id: EntityId,
}
283
284impl EditFileToolCard {
285    pub fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
286        let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
287        let editor = cx.new(|cx| {
288            let mut editor = Editor::new(
289                EditorMode::Full {
290                    scale_ui_elements_with_buffer_font_size: false,
291                    show_active_line_background: false,
292                    sized_by_content: true,
293                },
294                multibuffer.clone(),
295                Some(project.clone()),
296                window,
297                cx,
298            );
299            editor.set_show_gutter(false, cx);
300            editor.disable_inline_diagnostics();
301            editor.disable_expand_excerpt_buttons(cx);
302            editor.set_soft_wrap_mode(SoftWrap::None, cx);
303            editor.scroll_manager.set_forbid_vertical_scroll(true);
304            editor.set_show_scrollbars(false, cx);
305            editor.set_read_only(true);
306            editor.set_show_breakpoints(false, cx);
307            editor.set_show_code_actions(false, cx);
308            editor.set_show_git_diff_gutter(false, cx);
309            editor.set_expand_all_diff_hunks(cx);
310            editor
311        });
312        Self {
313            editor_unique_id: editor.entity_id(),
314            path,
315            project,
316            editor,
317            multibuffer,
318            diff_task: None,
319            preview_expanded: true,
320            error_expanded: false,
321            full_height_expanded: false,
322            total_lines: None,
323        }
324    }
325
326    pub fn set_diff(
327        &mut self,
328        path: Arc<Path>,
329        old_text: String,
330        new_text: String,
331        cx: &mut Context<Self>,
332    ) {
333        let language_registry = self.project.read(cx).languages().clone();
334        self.diff_task = Some(cx.spawn(async move |this, cx| {
335            let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
336            let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
337
338            this.update(cx, |this, cx| {
339                this.total_lines = this.multibuffer.update(cx, |multibuffer, cx| {
340                    let snapshot = buffer.read(cx).snapshot();
341                    let diff = buffer_diff.read(cx);
342                    let diff_hunk_ranges = diff
343                        .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
344                        .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
345                        .collect::<Vec<_>>();
346                    multibuffer.clear(cx);
347                    let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
348                        PathKey::for_buffer(&buffer, cx),
349                        buffer,
350                        diff_hunk_ranges,
351                        editor::DEFAULT_MULTIBUFFER_CONTEXT,
352                        cx,
353                    );
354                    debug_assert!(is_newly_added);
355                    multibuffer.add_diff(buffer_diff, cx);
356                    let end = multibuffer.len(cx);
357                    Some(multibuffer.snapshot(cx).offset_to_point(end).row + 1)
358                });
359
360                cx.notify();
361            })
362        }));
363    }
364}
365
366impl ToolCard for EditFileToolCard {
367    fn render(
368        &mut self,
369        status: &ToolUseStatus,
370        window: &mut Window,
371        workspace: WeakEntity<Workspace>,
372        cx: &mut Context<Self>,
373    ) -> impl IntoElement {
374        let (failed, error_message) = match status {
375            ToolUseStatus::Error(err) => (true, Some(err.to_string())),
376            _ => (false, None),
377        };
378
379        let path_label_button = h_flex()
380            .id(("edit-tool-path-label-button", self.editor_unique_id))
381            .w_full()
382            .max_w_full()
383            .px_1()
384            .gap_0p5()
385            .cursor_pointer()
386            .rounded_sm()
387            .opacity(0.8)
388            .hover(|label| {
389                label
390                    .opacity(1.)
391                    .bg(cx.theme().colors().element_hover.opacity(0.5))
392            })
393            .tooltip(Tooltip::text("Jump to File"))
394            .child(
395                h_flex()
396                    .child(
397                        Icon::new(IconName::Pencil)
398                            .size(IconSize::XSmall)
399                            .color(Color::Muted),
400                    )
401                    .child(
402                        div()
403                            .text_size(rems(0.8125))
404                            .child(self.path.display().to_string())
405                            .ml_1p5()
406                            .mr_0p5(),
407                    )
408                    .child(
409                        Icon::new(IconName::ArrowUpRight)
410                            .size(IconSize::XSmall)
411                            .color(Color::Ignored),
412                    ),
413            )
414            .on_click({
415                let path = self.path.clone();
416                let workspace = workspace.clone();
417                move |_, window, cx| {
418                    workspace
419                        .update(cx, {
420                            |workspace, cx| {
421                                let Some(project_path) =
422                                    workspace.project().read(cx).find_project_path(&path, cx)
423                                else {
424                                    return;
425                                };
426                                let open_task =
427                                    workspace.open_path(project_path, None, true, window, cx);
428                                window
429                                    .spawn(cx, async move |cx| {
430                                        let item = open_task.await?;
431                                        if let Some(active_editor) = item.downcast::<Editor>() {
432                                            active_editor
433                                                .update_in(cx, |editor, window, cx| {
434                                                    editor.go_to_singleton_buffer_point(
435                                                        language::Point::new(0, 0),
436                                                        window,
437                                                        cx,
438                                                    );
439                                                })
440                                                .log_err();
441                                        }
442                                        anyhow::Ok(())
443                                    })
444                                    .detach_and_log_err(cx);
445                            }
446                        })
447                        .ok();
448                }
449            })
450            .into_any_element();
451
452        let codeblock_header_bg = cx
453            .theme()
454            .colors()
455            .element_background
456            .blend(cx.theme().colors().editor_foreground.opacity(0.025));
457
458        let codeblock_header = h_flex()
459            .flex_none()
460            .p_1()
461            .gap_1()
462            .justify_between()
463            .rounded_t_md()
464            .when(!failed, |header| header.bg(codeblock_header_bg))
465            .child(path_label_button)
466            .map(|container| {
467                if failed {
468                    container.child(
469                        h_flex()
470                            .gap_1()
471                            .child(
472                                Icon::new(IconName::Close)
473                                    .size(IconSize::Small)
474                                    .color(Color::Error),
475                            )
476                            .child(
477                                Disclosure::new(
478                                    ("edit-file-error-disclosure", self.editor_unique_id),
479                                    self.error_expanded,
480                                )
481                                .opened_icon(IconName::ChevronUp)
482                                .closed_icon(IconName::ChevronDown)
483                                .on_click(cx.listener(
484                                    move |this, _event, _window, _cx| {
485                                        this.error_expanded = !this.error_expanded;
486                                    },
487                                )),
488                            ),
489                    )
490                } else {
491                    container.child(
492                        Disclosure::new(
493                            ("edit-file-disclosure", self.editor_unique_id),
494                            self.preview_expanded,
495                        )
496                        .opened_icon(IconName::ChevronUp)
497                        .closed_icon(IconName::ChevronDown)
498                        .on_click(cx.listener(
499                            move |this, _event, _window, _cx| {
500                                this.preview_expanded = !this.preview_expanded;
501                            },
502                        )),
503                    )
504                }
505            });
506
507        let (editor, editor_line_height) = self.editor.update(cx, |editor, cx| {
508            let line_height = editor
509                .style()
510                .map(|style| style.text.line_height_in_pixels(window.rem_size()))
511                .unwrap_or_default();
512
513            let element = editor.render(window, cx);
514            (element.into_any_element(), line_height)
515        });
516
517        let (full_height_icon, full_height_tooltip_label) = if self.full_height_expanded {
518            (IconName::ChevronUp, "Collapse Code Block")
519        } else {
520            (IconName::ChevronDown, "Expand Code Block")
521        };
522
523        let gradient_overlay = div()
524            .absolute()
525            .bottom_0()
526            .left_0()
527            .w_full()
528            .h_2_5()
529            .rounded_b_lg()
530            .bg(gpui::linear_gradient(
531                0.,
532                gpui::linear_color_stop(cx.theme().colors().editor_background, 0.),
533                gpui::linear_color_stop(cx.theme().colors().editor_background.opacity(0.), 1.),
534            ));
535
536        let border_color = cx.theme().colors().border.opacity(0.6);
537
538        const DEFAULT_COLLAPSED_LINES: u32 = 10;
539        let is_collapsible = self.total_lines.unwrap_or(0) > DEFAULT_COLLAPSED_LINES;
540
541        v_flex()
542            .mb_2()
543            .border_1()
544            .when(failed, |card| card.border_dashed())
545            .border_color(border_color)
546            .rounded_lg()
547            .overflow_hidden()
548            .child(codeblock_header)
549            .when(failed && self.error_expanded, |card| {
550                card.child(
551                    v_flex()
552                        .p_2()
553                        .gap_1()
554                        .border_t_1()
555                        .border_dashed()
556                        .border_color(border_color)
557                        .bg(cx.theme().colors().editor_background)
558                        .rounded_b_md()
559                        .child(
560                            Label::new("Error")
561                                .size(LabelSize::XSmall)
562                                .color(Color::Error),
563                        )
564                        .child(
565                            div()
566                                .rounded_md()
567                                .text_ui_sm(cx)
568                                .bg(cx.theme().colors().editor_background)
569                                .children(
570                                    error_message
571                                        .map(|error| div().child(error).into_any_element()),
572                                ),
573                        ),
574                )
575            })
576            .when(!failed && self.preview_expanded, |card| {
577                card.child(
578                    v_flex()
579                        .relative()
580                        .h_full()
581                        .when(!self.full_height_expanded, |editor_container| {
582                            editor_container
583                                .max_h(DEFAULT_COLLAPSED_LINES as f32 * editor_line_height)
584                        })
585                        .overflow_hidden()
586                        .border_t_1()
587                        .border_color(border_color)
588                        .bg(cx.theme().colors().editor_background)
589                        .child(div().pl_1().child(editor))
590                        .when(
591                            !self.full_height_expanded && is_collapsible,
592                            |editor_container| editor_container.child(gradient_overlay),
593                        ),
594                )
595                .when(is_collapsible, |editor_container| {
596                    editor_container.child(
597                        h_flex()
598                            .id(("expand-button", self.editor_unique_id))
599                            .flex_none()
600                            .cursor_pointer()
601                            .h_5()
602                            .justify_center()
603                            .rounded_b_md()
604                            .border_t_1()
605                            .border_color(border_color)
606                            .bg(cx.theme().colors().editor_background)
607                            .hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
608                            .child(
609                                Icon::new(full_height_icon)
610                                    .size(IconSize::Small)
611                                    .color(Color::Muted),
612                            )
613                            .tooltip(Tooltip::text(full_height_tooltip_label))
614                            .on_click(cx.listener(move |this, _event, _window, _cx| {
615                                this.full_height_expanded = !this.full_height_expanded;
616                            })),
617                    )
618                })
619            })
620    }
621}
622
623async fn build_buffer(
624    mut text: String,
625    path: Arc<Path>,
626    language_registry: &Arc<language::LanguageRegistry>,
627    cx: &mut AsyncApp,
628) -> Result<Entity<Buffer>> {
629    let line_ending = LineEnding::detect(&text);
630    LineEnding::normalize(&mut text);
631    let text = Rope::from(text);
632    let language = cx
633        .update(|_cx| language_registry.language_for_file_path(&path))?
634        .await
635        .ok();
636    let buffer = cx.new(|cx| {
637        let buffer = TextBuffer::new_normalized(
638            0,
639            cx.entity_id().as_non_zero_u64().into(),
640            line_ending,
641            text,
642        );
643        let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
644        buffer.set_language(language, cx);
645        buffer
646    })?;
647    Ok(buffer)
648}
649
650async fn build_buffer_diff(
651    mut old_text: String,
652    buffer: &Entity<Buffer>,
653    language_registry: &Arc<LanguageRegistry>,
654    cx: &mut AsyncApp,
655) -> Result<Entity<BufferDiff>> {
656    LineEnding::normalize(&mut old_text);
657
658    let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
659
660    let base_buffer = cx
661        .update(|cx| {
662            Buffer::build_snapshot(
663                old_text.clone().into(),
664                buffer.language().cloned(),
665                Some(language_registry.clone()),
666                cx,
667            )
668        })?
669        .await;
670
671    let diff_snapshot = cx
672        .update(|cx| {
673            BufferDiffSnapshot::new_with_base_buffer(
674                buffer.text.clone(),
675                Some(old_text.into()),
676                base_buffer,
677                cx,
678            )
679        })?
680        .await;
681
682    cx.new(|cx| {
683        let mut diff = BufferDiff::new(&buffer.text, cx);
684        diff.set_snapshot(diff_snapshot, &buffer.text, cx);
685        diff
686    })
687}
688
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn still_streaming_ui_text_with_path() {
        let input = json!({
            "path": "src/main.rs",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        let input = json!({
            "path": "",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            EditFileTool.still_streaming_ui_text(&input),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        let input = json!({
            "path": "src/main.rs",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            EditFileTool.still_streaming_ui_text(&input),
            "Fix error handling",
        );
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let input = json!({
            "path": "",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(
            EditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT,
        );
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let input = serde_json::Value::Null;

        assert_eq!(
            EditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT,
        );
    }

    // While streaming, only a prefix of the keys may have arrived; the missing
    // ones must default to empty rather than make deserialization fail.
    #[test]
    fn still_streaming_ui_text_with_missing_fields() {
        let input = json!({ "path": "src/main.rs" });
        assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");

        let input = json!({ "display_description": "Fix error handling" });
        assert_eq!(
            EditFileTool.still_streaming_ui_text(&input),
            "Fix error handling",
        );

        let input = json!({});
        assert_eq!(
            EditFileTool.still_streaming_ui_text(&input),
            DEFAULT_UI_TEXT,
        );
    }

    #[test]
    fn still_streaming_ui_text_trims_whitespace() {
        let input = json!({
            "path": "  src/main.rs  ",
            "display_description": "   "
        });

        assert_eq!(EditFileTool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn ui_text_uses_description_or_default() {
        let input = json!({
            "display_description": "Fix error handling",
            "path": "src/main.rs",
            "old_string": "old code",
            "new_string": "new code"
        });
        assert_eq!(EditFileTool.ui_text(&input), "Fix error handling");

        // Incomplete input can't be parsed as `EditFileToolInput`.
        let input = json!({ "path": "src/main.rs" });
        assert_eq!(EditFileTool.ui_text(&input), DEFAULT_UI_TEXT);
    }
}