agent_diff.rs

  1use crate::{Thread, ThreadEvent};
  2use anyhow::Result;
  3use buffer_diff::DiffHunkStatus;
  4use collections::HashSet;
  5use editor::{
  6    AnchorRangeExt, Direction, Editor, EditorEvent, MultiBuffer, ToPoint,
  7    actions::{GoToHunk, GoToPreviousHunk},
  8    scroll::Autoscroll,
  9};
 10use gpui::{
 11    Action, AnyElement, AnyView, App, Entity, EventEmitter, FocusHandle, Focusable, SharedString,
 12    Subscription, Task, WeakEntity, Window, prelude::*,
 13};
 14use language::{Capability, DiskState, OffsetRangeExt, Point};
 15use multi_buffer::PathKey;
 16use project::{Project, ProjectPath};
 17use std::{
 18    any::{Any, TypeId},
 19    ops::Range,
 20    sync::Arc,
 21};
 22use ui::{IconButtonShape, KeyBinding, Tooltip, prelude::*};
 23use workspace::{
 24    Item, ItemHandle, ItemNavHistory, ToolbarItemEvent, ToolbarItemLocation, ToolbarItemView,
 25    Workspace,
 26    item::{BreadcrumbText, ItemEvent, TabContentParams},
 27    searchable::SearchableItemHandle,
 28};
 29
 30pub struct AgentDiff {
 31    multibuffer: Entity<MultiBuffer>,
 32    editor: Entity<Editor>,
 33    thread: Entity<Thread>,
 34    focus_handle: FocusHandle,
 35    workspace: WeakEntity<Workspace>,
 36    title: SharedString,
 37    _subscriptions: Vec<Subscription>,
 38}
 39
 40impl AgentDiff {
 41    pub fn deploy(
 42        thread: Entity<Thread>,
 43        workspace: WeakEntity<Workspace>,
 44        window: &mut Window,
 45        cx: &mut App,
 46    ) -> Result<()> {
 47        let existing_diff = workspace.update(cx, |workspace, cx| {
 48            workspace
 49                .items_of_type::<AgentDiff>(cx)
 50                .find(|diff| diff.read(cx).thread == thread)
 51        })?;
 52        if let Some(existing_diff) = existing_diff {
 53            workspace.update(cx, |workspace, cx| {
 54                workspace.activate_item(&existing_diff, true, true, window, cx);
 55            })
 56        } else {
 57            let agent_diff =
 58                cx.new(|cx| AgentDiff::new(thread.clone(), workspace.clone(), window, cx));
 59            workspace.update(cx, |workspace, cx| {
 60                workspace.add_item_to_center(Box::new(agent_diff), window, cx);
 61            })
 62        }
 63    }
 64
 65    pub fn new(
 66        thread: Entity<Thread>,
 67        workspace: WeakEntity<Workspace>,
 68        window: &mut Window,
 69        cx: &mut Context<Self>,
 70    ) -> Self {
 71        let focus_handle = cx.focus_handle();
 72        let multibuffer = cx.new(|_| MultiBuffer::new(Capability::ReadWrite));
 73
 74        let project = thread.read(cx).project().clone();
 75        let render_diff_hunk_controls = Arc::new({
 76            let agent_diff = cx.entity();
 77            move |row,
 78                  status: &DiffHunkStatus,
 79                  hunk_range,
 80                  is_created_file,
 81                  line_height,
 82                  editor: &Entity<Editor>,
 83                  window: &mut Window,
 84                  cx: &mut App| {
 85                render_diff_hunk_controls(
 86                    row,
 87                    status,
 88                    hunk_range,
 89                    is_created_file,
 90                    line_height,
 91                    &agent_diff,
 92                    editor,
 93                    window,
 94                    cx,
 95                )
 96            }
 97        });
 98        let editor = cx.new(|cx| {
 99            let mut editor =
100                Editor::for_multibuffer(multibuffer.clone(), Some(project.clone()), window, cx);
101            editor.disable_inline_diagnostics();
102            editor.set_expand_all_diff_hunks(cx);
103            editor.set_render_diff_hunk_controls(render_diff_hunk_controls, cx);
104            editor.register_addon(AgentDiffAddon);
105            editor
106        });
107
108        let action_log = thread.read(cx).action_log().clone();
109        let mut this = Self {
110            _subscriptions: vec![
111                cx.observe_in(&action_log, window, |this, _action_log, window, cx| {
112                    this.update_excerpts(window, cx)
113                }),
114                cx.subscribe(&thread, |this, _thread, event, cx| {
115                    this.handle_thread_event(event, cx)
116                }),
117            ],
118            title: SharedString::default(),
119            multibuffer,
120            editor,
121            thread,
122            focus_handle,
123            workspace,
124        };
125        this.update_excerpts(window, cx);
126        this.update_title(cx);
127        this
128    }
129
130    fn update_excerpts(&mut self, window: &mut Window, cx: &mut Context<Self>) {
131        let thread = self.thread.read(cx);
132        let changed_buffers = thread.action_log().read(cx).changed_buffers(cx);
133        let mut paths_to_delete = self.multibuffer.read(cx).paths().collect::<HashSet<_>>();
134
135        for (buffer, diff_handle) in changed_buffers {
136            let Some(file) = buffer.read(cx).file().cloned() else {
137                continue;
138            };
139
140            let path_key = PathKey::namespaced(0, file.full_path(cx).into());
141            paths_to_delete.remove(&path_key);
142
143            let snapshot = buffer.read(cx).snapshot();
144            let diff = diff_handle.read(cx);
145
146            let diff_hunk_ranges = diff
147                .hunks_intersecting_range(
148                    language::Anchor::MIN..language::Anchor::MAX,
149                    &snapshot,
150                    cx,
151                )
152                .map(|diff_hunk| diff_hunk.buffer_range.to_point(&snapshot))
153                .collect::<Vec<_>>();
154
155            let (was_empty, is_excerpt_newly_added) =
156                self.multibuffer.update(cx, |multibuffer, cx| {
157                    let was_empty = multibuffer.is_empty();
158                    let (_, is_excerpt_newly_added) = multibuffer.set_excerpts_for_path(
159                        path_key.clone(),
160                        buffer.clone(),
161                        diff_hunk_ranges,
162                        editor::DEFAULT_MULTIBUFFER_CONTEXT,
163                        cx,
164                    );
165                    multibuffer.add_diff(diff_handle, cx);
166                    (was_empty, is_excerpt_newly_added)
167                });
168
169            self.editor.update(cx, |editor, cx| {
170                if was_empty {
171                    let first_hunk = editor
172                        .diff_hunks_in_ranges(
173                            &[editor::Anchor::min()..editor::Anchor::max()],
174                            &self.multibuffer.read(cx).read(cx),
175                        )
176                        .next();
177
178                    if let Some(first_hunk) = first_hunk {
179                        let first_hunk_start = first_hunk.multi_buffer_range().start;
180                        editor.change_selections(
181                            Some(Autoscroll::fit()),
182                            window,
183                            cx,
184                            |selections| {
185                                selections
186                                    .select_anchor_ranges([first_hunk_start..first_hunk_start]);
187                            },
188                        )
189                    }
190                }
191
192                if is_excerpt_newly_added
193                    && buffer
194                        .read(cx)
195                        .file()
196                        .map_or(false, |file| file.disk_state() == DiskState::Deleted)
197                {
198                    editor.fold_buffer(snapshot.text.remote_id(), cx)
199                }
200            });
201        }
202
203        self.multibuffer.update(cx, |multibuffer, cx| {
204            for path in paths_to_delete {
205                multibuffer.remove_excerpts_for_path(path, cx);
206            }
207        });
208
209        if self.multibuffer.read(cx).is_empty()
210            && self
211                .editor
212                .read(cx)
213                .focus_handle(cx)
214                .contains_focused(window, cx)
215        {
216            self.focus_handle.focus(window);
217        } else if self.focus_handle.is_focused(window) && !self.multibuffer.read(cx).is_empty() {
218            self.editor.update(cx, |editor, cx| {
219                editor.focus_handle(cx).focus(window);
220            });
221        }
222    }
223
224    fn update_title(&mut self, cx: &mut Context<Self>) {
225        let new_title = self
226            .thread
227            .read(cx)
228            .summary()
229            .unwrap_or("Assistant Changes".into());
230        if new_title != self.title {
231            self.title = new_title;
232            cx.emit(EditorEvent::TitleChanged);
233        }
234    }
235
236    fn handle_thread_event(&mut self, event: &ThreadEvent, cx: &mut Context<Self>) {
237        match event {
238            ThreadEvent::SummaryChanged => self.update_title(cx),
239            _ => {}
240        }
241    }
242
243    fn keep(&mut self, _: &crate::Keep, window: &mut Window, cx: &mut Context<Self>) {
244        let ranges = self
245            .editor
246            .read(cx)
247            .selections
248            .disjoint_anchor_ranges()
249            .collect::<Vec<_>>();
250        self.keep_edits_in_ranges(ranges, window, cx);
251    }
252
253    fn reject(&mut self, _: &crate::Reject, window: &mut Window, cx: &mut Context<Self>) {
254        let ranges = self
255            .editor
256            .read(cx)
257            .selections
258            .disjoint_anchor_ranges()
259            .collect::<Vec<_>>();
260        self.reject_edits_in_ranges(ranges, window, cx);
261    }
262
263    fn reject_all(&mut self, _: &crate::RejectAll, window: &mut Window, cx: &mut Context<Self>) {
264        self.reject_edits_in_ranges(
265            vec![editor::Anchor::min()..editor::Anchor::max()],
266            window,
267            cx,
268        );
269    }
270
271    fn keep_all(&mut self, _: &crate::KeepAll, _window: &mut Window, cx: &mut Context<Self>) {
272        self.thread
273            .update(cx, |thread, cx| thread.keep_all_edits(cx));
274    }
275
276    fn keep_edits_in_ranges(
277        &mut self,
278        ranges: Vec<Range<editor::Anchor>>,
279        window: &mut Window,
280        cx: &mut Context<Self>,
281    ) {
282        let snapshot = self.multibuffer.read(cx).snapshot(cx);
283        let diff_hunks_in_ranges = self
284            .editor
285            .read(cx)
286            .diff_hunks_in_ranges(&ranges, &snapshot)
287            .collect::<Vec<_>>();
288        let newest_cursor = self.editor.update(cx, |editor, cx| {
289            editor.selections.newest::<Point>(cx).head()
290        });
291        if diff_hunks_in_ranges.iter().any(|hunk| {
292            hunk.row_range
293                .contains(&multi_buffer::MultiBufferRow(newest_cursor.row))
294        }) {
295            self.update_selection(&diff_hunks_in_ranges, window, cx);
296        }
297
298        for hunk in &diff_hunks_in_ranges {
299            let buffer = self.multibuffer.read(cx).buffer(hunk.buffer_id);
300            if let Some(buffer) = buffer {
301                self.thread.update(cx, |thread, cx| {
302                    thread.keep_edits_in_range(buffer, hunk.buffer_range.clone(), cx)
303                });
304            }
305        }
306    }
307
308    fn reject_edits_in_ranges(
309        &mut self,
310        ranges: Vec<Range<editor::Anchor>>,
311        window: &mut Window,
312        cx: &mut Context<Self>,
313    ) {
314        let snapshot = self.multibuffer.read(cx).snapshot(cx);
315        let diff_hunks_in_ranges = self
316            .editor
317            .read(cx)
318            .diff_hunks_in_ranges(&ranges, &snapshot)
319            .collect::<Vec<_>>();
320        let newest_cursor = self.editor.update(cx, |editor, cx| {
321            editor.selections.newest::<Point>(cx).head()
322        });
323        if diff_hunks_in_ranges.iter().any(|hunk| {
324            hunk.row_range
325                .contains(&multi_buffer::MultiBufferRow(newest_cursor.row))
326        }) {
327            self.update_selection(&diff_hunks_in_ranges, window, cx);
328        }
329
330        let point_ranges = ranges
331            .into_iter()
332            .map(|range| range.to_point(&snapshot))
333            .collect();
334        self.editor.update(cx, |editor, cx| {
335            editor.restore_hunks_in_ranges(point_ranges, window, cx)
336        });
337    }
338
339    fn update_selection(
340        &mut self,
341        diff_hunks: &[multi_buffer::MultiBufferDiffHunk],
342        window: &mut Window,
343        cx: &mut Context<Self>,
344    ) {
345        let snapshot = self.multibuffer.read(cx).snapshot(cx);
346        let target_hunk = diff_hunks
347            .last()
348            .and_then(|last_kept_hunk| {
349                let last_kept_hunk_end = last_kept_hunk.multi_buffer_range().end;
350                self.editor
351                    .read(cx)
352                    .diff_hunks_in_ranges(&[last_kept_hunk_end..editor::Anchor::max()], &snapshot)
353                    .skip(1)
354                    .next()
355            })
356            .or_else(|| {
357                let first_kept_hunk = diff_hunks.first()?;
358                let first_kept_hunk_start = first_kept_hunk.multi_buffer_range().start;
359                self.editor
360                    .read(cx)
361                    .diff_hunks_in_ranges(
362                        &[editor::Anchor::min()..first_kept_hunk_start],
363                        &snapshot,
364                    )
365                    .next()
366            });
367
368        if let Some(target_hunk) = target_hunk {
369            self.editor.update(cx, |editor, cx| {
370                editor.change_selections(Some(Autoscroll::fit()), window, cx, |selections| {
371                    let next_hunk_start = target_hunk.multi_buffer_range().start;
372                    selections.select_anchor_ranges([next_hunk_start..next_hunk_start]);
373                })
374            });
375        }
376    }
377}
378
379impl EventEmitter<EditorEvent> for AgentDiff {}
380
381impl Focusable for AgentDiff {
382    fn focus_handle(&self, cx: &App) -> FocusHandle {
383        if self.multibuffer.read(cx).is_empty() {
384            self.focus_handle.clone()
385        } else {
386            self.editor.focus_handle(cx)
387        }
388    }
389}
390
391impl Item for AgentDiff {
392    type Event = EditorEvent;
393
394    fn tab_icon(&self, _window: &Window, _cx: &App) -> Option<Icon> {
395        Some(Icon::new(IconName::ZedAssistant).color(Color::Muted))
396    }
397
398    fn to_item_events(event: &EditorEvent, f: impl FnMut(ItemEvent)) {
399        Editor::to_item_events(event, f)
400    }
401
402    fn deactivated(&mut self, window: &mut Window, cx: &mut Context<Self>) {
403        self.editor
404            .update(cx, |editor, cx| editor.deactivated(window, cx));
405    }
406
407    fn navigate(
408        &mut self,
409        data: Box<dyn Any>,
410        window: &mut Window,
411        cx: &mut Context<Self>,
412    ) -> bool {
413        self.editor
414            .update(cx, |editor, cx| editor.navigate(data, window, cx))
415    }
416
417    fn tab_tooltip_text(&self, _: &App) -> Option<SharedString> {
418        Some("Agent Diff".into())
419    }
420
421    fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement {
422        let summary = self
423            .thread
424            .read(cx)
425            .summary()
426            .unwrap_or("Assistant Changes".into());
427        Label::new(format!("Review: {}", summary))
428            .color(if params.selected {
429                Color::Default
430            } else {
431                Color::Muted
432            })
433            .into_any_element()
434    }
435
436    fn telemetry_event_text(&self) -> Option<&'static str> {
437        Some("Assistant Diff Opened")
438    }
439
440    fn as_searchable(&self, _: &Entity<Self>) -> Option<Box<dyn SearchableItemHandle>> {
441        Some(Box::new(self.editor.clone()))
442    }
443
444    fn for_each_project_item(
445        &self,
446        cx: &App,
447        f: &mut dyn FnMut(gpui::EntityId, &dyn project::ProjectItem),
448    ) {
449        self.editor.for_each_project_item(cx, f)
450    }
451
452    fn is_singleton(&self, _: &App) -> bool {
453        false
454    }
455
456    fn set_nav_history(
457        &mut self,
458        nav_history: ItemNavHistory,
459        _: &mut Window,
460        cx: &mut Context<Self>,
461    ) {
462        self.editor.update(cx, |editor, _| {
463            editor.set_nav_history(Some(nav_history));
464        });
465    }
466
467    fn clone_on_split(
468        &self,
469        _workspace_id: Option<workspace::WorkspaceId>,
470        window: &mut Window,
471        cx: &mut Context<Self>,
472    ) -> Option<Entity<Self>>
473    where
474        Self: Sized,
475    {
476        Some(cx.new(|cx| Self::new(self.thread.clone(), self.workspace.clone(), window, cx)))
477    }
478
479    fn is_dirty(&self, cx: &App) -> bool {
480        self.multibuffer.read(cx).is_dirty(cx)
481    }
482
483    fn has_conflict(&self, cx: &App) -> bool {
484        self.multibuffer.read(cx).has_conflict(cx)
485    }
486
487    fn can_save(&self, _: &App) -> bool {
488        true
489    }
490
491    fn save(
492        &mut self,
493        format: bool,
494        project: Entity<Project>,
495        window: &mut Window,
496        cx: &mut Context<Self>,
497    ) -> Task<Result<()>> {
498        self.editor.save(format, project, window, cx)
499    }
500
501    fn save_as(
502        &mut self,
503        _: Entity<Project>,
504        _: ProjectPath,
505        _window: &mut Window,
506        _: &mut Context<Self>,
507    ) -> Task<Result<()>> {
508        unreachable!()
509    }
510
511    fn reload(
512        &mut self,
513        project: Entity<Project>,
514        window: &mut Window,
515        cx: &mut Context<Self>,
516    ) -> Task<Result<()>> {
517        self.editor.reload(project, window, cx)
518    }
519
520    fn act_as_type<'a>(
521        &'a self,
522        type_id: TypeId,
523        self_handle: &'a Entity<Self>,
524        _: &'a App,
525    ) -> Option<AnyView> {
526        if type_id == TypeId::of::<Self>() {
527            Some(self_handle.to_any())
528        } else if type_id == TypeId::of::<Editor>() {
529            Some(self.editor.to_any())
530        } else {
531            None
532        }
533    }
534
535    fn breadcrumb_location(&self, _: &App) -> ToolbarItemLocation {
536        ToolbarItemLocation::PrimaryLeft
537    }
538
539    fn breadcrumbs(&self, theme: &theme::Theme, cx: &App) -> Option<Vec<BreadcrumbText>> {
540        self.editor.breadcrumbs(theme, cx)
541    }
542
543    fn added_to_workspace(
544        &mut self,
545        workspace: &mut Workspace,
546        window: &mut Window,
547        cx: &mut Context<Self>,
548    ) {
549        self.editor.update(cx, |editor, cx| {
550            editor.added_to_workspace(workspace, window, cx)
551        });
552    }
553}
554
555impl Render for AgentDiff {
556    fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
557        let is_empty = self.multibuffer.read(cx).is_empty();
558
559        div()
560            .track_focus(&self.focus_handle)
561            .key_context(if is_empty { "EmptyPane" } else { "AgentDiff" })
562            .on_action(cx.listener(Self::keep))
563            .on_action(cx.listener(Self::reject))
564            .on_action(cx.listener(Self::reject_all))
565            .on_action(cx.listener(Self::keep_all))
566            .bg(cx.theme().colors().editor_background)
567            .flex()
568            .items_center()
569            .justify_center()
570            .size_full()
571            .when(is_empty, |el| el.child("No changes to review"))
572            .when(!is_empty, |el| el.child(self.editor.clone()))
573    }
574}
575
576fn render_diff_hunk_controls(
577    row: u32,
578    _status: &DiffHunkStatus,
579    hunk_range: Range<editor::Anchor>,
580    is_created_file: bool,
581    line_height: Pixels,
582    agent_diff: &Entity<AgentDiff>,
583    editor: &Entity<Editor>,
584    window: &mut Window,
585    cx: &mut App,
586) -> AnyElement {
587    let editor = editor.clone();
588    h_flex()
589        .h(line_height)
590        .mr_0p5()
591        .gap_1()
592        .px_0p5()
593        .pb_1()
594        .border_x_1()
595        .border_b_1()
596        .border_color(cx.theme().colors().border)
597        .rounded_b_md()
598        .bg(cx.theme().colors().editor_background)
599        .gap_1()
600        .occlude()
601        .shadow_md()
602        .children(vec![
603            Button::new(("reject", row as u64), "Reject")
604                .disabled(is_created_file)
605                .key_binding(
606                    KeyBinding::for_action_in(
607                        &crate::Reject,
608                        &editor.read(cx).focus_handle(cx),
609                        window,
610                        cx,
611                    )
612                    .map(|kb| kb.size(rems_from_px(12.))),
613                )
614                .on_click({
615                    let agent_diff = agent_diff.clone();
616                    move |_event, window, cx| {
617                        agent_diff.update(cx, |diff, cx| {
618                            diff.reject_edits_in_ranges(
619                                vec![hunk_range.start..hunk_range.start],
620                                window,
621                                cx,
622                            );
623                        });
624                    }
625                }),
626            Button::new(("keep", row as u64), "Keep")
627                .key_binding(
628                    KeyBinding::for_action_in(
629                        &crate::Keep,
630                        &editor.read(cx).focus_handle(cx),
631                        window,
632                        cx,
633                    )
634                    .map(|kb| kb.size(rems_from_px(12.))),
635                )
636                .on_click({
637                    let agent_diff = agent_diff.clone();
638                    move |_event, window, cx| {
639                        agent_diff.update(cx, |diff, cx| {
640                            diff.keep_edits_in_ranges(
641                                vec![hunk_range.start..hunk_range.start],
642                                window,
643                                cx,
644                            );
645                        });
646                    }
647                }),
648        ])
649        .when(
650            !editor.read(cx).buffer().read(cx).all_diff_hunks_expanded(),
651            |el| {
652                el.child(
653                    IconButton::new(("next-hunk", row as u64), IconName::ArrowDown)
654                        .shape(IconButtonShape::Square)
655                        .icon_size(IconSize::Small)
656                        // .disabled(!has_multiple_hunks)
657                        .tooltip({
658                            let focus_handle = editor.focus_handle(cx);
659                            move |window, cx| {
660                                Tooltip::for_action_in(
661                                    "Next Hunk",
662                                    &GoToHunk,
663                                    &focus_handle,
664                                    window,
665                                    cx,
666                                )
667                            }
668                        })
669                        .on_click({
670                            let editor = editor.clone();
671                            move |_event, window, cx| {
672                                editor.update(cx, |editor, cx| {
673                                    let snapshot = editor.snapshot(window, cx);
674                                    let position =
675                                        hunk_range.end.to_point(&snapshot.buffer_snapshot);
676                                    editor.go_to_hunk_before_or_after_position(
677                                        &snapshot,
678                                        position,
679                                        Direction::Next,
680                                        window,
681                                        cx,
682                                    );
683                                    editor.expand_selected_diff_hunks(cx);
684                                });
685                            }
686                        }),
687                )
688                .child(
689                    IconButton::new(("prev-hunk", row as u64), IconName::ArrowUp)
690                        .shape(IconButtonShape::Square)
691                        .icon_size(IconSize::Small)
692                        // .disabled(!has_multiple_hunks)
693                        .tooltip({
694                            let focus_handle = editor.focus_handle(cx);
695                            move |window, cx| {
696                                Tooltip::for_action_in(
697                                    "Previous Hunk",
698                                    &GoToPreviousHunk,
699                                    &focus_handle,
700                                    window,
701                                    cx,
702                                )
703                            }
704                        })
705                        .on_click({
706                            let editor = editor.clone();
707                            move |_event, window, cx| {
708                                editor.update(cx, |editor, cx| {
709                                    let snapshot = editor.snapshot(window, cx);
710                                    let point =
711                                        hunk_range.start.to_point(&snapshot.buffer_snapshot);
712                                    editor.go_to_hunk_before_or_after_position(
713                                        &snapshot,
714                                        point,
715                                        Direction::Prev,
716                                        window,
717                                        cx,
718                                    );
719                                    editor.expand_selected_diff_hunks(cx);
720                                });
721                            }
722                        }),
723                )
724            },
725        )
726        .into_any_element()
727}
728
729struct AgentDiffAddon;
730
731impl editor::Addon for AgentDiffAddon {
732    fn to_any(&self) -> &dyn std::any::Any {
733        self
734    }
735
736    fn extend_key_context(&self, key_context: &mut gpui::KeyContext, _: &App) {
737        key_context.add("agent_diff");
738    }
739}
740
741pub struct AgentDiffToolbar {
742    agent_diff: Option<WeakEntity<AgentDiff>>,
743    _workspace: WeakEntity<Workspace>,
744}
745
746impl AgentDiffToolbar {
747    pub fn new(workspace: &Workspace, _: &mut Context<Self>) -> Self {
748        Self {
749            agent_diff: None,
750            _workspace: workspace.weak_handle(),
751        }
752    }
753
754    fn agent_diff(&self, _: &App) -> Option<Entity<AgentDiff>> {
755        self.agent_diff.as_ref()?.upgrade()
756    }
757
758    fn dispatch_action(&self, action: &dyn Action, window: &mut Window, cx: &mut Context<Self>) {
759        if let Some(agent_diff) = self.agent_diff(cx) {
760            agent_diff.focus_handle(cx).focus(window);
761        }
762        let action = action.boxed_clone();
763        cx.defer(move |cx| {
764            cx.dispatch_action(action.as_ref());
765        })
766    }
767}
768
769impl EventEmitter<ToolbarItemEvent> for AgentDiffToolbar {}
770
771impl ToolbarItemView for AgentDiffToolbar {
772    fn set_active_pane_item(
773        &mut self,
774        active_pane_item: Option<&dyn ItemHandle>,
775        _: &mut Window,
776        cx: &mut Context<Self>,
777    ) -> ToolbarItemLocation {
778        self.agent_diff = active_pane_item
779            .and_then(|item| item.act_as::<AgentDiff>(cx))
780            .map(|entity| entity.downgrade());
781        if self.agent_diff.is_some() {
782            ToolbarItemLocation::PrimaryRight
783        } else {
784            ToolbarItemLocation::Hidden
785        }
786    }
787
788    fn pane_focus_update(
789        &mut self,
790        _pane_focused: bool,
791        _window: &mut Window,
792        _cx: &mut Context<Self>,
793    ) {
794    }
795}
796
797impl Render for AgentDiffToolbar {
798    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
799        let agent_diff = match self.agent_diff(cx) {
800            Some(ad) => ad,
801            None => return div(),
802        };
803
804        let is_empty = agent_diff.read(cx).multibuffer.read(cx).is_empty();
805
806        if is_empty {
807            return div();
808        }
809
810        h_group_xl()
811            .my_neg_1()
812            .items_center()
813            .p_1()
814            .flex_wrap()
815            .justify_between()
816            .child(
817                h_group_sm()
818                    .child(
819                        Button::new("reject-all", "Reject All").on_click(cx.listener(
820                            |this, _, window, cx| {
821                                this.dispatch_action(&crate::RejectAll, window, cx)
822                            },
823                        )),
824                    )
825                    .child(Button::new("keep-all", "Keep All").on_click(cx.listener(
826                        |this, _, window, cx| this.dispatch_action(&crate::KeepAll, window, cx),
827                    ))),
828            )
829    }
830}
831
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::{TestAppContext, VisualTestContext};
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;

    /// Installs a test `SettingsStore` global and registers the settings and
    /// subsystems this test touches, in the same order the test relies on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    /// Returns the text currently displayed by `editor`.
    fn editor_text(editor: &Entity<Editor>, cx: &mut VisualTestContext) -> String {
        editor.read_with(cx, |editor, cx| editor.text(cx))
    }

    /// Returns the range, in `Point` coordinates, of `editor`'s newest selection.
    fn newest_selection(editor: &Entity<Editor>, cx: &mut VisualTestContext) -> Range<Point> {
        editor
            .update(cx, |editor, cx| editor.selections.newest::<Point>(cx))
            .range()
    }

    #[gpui::test]
    async fn test_agent_diff(cx: &mut TestAppContext) {
        init_test(cx);

        let fake_fs = FakeFs::new(cx.executor());
        fake_fs
            .insert_tree(
                path!("/test"),
                json!({"file1": "abc\ndef\nghi\njkl\nmno\npqr\nstu\nvwx\nyz"}),
            )
            .await;
        let project = Project::test(fake_fs, [path!("/test").as_ref()], cx).await;
        let file_path = project
            .read_with(cx, |project, cx| project.find_project_path("test/file1", cx))
            .unwrap();

        let threads = cx.update(|cx| {
            let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
            ThreadStore::new(project.clone(), Arc::default(), prompt_builder, cx).unwrap()
        });
        let thread = threads.update(cx, |store, cx| store.create_thread(cx));
        let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
        let agent_diff = cx.new_window_entity(|window, cx| {
            AgentDiff::new(thread.clone(), workspace.downgrade(), window, cx)
        });
        let diff_editor = agent_diff.read_with(cx, |diff, _| diff.editor.clone());

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(file_path, cx))
            .await
            .unwrap();
        cx.update(|_, cx| {
            // Mark the buffer as read by the agent, apply four one-character
            // edits on separate lines, then report the edit to the action log.
            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
            let edits = [
                (Point::new(1, 1)..Point::new(1, 2), "E"),
                (Point::new(3, 2)..Point::new(3, 3), "L"),
                (Point::new(5, 0)..Point::new(5, 1), "P"),
                (Point::new(7, 1)..Point::new(7, 2), "W"),
            ];
            buffer.update(cx, |buffer, cx| buffer.edit(edits, None, cx).unwrap());
            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
        });
        cx.run_until_parked();

        // An empty selection sitting at column 0 of `row`.
        let caret = |row: u32| Point::new(row, 0)..Point::new(row, 0);

        // Freshly opened: every hunk is shown and the cursor rests on the first one.
        assert_eq!(
            editor_text(&diff_editor, cx),
            "abc\ndef\ndEf\nghi\njkl\njkL\nmno\npqr\nPqr\nstu\nvwx\nvWx\nyz"
        );
        assert_eq!(newest_selection(&diff_editor, cx), caret(1));

        // Keep: the hunk's old text disappears and the cursor advances to the next hunk.
        agent_diff.update_in(cx, |diff, window, cx| diff.keep(&crate::Keep, window, cx));
        cx.run_until_parked();
        assert_eq!(
            editor_text(&diff_editor, cx),
            "abc\ndEf\nghi\njkl\njkL\nmno\npqr\nPqr\nstu\nvwx\nvWx\nyz"
        );
        assert_eq!(newest_selection(&diff_editor, cx), caret(3));

        // Reject with the cursor on the last hunk: the edit is reverted and the
        // cursor wraps around to the first remaining hunk.
        diff_editor.update_in(cx, |editor, window, cx| {
            editor.change_selections(None, window, cx, |selections| {
                selections.select_ranges([caret(10)])
            });
        });
        agent_diff.update_in(cx, |diff, window, cx| {
            diff.reject(&crate::Reject, window, cx)
        });
        cx.run_until_parked();
        assert_eq!(
            editor_text(&diff_editor, cx),
            "abc\ndEf\nghi\njkl\njkL\nmno\npqr\nPqr\nstu\nvwx\nyz"
        );
        assert_eq!(newest_selection(&diff_editor, cx), caret(3));

        // Keeping edits in a range away from the cursor leaves the cursor where it was.
        agent_diff.update_in(cx, |diff, window, cx| {
            let anchor = diff_editor
                .read(cx)
                .buffer()
                .read(cx)
                .read(cx)
                .anchor_before(Point::new(7, 0));
            diff.keep_edits_in_ranges(vec![anchor..anchor], window, cx)
        });
        cx.run_until_parked();
        assert_eq!(
            editor_text(&diff_editor, cx),
            "abc\ndEf\nghi\njkl\njkL\nmno\nPqr\nstu\nvwx\nyz"
        );
        assert_eq!(newest_selection(&diff_editor, cx), caret(3));
    }
}