edit_file_tool.rs

  1use crate::{
  2    replace::{replace_exact, replace_with_flexible_indent},
  3    schema::json_schema_for,
  4};
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolUseStatus};
  7use buffer_diff::{BufferDiff, BufferDiffSnapshot};
  8use editor::{Editor, EditorMode, MultiBuffer, PathKey};
  9use gpui::{
 10    AnyWindowHandle, App, AppContext, AsyncApp, Context, Entity, EntityId, Task, WeakEntity,
 11};
 12use language::{
 13    Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Rope, TextBuffer,
 14};
 15use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 16use project::Project;
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use std::{
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use ui::{Disclosure, Tooltip, Window, prelude::*};
 24use util::ResultExt;
 25use workspace::Workspace;
 26
 27#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 28pub struct EditFileToolInput {
 29    /// The full path of the file to modify in the project.
 30    ///
 31    /// WARNING: When specifying which file path need changing, you MUST
 32    /// start each path with one of the project's root directories.
 33    ///
 34    /// The following examples assume we have two root directories in the project:
 35    /// - backend
 36    /// - frontend
 37    ///
 38    /// <example>
 39    /// `backend/src/main.rs`
 40    ///
 41    /// Notice how the file path starts with root-1. Without that, the path
 42    /// would be ambiguous and the call would fail!
 43    /// </example>
 44    ///
 45    /// <example>
 46    /// `frontend/db.js`
 47    /// </example>
 48    pub path: PathBuf,
 49
 50    /// A user-friendly markdown description of what's being replaced. This will be shown in the UI.
 51    ///
 52    /// <example>Fix API endpoint URLs</example>
 53    /// <example>Update copyright year in `page_footer`</example>
 54    pub display_description: String,
 55
 56    /// The text to replace.
 57    pub old_string: String,
 58
 59    /// The text to replace it with.
 60    pub new_string: String,
 61}
 62
// Lenient counterpart of `EditFileToolInput`, parsed by
// `still_streaming_ui_text` while the tool call is still arriving. Every
// field is `#[serde(default)]`, so an object with missing fields still
// deserializes and the absent fields are left as empty strings.
// `path` is a plain `String` so it can be trimmed and checked for emptiness.
// Plain `//` comments are used here on purpose: `///` doc comments would be
// turned into schema descriptions by the `JsonSchema` derive.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct PartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
    #[serde(default)]
    old_string: String,
    #[serde(default)]
    new_string: String,
}
 74
/// The `edit_file` tool: replaces `old_string` with `new_string` in a file of
/// the project and saves the result.
pub struct EditFileTool;
 76
/// Fallback UI label for an edit whose (streaming) input cannot be parsed or
/// provides neither a description nor a path.
const DEFAULT_UI_TEXT: &str = "Editing file";
 78
 79impl Tool for EditFileTool {
 80    fn name(&self) -> String {
 81        "edit_file".into()
 82    }
 83
 84    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 85        false
 86    }
 87
 88    fn description(&self) -> String {
 89        include_str!("edit_file_tool/description.md").to_string()
 90    }
 91
 92    fn icon(&self) -> IconName {
 93        IconName::Pencil
 94    }
 95
 96    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 97        json_schema_for::<EditFileToolInput>(format)
 98    }
 99
