agent_diff.rs

   1use crate::{Keep, Reject, Thread, ThreadEvent};
   2use anyhow::Result;
   3use buffer_diff::DiffHunkStatus;
   4use collections::HashSet;
   5use editor::{
   6    AnchorRangeExt, Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
   7    actions::{GoToHunk, GoToPreviousHunk},
   8    scroll::Autoscroll,
   9};
  10use gpui::{
  11    Action, AnyElement, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, SharedString,
  12    Subscription, Task, WeakEntity, Window, prelude::*,
  13};
  14use language::{Capability, DiskState, OffsetRangeExt, Point};
  15use multi_buffer::PathKey;
  16use project::{Project, ProjectPath};
  17use std::{
  18    any::{Any, TypeId},
  19    ops::Range,
  20    sync::Arc,
  21};
  22use ui::{IconButtonShape, KeyBinding, Tooltip, prelude::*};
  23use workspace::{
  24    Item, ItemHandle, ItemNavHistory, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
  25    Workspace,
  26    item::{BreadcrumbText, ItemEvent, TabContentParams},
  27    searchable::SearchableItemHandle,
  28};
  29use zed_actions::assistant::ToggleFocus;
  30
/// Workspace item that shows every buffer changed by an agent `Thread` in one
/// multibuffer editor, with per-hunk "Keep" / "Reject" controls.
pub struct AgentDiff {
    /// Multibuffer holding the excerpts of every changed file (one path key per file).
    multibuffer: Entity<MultiBuffer>,
    /// Editor over `multibuffer`, configured to expand all diff hunks.
    editor: Entity<Editor>,
    /// Thread whose action log provides the changed buffers and whose summary
    /// provides the title.
    thread: Entity<Thread>,
    /// Focus handle used while the multibuffer is empty (nothing to review).
    focus_handle: FocusHandle,
    /// Workspace this item was created for.
    workspace: WeakEntity<Workspace>,
    /// Title derived from the thread summary; refreshed by `update_title`.
    title: SharedString,
    /// Keeps the action-log observation and thread subscription alive.
    _subscriptions: Vec<Subscription>,
}
  40
  41impl AgentDiff {
  42    pub fn deploy(
  43        thread: Entity<Thread>,
  44        workspace: WeakEntity<Workspace>,
  45        window: &mut Window,
  46        cx: &mut App,
  47    ) -> Result<()> {
  48        let existing_diff = workspace.update(cx, |workspace, cx| {
  49            workspace
  50                .items_of_type::<AgentDiff>(cx)
  51                .find(|diff| diff.read(cx).thread == thread)
  52        })?;
  53        if let Some(existing_diff) = existing_diff {
  54            workspace.update(cx, |workspace, cx| {
  55                workspace.activate_item(&existing_diff, true, true, window, cx);
  56            })
  57        } else {
  58            let agent_diff =
  59                cx.new(|cx| AgentDiff::new(thread.clone(), workspace.clone(), window, cx));
  60            workspace.update(cx, |workspace, cx| {
  61                workspace.add_item_to_center(Box::new(agent_diff), window, cx);
  62            })
  63        }
  64    }
  65
  66    pub fn new(
  67        thread: Entity<Thread>,
  68        workspace: WeakEntity<Workspace>,
  69        window: &mut Window,
  70        cx: &mut Context<Self>,
  71    ) -> Self {
  72        let focus_handle = cx.focus_handle();
  73        let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite));
  74
  75        let project = thread.read(cx).project().clone();
  76        let render_diff_hunk_controls = Arc::new({
  77            let agent_diff = cx.entity();
  78            move |row,
  79                  status: &DiffHunkStatus,
  80                  hunk_range,
  81                  is_created_file,
  82                  line_height,
  83                  editor: &Entity<Editor>,
  84                  window: &mut Window,
  85                  cx: &mut App| {
  86                render_diff_hunk_controls(
  87                    row,
  88                    status,
  89                    hunk_range,
  90                    is_created_file,
  91                    line_height,
  92                    &agent_diff,
  93                    editor,
  94                    window,
  95                    cx,
  96                )
  97            }
  98        });
  99        let editor = cx.new(|cx| {
 100            let mut editor =
 101                Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx);
 102            editor.disable_inline_diagnostics();
 103            editor.set_expand_all_diff_hunks(cx);
 104            editor.set_render_diff_hunk_controls(render_diff_hunk_controls, cx);
 105            editor.register_addon(AgentDiffAddon);
 106            editor
 107        });
 108
 109        let action_log = thread.read(cx).action_log().clone();
 110        let mut this = Self {
 111            _subscriptions: vec![
 112                cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
 113                    this.update_excerpts(window, cx)
 114                }),
 115                cx.subscribe(&thread, |this, _thread, event, cx| {
 116                    this.handle_thread_event(event, cx)
 117                }),
 118            ],
 119            title: SharedString::default(),
 120            multibuffer,
 121            editor,
 122            thread,
 123            focus_handle,
 124            workspace,
 125        };
 126        this.update_excerpts(window, cx);
 127        this.update_title(cx);
 128        this
 129    }
 130
 131    fn update_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 132        let thread = self.thread.read(cx);
 133        let changed_buffers = thread.action_log().read(cx).changed_buffers(cx);
 134        let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::<HashSet<_>>();
 135
 136        for (buffer, diff_handle) in changed_buffers {
 137            let Some(file) = buffer.read(cx).file().cloned() else {
 138                continue;
 139            };
 140
 141            let path_key = PathKey::namespaced(0, file.full_path(cx).into());
 142            paths_to_delete.remove(&path_key);
 143
 144            let snapshot = buffer.read(cx).snapshot();
 145            let diff = diff_handle.read(cx);
 146
 147            let diff_hunk_ranges = diff
 148                .hunks_intersecting_range(
 149                    language::Anchor::MIN..language::Anchor::MAX,
 150                    &snapshot,
 151                    cx,
 152                )
 153                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
 154                .collect::<Vec<_>>();
 155
 156            let (was_empty, is_excerpt_newly_added) =
 157                self.multibuffer.update(cx, |multibuffer, cx| {
 158                    let was_empty = multibuffer.is_empty();
 159                    let (_, is_excerpt_newly_added) = multibuffer.set_excerpts_for_path(
 160                        path_key.clone(),
 161                        buffer.clone(),
 162                        diff_hunk_ranges,
 163                        editor::DEFAULT_MULTIBUFFER_CONTEXT,
 164                        cx,
 165                    );
 166                    multibuffer.add_diff(diff_handle, cx);
 167                    (was_empty, is_excerpt_newly_added)
 168                });
 169
 170            self.editor.update(cx, |editor, cx| {
 171                if was_empty {
 172                    let first_hunk = editor
 173                        .diff_hunks_in_ranges(
 174                            &[editor::Anchor::min()..editor::Anchor::max()],
 175                            &self.multibuffer.read(cx).read(cx),
 176                        )
 177                        .next();
 178
 179                    if let Some(first_hunk) = first_hunk {
 180                        let first_hunk_start = first_hunk.multi_buffer_range().start;
 181                        editor.change_selections(
 182                            Some(Autoscroll::fit()),
 183                            window,
 184                            cx,
 185                            |selections| {
 186                                selections
 187                                    .select_anchor_ranges([first_hunk_start..first_hunk_start]);
 188                            },
 189                        )
 190                    }
 191                }
 192
 193                if is_excerpt_newly_added
 194                    && buffer
 195                        .read(cx)
 196                        .file()
 197                        .map_or(false, |file| file.disk_state() == DiskState::Deleted)
 198                {
 199                    editor.fold_buffer(snapshot.text.remote_id(), cx)
 200                }
 201            });
 202        }
 203
 204        self.multibuffer.update(cx, |multibuffer, cx| {
 205            for path in paths_to_delete {
 206                multibuffer.remove_excerpts_for_path(path, cx);
 207            }
 208        });
 209
 210        if self.multibuffer.read(cx).is_empty()
 211            && self
 212                .editor
 213                .read(cx)
 214                .focus_handle(cx)
 215                .contains_focused(window, cx)
 216        {
 217            self.focus_handle.focus(window);
 218        } else if self.focus_handle.is_focused(window) && !self.multibuffer.read(cx).is_empty() {
 219            self.editor.update(cx, |editor, cx| {
 220                editor.focus_handle(cx).focus(window);
 221            });
 222        }
 223    }
 224
 225    fn update_title(&mut self, cx: &mut Context<Self>) {
 226        let new_title = self
 227            .thread
 228            .read(cx)
 229            .summary()
 230            .unwrap_or("Assistant Changes".into());
 231        if new_title != self.title {
 232            self.title = new_title;
 233            cx.emit(EditorEvent::TitleChanged);
 234        }
 235    }
 236
 237    fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
 238        match event {
 239            ThreadEvent::SummaryChanged => self.update_title(cx),
 240            _ => {}
 241        }
 242    }
 243
 244    fn keep(&mut self, _: &crate::Keep, window: &mut Window, cx: &mut Context<Self>) {
 245        let ranges = self
 246            .editor
 247            .read(cx)
 248            .selections
 249            .disjoint_anchor_ranges()
 250            .collect::<Vec<_>>();
 251        self.keep_edits_in_ranges(ranges, window, cx);
 252    }
 253
 254    fn reject(&mut self, _: &crate::Reject, window: &mut Window, cx: &mut Context<Self>) {
 255        let ranges = self
 256            .editor
 257            .read(cx)
 258            .selections
 259            .disjoint_anchor_ranges()
 260            .collect::<Vec<_>>();
 261        self.reject_edits_in_ranges(ranges, window, cx);
 262    }
 263
 264    fn reject_all(&mut self, _: &crate::RejectAll, window: &mut Window, cx: &mut Context<Self>) {
 265        self.reject_edits_in_ranges(
 266            vec![editor::Anchor::min()..editor::Anchor::max()],
 267            window,
 268            cx,
 269        );
 270    }
 271
 272    fn keep_all(&mut self, _: &crate::KeepAll, _window: &mut Window, cx: &mut Context<Self>) {
 273        self.thread
 274            .update(cx, |thread, cx| thread.keep_all_edits(cx));
 275    }
 276
 277    fn keep_edits_in_ranges(
 278        &mut self,
 279        ranges: Vec<Range<editor::Anchor>>,
 280        window: &mut Window,
 281        cx: &mut Context<Self>,
 282    ) {
 283        let snapshot = self.multibuffer.read(cx).snapshot(cx);
 284        let diff_hunks_in_ranges = self
 285            .editor
 286            .read(cx)
 287            .diff_hunks_in_ranges(&ranges, &snapshot)
 288            .collect::<Vec<_>>();
 289        let newest_cursor = self.editor.update(cx, |editor, cx| {
 290            editor.selections.newest::<Point>(cx).head()
 291        });
 292        if diff_hunks_in_ranges.iter().any(|hunk| {
 293            hunk.row_range
 294                .contains(&multi_buffer::MultiBufferRow(newest_cursor.row))
 295        }) {
 296            self.update_selection(&diff_hunks_in_ranges, window, cx);
 297        }
 298
 299        for hunk in &diff_hunks_in_ranges {
 300            let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
 301            if let Some(buffer) = buffer {
 302                self.thread.update(cx, |thread, cx| {
 303                    thread.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
 304                });
 305            }
 306        }
 307    }
 308
 309    fn reject_edits_in_ranges(
 310        &mut self,
 311        ranges: Vec<Range<editor::Anchor>>,
 312        window: &mut Window,
 313        cx: &mut Context<Self>,
 314    ) {
 315        let snapshot = self.multibuffer.read(cx).snapshot(cx);
 316        let diff_hunks_in_ranges = self
 317            .editor
 318            .read(cx)
 319            .diff_hunks_in_ranges(&ranges, &snapshot)
 320            .collect::<Vec<_>>();
 321        let newest_cursor = self.editor.update(cx, |editor, cx| {
 322            editor.selections.newest::<Point>(cx).head()
 323        });
 324        if diff_hunks_in_ranges.iter().any(|hunk| {
 325            hunk.row_range
 326                .contains(&multi_buffer::MultiBufferRow(newest_cursor.row))
 327        }) {
 328            self.update_selection(&diff_hunks_in_ranges, window, cx);
 329        }
 330
 331        let point_ranges = ranges
 332            .into_iter()
 333            .map(|range| range.to_point(&snapshot))
 334            .collect();
 335        self.editor.update(cx, |editor, cx| {
 336            editor.restore_hunks_in_ranges(point_ranges, window, cx)
 337        });
 338    }
 339
 340    fn update_selection(
 341        &mut self,
 342        diff_hunks: &[multi_buffer::MultiBufferDiffHunk],
 343        window: &mut Window,
 344        cx: &mut Context<Self>,
 345    ) {
 346        let snapshot = self.multibuffer.read(cx).snapshot(cx);
 347        let target_hunk = diff_hunks
 348            .last()
 349            .and_then(|last_kept_hunk| {
 350                let last_kept_hunk_end = last_kept_hunk.multi_buffer_range().end;
 351                self.editor
 352                    .read(cx)
 353                    .diff_hunks_in_ranges(&[last_kept_hunk_end..editor::Anchor::max()], &snapshot)
 354                    .skip(1)
 355                    .next()
 356            })
 357            .or_else(|| {
 358                let first_kept_hunk = diff_hunks.first()?;
 359                let first_kept_hunk_start = first_kept_hunk.multi_buffer_range().start;
 360                self.editor
 361                    .read(cx)
 362                    .diff_hunks_in_ranges(
 363                        &[editor::Anchor::min()..first_kept_hunk_start],
 364                        &snapshot,
 365                    )
 366                    .next()
 367            });
 368
 369        if let Some(target_hunk) = target_hunk {
 370            self.editor.update(cx, |editor, cx| {
 371                editor.change_selections(Some(Autoscroll::fit()), window, cx, |selections| {
 372                    let next_hunk_start = target_hunk.multi_buffer_range().start;
 373                    selections.select_anchor_ranges([next_hunk_start..next_hunk_start]);
 374                })
 375            });
 376        }
 377    }
 378}
 379
