edit_file_tool.rs

  1use crate::{
  2    replace::{replace_exact, replace_with_flexible_indent},
  3    schema::json_schema_for,
  4};
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
  7use buffer_diff::{BufferDiff, BufferDiffSnapshot};
  8use editor::{Editor, EditorMode, MultiBuffer, PathKey};
  9use gpui::{
 10    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId, Task, WeakEntity,
 11};
 12use language::{
 13    Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
 14};
 15use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 16use project::Project;
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use std::{
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use ui::{Disclosure, Tooltip, Window, prelude::*};
 24use util::ResultExt;
 25use workspace::Workspace;
 26
/// Assistant tool, exposed under the name `edit_file`, that replaces one span of text
/// (`old_string`) in a project file with new text (`new_string`).
pub struct EditFileTool;
 28
 29#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 30pub struct EditFileToolInput {
 31    /// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
 32    ///
 33    /// <example>Fix API endpoint URLs</example>
 34    /// <example>Update copyright year in `page_footer`</example>
 35    ///
 36    /// Make sure to include this field before all the others in the input object
 37    /// so that we can display it immediately.
 38    pub display_description: String,
 39
 40    /// The full path of the file to modify in the project.
 41    ///
 42    /// WARNING: When specifying which file path need changing, you MUST
 43    /// start each path with one of the project's root directories.
 44    ///
 45    /// The following examples assume we have two root directories in the project:
 46    /// - backend
 47    /// - frontend
 48    ///
 49    /// <example>
 50    /// `backend/src/main.rs`
 51    ///
 52    /// Notice how the file path starts with root-1. Without that, the path
 53    /// would be ambiguous and the call would fail!
 54    /// </example>
 55    ///
 56    /// <example>
 57    /// `frontend/db.js`
 58    /// </example>
 59    pub path: PathBuf,
 60
 61    /// The text to replace.
 62    pub old_string: String,
 63
 64    /// The text to replace it with.
 65    pub new_string: String,
 66}
 67
/// Lenient counterpart of [`EditFileToolInput`], deserialized by
/// `still_streaming_ui_text` to pick a label for the tool call.
///
/// Every field is `#[serde(default)]`, so an object that is missing some keys still
/// deserializes; absent fields come back as empty strings.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    // Kept as a plain `String` (not `PathBuf`): it is only trimmed and displayed.
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
    #[serde(default)]
    old_string: String,
    #[serde(default)]
    new_string: String,
}
 79
/// Fallback label returned by `still_streaming_ui_text` when the input yields neither a
/// non-blank description nor a non-blank path.
const DEFAULT_UI_TEXT: &str = "Editing file";
 81
 82impl Tool for EditFileTool {
 83    fn name(&self) -> String {
 84        "edit_file".into()
 85    }
 86
 87    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 88        false
 89    }
 90
 91    fn description(&self) -> String {
 92        include_str!("edit_file_tool/description.md").to_string()
 93    }
 94
 95    fn icon(&self) -> IconName {
 96        IconName::Pencil
 97    }
 98
 99    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