100    fn ui_text(&self, input: &serde_json::Value) -> String {
101        match serde_json::from_value::<EditFileToolInput>(input.clone()) {
102            Ok(input) => input.display_description,
103            Err(_) => "Editing file".to_string(),
104        }
105    }
106
107    fn still_streaming_ui_text(&self, input: &serde_json::Value) -> String {
108        if let Some(input) = serde_json::from_value::<PartialInput>(input.clone()).ok() {
109            let description = input.display_description.trim();
110            if !description.is_empty() {
111                return description.to_string();
112            }
113
114            let path = input.path.trim();
115            if !path.is_empty() {
116                return path.to_string();
117            }
118        }
119
120        DEFAULT_UI_TEXT.to_string()
121    }
122
123    fn run(
124        self: Arc<Self>,
125        input: serde_json::Value,
126        _messages: &[LanguageModelRequestMessage],
127        project: Entity<Project>,
128        action_log: Entity<ActionLog>,
129        window: Option<AnyWindowHandle>,
130        cx: &mut App,
131    ) -> ToolResult {
132        let input = match serde_json::from_value::<EditFileToolInput>(input) {
133            Ok(input) => input,
134            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
135        };
136
137        let card = window.and_then(|window| {
138            window
139                .update(cx, |_, window, cx| {
140                    cx.new(|cx| {
141                        EditFileToolCard::new(input.path.clone(), project.clone(), window, cx)
142                    })
143                })
144                .ok()
145        });
146
147        let card_clone = card.clone();
148        let task = cx.spawn(async move |cx: &mut AsyncApp| {
149            let project_path = project.read_with(cx, |project, cx| {
150                project
151                    .find_project_path(&input.path, cx)
152                    .context("Path not found in project")
153            })??;
154
155            let buffer = project
156                .update(cx, |project, cx| {
157                    project.open_buffer(project_path.clone(), cx)
158                })?
159                .await?;
160
161            let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
162
163            if input.old_string.is_empty() {
164                return Err(anyhow!(
165                    "`old_string` can't be empty, use another tool if you want to create a file."
166                ));
167            }
168
169            if input.old_string == input.new_string {
170                return Err(anyhow!(
171                    "The `old_string` and `new_string` are identical, so no changes would be made."
172                ));
173            }
174
175            let result = cx
176                .background_spawn(async move {
177                    // Try to match exactly
178                    let diff = replace_exact(&input.old_string, &input.new_string, &snapshot)
179                        .await
180                        // If that fails, try being flexible about indentation
181                        .or_else(|| {
182                            replace_with_flexible_indent(
183                                &input.old_string,
184                                &input.new_string,
185                                &snapshot,
186                            )
187                        })?;
188
189                    if diff.edits.is_empty() {
190                        return None;
191                    }
192
193                    let old_text = snapshot.text();
194
195                    Some((old_text, diff))
196                })
197                .await;
198
199            let Some((old_text, diff)) = result else {
200                let err = buffer.read_with(cx, |buffer, _cx| {
201                    let file_exists = buffer
202                        .file()
203                        .map_or(false, |file| file.disk_state().exists());
204
205                    if !file_exists {
206                        anyhow!("{} does not exist", input.path.display())
207                    } else if buffer.is_empty() {
208                        anyhow!(
209                            "{} is empty, so the provided `old_string` wasn't found.",
210                            input.path.display()
211                        )
212                    } else {
213                        anyhow!("Failed to match the provided `old_string`")
214                    }
215                })?;
216
217                return Err(err);
218            };
219
220            let snapshot = cx.update(|cx| {
221                action_log.update(cx, |log, cx| log.track_buffer(buffer.clone(), cx));
222
223                let snapshot = buffer.update(cx, |buffer, cx| {
224                    buffer.finalize_last_transaction();
225                    buffer.apply_diff(diff, cx);
226                    buffer.finalize_last_transaction();
227                    buffer.snapshot()
228                });
229                action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
230                snapshot
231            })?;
232
233            project
234                .update(cx, |project, cx| project.save_buffer(buffer, cx))?
235                .await?;
236
237            let new_text = snapshot.text();
238            let diff_str = cx
239                .background_spawn({
240                    let old_text = old_text.clone();
241                    let new_text = new_text.clone();
242                    async move { language::unified_diff(&old_text, &new_text) }
243                })
244                .await;
245
246            if let Some(card) = card_clone {
247                card.update(cx, |card, cx| {
248                    card.set_diff(project_path.path.clone(), old_text, new_text, cx);
249                })
250                .log_err();
251            }
252
253            Ok(format!(
254                "Edited {}:\n\n```diff\n{}\n```",
255                input.path.display(),
256                diff_str
257            ))
258        });
259
260        ToolResult {
261            output: task,
262            card: card.map(AnyToolCard::from),
263        }
264    }
265}
266
/// UI card for an `edit_file` tool call.
///
/// Shows a clickable header with the file path, followed by a collapsible,
/// read-only editor that previews the diff of the applied edit.
pub struct EditFileToolCard {
    /// Path as given in the tool input; it is resolved to a project path only
    /// when the header is clicked.
    path: PathBuf,
    /// Read-only editor that renders `multibuffer` inside the card.
    editor: Entity<Editor>,
    /// Excerpts (diff hunks plus surrounding context) shown in the preview.
    multibuffer: Entity<MultiBuffer>,
    /// Project the card belongs to; its language registry is used to build
    /// the preview buffers.
    project: Entity<Project>,
    /// Task that builds the diff preview. It is stored here, presumably so
    /// that it is not dropped or cancelled; confirm against gpui `Task` semantics.
    diff_task: Option<Task<Result<()>>>,
    /// Whether the diff preview is visible (toggled by the header disclosure).
    preview_expanded: bool,
    /// Whether the preview uses its full height instead of being capped.
    full_height_expanded: bool,
    /// Entity id of `editor`, used to make this card's element ids unique.
    editor_unique_id: EntityId,
}
277
278impl EditFileToolCard {
279    fn new(path: PathBuf, project: Entity<Project>, window: &mut Window, cx: &mut App) -> Self {
280        let multibuffer = cx.new(|_| MultiBuffer::without_headers(Capability::ReadOnly));
281        let editor = cx.new(|cx| {
282            let mut editor = Editor::new(
283                EditorMode::Full {
284                    scale_ui_elements_with_buffer_font_size: false,
285                    show_active_line_background: false,
286                    sized_by_content: true,
287                },
288                multibuffer.clone(),
289                Some(project.clone()),
290                window,
291                cx,
292            );
293            editor.set_show_scrollbars(false, cx);
294            editor.set_show_gutter(false, cx);
295            editor.disable_inline_diagnostics();
296            editor.disable_scrolling(cx);
297            editor.disable_expand_excerpt_buttons(cx);
298            editor.set_show_breakpoints(false, cx);
299            editor.set_show_code_actions(false, cx);
300            editor.set_show_git_diff_gutter(false, cx);
301            editor.set_expand_all_diff_hunks(cx);
302            editor
303        });
304        Self {
305            editor_unique_id: editor.entity_id(),
306            path,
307            project,
308            editor,
309            multibuffer,
310            diff_task: None,
311            preview_expanded: true,
312            full_height_expanded: false,
313        }
314    }
315
316    fn set_diff(
317        &mut self,
318        path: Arc<Path>,
319        old_text: String,
320        new_text: String,
321        cx: &mut Context<Self>,
322    ) {
323        let language_registry = self.project.read(cx).languages().clone();
324        self.diff_task = Some(cx.spawn(async move |this, cx| {
325            let buffer = build_buffer(new_text, path.clone(), &language_registry, cx).await?;
326            let buffer_diff = build_buffer_diff(old_text, &buffer, &language_registry, cx).await?;
327
328            this.update(cx, |this, cx| {
329                this.multibuffer.update(cx, |multibuffer, cx| {
330                    let snapshot = buffer.read(cx).snapshot();
331                    let diff = buffer_diff.read(cx);
332                    let diff_hunk_ranges = diff
333                        .hunks_intersecting_range(Anchor::MIN..Anchor::MAX, &snapshot, cx)
334                        .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
335                        .collect::<Vec<_>>();
336                    let (_, is_newly_added) = multibuffer.set_excerpts_for_path(
337                        PathKey::for_buffer(&buffer, cx),
338                        buffer,
339                        diff_hunk_ranges,
340                        editor::DEFAULT_MULTIBUFFER_CONTEXT,
341                        cx,
342                    );
343                    debug_assert!(is_newly_added);
344                    multibuffer.add_diff(buffer_diff, cx);
345                });
346                cx.notify();
347            })
348        }));
349    }
350}
351
352impl ToolCard for EditFileToolCard {
353    fn render(
354        &mut self,
355        status: &ToolUseStatus,
356        window: &mut Window,
357        workspace: WeakEntity<Workspace>,
358        cx: &mut Context<Self>,
359    ) -> impl IntoElement {
360        let failed = matches!(status, ToolUseStatus::Error(_));
361
362        let path_label_button = h_flex()
363            .id(("edit-tool-path-label-button", self.editor_unique_id))
364            .w_full()
365            .max_w_full()
366            .px_1()
367            .gap_0p5()
368            .cursor_pointer()
369            .rounded_sm()
370            .opacity(0.8)
371            .hover(|label| {
372                label
373                    .opacity(1.)
374                    .bg(cx.theme().colors().element_hover.opacity(0.5))
375            })
376            .tooltip(Tooltip::text("Jump to File"))
377            .child(
378                h_flex()
379                    .child(
380                        Icon::new(IconName::Pencil)
381                            .size(IconSize::XSmall)
382                            .color(Color::Muted),
383                    )
384                    .child(
385                        div()
386                            .text_size(rems(0.8125))
387                            .child(self.path.display().to_string())
388                            .ml_1p5()
389                            .mr_0p5(),
390                    )
391                    .child(
392                        Icon::new(IconName::ArrowUpRight)
393                            .size(IconSize::XSmall)
394                            .color(Color::Ignored),
395                    ),
396            )
397            .on_click({
398                let path = self.path.clone();
399                let workspace = workspace.clone();
400                move |_, window, cx| {
401                    workspace
402                        .update(cx, {
403                            |workspace, cx| {
404                                let Some(project_path) =
405                                    workspace.project().read(cx).find_project_path(&path, cx)
406                                else {
407                                    return;
408                                };
409                                let open_task =
410                                    workspace.open_path(project_path, None, true, window, cx);
411                                window
412                                    .spawn(cx, async move |cx| {
413                                        let item = open_task.await?;
414                                        if let Some(active_editor) = item.downcast::<Editor>() {
415                                            active_editor
416                                                .update_in(cx, |editor, window, cx| {
417                                                    editor.go_to_singleton_buffer_point(
418                                                        language::Point::new(0, 0),
419                                                        window,
420                                                        cx,
421                                                    );
422                                                })
423                                                .log_err();
424                                        }
425                                        anyhow::Ok(())
426                                    })
427                                    .detach_and_log_err(cx);
428                            }
429                        })
430                        .ok();
431                }
432            })
433            .into_any_element();
434
435        let codeblock_header_bg = cx
436            .theme()
437            .colors()
438            .element_background
439            .blend(cx.theme().colors().editor_foreground.opacity(0.025));
440
441        let codeblock_header = h_flex()
442            .flex_none()
443            .p_1()
444            .gap_1()
445            .justify_between()
446            .rounded_t_md()
447            .when(!failed, |header| header.bg(codeblock_header_bg))
448            .child(path_label_button)
449            .map(|container| {
450                if failed {
451                    container.child(
452                        Icon::new(IconName::Close)
453                            .size(IconSize::Small)
454                            .color(Color::Error),
455                    )
456                } else {
457                    container.child(
458                        Disclosure::new(
459                            ("edit-file-disclosure", self.editor_unique_id),
460                            self.preview_expanded,
461                        )
462                        .opened_icon(IconName::ChevronUp)
463                        .closed_icon(IconName::ChevronDown)
464                        .on_click(cx.listener(
465                            move |this, _event, _window, _cx| {
466                                this.preview_expanded = !this.preview_expanded;
467                            },
468                        )),
469                    )
470                }
471            });
472
473        let editor = self.editor.update(cx, |editor, cx| {
474            editor.render(window, cx).into_any_element()
475        });
476
477        let (full_height_icon, full_height_tooltip_label) = if self.full_height_expanded {
478            (IconName::ChevronUp, "Collapse Code Block")
479        } else {
480            (IconName::ChevronDown, "Expand Code Block")
481        };
482
483        let gradient_overlay = div()
484            .absolute()
485            .bottom_0()
486            .left_0()
487            .w_full()
488            .h_2_5()
489            .rounded_b_lg()
490            .bg(gpui::linear_gradient(
491                0.,
492                gpui::linear_color_stop(cx.theme().colors().editor_background, 0.),
493                gpui::linear_color_stop(cx.theme().colors().editor_background.opacity(0.), 1.),
494            ));
495
496        let border_color = cx.theme().colors().border.opacity(0.6);
497
498        v_flex()
499            .mb_2()
500            .border_1()
501            .when(failed, |card| card.border_dashed())
502            .border_color(border_color)
503            .rounded_lg()
504            .overflow_hidden()
505            .child(codeblock_header)
506            .when(!failed && self.preview_expanded, |card| {
507                card.child(
508                    v_flex()
509                        .relative()
510                        .overflow_hidden()
511                        .border_t_1()
512                        .border_color(border_color)
513                        .bg(cx.theme().colors().editor_background)
514                        .map(|editor_container| {
515                            if self.full_height_expanded {
516                                editor_container.h_full()
517                            } else {
518                                editor_container.max_h_64()
519                            }
520                        })
521                        .child(div().pl_1().child(editor))
522                        .when(!self.full_height_expanded, |editor_container| {
523                            editor_container.child(gradient_overlay)
524                        }),
525                )
526            })
527            .when(!failed && self.preview_expanded, |card| {
528                card.child(
529                    h_flex()
530                        .id(("edit-tool-card-inner-hflex", self.editor_unique_id))
531                        .flex_none()
532                        .cursor_pointer()
533                        .h_5()
534                        .justify_center()
535                        .rounded_b_md()
536                        .border_t_1()
537                        .border_color(border_color)
538                        .bg(cx.theme().colors().editor_background)
539                        .hover(|style| style.bg(cx.theme().colors().element_hover.opacity(0.1)))
540                        .child(
541                            Icon::new(full_height_icon)
542                                .size(IconSize::Small)
543                                .color(Color::Muted),
544                        )
545                        .tooltip(Tooltip::text(full_height_tooltip_label))
546                        .on_click(cx.listener(move |this, _event, _window, _cx| {
547                            this.full_height_expanded = !this.full_height_expanded;
548                        })),
549                )
550            })
551    }
552}
553
554async fn build_buffer(
555    mut text: String,
556    path: Arc<Path>,
557    language_registry: &Arc<language::LanguageRegistry>,
558    cx: &mut AsyncApp,
559) -> Result<Entity<Buffer>> {
560    let line_ending = LineEnding::detect(&text);
561    LineEnding::normalize(&mut text);
562    let text = Rope::from(text);
563    let language = cx
564        .update(|_cx| language_registry.language_for_file_path(&path))?
565        .await
566        .ok();
567    let buffer = cx.new(|cx| {
568        let buffer = TextBuffer::new_normalized(
569            0,
570            cx.entity_id().as_non_zero_u64().into(),
571            line_ending,
572            text,
573        );
574        let mut buffer = Buffer::build(buffer, None, Capability::ReadWrite);
575        buffer.set_language(language, cx);
576        buffer
577    })?;
578    Ok(buffer)
579}
580
581async fn build_buffer_diff(
582    mut old_text: String,
583    buffer: &Entity<Buffer>,
584    language_registry: &Arc<LanguageRegistry>,
585    cx: &mut AsyncApp,
586) -> Result<Entity<BufferDiff>> {
587    LineEnding::normalize(&mut old_text);
588
589    let buffer = cx.update(|cx| buffer.read(cx).snapshot())?;
590
591    let base_buffer = cx
592        .update(|cx| {
593            Buffer::build_snapshot(
594                old_text.clone().into(),
595                buffer.language().cloned(),
596                Some(language_registry.clone()),
597                cx,
598            )
599        })?
600        .await;
601
602    let diff_snapshot = cx
603        .update(|cx| {
604            BufferDiffSnapshot::new_with_base_buffer(
605                buffer.text.clone(),
606                Some(old_text.into()),
607                base_buffer,
608                cx,
609            )
610        })?
611        .await;
612
613    cx.new(|cx| {
614        let mut diff = BufferDiff::new(&buffer.text, cx);
615        diff.set_snapshot(diff_snapshot, &buffer.text, cx);
616        diff
617    })
618}
619
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn still_streaming_ui_text_with_path() {
        let tool = EditFileTool;
        let input = json!({
            "path": "src/main.rs",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn still_streaming_ui_text_with_description() {
        let tool = EditFileTool;
        let input = json!({
            "path": "",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_with_path_and_description() {
        let tool = EditFileTool;
        let input = json!({
            "path": "src/main.rs",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(tool.still_streaming_ui_text(&input), "Fix error handling");
    }

    #[test]
    fn still_streaming_ui_text_no_path_or_description() {
        let tool = EditFileTool;
        let input = json!({
            "path": "",
            "display_description": "",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    #[test]
    fn still_streaming_ui_text_with_null() {
        let tool = EditFileTool;
        let input = serde_json::Value::Null;

        assert_eq!(tool.still_streaming_ui_text(&input), DEFAULT_UI_TEXT);
    }

    // While streaming, later fields may not have arrived yet; the lenient
    // `PartialInput` must still yield the path.
    #[test]
    fn still_streaming_ui_text_with_missing_fields() {
        let tool = EditFileTool;
        let input = json!({ "path": "src/main.rs" });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    // Whitespace-only descriptions are ignored and the path is trimmed.
    #[test]
    fn still_streaming_ui_text_trims_whitespace() {
        let tool = EditFileTool;
        let input = json!({
            "path": "  src/main.rs  ",
            "display_description": "   "
        });

        assert_eq!(tool.still_streaming_ui_text(&input), "src/main.rs");
    }

    #[test]
    fn ui_text_with_complete_input() {
        let tool = EditFileTool;
        let input = json!({
            "path": "src/main.rs",
            "display_description": "Fix error handling",
            "old_string": "old code",
            "new_string": "new code"
        });

        assert_eq!(tool.ui_text(&input), "Fix error handling");
    }

    // `ui_text` requires the full input; anything else falls back.
    #[test]
    fn ui_text_with_incomplete_input() {
        let tool = EditFileTool;

        assert_eq!(tool.ui_text(&json!({ "path": "src/main.rs" })), DEFAULT_UI_TEXT);
        assert_eq!(tool.ui_text(&serde_json::Value::Null), DEFAULT_UI_TEXT);
    }
}