entry_view_state.rs

  1use std::ops::Range;
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent::ThreadStore;
  5use agent_client_protocol::ToolCallId;
  6use collections::HashMap;
  7use editor::{Editor, EditorEvent, EditorMode, MinimapVisibility, SizingBehavior};
  8use gpui::{
  9    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 10    ScrollHandle, TextStyleRefinement, WeakEntity, Window,
 11};
 12use language::language_settings::SoftWrap;
 13use project::{AgentId, Project};
 14use prompt_store::PromptStore;
 15use rope::Point;
 16use settings::Settings as _;
 17use terminal_view::TerminalView;
 18use theme_settings::ThemeSettings;
 19use ui::{Context, TextSize};
 20use workspace::Workspace;
 21
 22use crate::message_editor::{MessageEditor, MessageEditorEvent, SharedSessionCapabilities};
 23
/// View state for the entries of an agent thread.
///
/// For each thread entry index it holds whatever the UI needs to render that
/// entry (message editors, terminal views, diff editors, scroll and focus
/// handles), together with the context needed to create those views on demand
/// in `sync_entry`.
pub struct EntryViewState {
    workspace: WeakEntity<Workspace>,
    project: WeakEntity<Project>,
    thread_store: Option<Entity<ThreadStore>>,
    prompt_store: Option<Entity<PromptStore>>,
    /// Per-entry view state; `entries[i]` corresponds to the thread's entry `i`.
    entries: Vec<Entry>,
    session_capabilities: SharedSessionCapabilities,
    agent_id: AgentId,
}
 33
 34impl EntryViewState {
 35    pub fn new(
 36        workspace: WeakEntity<Workspace>,
 37        project: WeakEntity<Project>,
 38        thread_store: Option<Entity<ThreadStore>>,
 39        prompt_store: Option<Entity<PromptStore>>,
 40        session_capabilities: SharedSessionCapabilities,
 41        agent_id: AgentId,
 42    ) -> Self {
 43        Self {
 44            workspace,
 45            project,
 46            thread_store,
 47            prompt_store,
 48            entries: Vec::new(),
 49            session_capabilities,
 50            agent_id,
 51        }
 52    }
 53
 54    pub fn entry(&self, index: usize) -> Option<&Entry> {
 55        self.entries.get(index)
 56    }
 57
 58    pub fn sync_entry(
 59        &mut self,
 60        index: usize,
 61        thread: &Entity<AcpThread>,
 62        window: &mut Window,
 63        cx: &mut Context<Self>,
 64    ) {
 65        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 66            return;
 67        };
 68
 69        match thread_entry {
 70            AgentThreadEntry::UserMessage(message) => {
 71                let can_rewind = thread.read(cx).supports_truncate(cx);
 72                let has_id = message.id.is_some();
 73                let is_subagent = thread.read(cx).parent_session_id().is_some();
 74                let chunks = message.chunks.clone();
 75                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 76                    if !editor.focus_handle(cx).is_focused(window) {
 77                        // Only update if we are not editing.
 78                        // If we are, cancelling the edit will set the message to the newest content.
 79                        editor.update(cx, |editor, cx| {
 80                            editor.set_message(chunks, window, cx);
 81                        });
 82                    }
 83                } else {
 84                    let message_editor = cx.new(|cx| {
 85                        let mut editor = MessageEditor::new(
 86                            self.workspace.clone(),
 87                            self.project.clone(),
 88                            self.thread_store.clone(),
 89                            self.prompt_store.clone(),
 90                            self.session_capabilities.clone(),
 91                            self.agent_id.clone(),
 92                            "Edit message - @ to include context",
 93                            editor::EditorMode::AutoHeight {
 94                                min_lines: 1,
 95                                max_lines: None,
 96                            },
 97                            window,
 98                            cx,
 99                        );
100                        if !can_rewind || !has_id || is_subagent {
101                            editor.set_read_only(true, cx);
102                        }
103                        editor.set_message(chunks, window, cx);
104                        editor
105                    });
106                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
107                        cx.emit(EntryViewEvent {
108                            entry_index: index,
109                            view_event: ViewEvent::MessageEditorEvent(editor, event.clone()),
110                        })
111                    })
112                    .detach();
113                    self.set_entry(index, Entry::UserMessage(message_editor));
114                }
115            }
116            AgentThreadEntry::ToolCall(tool_call) => {
117                let id = tool_call.id.clone();
118                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
119                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
120
121                let views = if let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) {
122                    &mut tool_call.content
123                } else {
124                    self.set_entry(
125                        index,
126                        Entry::ToolCall(ToolCallEntry {
127                            content: HashMap::default(),
128                            focus_handle: cx.focus_handle(),
129                        }),
130                    );
131                    let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) else {
132                        unreachable!()
133                    };
134                    &mut tool_call.content
135                };
136
137                let is_tool_call_completed =
138                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
139
140                for terminal in terminals {
141                    match views.entry(terminal.entity_id()) {
142                        collections::hash_map::Entry::Vacant(entry) => {
143                            let element = create_terminal(
144                                self.workspace.clone(),
145                                self.project.clone(),
146                                terminal.clone(),
147                                window,
148                                cx,
149                            )
150                            .into_any();
151                            cx.emit(EntryViewEvent {
152                                entry_index: index,
153                                view_event: ViewEvent::NewTerminal(id.clone()),
154                            });
155                            entry.insert(element);
156                        }
157                        collections::hash_map::Entry::Occupied(_entry) => {
158                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
159                                cx.emit(EntryViewEvent {
160                                    entry_index: index,
161                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
162                                });
163                            }
164                        }
165                    }
166                }
167
168                for diff in diffs {
169                    views.entry(diff.entity_id()).or_insert_with(|| {
170                        let editor = create_editor_diff(diff.clone(), window, cx);
171                        cx.subscribe(&editor, {
172                            let diff = diff.clone();
173                            let entry_index = index;
174                            move |_this, _editor, event: &EditorEvent, cx| {
175                                if let EditorEvent::OpenExcerptsRequested {
176                                    selections_by_buffer,
177                                    split,
178                                } = event
179                                {
180                                    let multibuffer = diff.read(cx).multibuffer();
181                                    if let Some((buffer_id, (ranges, _))) =
182                                        selections_by_buffer.iter().next()
183                                    {
184                                        if let Some(buffer) =
185                                            multibuffer.read(cx).buffer(*buffer_id)
186                                        {
187                                            if let Some(range) = ranges.first() {
188                                                let point =
189                                                    buffer.read(cx).offset_to_point(range.start.0);
190                                                if let Some(path) = diff.read(cx).file_path(cx) {
191                                                    cx.emit(EntryViewEvent {
192                                                        entry_index,
193                                                        view_event: ViewEvent::OpenDiffLocation {
194                                                            path,
195                                                            position: point,
196                                                            split: *split,
197                                                        },
198                                                    });
199                                                }
200                                            }
201                                        }
202                                    }
203                                }
204                            }
205                        })
206                        .detach();
207                        cx.emit(EntryViewEvent {
208                            entry_index: index,
209                            view_event: ViewEvent::NewDiff(id.clone()),
210                        });
211                        editor.into_any()
212                    });
213                }
214            }
215            AgentThreadEntry::AssistantMessage(message) => {
216                let entry = if let Some(Entry::AssistantMessage(entry)) =
217                    self.entries.get_mut(index)
218                {
219                    entry
220                } else {
221                    self.set_entry(
222                        index,
223                        Entry::AssistantMessage(AssistantMessageEntry {
224                            scroll_handles_by_chunk_index: HashMap::default(),
225                            focus_handle: cx.focus_handle(),
226                        }),
227                    );
228                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
229                        unreachable!()
230                    };
231                    entry
232                };
233                entry.sync(message);
234            }
235            AgentThreadEntry::CompletedPlan(_) => {
236                if !matches!(self.entries.get(index), Some(Entry::CompletedPlan)) {
237                    self.set_entry(index, Entry::CompletedPlan);
238                }
239            }
240        };
241    }
242
243    fn set_entry(&mut self, index: usize, entry: Entry) {
244        if index == self.entries.len() {
245            self.entries.push(entry);
246        } else {
247            self.entries[index] = entry;
248        }
249    }
250
251    pub fn remove(&mut self, range: Range<usize>) {
252        self.entries.drain(range);
253    }
254
255    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
256        for entry in self.entries.iter() {
257            match entry {
258                Entry::UserMessage { .. }
259                | Entry::AssistantMessage { .. }
260                | Entry::CompletedPlan => {}
261                Entry::ToolCall(ToolCallEntry { content, .. }) => {
262                    for view in content.values() {
263                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
264                            diff_editor.update(cx, |diff_editor, cx| {
265                                diff_editor.set_text_style_refinement(
266                                    diff_editor_text_style_refinement(cx),
267                                );
268                                cx.notify();
269                            })
270                        }
271                    }
272                }
273            }
274        }
275    }
276}
277
// Lets `EntryViewState` emit `EntryViewEvent`s through `cx.emit`, as `sync_entry` does.
impl EventEmitter<EntryViewEvent> for EntryViewState {}
279
/// Event emitted by [`EntryViewState`] about the views of one thread entry.
pub struct EntryViewEvent {
    /// Index of the thread entry the event refers to.
    pub entry_index: usize,
    /// What happened to the entry's views.
    pub view_event: ViewEvent,
}
284
/// The kinds of view events reported by [`EntryViewState`].
pub enum ViewEvent {
    /// A diff editor was created for one of the given tool call's diffs.
    NewDiff(ToolCallId),
    /// A terminal view was created for one of the given tool call's terminals.
    NewTerminal(ToolCallId),
    /// The given tool call is completed while one of its terminals still reports
    /// no output (its `output()` is `None`).
    TerminalMovedToBackground(ToolCallId),
    /// An event forwarded from a user message's editor.
    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
    /// A request, originating from a diff editor, to open a location in a file.
    OpenDiffLocation {
        /// File path as reported by the diff.
        path: String,
        /// Position of the start of the requested range within the file.
        position: Point,
        /// The `split` flag carried by the originating open-excerpts request.
        split: bool,
    },
}
296
/// View state for an assistant message entry.
#[derive(Debug)]
pub struct AssistantMessageEntry {
    /// Scroll handles keyed by chunk index. A handle is created on demand for a
    /// thought chunk while that chunk is the last one in the message.
    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
    focus_handle: FocusHandle,
}
302
303impl AssistantMessageEntry {
304    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
305        self.scroll_handles_by_chunk_index.get(&ix).cloned()
306    }
307
308    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
309        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
310            let ix = message.chunks.len() - 1;
311            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
312            handle.scroll_to_bottom();
313        }
314    }
315}
316
/// View state for a tool call entry.
#[derive(Debug)]
pub struct ToolCallEntry {
    /// Views for the tool call's content, keyed by the entity id of the source:
    /// terminals map to `TerminalView`s and diffs map to diff `Editor`s.
    content: HashMap<EntityId, AnyEntity>,
    focus_handle: FocusHandle,
}
322
/// View state for a single thread entry.
#[derive(Debug)]
pub enum Entry {
    /// A user message, displayed in a message editor.
    UserMessage(Entity<MessageEditor>),
    /// An assistant message and its per-chunk scroll handles.
    AssistantMessage(AssistantMessageEntry),
    /// A tool call and the views created for its terminals and diffs.
    ToolCall(ToolCallEntry),
    /// A completed plan; it carries no view state.
    CompletedPlan,
}
330
331impl Entry {
332    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
333        match self {
334            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
335            Self::AssistantMessage(message) => Some(message.focus_handle.clone()),
336            Self::ToolCall(tool_call) => Some(tool_call.focus_handle.clone()),
337            Self::CompletedPlan => None,
338        }
339    }
340
341    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
342        match self {
343            Self::UserMessage(editor) => Some(editor),
344            Self::AssistantMessage(_) | Self::ToolCall(_) | Self::CompletedPlan => None,
345        }
346    }
347
348    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
349        self.content_map()?
350            .get(&diff.entity_id())
351            .cloned()
352            .map(|entity| entity.downcast::<Editor>().unwrap())
353    }
354
355    pub fn terminal(
356        &self,
357        terminal: &Entity<acp_thread::Terminal>,
358    ) -> Option<Entity<TerminalView>> {
359        self.content_map()?
360            .get(&terminal.entity_id())
361            .cloned()
362            .map(|entity| entity.downcast::<TerminalView>().unwrap())
363    }
364
365    pub fn scroll_handle_for_assistant_message_chunk(
366        &self,
367        chunk_ix: usize,
368    ) -> Option<ScrollHandle> {
369        match self {
370            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
371            Self::UserMessage(_) | Self::ToolCall(_) | Self::CompletedPlan => None,
372        }
373    }
374
375    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
376        match self {
377            Self::ToolCall(ToolCallEntry { content, .. }) => Some(content),
378            _ => None,
379        }
380    }
381
382    #[cfg(test)]
383    pub fn has_content(&self) -> bool {
384        match self {
385            Self::ToolCall(ToolCallEntry { content, .. }) => !content.is_empty(),
386            Self::UserMessage(_) | Self::AssistantMessage(_) | Self::CompletedPlan => false,
387        }
388    }
389}
390
391impl Focusable for ToolCallEntry {
392    fn focus_handle(&self, _cx: &App) -> FocusHandle {
393        self.focus_handle.clone()
394    }
395}
396
397impl Focusable for Entry {
398    fn focus_handle(&self, cx: &App) -> FocusHandle {
399        match self {
400            Self::UserMessage(editor) => editor.read(cx).focus_handle(cx),
401            Self::AssistantMessage(message) => message.focus_handle.clone(),
402            Self::ToolCall(tool_call) => tool_call.focus_handle.clone(),
403            Self::CompletedPlan => cx.focus_handle(),
404        }
405    }
406}
407
408fn create_terminal(
409    workspace: WeakEntity<Workspace>,
410    project: WeakEntity<Project>,
411    terminal: Entity<acp_thread::Terminal>,
412    window: &mut Window,
413    cx: &mut App,
414) -> Entity<TerminalView> {
415    cx.new(|cx| {
416        let mut view = TerminalView::new(
417            terminal.read(cx).inner().clone(),
418            workspace,
419            None,
420            project,
421            window,
422            cx,
423        );
424        view.set_embedded_mode(Some(1000), cx);
425        view
426    })
427}
428
/// Creates the read-only, content-sized editor that displays an agent diff.
///
/// The editor renders the diff's multibuffer with every diff hunk expanded and
/// text styled by `diff_editor_text_style_refinement`. The following are turned
/// off:
/// - gutter, bookmarks, breakpoints, code actions and git diff gutter;
/// - inline diagnostics and expand-excerpt buttons;
/// - vertical scrollbar, vertical scrolling and the minimap;
/// - indent guides and soft wrap.
fn create_editor_diff(
    diff: Entity<acp_thread::Diff>,
    window: &mut Window,
    cx: &mut App,
) -> Entity<Editor> {
    cx.new(|cx| {
        let mut editor = Editor::new(
            EditorMode::Full {
                scale_ui_elements_with_buffer_font_size: false,
                show_active_line_background: false,
                // Grow to fit the diff instead of scrolling within a fixed height.
                sizing_behavior: SizingBehavior::SizeByContent,
            },
            diff.read(cx).multibuffer().clone(),
            None,
            window,
            cx,
        );
        editor.set_show_gutter(false, cx);
        editor.disable_inline_diagnostics();
        editor.disable_expand_excerpt_buttons(cx);
        editor.set_show_vertical_scrollbar(false, cx);
        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
        editor.set_soft_wrap_mode(SoftWrap::None, cx);
        editor.scroll_manager.set_forbid_vertical_scroll(true);
        editor.set_show_indent_guides(false, cx);
        editor.set_read_only(true);
        // NOTE(review): presumably makes the editor report open-excerpt requests as
        // `EditorEvent::OpenExcerptsRequested` rather than opening them itself; the
        // subscription created in `EntryViewState::sync_entry` handles that event.
        editor.set_delegate_open_excerpts(true);
        editor.set_show_bookmarks(false, cx);
        editor.set_show_breakpoints(false, cx);
        editor.set_show_code_actions(false, cx);
        editor.set_show_git_diff_gutter(false, cx);
        editor.set_expand_all_diff_hunks(cx);
        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
        editor
    })
}
465
466fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
467    TextStyleRefinement {
468        font_size: Some(
469            TextSize::Small
470                .rems(cx)
471                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
472                .into(),
473        ),
474        ..Default::default()
475    }
476}
477
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::rc::Rc;
    use std::sync::Arc;

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::RowInfo;
    use fs::FakeFs;
    use gpui::{AppContext as _, TestAppContext};
    use parking_lot::RwLock;

    use crate::entry_view_state::EntryViewState;
    use crate::message_editor::SessionCapabilities;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use workspace::{MultiWorkspace, PathList};

    /// Syncing a tool call that carries a diff creates a diff editor.
    ///
    /// The editor should show the old text as a deleted row followed by the new
    /// text as an added row.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/project",
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());

        // An in-progress tool call whose only content is a diff of hello.txt.
        let tool_call = acp::ToolCall::new("tool", "Tool call")
            .status(acp::ToolCallStatus::InProgress)
            .content(vec![acp::ToolCallContent::Diff(
                acp::Diff::new("/project/hello.txt", "hello world").old_text("hi world"),
            )]);
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new(path!("/project"))]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        // Deliver the tool call to the thread; it becomes entry 0.
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = None;

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.downgrade(),
                thread_store,
                None,
                Arc::new(RwLock::new(SessionCapabilities::default())),
                "Test Agent".into(),
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        // Let pending background work settle before inspecting the editor.
        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        // Row 0 is the deleted old line; row 1 is the added new line.
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Sets up the global test state: a test settings store, theme settings with
    /// only the base themes, and release channel version 0.0.0.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme_settings::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(semver::Version::new(0, 0, 0), cx);
        });
    }
}