100        json_schema_for::<EditFileToolInput>(format)
101    }
102
103    fn ui_text(&self, input: &serde_json::Value) -> String {
104        match serde_json::from_value::<EditFileToolInput>(input.clone()) {
105            Ok(input) => input.display_description,
106            Err(_) => "Editing file".to_string(),
107        }
108    }
109
110    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
111        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
112            let description = input.display_description.trim();
113            if !description.is_empty() {
114                return description.to_string();
115            }
116
117            let path = input.path.trim();
118            if !path.is_empty() {
119                return path.to_string();
120            }
121        }
122
123        DEFAULT_UI_TEXT.to_string()
124    }
125
126    fn run(
127        self: Arc<Self>,
128        input: serde_json::Value,
129        _messages: &[LanguageModelRequestMessage],
130        project: Entity<Project>,
131        action_log: Entity<ActionLog>,
132        window: Option<AnyWindowHandle>,
133        cx: &mut App,
134    ) -> ToolResult {
135        let input = match serde_json::from_value::<EditFileToolInput>(input) {
136            Ok(input) => input,
137            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
138        };
139
140        let card = window.and_then(|window| {
141            window
142                .update(cx, |_, window, cx| {
143                    cx.new(|cx| {
144                        EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
145                    })
146                })
147                .ok()
148        });
149
150        let card_clone = card.clone();
151        let task = cx.spawn(async move |cx: &mut AsyncApp| {
152            let project_path = project.read_with(cx, |project, cx| {
153                project
154                    .find_project_path(&input.path, cx)
155                    .context("Path not found in project")
156            })??;
157
158            let buffer = project
159                .update(cx, |project, cx| {
160                    project.open_buffer(project_path.clone(), cx)
161                })?
162                .await?;
163
164            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
165
166            if input.old_string.is_empty() {
167                return Err(anyhow!(
168                    "`old_string` can't be empty, use another tool if you want to create a file."
169                ));
170            }
171
172            if input.old_string == input.new_string {
173                return Err(anyhow!(
174                    "The `old_string` and `new_string` are identical, so no changes would be made."
175                ));
176            }
177
178            let result = cx
179                .background_spawn(async move {
180                    // Try to match exactly
181                    let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
182                        .await
183                        // If that fails, try being flexible about indentation
184                        .or_else(|| {
185                            replace_with_flexible_indent(
186                                &input.old_string,
187                                &input.new_string,
188                                &snapshot,
189                            )
190                        })?;
191
192                    if diff.edits.is_empty() {
193                        return None;
194                    }
195
196                    let old_text = snapshot.text();
197
198                    Some((old_text, diff))
199                })
200                .await;
201
202            let Some((old_text, diff)) = result else {
203                let err = buffer.read_with(cx, |buffer, _cx| {
204                    let file_exists = buffer
205                        .file()
206                        .map_or(false, |file| file.disk_state().exists());
207
208                    if !file_exists {
209                        anyhow!("{} does not exist", input.path.display())
210                    } else if buffer.is_empty() {
211                        anyhow!(
212                            "{} is empty, so the provided `old_string` wasn't found.",
213                            input.path.display()
214                        )
215                    } else {
216                        anyhow!("Failed to match the provided `old_string`")
217                    }
218                })?;
219
220                return Err(err);
221            };
222
223            let snapshot = cx.update(|cx| {
224                action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
225
226                let snapshot = buffer.update(cx, |buffer, cx| {
227                    buffer.finalize_last_transaction();
228                    buffer.apply_diff(diff, cx);
229                    buffer.finalize_last_transaction();
230                    buffer.snapshot()
231                });
232                action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
233                snapshot
234            })?;
235
236            project
237                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
238                .await?;
239
240            let new_text = snapshot.text();
241            let diff_str = cx
242                .background_spawn({
243                    let old_text = old_text.clone();
244                    let new_text = new_text.clone();
245                    async move { language::unified_diff(&old_text, &new_text) }
246                })
247                .await;
248
249            if let Some(card) = card_clone {
250                card.update(cx, |card, cx| {
251                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
252                })
253                .log_err();
254            }
255
256            Ok(format!(
257                "Edited {}:\n\n```diff\n{}\n```",
258                input.path.display(),
259                diff_str
260            ))
261        });
262
263        ToolResult {
264            output: task,
265            card: card.map(AnyToolCard::from),
266        }
267    }
268}
269
/// UI card for one `edit_file` tool call: a header with the file path plus a
/// collapsible diff preview of the edit.
pub struct EditFileToolCard {
    /// Path as given in the tool input; shown in the header and used to jump to the file.
    path: PathBuf,
    /// Editor that renders `multibuffer` as the diff preview.
    editor: Entity<Editor>,
    /// Read-only multibuffer behind `editor`; filled in by `set_diff`.
    multibuffer: Entity<MultiBuffer>,
    /// Project the edit belongs to; provides the language registry for the preview.
    project: Entity<Project>,
    /// Task spawned by the latest `set_diff` call, stored here so it is owned by the card.
    diff_task: Option<Task<Result<()>>>,
    /// Whether the preview below the header is shown at all.
    preview_expanded: bool,
    /// Whether the preview uses its full height instead of the capped height.
    full_height_expanded: bool,
    /// Entity id of `editor`, used to make the card's element ids unique.
    editor_unique_id: EntityId,
}
280
281impl EditFileToolCard {
282    fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
283        let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
284        let editor = cx.new(|cx| {
285            let mut editor = Editor::new(
286                EditorMode::Full {
287                    scale_ui_elements_with_buffer_font_size: false,
288                    show_active_line_background: false,
289                    sized_by_content: true,
290                },
291                multibuffer.clone(),
292                Some(project.clone()),
293                window,
294                cx,
295            );
296            editor.set_show_scrollbars(false, cx);
297            editor.set_show_gutter(false, cx);
298            editor.disable_inline_diagnostics();
299            editor.disable_scrolling(cx);
300            editor.disable_expand_excerpt_buttons(cx);
301            editor.set_show_breakpoints(false, cx);
302            editor.set_show_code_actions(false, cx);
303            editor.set_show_git_diff_gutter(false, cx);
304            editor.set_expand_all_diff_hunks(cx);
305            editor
306        });
307        Self {
308            editor_unique_id: editor.entity_id(),
309            path,
310            project,
311            editor,
312            multibuffer,
313            diff_task: None,
314            preview_expanded: true,
315            full_height_expanded: false,
316        }
317    }
318
319    fn set_diff(
320        &mut self,
321        path: Arc<Path>,
322        old_text: String,
323        new_text: String,
324        cx: &mut Context<Self>,
325    ) {
326        let language_registry = self.project.read(cx).languages().clone();
327        self.diff_task = Some(cx.spawn(async move |this, cx| {
328            let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
329            let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
330
331            this.update(cx, |this, cx| {
332                this.multibuffer.update(cx, |multibuffer, cx| {
333                    let snapshot = buffer.read(cx).snapshot();
334                    let diff = buffer_diff.read(cx);
335                    let diff_hunk_ranges = diff
336                        .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
337                        .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
338                        .collect::<Vec<_>>();
339                    let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
340                        PathKey::for_buffer(&buffer, cx),
341                        buffer,
342                        diff_hunk_ranges,
343                        editor::DEFAULT_MULTIBUFFER_CONTEXT,
344                        cx,
345                    );
346                    debug_assert!(is_newly_added);
347                    multibuffer.add_diff(buffer_diff, cx);
348                });
349                cx.notify();
350            })
351        }));
352    }
353}
354
355impl ToolCard for EditFileToolCard {
356    fn render(
357        &mut self,
358        status: &ToolUseStatus,
359        window: &mut Window,
360        workspace: WeakEntity<Workspace>,
361        cx: &mut Context<Self>,
362    ) -> impl IntoElement {
363        let failed = matches!(status, ToolUseStatus::Error(_));
364
365        let path_label_button = h_flex()
366            .id(("edit-tool-path-label-button", self.editor_unique_id))
367            .w_full()
368            .max_w_full()
369            .px_1()
370            .gap_0p5()
371            .cursor_pointer()
372            .rounded_sm()
373            .opacity(0.8)
374            .hover(|label| {
375                label
376                    .opacity(1.)
377                    .bg(cx.theme().colors().element_hover.opacity(0.5))
378            })
379            .tooltip(Tooltip::text("Jump to File"))
380            .child(
381                h_flex()
382                    .child(
383                        Icon::new(IconName::Pencil)
384                            .size(IconSize::XSmall)
385                            .color(Color::Muted),
386                    )
387                    .child(
388                        div()
389                            .text_size(rems(0.8125))
390                            .child(self.path.display().to_string())
391                            .ml_1p5()
392                            .mr_0p5(),
393                    )
394                    .child(
395                        Icon::new(IconName::ArrowUpRight)
396                            .size(IconSize::XSmall)
397                            .color(Color::Ignored),
398                    ),
399            )
400            .on_click({
401                let path = self.path.clone();
402                let workspace = workspace.clone();
403                move |_, window, cx| {
404                    workspace
405                        .update(cx, {
406                            |workspace, cx| {
407                                let Some(project_path) =
408                                    workspace.project().read(cx).find_project_path(&path, cx)
409                                else {
410                                    return;
411                                };
412                                let open_task =
413                                    workspace.open_path(project_path, None, true, window, cx);
414                                window
415                                    .spawn(cx, async move |cx| {
416                                        let item = open_task.await?;
417                                        if let Some(active_editor) = item.downcast::<Editor>() {
418                                            active_editor
419                                                .update_in(cx, |editor, window, cx| {
420                                                    editor.go_to_singleton_buffer_point(
421                                                        language::Point::new(0, 0),
422                                                        window,
423                                                        cx,
424                                                    );
425                                                })
426                                                .log_err();
427                                        }
428                                        anyhow::Ok(())
429                                    })
430                                    .detach_and_log_err(cx);
431                            }
432                        })
433                        .ok();
434                }
435            })
436            .into_any_element();
437
438        let codeblock_header_bg = cx
439            .theme()
440            .colors()
441            .element_background
442            .blend(cx.theme().colors().editor_foreground.opacity(0.025));
443
444        let codeblock_header = h_flex()
445            .flex_none()
446            .p_1()
447            .gap_1()
448            .justify_between()
449            .rounded_t_md()
450            .when(!failed, |header| header.bg(codeblock_header_bg))
451            .child(path_label_button)
452            .map(|container| {
453                if failed {
454                    container.child(
455                        Icon::new(IconName::Close)
456                            .size(IconSize::Small)
457                            .color(Color::Error),
458                    )
459                } else {
460                    container.child(
461                        Disclosure::new(
462                            ("edit-file-disclosure", self.editor_unique_id),
463                            self.preview_expanded,
464                        )
465                        .opened_icon(IconName::ChevronUp)
466                        .closed_icon(IconName::ChevronDown)
467                        .on_click(cx.listener(
468                            move |this, _event, _window, _cx| {
469                                this.preview_expanded = !this.preview_expanded;
470                            },
471                        )),
472                    )
473                }
474            });
475
476        let editor = self.editor.update(cx, |editor, cx| {
477            editor.render(window, cx).into_any_element()
478        });
479
480        let (full_height_icon, full_height_tooltip_label) = if self.full_height_expanded {
481            (IconName::ChevronUp, "Collapse Code Block")
482        } else {
483            (IconName::ChevronDown, "Expand Code Block")
484        };
485
486        let gradient_overlay = div()
487            .absolute()
488            .bottom_0()
489            .left_0()
490            .w_full()
491            .h_2_5()
492            .rounded_b_lg()
493            .bg(gpui::linear_gradient(
494                0.,
495                gpui::linear_color_stop(cx.theme().colors().editor_background, 0.),
496                gpui::linear_color_stop(cx.theme().colors().editor_background.opacity(0.), 1.),
497            ));
498
499        let border_color = cx.theme().colors().border.opacity(0.6);
500
501        v_flex()
502            .mb_2()
503            .border_1()
504            .when(failed, |card| card.border_dashed())
505            .border_color(border_color)
506            .rounded_lg()
507            .overflow_hidden()
508            .child(codeblock_header)
509            .when(!failed && self.preview_expanded, |card| {
510                card.child(
511                    v_flex()
512                        .relative()
513                        .overflow_hidden()
514                        .border_t_1()
515                        .border_color(border_color)
516                        .bg(cx.theme().colors().editor_background)
517                        .map(|editor_container| {
518                            if self.full_height_expanded {
519                                editor_container.h_full()
520                            } else {
521                                editor_container.max_h_64()
522                            }
523                        })
524                        .child(div().pl_1().child(editor))
525                        .when(!self.full_height_expanded, |editor_container| {
526                            editor_container.child(gradient_overlay)
527                        }),
528                )
529            })
530            .when(!failed && self.preview_expanded, |card| {
531                card.child(
532                    h_flex()
533                        .id(("edit-tool-card-inner-hflex", self.editor_unique_id))
534                        .flex_none()
535                        .cursor_pointer()
536                        .h_5()
537                        .justify_center()
538                        .rounded_b_md()
539                        .border_t_1()
540                        .border_color(border_color)
541                        .bg(cx.theme().colors().editor_background)
542                        .hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
543                        .child(
544                            Icon::new(full_height_icon)
545                                .size(IconSize::Small)
546                                .color(Color::Muted),
547                        )
548                        .tooltip(Tooltip::text(full_height_tooltip_label))
549                        .on_click(cx.listener(move |this, _event, _window, _cx| {
550                            this.full_height_expanded = !this.full_height_expanded;
551                        })),
552                )
553            })
554    }
555}
556
557async fn build_buffer(
558    mut text: String,
559    path: Arc<Path>,
560    language_registry: &Arc<language::LanguageRegistry>,
561    cx: &mut AsyncApp,
562) -> Result<Entity<Buffer>> {
563    let line_ending = LineEnding::detect(&text);
564    LineEnding::normalize(&mut text);
565    let text = Rope::from(text);
566    let language = cx
567        .update(|_cx| language_registry.language_for_file_path(&path))?
568        .await
569        .ok();
570    let buffer = cx.new(|cx| {
571        let buffer = TextBuffer::new_normalized(
572            0,
573            cx.entity_id().as_non_zero_u64().into(),
574            line_ending,
575            text,
576        );
577        let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
578        buffer.set_language(language, cx);
579        buffer
580    })?;
581    Ok(buffer)
582}
583
584async fn build_buffer_diff(
585    mut old_text: String,
586    buffer: &Entity<Buffer>,
587    language_registry: &Arc<LanguageRegistry>,
588    cx: &mut AsyncApp,
589) -> Result<Entity<BufferDiff>> {
590    LineEnding::normalize(&mut old_text);
591
592    let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
593
594    let base_buffer = cx
595        .update(|cx| {
596            Buffer::build_snapshot(
597                old_text.clone().into(),
598                buffer.language().cloned(),
599                Some(language_registry.clone()),
600                cx,
601            )
602        })?
603        .await;
604
605    let diff_snapshot = cx
606        .update(|cx| {
607            BufferDiffSnapshot::new_with_base_buffer(
608                buffer.text.clone(),
609                Some(old_text.into()),
610                base_buffer,
611                cx,
612            )
613        })?
614        .await;
615
616    cx.new(|cx| {
617        let mut diff = BufferDiff::new(&buffer.text, cx);
618        diff.set_snapshot(diff_snapshot, &buffer.text, cx);
619        diff
620    })
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Complete tool input with the given path and description; the strings to replace
    /// are fixed because the label never depends on them.
    fn input_with(path: &str, display_description: &str) -> serde_json::Value {
        json!({
            "path": path,
            "display_description": display_description,
            "old_string": "old code",
            "new_string": "new code"
        })
    }

    /// Label the tool reports for `input` while it is still streaming.
    fn streaming_label(input: &serde_json::Value) -> String {
        EditFileTool.still_streaming_ui_text(input)
    }

    #[test]
    fn still_streaming_ui_text_with_path() {
        let input = input_with("src/main.rs", "");

        assert_eq!(streaming_label(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        let input = input_with("", "Fix error handling");

        assert_eq!(streaming_label(&input), "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        // The description wins over the path when both are present.
        let input = input_with("src/main.rs", "Fix error handling");

        assert_eq!(streaming_label(&input), "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let input = input_with("", "");

        assert_eq!(streaming_label(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        // Not an object at all, so it cannot be deserialized.
        let input = serde_json::Value::Null;

        assert_eq!(streaming_label(&input), DEFAULT_UI_TEXT);
    }
}