/// `AgentDiff` emits `EditorEvent`s (e.g. `TitleChanged` from `update_title`);
/// `Item::to_item_events` translates them into workspace item events.
impl EventEmitter<EditorEvent> for AgentDiff {}
 381
 382impl Focusable for AgentDiff {
 383    fn focus_handle(&self, cx: &App) -> FocusHandle {
 384        if self.multibuffer.read(cx).is_empty() {
 385            self.focus_handle.clone()
 386        } else {
 387            self.editor.focus_handle(cx)
 388        }
 389    }
 390}
 391
 392impl Item for AgentDiff {
 393    type Event = EditorEvent;
 394
 395    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
 396        Some(Icon::new(IconName::ZedAssistant).color(Color::Muted))
 397    }
 398
 399    fn to_item_events(event: &EditorEvent, f: impl FnMut(ItemEvent)) {
 400        Editor::to_item_events(event, f)
 401    }
 402
 403    fn deactivated(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 404        self.editor
 405            .update(cx, |editor, cx| editor.deactivated(window, cx));
 406    }
 407
 408    fn navigate(
 409        &mut self,
 410        data: Box<dyn Any>,
 411        window: &mut Window,
 412        cx: &mut Context<Self>,
 413    ) -> bool {
 414        self.editor
 415            .update(cx, |editor, cx| editor.navigate(data, window, cx))
 416    }
 417
 418    fn tab_tooltip_text(&self, _: &App) -> Option<SharedString> {
 419        Some("Agent Diff".into())
 420    }
 421
 422    fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
 423        let summary = self
 424            .thread
 425            .read(cx)
 426            .summary()
 427            .unwrap_or("Assistant Changes".into());
 428        Label::new(format!("Review: {}", summary))
 429            .color(if params.selected {
 430                Color::Default
 431            } else {
 432                Color::Muted
 433            })
 434            .into_any_element()
 435    }
 436
 437    fn telemetry_event_text(&self) -> Option<&'static str> {
 438        Some("Assistant Diff Opened")
 439    }
 440
 441    fn as_searchable(&self, _: &Entity<Self>) -> Option<Box<dyn SearchableItemHandle>> {
 442        Some(Box::new(self.editor.clone()))
 443    }
 444
 445    fn for_each_project_item(
 446        &self,
 447        cx: &App,
 448        f: &mut dyn FnMut(gpui::EntityId, &dyn project::ProjectItem),
 449    ) {
 450        self.editor.for_each_project_item(cx, f)
 451    }
 452
 453    fn is_singleton(&self, _: &App) -> bool {
 454        false
 455    }
 456
 457    fn set_nav_history(
 458        &mut self,
 459        nav_history: ItemNavHistory,
 460        _: &mut Window,
 461        cx: &mut Context<Self>,
 462    ) {
 463        self.editor.update(cx, |editor, _| {
 464            editor.set_nav_history(Some(nav_history));
 465        });
 466    }
 467
 468    fn clone_on_split(
 469        &self,
 470        _workspace_id: Option<workspace::WorkspaceId>,
 471        window: &mut Window,
 472        cx: &mut Context<Self>,
 473    ) -> Option<Entity<Self>>
 474    where
 475        Self: Sized,
 476    {
 477        Some(cx.new(|cx| Self::new(self.thread.clone(), self.workspace.clone(), window, cx)))
 478    }
 479
 480    fn is_dirty(&self, cx: &App) -> bool {
 481        self.multibuffer.read(cx).is_dirty(cx)
 482    }
 483
 484    fn has_conflict(&self, cx: &App) -> bool {
 485        self.multibuffer.read(cx).has_conflict(cx)
 486    }
 487
 488    fn can_save(&self, _: &App) -> bool {
 489        true
 490    }
 491
 492    fn save(
 493        &mut self,
 494        format: bool,
 495        project: Entity<Project>,
 496        window: &mut Window,
 497        cx: &mut Context<Self>,
 498    ) -> Task<Result<()>> {
 499        self.editor.save(format, project, window, cx)
 500    }
 501
 502    fn save_as(
 503        &mut self,
 504        _: Entity<Project>,
 505        _: ProjectPath,
 506        _window: &mut Window,
 507        _: &mut Context<Self>,
 508    ) -> Task<Result<()>> {
 509        unreachable!()
 510    }
 511
 512    fn reload(
 513        &mut self,
 514        project: Entity<Project>,
 515        window: &mut Window,
 516        cx: &mut Context<Self>,
 517    ) -> Task<Result<()>> {
 518        self.editor.reload(project, window, cx)
 519    }
 520
 521    fn act_as_type<'a>(
 522        &'a self,
 523        type_id: TypeId,
 524        self_handle: &'a Entity<Self>,
 525        _: &'a App,
 526    ) -> Option<AnyView> {
 527        if type_id == TypeId::of::<Self>() {
 528            Some(self_handle.to_any())
 529        } else if type_id == TypeId::of::<Editor>() {
 530            Some(self.editor.to_any())
 531        } else {
 532            None
 533        }
 534    }
 535
 536    fn breadcrumb_location(&self, _: &App) -> ToolbarItemLocation {
 537        ToolbarItemLocation::PrimaryLeft
 538    }
 539
 540    fn breadcrumbs(&self, theme: &theme::Theme, cx: &App) -> Option<Vec<BreadcrumbText>> {
 541        self.editor.breadcrumbs(theme, cx)
 542    }
 543
 544    fn added_to_workspace(
 545        &mut self,
 546        workspace: &mut Workspace,
 547        window: &mut Window,
 548        cx: &mut Context<Self>,
 549    ) {
 550        self.editor.update(cx, |editor, cx| {
 551            editor.added_to_workspace(workspace, window, cx)
 552        });
 553    }
 554}
 555
 556impl Render for AgentDiff {
 557    fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 558        let is_empty = self.multibuffer.read(cx).is_empty();
 559        let focus_handle = &self.focus_handle;
 560
 561        div()
 562            .track_focus(focus_handle)
 563            .key_context(if is_empty { "EmptyPane" } else { "AgentDiff" })
 564            .on_action(cx.listener(Self::keep))
 565            .on_action(cx.listener(Self::reject))
 566            .on_action(cx.listener(Self::reject_all))
 567            .on_action(cx.listener(Self::keep_all))
 568            .bg(cx.theme().colors().editor_background)
 569            .flex()
 570            .items_center()
 571            .justify_center()
 572            .size_full()
 573            .when(is_empty, |el| {
 574                el.child(
 575                    v_flex()
 576                        .items_center()
 577                        .gap_2()
 578                        .child("No changes to review")
 579                        .child(
 580                            Button::new("continue-iterating", "Continue Iterating")
 581                                .style(ButtonStyle::Filled)
 582                                .icon(IconName::ForwardArrow)
 583                                .icon_position(IconPosition::Start)
 584                                .icon_size(IconSize::Small)
 585                                .icon_color(Color::Muted)
 586                                .full_width()
 587                                .key_binding(KeyBinding::for_action_in(
 588                                    &ToggleFocus,
 589                                    &focus_handle.clone(),
 590                                    window,
 591                                    cx,
 592                                ))
 593                                .on_click(|_event, window, cx| {
 594                                    window.dispatch_action(ToggleFocus.boxed_clone(), cx)
 595                                }),
 596                        ),
 597                )
 598            })
 599            .when(!is_empty, |el| el.child(self.editor.clone()))
 600    }
 601}
 602
 603fn render_diff_hunk_controls(
 604    row: u32,
 605    _status: &DiffHunkStatus,
 606    hunk_range: Range<editor::Anchor>,
 607    is_created_file: bool,
 608    line_height: Pixels,
 609    agent_diff: &Entity<AgentDiff>,
 610    editor: &Entity<Editor>,
 611    window: &mut Window,
 612    cx: &mut App,
 613) -> AnyElement {
 614    let editor = editor.clone();
 615    h_flex()
 616        .h(line_height)
 617        .mr_0p5()
 618        .gap_1()
 619        .px_0p5()
 620        .pb_1()
 621        .border_x_1()
 622        .border_b_1()
 623        .border_color(cx.theme().colors().border)
 624        .rounded_b_md()
 625        .bg(cx.theme().colors().editor_background)
 626        .gap_1()
 627        .occlude()
 628        .shadow_md()
 629        .children(vec![
 630            Button::new(("reject", row as u64), "Reject")
 631                .disabled(is_created_file)
 632                .key_binding(
 633                    KeyBinding::for_action_in(
 634                        &Reject,
 635                        &editor.read(cx).focus_handle(cx),
 636                        window,
 637                        cx,
 638                    )
 639                    .map(|kb| kb.size(rems_from_px(12.))),
 640                )
 641                .on_click({
 642                    let agent_diff = agent_diff.clone();
 643                    move |_event, window, cx| {
 644                        agent_diff.update(cx, |diff, cx| {
 645                            diff.reject_edits_in_ranges(
 646                                vec![hunk_range.start..hunk_range.start],
 647                                window,
 648                                cx,
 649                            );
 650                        });
 651                    }
 652                }),
 653            Button::new(("keep", row as u64), "Keep")
 654                .key_binding(
 655                    KeyBinding::for_action_in(&Keep, &editor.read(cx).focus_handle(cx), window, cx)
 656                        .map(|kb| kb.size(rems_from_px(12.))),
 657                )
 658                .on_click({
 659                    let agent_diff = agent_diff.clone();
 660                    move |_event, window, cx| {
 661                        agent_diff.update(cx, |diff, cx| {
 662                            diff.keep_edits_in_ranges(
 663                                vec![hunk_range.start..hunk_range.start],
 664                                window,
 665                                cx,
 666                            );
 667                        });
 668                    }
 669                }),
 670        ])
 671        .when(
 672            !editor.read(cx).buffer().read(cx).all_diff_hunks_expanded(),
 673            |el| {
 674                el.child(
 675                    IconButton::new(("next-hunk", row as u64), IconName::ArrowDown)
 676                        .shape(IconButtonShape::Square)
 677                        .icon_size(IconSize::Small)
 678                        // .disabled(!has_multiple_hunks)
 679                        .tooltip({
 680                            let focus_handle = editor.focus_handle(cx);
 681                            move |window, cx| {
 682                                Tooltip::for_action_in(
 683                                    "Next Hunk",
 684                                    &GoToHunk,
 685                                    &focus_handle,
 686                                    window,
 687                                    cx,
 688                                )
 689                            }
 690                        })
 691                        .on_click({
 692                            let editor = editor.clone();
 693                            move |_event, window, cx| {
 694                                editor.update(cx, |editor, cx| {
 695                                    let snapshot = editor.snapshot(window, cx);
 696                                    let position =
 697                                        hunk_range.end.to_point(&snapshot.buffer_snapshot);
 698                                    editor.go_to_hunk_before_or_after_position(
 699                                        &snapshot,
 700                                        position,
 701                                        Direction::Next,
 702                                        window,
 703                                        cx,
 704                                    );
 705                                    editor.expand_selected_diff_hunks(cx);
 706                                });
 707                            }
 708                        }),
 709                )
 710                .child(
 711                    IconButton::new(("prev-hunk", row as u64), IconName::ArrowUp)
 712                        .shape(IconButtonShape::Square)
 713                        .icon_size(IconSize::Small)
 714                        // .disabled(!has_multiple_hunks)
 715                        .tooltip({
 716                            let focus_handle = editor.focus_handle(cx);
 717                            move |window, cx| {
 718                                Tooltip::for_action_in(
 719                                    "Previous Hunk",
 720                                    &GoToPreviousHunk,
 721                                    &focus_handle,
 722                                    window,
 723                                    cx,
 724                                )
 725                            }
 726                        })
 727                        .on_click({
 728                            let editor = editor.clone();
 729                            move |_event, window, cx| {
 730                                editor.update(cx, |editor, cx| {
 731                                    let snapshot = editor.snapshot(window, cx);
 732                                    let point =
 733                                        hunk_range.start.to_point(&snapshot.buffer_snapshot);
 734                                    editor.go_to_hunk_before_or_after_position(
 735                                        &snapshot,
 736                                        point,
 737                                        Direction::Prev,
 738                                        window,
 739                                        cx,
 740                                    );
 741                                    editor.expand_selected_diff_hunks(cx);
 742                                });
 743                            }
 744                        }),
 745                )
 746            },
 747        )
 748        .into_any_element()
 749}
 750
