agent_diff.rs

   1use crate::{Keep, KeepAll, Reject, RejectAll, Thread, ThreadEvent};
   2use anyhow::Result;
   3use buffer_diff::DiffHunkStatus;
   4use collections::{HashMap, HashSet};
   5use editor::{
   6    Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
   7    actions::{GoToHunk, GoToPreviousHunk},
   8    scroll::Autoscroll,
   9};
  10use gpui::{
  11    Action, AnyElement, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, SharedString,
  12    Subscription, Task, WeakEntity, Window, prelude::*,
  13};
  14use language::{Capability, DiskState, OffsetRangeExt, Point};
  15use multi_buffer::PathKey;
  16use project::{Project, ProjectPath};
  17use std::{
  18    any::{Any, TypeId},
  19    ops::Range,
  20    sync::Arc,
  21};
  22use ui::{IconButtonShape, KeyBinding, Tooltip, prelude::*};
  23use workspace::{
  24    Item, ItemHandle, ItemNavHistory, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
  25    Workspace,
  26    item::{BreadcrumbText, ItemEvent, TabContentParams},
  27    searchable::SearchableItemHandle,
  28};
  29use zed_actions::assistant::ToggleFocus;
  30
  31pub struct AgentDiff {
  32    multibuffer: Entity<MultiBuffer>,
  33    editor: Entity<Editor>,
  34    thread: Entity<Thread>,
  35    focus_handle: FocusHandle,
  36    workspace: WeakEntity<Workspace>,
  37    title: SharedString,
  38    _subscriptions: Vec<Subscription>,
  39}
  40
  41impl AgentDiff {
  42    pub fn deploy(
  43        thread: Entity<Thread>,
  44        workspace: WeakEntity<Workspace>,
  45        window: &mut Window,
  46        cx: &mut App,
  47    ) -> Result<Entity<Self>> {
  48        workspace.update(cx, |workspace, cx| {
  49            Self::deploy_in_workspace(thread, workspace, window, cx)
  50        })
  51    }
  52
  53    pub fn deploy_in_workspace(
  54        thread: Entity<Thread>,
  55        workspace: &mut Workspace,
  56        window: &mut Window,
  57        cx: &mut Context<Workspace>,
  58    ) -> Entity<Self> {
  59        let existing_diff = workspace
  60            .items_of_type::<AgentDiff>(cx)
  61            .find(|diff| diff.read(cx).thread == thread);
  62        if let Some(existing_diff) = existing_diff {
  63            workspace.activate_item(&existing_diff, true, true, window, cx);
  64            existing_diff
  65        } else {
  66            let agent_diff =
  67                cx.new(|cx| AgentDiff::new(thread.clone(), workspace.weak_handle(), window, cx));
  68            workspace.add_item_to_center(Box::new(agent_diff.clone()), window, cx);
  69            agent_diff
  70        }
  71    }
  72
  73    pub fn new(
  74        thread: Entity<Thread>,
  75        workspace: WeakEntity<Workspace>,
  76        window: &mut Window,
  77        cx: &mut Context<Self>,
  78    ) -> Self {
  79        let focus_handle = cx.focus_handle();
  80        let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite));
  81
  82        let project = thread.read(cx).project().clone();
  83        let render_diff_hunk_controls = Arc::new({
  84            let agent_diff = cx.entity();
  85            move |row,
  86                  status: &DiffHunkStatus,
  87                  hunk_range,
  88                  is_created_file,
  89                  line_height,
  90                  editor: &Entity<Editor>,
  91                  window: &mut Window,
  92                  cx: &mut App| {
  93                render_diff_hunk_controls(
  94                    row,
  95                    status,
  96                    hunk_range,
  97                    is_created_file,
  98                    line_height,
  99                    &agent_diff,
 100                    editor,
 101                    window,
 102                    cx,
 103                )
 104            }
 105        });
 106        let editor = cx.new(|cx| {
 107            let mut editor =
 108                Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx);
 109            editor.disable_inline_diagnostics();
 110            editor.set_expand_all_diff_hunks(cx);
 111            editor.set_render_diff_hunk_controls(render_diff_hunk_controls, cx);
 112            editor.register_addon(AgentDiffAddon);
 113            editor
 114        });
 115
 116        let action_log = thread.read(cx).action_log().clone();
 117        let mut this = Self {
 118            _subscriptions: vec![
 119                cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
 120                    this.update_excerpts(window, cx)
 121                }),
 122                cx.subscribe(&thread, |this, _thread, event, cx| {
 123                    this.handle_thread_event(event, cx)
 124                }),
 125            ],
 126            title: SharedString::default(),
 127            multibuffer,
 128            editor,
 129            thread,
 130            focus_handle,
 131            workspace,
 132        };
 133        this.update_excerpts(window, cx);
 134        this.update_title(cx);
 135        this
 136    }
 137
 138    fn update_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 139        let thread = self.thread.read(cx);
 140        let changed_buffers = thread.action_log().read(cx).changed_buffers(cx);
 141        let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::<HashSet<_>>();
 142
 143        for (buffer, diff_handle) in changed_buffers {
 144            if buffer.read(cx).file().is_none() {
 145                continue;
 146            }
 147
 148            let path_key = PathKey::for_buffer(&buffer, cx);
 149            paths_to_delete.remove(&path_key);
 150
 151            let snapshot = buffer.read(cx).snapshot();
 152            let diff = diff_handle.read(cx);
 153
 154            let diff_hunk_ranges = diff
 155                .hunks_intersecting_range(
 156                    language::Anchor::MIN..language::Anchor::MAX,
 157                    &snapshot,
 158                    cx,
 159                )
 160                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
 161                .collect::<Vec<_>>();
 162
 163            let (was_empty, is_excerpt_newly_added) =
 164                self.multibuffer.update(cx, |multibuffer, cx| {
 165                    let was_empty = multibuffer.is_empty();
 166                    let (_, is_excerpt_newly_added) = multibuffer.set_excerpts_for_path(
 167                        path_key.clone(),
 168                        buffer.clone(),
 169                        diff_hunk_ranges,
 170                        editor::DEFAULT_MULTIBUFFER_CONTEXT,
 171                        cx,
 172                    );
 173                    multibuffer.add_diff(diff_handle, cx);
 174                    (was_empty, is_excerpt_newly_added)
 175                });
 176
 177            self.editor.update(cx, |editor, cx| {
 178                if was_empty {
 179                    let first_hunk = editor
 180                        .diff_hunks_in_ranges(
 181                            &[editor::Anchor::min()..editor::Anchor::max()],
 182                            &self.multibuffer.read(cx).read(cx),
 183                        )
 184                        .next();
 185
 186                    if let Some(first_hunk) = first_hunk {
 187                        let first_hunk_start = first_hunk.multi_buffer_range().start;
 188                        editor.change_selections(
 189                            Some(Autoscroll::fit()),
 190                            window,
 191                            cx,
 192                            |selections| {
 193                                selections
 194                                    .select_anchor_ranges([first_hunk_start..first_hunk_start]);
 195                            },
 196                        )
 197                    }
 198                }
 199
 200                if is_excerpt_newly_added
 201                    && buffer
 202                        .read(cx)
 203                        .file()
 204                        .map_or(false, |file| file.disk_state() == DiskState::Deleted)
 205                {
 206                    editor.fold_buffer(snapshot.text.remote_id(), cx)
 207                }
 208            });
 209        }
 210
 211        self.multibuffer.update(cx, |multibuffer, cx| {
 212            for path in paths_to_delete {
 213                multibuffer.remove_excerpts_for_path(path, cx);
 214            }
 215        });
 216
 217        if self.multibuffer.read(cx).is_empty()
 218            && self
 219                .editor
 220                .read(cx)
 221                .focus_handle(cx)
 222                .contains_focused(window, cx)
 223        {
 224            self.focus_handle.focus(window);
 225        } else if self.focus_handle.is_focused(window) && !self.multibuffer.read(cx).is_empty() {
 226            self.editor.update(cx, |editor, cx| {
 227                editor.focus_handle(cx).focus(window);
 228            });
 229        }
 230    }
 231
 232    fn update_title(&mut self, cx: &mut Context<Self>) {
 233        let new_title = self
 234            .thread
 235            .read(cx)
 236            .summary()
 237            .unwrap_or("Assistant Changes".into());
 238        if new_title != self.title {
 239            self.title = new_title;
 240            cx.emit(EditorEvent::TitleChanged);
 241        }
 242    }
 243
 244    fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
 245        match event {
 246            ThreadEvent::SummaryGenerated => self.update_title(cx),
 247            _ => {}
 248        }
 249    }
 250
 251    pub fn move_to_path(&mut self, path_key: PathKey, window: &mut Window, cx: &mut Context<Self>) {
 252        if let Some(position) = self.multibuffer.read(cx).location_for_path(&path_key, cx) {
 253            self.editor.update(cx, |editor, cx| {
 254                let first_hunk = editor
 255                    .diff_hunks_in_ranges(
 256                        &[position..editor::Anchor::max()],
 257                        &self.multibuffer.read(cx).read(cx),
 258                    )
 259                    .next();
 260
 261                if let Some(first_hunk) = first_hunk {
 262                    let first_hunk_start = first_hunk.multi_buffer_range().start;
 263                    editor.change_selections(Some(Autoscroll::fit()), window, cx, |selections| {
 264                        selections.select_anchor_ranges([first_hunk_start..first_hunk_start]);
 265                    })
 266                }
 267            });
 268        }
 269    }
 270
 271    fn keep(&mut self, _: &crate::Keep, window: &mut Window, cx: &mut Context<Self>) {
 272        let ranges = self
 273            .editor
 274            .read(cx)
 275            .selections
 276            .disjoint_anchor_ranges()
 277            .collect::<Vec<_>>();
 278        self.keep_edits_in_ranges(ranges, window, cx);
 279    }
 280
 281    fn reject(&mut self, _: &crate::Reject, window: &mut Window, cx: &mut Context<Self>) {
 282        let ranges = self
 283            .editor
 284            .read(cx)
 285            .selections
 286            .disjoint_anchor_ranges()
 287            .collect::<Vec<_>>();
 288        self.reject_edits_in_ranges(ranges, window, cx);
 289    }
 290
 291    fn reject_all(&mut self, _: &crate::RejectAll, window: &mut Window, cx: &mut Context<Self>) {
 292        self.reject_edits_in_ranges(
 293            vec![editor::Anchor::min()..editor::Anchor::max()],
 294            window,
 295            cx,
 296        );
 297    }
 298
 299    fn keep_all(&mut self, _: &crate::KeepAll, _window: &mut Window, cx: &mut Context<Self>) {
 300        self.thread
 301            .update(cx, |thread, cx| thread.keep_all_edits(cx));
 302    }
 303
 304    fn keep_edits_in_ranges(
 305        &mut self,
 306        ranges: Vec<Range<editor::Anchor>>,
 307        window: &mut Window,
 308        cx: &mut Context<Self>,
 309    ) {
 310        let snapshot = self.multibuffer.read(cx).snapshot(cx);
 311        let diff_hunks_in_ranges = self
 312            .editor
 313            .read(cx)
 314            .diff_hunks_in_ranges(&ranges, &snapshot)
 315            .collect::<Vec<_>>();
 316        let newest_cursor = self.editor.update(cx, |editor, cx| {
 317            editor.selections.newest::<Point>(cx).head()
 318        });
 319        if diff_hunks_in_ranges.iter().any(|hunk| {
 320            hunk.row_range
 321                .contains(&multi_buffer::MultiBufferRow(newest_cursor.row))
 322        }) {
 323            self.update_selection(&diff_hunks_in_ranges, window, cx);
 324        }
 325
 326        for hunk in &diff_hunks_in_ranges {
 327            let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
 328            if let Some(buffer) = buffer {
 329                self.thread.update(cx, |thread, cx| {
 330                    thread.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
 331                });
 332            }
 333        }
 334    }
 335
 336    fn reject_edits_in_ranges(
 337        &mut self,
 338        ranges: Vec<Range<editor::Anchor>>,
 339        window: &mut Window,
 340        cx: &mut Context<Self>,
 341    ) {
 342        let snapshot = self.multibuffer.read(cx).snapshot(cx);
 343        let diff_hunks_in_ranges = self
 344            .editor
 345            .read(cx)
 346            .diff_hunks_in_ranges(&ranges, &snapshot)
 347            .collect::<Vec<_>>();
 348        let newest_cursor = self.editor.update(cx, |editor, cx| {
 349            editor.selections.newest::<Point>(cx).head()
 350        });
 351        if diff_hunks_in_ranges.iter().any(|hunk| {
 352            hunk.row_range
 353                .contains(&multi_buffer::MultiBufferRow(newest_cursor.row))
 354        }) {
 355            self.update_selection(&diff_hunks_in_ranges, window, cx);
 356        }
 357
 358        let mut ranges_by_buffer = HashMap::default();
 359        for hunk in &diff_hunks_in_ranges {
 360            let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
 361            if let Some(buffer) = buffer {
 362                ranges_by_buffer
 363                    .entry(buffer.clone())
 364                    .or_insert_with(Vec::new)
 365                    .push(hunk.buffer_range.clone());
 366            }
 367        }
 368
 369        for (buffer, ranges) in ranges_by_buffer {
 370            self.thread
 371                .update(cx, |thread, cx| {
 372                    thread.reject_edits_in_ranges(buffer, ranges, cx)
 373                })
 374                .detach_and_log_err(cx);
 375        }
 376    }
 377
 378    fn update_selection(
 379        &mut self,
 380        diff_hunks: &[multi_buffer::MultiBufferDiffHunk],
 381        window: &mut Window,
 382        cx: &mut Context<Self>,
 383    ) {
 384        let snapshot = self.multibuffer.read(cx).snapshot(cx);
 385        let target_hunk = diff_hunks
 386            .last()
 387            .and_then(|last_kept_hunk| {
 388                let last_kept_hunk_end = last_kept_hunk.multi_buffer_range().end;
 389                self.editor
 390                    .read(cx)
 391                    .diff_hunks_in_ranges(&[last_kept_hunk_end..editor::Anchor::max()], &snapshot)
 392                    .skip(1)
 393                    .next()
 394            })
 395            .or_else(|| {
 396                let first_kept_hunk = diff_hunks.first()?;
 397                let first_kept_hunk_start = first_kept_hunk.multi_buffer_range().start;
 398                self.editor
 399                    .read(cx)
 400                    .diff_hunks_in_ranges(
 401                        &[editor::Anchor::min()..first_kept_hunk_start],
 402                        &snapshot,
 403                    )
 404                    .next()
 405            });
 406
 407        if let Some(target_hunk) = target_hunk {
 408            self.editor.update(cx, |editor, cx| {
 409                editor.change_selections(Some(Autoscroll::fit()), window, cx, |selections| {
 410                    let next_hunk_start = target_hunk.multi_buffer_range().start;
 411                    selections.select_anchor_ranges([next_hunk_start..next_hunk_start]);
 412                })
 413            });
 414        }
 415    }
 416}
 417
 418impl EventEmitter<EditorEvent> for AgentDiff {}
 419
 420impl Focusable for AgentDiff {
 421    fn focus_handle(&self, cx: &App) -> FocusHandle {
 422        if self.multibuffer.read(cx).is_empty() {
 423            self.focus_handle.clone()
 424        } else {
 425            self.editor.focus_handle(cx)
 426        }
 427    }
 428}
 429
 430impl Item for AgentDiff {
 431    type Event = EditorEvent;
 432
 433    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
 434        Some(Icon::new(IconName::ZedAssistant).color(Color::Muted))
 435    }
 436
 437    fn to_item_events(event: &EditorEvent, f: impl FnMut(ItemEvent)) {
 438        Editor::to_item_events(event, f)
 439    }
 440
 441    fn deactivated(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 442        self.editor
 443            .update(cx, |editor, cx| editor.deactivated(window, cx));
 444    }
 445
 446    fn navigate(
 447        &mut self,
 448        data: Box<dyn Any>,
 449        window: &mut Window,
 450        cx: &mut Context<Self>,
 451    ) -> bool {
 452        self.editor
 453            .update(cx, |editor, cx| editor.navigate(data, window, cx))
 454    }
 455
 456    fn tab_tooltip_text(&self, _: &App) -> Option<SharedString> {
 457        Some("Agent Diff".into())
 458    }
 459
 460    fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
 461        let summary = self
 462            .thread
 463            .read(cx)
 464            .summary()
 465            .unwrap_or("Assistant Changes".into());
 466        Label::new(format!("Review: {}", summary))
 467            .color(if params.selected {
 468                Color::Default
 469            } else {
 470                Color::Muted
 471            })
 472            .into_any_element()
 473    }
 474
 475    fn telemetry_event_text(&self) -> Option<&'static str> {
 476        Some("Assistant Diff Opened")
 477    }
 478
 479    fn as_searchable(&self, _: &Entity<Self>) -> Option<Box<dyn SearchableItemHandle>> {
 480        Some(Box::new(self.editor.clone()))
 481    }
 482
 483    fn for_each_project_item(
 484        &self,
 485        cx: &App,
 486        f: &mut dyn FnMut(gpui::EntityId, &dyn project::ProjectItem),
 487    ) {
 488        self.editor.for_each_project_item(cx, f)
 489    }
 490
 491    fn is_singleton(&self, _: &App) -> bool {
 492        false
 493    }
 494
 495    fn set_nav_history(
 496        &mut self,
 497        nav_history: ItemNavHistory,
 498        _: &mut Window,
 499        cx: &mut Context<Self>,
 500    ) {
 501        self.editor.update(cx, |editor, _| {
 502            editor.set_nav_history(Some(nav_history));
 503        });
 504    }
 505
 506    fn clone_on_split(
 507        &self,
 508        _workspace_id: Option<workspace::WorkspaceId>,
 509        window: &mut Window,
 510        cx: &mut Context<Self>,
 511    ) -> Option<Entity<Self>>
 512    where
 513        Self: Sized,
 514    {
 515        Some(cx.new(|cx| Self::new(self.thread.clone(), self.workspace.clone(), window, cx)))
 516    }
 517
 518    fn is_dirty(&self, cx: &App) -> bool {
 519        self.multibuffer.read(cx).is_dirty(cx)
 520    }
 521
 522    fn has_conflict(&self, cx: &App) -> bool {
 523        self.multibuffer.read(cx).has_conflict(cx)
 524    }
 525
 526    fn can_save(&self, _: &App) -> bool {
 527        true
 528    }
 529
 530    fn save(
 531        &mut self,
 532        format: bool,
 533        project: Entity<Project>,
 534        window: &mut Window,
 535        cx: &mut Context<Self>,
 536    ) -> Task<Result<()>> {
 537        self.editor.save(format, project, window, cx)
 538    }
 539
 540    fn save_as(
 541        &mut self,
 542        _: Entity<Project>,
 543        _: ProjectPath,
 544        _window: &mut Window,
 545        _: &mut Context<Self>,
 546    ) -> Task<Result<()>> {
 547        unreachable!()
 548    }
 549
 550    fn reload(
 551        &mut self,
 552        project: Entity<Project>,
 553        window: &mut Window,
 554        cx: &mut Context<Self>,
 555    ) -> Task<Result<()>> {
 556        self.editor.reload(project, window, cx)
 557    }
 558
 559    fn act_as_type<'a>(
 560        &'a self,
 561        type_id: TypeId,
 562        self_handle: &'a Entity<Self>,
 563        _: &'a App,
 564    ) -> Option<AnyView> {
 565        if type_id == TypeId::of::<Self>() {
 566            Some(self_handle.to_any())
 567        } else if type_id == TypeId::of::<Editor>() {
 568            Some(self.editor.to_any())
 569        } else {
 570            None
 571        }
 572    }
 573
 574    fn breadcrumb_location(&self, _: &App) -> ToolbarItemLocation {
 575        ToolbarItemLocation::PrimaryLeft
 576    }
 577
 578    fn breadcrumbs(&self, theme: &theme::Theme, cx: &App) -> Option<Vec<BreadcrumbText>> {
 579        self.editor.breadcrumbs(theme, cx)
 580    }
 581
 582    fn added_to_workspace(
 583        &mut self,
 584        workspace: &mut Workspace,
 585        window: &mut Window,
 586        cx: &mut Context<Self>,
 587    ) {
 588        self.editor.update(cx, |editor, cx| {
 589            editor.added_to_workspace(workspace, window, cx)
 590        });
 591    }
 592}
 593
 594impl Render for AgentDiff {
 595    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 596        let is_empty = self.multibuffer.read(cx).is_empty();
 597        let focus_handle = &self.focus_handle;
 598
 599        div()
 600            .track_focus(focus_handle)
 601            .key_context(if is_empty { "EmptyPane" } else { "AgentDiff" })
 602            .on_action(cx.listener(Self::keep))
 603            .on_action(cx.listener(Self::reject))
 604            .on_action(cx.listener(Self::reject_all))
 605            .on_action(cx.listener(Self::keep_all))
 606            .bg(cx.theme().colors().editor_background)
 607            .flex()
 608            .items_center()
 609            .justify_center()
 610            .size_full()
 611            .when(is_empty, |el| {
 612                el.child(
 613                    v_flex()
 614                        .items_center()
 615                        .gap_2()
 616                        .child("No changes to review")
 617                        .child(
 618                            Button::new("continue-iterating", "Continue Iterating")
 619                                .style(ButtonStyle::Filled)
 620                                .icon(IconName::ForwardArrow)
 621                                .icon_position(IconPosition::Start)
 622                                .icon_size(IconSize::Small)
 623                                .icon_color(Color::Muted)
 624                                .full_width()
 625                                .key_binding(KeyBinding::for_action_in(
 626                                    &ToggleFocus,
 627                                    &focus_handle.clone(),
 628                                    window,
 629                                    cx,
 630                                ))
 631                                .on_click(|_event, window, cx| {
 632                                    window.dispatch_action(ToggleFocus.boxed_clone(), cx)
 633                                }),
 634                        ),
 635                )
 636            })
 637            .when(!is_empty, |el| el.child(self.editor.clone()))
 638    }
 639}
 640
 641fn render_diff_hunk_controls(
 642    row: u32,
 643    _status: &DiffHunkStatus,
 644    hunk_range: Range<editor::Anchor>,
 645    is_created_file: bool,
 646    line_height: Pixels,
 647    agent_diff: &Entity<AgentDiff>,
 648    editor: &Entity<Editor>,
 649    window: &mut Window,
 650    cx: &mut App,
 651) -> AnyElement {
 652    let editor = editor.clone();
 653    h_flex()
 654        .h(line_height)
 655        .mr_0p5()
 656        .gap_1()
 657        .px_0p5()
 658        .pb_1()
 659        .border_x_1()
 660        .border_b_1()
 661        .border_color(cx.theme().colors().border)
 662        .rounded_b_md()
 663        .bg(cx.theme().colors().editor_background)
 664        .gap_1()
 665        .occlude()
 666        .shadow_md()
 667        .children(vec![
 668            Button::new(("reject", row as u64), "Reject")
 669                .disabled(is_created_file)
 670                .key_binding(
 671                    KeyBinding::for_action_in(
 672                        &Reject,
 673                        &editor.read(cx).focus_handle(cx),
 674                        window,
 675                        cx,
 676                    )
 677                    .map(|kb| kb.size(rems_from_px(12.))),
 678                )
 679                .on_click({
 680                    let agent_diff = agent_diff.clone();
 681                    move |_event, window, cx| {
 682                        agent_diff.update(cx, |diff, cx| {
 683                            diff.reject_edits_in_ranges(
 684                                vec![hunk_range.start..hunk_range.start],
 685                                window,
 686                                cx,
 687                            );
 688                        });
 689                    }
 690                }),
 691            Button::new(("keep", row as u64), "Keep")
 692                .key_binding(
 693                    KeyBinding::for_action_in(&Keep, &editor.read(cx).focus_handle(cx), window, cx)
 694                        .map(|kb| kb.size(rems_from_px(12.))),
 695                )
 696                .on_click({
 697                    let agent_diff = agent_diff.clone();
 698                    move |_event, window, cx| {
 699                        agent_diff.update(cx, |diff, cx| {
 700                            diff.keep_edits_in_ranges(
 701                                vec![hunk_range.start..hunk_range.start],
 702                                window,
 703                                cx,
 704                            );
 705                        });
 706                    }
 707                }),
 708        ])
 709        .when(
 710            !editor.read(cx).buffer().read(cx).all_diff_hunks_expanded(),
 711            |el| {
 712                el.child(
 713                    IconButton::new(("next-hunk", row as u64), IconName::ArrowDown)
 714                        .shape(IconButtonShape::Square)
 715                        .icon_size(IconSize::Small)
 716                        // .disabled(!has_multiple_hunks)
 717                        .tooltip({
 718                            let focus_handle = editor.focus_handle(cx);
 719                            move |window, cx| {
 720                                Tooltip::for_action_in(
 721                                    "Next Hunk",
 722                                    &GoToHunk,
 723                                    &focus_handle,
 724                                    window,
 725                                    cx,
 726                                )
 727                            }
 728                        })
 729                        .on_click({
 730                            let editor = editor.clone();
 731                            move |_event, window, cx| {
 732                                editor.update(cx, |editor, cx| {
 733                                    let snapshot = editor.snapshot(window, cx);
 734                                    let position =
 735                                        hunk_range.end.to_point(&snapshot.buffer_snapshot);
 736                                    editor.go_to_hunk_before_or_after_position(
 737                                        &snapshot,
 738                                        position,
 739                                        Direction::Next,
 740                                        window,
 741                                        cx,
 742                                    );
 743                                    editor.expand_selected_diff_hunks(cx);
 744                                });
 745                            }
 746                        }),
 747                )
 748                .child(
 749                    IconButton::new(("prev-hunk", row as u64), IconName::ArrowUp)
 750                        .shape(IconButtonShape::Square)
 751                        .icon_size(IconSize::Small)
 752                        // .disabled(!has_multiple_hunks)
 753                        .tooltip({
 754                            let focus_handle = editor.focus_handle(cx);
 755                            move |window, cx| {
 756                                Tooltip::for_action_in(
 757                                    "Previous Hunk",
 758                                    &GoToPreviousHunk,
 759                                    &focus_handle,
 760                                    window,
 761                                    cx,
 762                                )
 763                            }
 764                        })
 765                        .on_click({
 766                            let editor = editor.clone();
 767                            move |_event, window, cx| {
 768                                editor.update(cx, |editor, cx| {
 769                                    let snapshot = editor.snapshot(window, cx);
 770                                    let point =
 771                                        hunk_range.start.to_point(&snapshot.buffer_snapshot);
 772                                    editor.go_to_hunk_before_or_after_position(
 773                                        &snapshot,
 774                                        point,
 775                                        Direction::Prev,
 776                                        window,
 777                                        cx,
 778                                    );
 779                                    editor.expand_selected_diff_hunks(cx);
 780                                });
 781                            }
 782                        }),
 783                )
 784            },
 785        )
 786        .into_any_element()
 787}
 788
 789struct AgentDiffAddon;
 790
 791impl editor::Addon for AgentDiffAddon {
 792    fn to_any(&self) -> &dyn std::any::Any {
 793        self
 794    }
 795
 796    fn extend_key_context(&self, key_context: &mut gpui::KeyContext, _: &App) {
 797        key_context.add("agent_diff");
 798    }
 799}
 800
 801pub struct AgentDiffToolbar {
 802    agent_diff: Option<WeakEntity<AgentDiff>>,
 803}
 804
 805impl AgentDiffToolbar {
 806    pub fn new() -> Self {
 807        Self { agent_diff: None }
 808    }
 809
 810    fn agent_diff(&self, _: &App) -> Option<Entity<AgentDiff>> {
 811        self.agent_diff.as_ref()?.upgrade()
 812    }
 813
 814    fn dispatch_action(&self, action: &dyn Action, window: &mut Window, cx: &mut Context<Self>) {
 815        if let Some(agent_diff) = self.agent_diff(cx) {
 816            agent_diff.focus_handle(cx).focus(window);
 817        }
 818        let action = action.boxed_clone();
 819        cx.defer(move |cx| {
 820            cx.dispatch_action(action.as_ref());
 821        })
 822    }
 823}
 824
 825impl EventEmitter<ToolbarItemEvent> for AgentDiffToolbar {}
 826
 827impl ToolbarItemView for AgentDiffToolbar {
 828    fn set_active_pane_item(
 829        &mut self,
 830        active_pane_item: Option<&dyn ItemHandle>,
 831        _: &mut Window,
 832        cx: &mut Context<Self>,
 833    ) -> ToolbarItemLocation {
 834        self.agent_diff = active_pane_item
 835            .and_then(|item| item.act_as::<AgentDiff>(cx))
 836            .map(|entity| entity.downgrade());
 837        if self.agent_diff.is_some() {
 838            ToolbarItemLocation::PrimaryRight
 839        } else {
 840            ToolbarItemLocation::Hidden
 841        }
 842    }
 843
 844    fn pane_focus_update(
 845        &mut self,
 846        _pane_focused: bool,
 847        _window: &mut Window,
 848        _cx: &mut Context<Self>,
 849    ) {
 850    }
 851}
 852
 853impl Render for AgentDiffToolbar {
 854    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 855        let agent_diff = match self.agent_diff(cx) {
 856            Some(ad) => ad,
 857            None => return div(),
 858        };
 859
 860        let is_empty = agent_diff.read(cx).multibuffer.read(cx).is_empty();
 861
 862        if is_empty {
 863            return div();
 864        }
 865
 866        let focus_handle = agent_diff.focus_handle(cx);
 867
 868        h_group_xl()
 869            .my_neg_1()
 870            .items_center()
 871            .p_1()
 872            .flex_wrap()
 873            .justify_between()
 874            .child(
 875                h_group_sm()
 876                    .child(
 877                        Button::new("reject-all", "Reject All")
 878                            .key_binding({
 879                                KeyBinding::for_action_in(&RejectAll, &focus_handle, window, cx)
 880                                    .map(|kb| kb.size(rems_from_px(12.)))
 881                            })
 882                            .on_click(cx.listener(|this, _, window, cx| {
 883                                this.dispatch_action(&RejectAll, window, cx)
 884                            })),
 885                    )
 886                    .child(
 887                        Button::new("keep-all", "Keep All")
 888                            .key_binding({
 889                                KeyBinding::for_action_in(&KeepAll, &focus_handle, window, cx)
 890                                    .map(|kb| kb.size(rems_from_px(12.)))
 891                            })
 892                            .on_click(cx.listener(|this, _, window, cx| {
 893                                this.dispatch_action(&KeepAll, window, cx)
 894                            })),
 895                    ),
 896            )
 897    }
 898}
 899
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, thread_store};
    use assistant_settings::AssistantSettings;
    use assistant_tool::ToolWorkingSet;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;

    /// End-to-end check of `AgentDiff` hunk review.
    ///
    /// Edits recorded in a thread's action log appear as hunks (old line
    /// followed by new line) in the diff editor. Keep/reject operations then
    /// update both the displayed text and the cursor position.
    #[gpui::test]
    async fn test_agent_diff(cx: &mut TestAppContext) {
        // Install a test settings store and register the settings/initialization
        // used by the project, thread store, workspace, and editor below.
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });

        // A fake filesystem holding one file of nine three-letter lines.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"file1": "abc\ndef\nghi\njkl\nmno\npqr\nstu\nvwx\nyz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let buffer_path = project
            .read_with(cx, |project, cx| {
                project.find_project_path("test/file1", cx)
            })
            .unwrap();

        // Create a thread; its action log is what feeds hunks into the diff.
        let thread_store = cx
            .update(|cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;
        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

        // Build a window-backed workspace and an `AgentDiff` for the thread,
        // keeping a handle to its inner editor for assertions.
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
        let agent_diff = cx.new_window_entity(|window, cx| {
            AgentDiff::new(thread.clone(), workspace.downgrade(), window, cx)
        });
        let editor = agent_diff.read_with(cx, |diff, _cx| diff.editor.clone());

        // Tell the action log the buffer was read, apply four single-character
        // edits on lines 1, 3, 5 and 7, then report the edit to the action log.
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();
        cx.update(|_, cx| {
            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
            buffer.update(cx, |buffer, cx| {
                buffer
                    .edit(
                        [
                            (Point::new(1, 1)..Point::new(1, 2), "E"),
                            (Point::new(3, 2)..Point::new(3, 3), "L"),
                            (Point::new(5, 0)..Point::new(5, 1), "P"),
                            (Point::new(7, 1)..Point::new(7, 2), "W"),
                        ],
                        None,
                        cx,
                    )
                    .unwrap()
            });
            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
        });
        cx.run_until_parked();

        // When opening the assistant diff, the cursor is positioned on the first hunk.
        // The editor text shows each hunk as the old line followed by the new one.
        assert_eq!(
            editor.read_with(cx, |editor, cx| editor.text(cx)),
            "abc\ndef\ndEf\nghi\njkl\njkL\nmno\npqr\nPqr\nstu\nvwx\nvWx\nyz"
        );
        assert_eq!(
            editor
                .update(cx, |editor, cx| editor.selections.newest::<Point>(cx))
                .range(),
            Point::new(1, 0)..Point::new(1, 0)
        );

        // After keeping a hunk, the cursor should be positioned on the second hunk.
        // The kept hunk's old line ("def") is removed from the diff view.
        agent_diff.update_in(cx, |diff, window, cx| diff.keep(&crate::Keep, window, cx));
        cx.run_until_parked();
        assert_eq!(
            editor.read_with(cx, |editor, cx| editor.text(cx)),
            "abc\ndEf\nghi\njkl\njkL\nmno\npqr\nPqr\nstu\nvwx\nvWx\nyz"
        );
        assert_eq!(
            editor
                .update(cx, |editor, cx| editor.selections.newest::<Point>(cx))
                .range(),
            Point::new(3, 0)..Point::new(3, 0)
        );

        // Rejecting a hunk also moves the cursor to the next hunk, wrapping to the
        // first remaining hunk when the rejected one is the last. Here the cursor
        // is placed on the last hunk (row 10, "vwx"), so it wraps back to row 3.
        editor.update_in(cx, |editor, window, cx| {
            editor.change_selections(None, window, cx, |selections| {
                selections.select_ranges([Point::new(10, 0)..Point::new(10, 0)])
            });
        });
        agent_diff.update_in(cx, |diff, window, cx| {
            diff.reject(&crate::Reject, window, cx)
        });
        cx.run_until_parked();
        assert_eq!(
            editor.read_with(cx, |editor, cx| editor.text(cx)),
            "abc\ndEf\nghi\njkl\njkL\nmno\npqr\nPqr\nstu\nvwx\nyz"
        );
        assert_eq!(
            editor
                .update(cx, |editor, cx| editor.selections.newest::<Point>(cx))
                .range(),
            Point::new(3, 0)..Point::new(3, 0)
        );

        // Keeping a range that doesn't intersect the current selection doesn't move it.
        // The empty range is anchored at multibuffer row 7 ("pqr"), inside the
        // "pqr" -> "Pqr" hunk, while the cursor stays on row 3.
        agent_diff.update_in(cx, |diff, window, cx| {
            let position = editor
                .read(cx)
                .buffer()
                .read(cx)
                .read(cx)
                .anchor_before(Point::new(7, 0));
            diff.keep_edits_in_ranges(vec![position..position], window, cx)
        });
        cx.run_until_parked();
        assert_eq!(
            editor.read_with(cx, |editor, cx| editor.text(cx)),
            "abc\ndEf\nghi\njkl\njkL\nmno\nPqr\nstu\nvwx\nyz"
        );
        assert_eq!(
            editor
                .update(cx, |editor, cx| editor.selections.newest::<Point>(cx))
                .range(),
            Point::new(3, 0)..Point::new(3, 0)
        );
    }
}