/// Editor addon registered on the review editor; adds the `agent_diff` key context.
struct AgentDiffAddon;
 752
 753impl editor::Addon for AgentDiffAddon {
 754    fn to_any(&self) -> &dyn std::any::Any {
 755        self
 756    }
 757
 758    fn extend_key_context(&self, key_context: &mut gpui::KeyContext, _: &App) {
 759        key_context.add("agent_diff");
 760    }
 761}
 762
/// Toolbar item that shows "Reject All" / "Keep All" buttons while the active
/// pane item is an `AgentDiff` with changes to review.
pub struct AgentDiffToolbar {
    /// The active `AgentDiff`, if the active pane item is one (held weakly).
    agent_diff: Option<WeakEntity<AgentDiff>>,
    /// Workspace the toolbar was created for; not read in the code shown here.
    _workspace: WeakEntity<Workspace>,
}
 767
 768impl AgentDiffToolbar {
 769    pub fn new(workspace: &Workspace, _: &mut Context<Self>) -> Self {
 770        Self {
 771            agent_diff: None,
 772            _workspace: workspace.weak_handle(),
 773        }
 774    }
 775
 776    fn agent_diff(&self, _: &App) -> Option<Entity<AgentDiff>> {
 777        self.agent_diff.as_ref()?.upgrade()
 778    }
 779
 780    fn dispatch_action(&self, action: &dyn Action, window: &mut Window, cx: &mut Context<Self>) {
 781        if let Some(agent_diff) = self.agent_diff(cx) {
 782            agent_diff.focus_handle(cx).focus(window);
 783        }
 784        let action = action.boxed_clone();
 785        cx.defer(move |cx| {
 786            cx.dispatch_action(action.as_ref());
 787        })
 788    }
 789}
 790
/// Required by `ToolbarItemView`; this file never emits any `ToolbarItemEvent` itself.
impl EventEmitter<ToolbarItemEvent> for AgentDiffToolbar {}
 792
 793impl ToolbarItemView for AgentDiffToolbar {
 794    fn set_active_pane_item(
 795        &mut self,
 796        active_pane_item: Option<&dyn ItemHandle>,
 797        _: &mut Window,
 798        cx: &mut Context<Self>,
 799    ) -> ToolbarItemLocation {
 800        self.agent_diff = active_pane_item
 801            .and_then(|item| item.act_as::<AgentDiff>(cx))
 802            .map(|entity| entity.downgrade());
 803        if self.agent_diff.is_some() {
 804            ToolbarItemLocation::PrimaryRight
 805        } else {
 806            ToolbarItemLocation::Hidden
 807        }
 808    }
 809
 810    fn pane_focus_update(
 811        &mut self,
 812        _pane_focused: bool,
 813        _window: &mut Window,
 814        _cx: &mut Context<Self>,
 815    ) {
 816    }
 817}
 818
 819impl Render for AgentDiffToolbar {
 820    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 821        let agent_diff = match self.agent_diff(cx) {
 822            Some(ad) => ad,
 823            None => return div(),
 824        };
 825
 826        let is_empty = agent_diff.read(cx).multibuffer.read(cx).is_empty();
 827
 828        if is_empty {
 829            return div();
 830        }
 831
 832        h_group_xl()
 833            .my_neg_1()
 834            .items_center()
 835            .p_1()
 836            .flex_wrap()
 837            .justify_between()
 838            .child(
 839                h_group_sm()
 840                    .child(
 841                        Button::new("reject-all", "Reject All").on_click(cx.listener(
 842                            |this, _, window, cx| {
 843                                this.dispatch_action(&crate::RejectAll, window, cx)
 844                            },
 845                        )),
 846                    )
 847                    .child(Button::new("keep-all", "Keep All").on_click(cx.listener(
 848                        |this, _, window, cx| this.dispatch_action(&crate::KeepAll, window, cx),
 849                    ))),
 850            )
 851    }
 852}
 853
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;

    /// End-to-end check of `AgentDiff`. Four single-character edits reported
    /// through the thread's action log show up as hunks in the diff editor.
    /// Keeping or rejecting hunks then updates both the displayed text and the
    /// cursor position.
    #[gpui::test]
    async fn test_agent_diff(cx: &mut TestAppContext) {
        // Install a test settings store and register the settings/subsystems
        // used by the project, thread store, workspace, theme and editor.
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });

        // Fake filesystem with a single 9-line file (rows 0..=8: abc, def,
        // ghi, jkl, mno, pqr, stu, vwx, yz).
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/test"),
            json!({"file1": "abc\ndef\nghi\njkl\nmno\npqr\nstu\nvwx\nyz"}),
        )
        .await;
        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
        let buffer_path = project
            .read_with(cx, |project, cx| {
                project.find_project_path("test/file1", cx)
            })
            .unwrap();

        // Create a thread; its action log is where the buffer reads and edits
        // below are recorded.
        let thread_store = cx.update(|cx| {
            ThreadStore::new(
                project.clone(),
                Arc::default(),
                Arc::new(PromptBuilder::new(None).unwrap()),
                cx,
            )
            .unwrap()
        });
        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

        // Open a workspace window and the diff view for the thread. From here
        // on `cx` is the window's visual test context.
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
        let agent_diff = cx.new_window_entity(|window, cx| {
            AgentDiff::new(thread.clone(), workspace.downgrade(), window, cx)
        });
        let editor = agent_diff.read_with(cx, |diff, _cx| diff.editor.clone());

        // Mark the buffer as read, then capitalize one character on rows 1, 3,
        // 5 and 7 (def->dEf, jkl->jkL, pqr->Pqr, vwx->vWx). Finally, report the
        // edit to the action log.
        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();
        cx.update(|_, cx| {
            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
            buffer.update(cx, |buffer, cx| {
                buffer
                    .edit(
                        [
                            (Point::new(1, 1)..Point::new(1, 2), "E"),
                            (Point::new(3, 2)..Point::new(3, 3), "L"),
                            (Point::new(5, 0)..Point::new(5, 1), "P"),
                            (Point::new(7, 1)..Point::new(7, 2), "W"),
                        ],
                        None,
                        cx,
                    )
                    .unwrap()
            });
            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
        });
        cx.run_until_parked();

        // When opening the assistant diff, the cursor is positioned on the first hunk.
        // Each hunk is displayed as the old line followed by the new line.
        assert_eq!(
            editor.read_with(cx, |editor, cx| editor.text(cx)),
            "abc\ndef\ndEf\nghi\njkl\njkL\nmno\npqr\nPqr\nstu\nvwx\nvWx\nyz"
        );
        assert_eq!(
            editor
                .update(cx, |editor, cx| editor.selections.newest::<Point>(cx))
                .range(),
            Point::new(1, 0)..Point::new(1, 0)
        );

        // After keeping a hunk, the cursor should be positioned on the second hunk.
        // The kept hunk's old line ("def") disappears from the diff.
        agent_diff.update_in(cx, |diff, window, cx| diff.keep(&crate::Keep, window, cx));
        cx.run_until_parked();
        assert_eq!(
            editor.read_with(cx, |editor, cx| editor.text(cx)),
            "abc\ndEf\nghi\njkl\njkL\nmno\npqr\nPqr\nstu\nvwx\nvWx\nyz"
        );
        assert_eq!(
            editor
                .update(cx, |editor, cx| editor.selections.newest::<Point>(cx))
                .range(),
            Point::new(3, 0)..Point::new(3, 0)
        );

        // Restoring a hunk also moves the cursor to the next hunk, possibly cycling if it's at the end.
        // Row 10 is the "vWx" line of the last hunk. Rejecting it restores "vwx".
        // The cursor then wraps back to the first remaining hunk at row 3.
        editor.update_in(cx, |editor, window, cx| {
            editor.change_selections(None, window, cx, |selections| {
                selections.select_ranges([Point::new(10, 0)..Point::new(10, 0)])
            });
        });
        agent_diff.update_in(cx, |diff, window, cx| {
            diff.reject(&crate::Reject, window, cx)
        });
        cx.run_until_parked();
        assert_eq!(
            editor.read_with(cx, |editor, cx| editor.text(cx)),
            "abc\ndEf\nghi\njkl\njkL\nmno\npqr\nPqr\nstu\nvwx\nyz"
        );
        assert_eq!(
            editor
                .update(cx, |editor, cx| editor.selections.newest::<Point>(cx))
                .range(),
            Point::new(3, 0)..Point::new(3, 0)
        );

        // Keeping a range that doesn't intersect the current selection doesn't move it.
        // Row 7 is the "Pqr" line, so this keeps the pqr->Pqr hunk while the
        // cursor stays at row 3.
        agent_diff.update_in(cx, |diff, window, cx| {
            let position = editor
                .read(cx)
                .buffer()
                .read(cx)
                .read(cx)
                .anchor_before(Point::new(7, 0));
            diff.keep_edits_in_ranges(vec![position..position], window, cx)
        });
        cx.run_until_parked();
        assert_eq!(
            editor.read_with(cx, |editor, cx| editor.text(cx)),
            "abc\ndEf\nghi\njkl\njkL\nmno\nPqr\nstu\nvwx\nyz"
        );
        assert_eq!(
            editor
                .update(cx, |editor, cx| editor.selections.newest::<Point>(cx))
                .range(),
            Point::new(3, 0)..Point::new(3, 0)
        );
    }
}