entry_view_state.rs

  1use std::ops::Range;
  2
  3use super::thread_history::ThreadHistory;
  4use acp_thread::{AcpThread, AgentThreadEntry};
  5use agent::ThreadStore;
  6use agent_client_protocol::ToolCallId;
  7use collections::HashMap;
  8use editor::{Editor, EditorEvent, EditorMode, MinimapVisibility, SizingBehavior};
  9use gpui::{
 10    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 11    ScrollHandle, TextStyleRefinement, WeakEntity, Window,
 12};
 13use language::language_settings::SoftWrap;
 14use project::{AgentId, Project};
 15use prompt_store::PromptStore;
 16use rope::Point;
 17use settings::Settings as _;
 18use terminal_view::TerminalView;
 19use theme_settings::ThemeSettings;
 20use ui::{Context, TextSize};
 21use workspace::Workspace;
 22
 23use crate::message_editor::{MessageEditor, MessageEditorEvent, SharedSessionCapabilities};
 24
/// View state for the entries of an [`AcpThread`]: the editors, terminal views, scroll
/// handles and focus handles used to render each thread entry.
///
/// `entries` is indexed in parallel with the thread's entries. It is populated by
/// [`EntryViewState::sync_entry`] and trimmed by [`EntryViewState::remove`].
pub struct EntryViewState {
    /// Passed to the user-message editors and terminal views created by this state.
    workspace: WeakEntity<Workspace>,
    /// Passed to the user-message editors and terminal views created by this state.
    project: WeakEntity<Project>,
    /// Forwarded to each user-message [`MessageEditor`].
    thread_store: Option<Entity<ThreadStore>>,
    /// Forwarded to each user-message [`MessageEditor`].
    history: Option<WeakEntity<ThreadHistory>>,
    /// Forwarded to each user-message [`MessageEditor`].
    prompt_store: Option<Entity<PromptStore>>,
    /// View state for each synced thread entry, stored at the same index as the thread entry.
    entries: Vec<Entry>,
    /// Cloned into each user-message [`MessageEditor`].
    session_capabilities: SharedSessionCapabilities,
    /// Forwarded to each user-message [`MessageEditor`].
    agent_id: AgentId,
}
 35
impl EntryViewState {
    /// Creates an empty view state. Per-entry state is added later by
    /// [`EntryViewState::sync_entry`].
    pub fn new(
        workspace: WeakEntity<Workspace>,
        project: WeakEntity<Project>,
        thread_store: Option<Entity<ThreadStore>>,
        history: Option<WeakEntity<ThreadHistory>>,
        prompt_store: Option<Entity<PromptStore>>,
        session_capabilities: SharedSessionCapabilities,
        agent_id: AgentId,
    ) -> Self {
        Self {
            workspace,
            project,
            thread_store,
            history,
            prompt_store,
            entries: Vec::new(),
            session_capabilities,
            agent_id,
        }
    }

    /// Returns the view state for the thread entry at `index`.
    ///
    /// Returns `None` if that entry was never synced or has since been removed.
    pub fn entry(&self, index: usize) -> Option<&Entry> {
        self.entries.get(index)
    }

    /// Creates or refreshes the view state for the thread entry at `index`.
    ///
    /// Returns early if `thread` has no entry at `index`. Otherwise, by entry kind:
    ///
    /// - **User message**: creates a [`MessageEditor`] showing the message chunks.
    ///   - The editor is read-only unless the thread supports truncation, the message has an
    ///     id, and the thread has no parent session.
    ///   - Every editor event is re-emitted as [`ViewEvent::MessageEditorEvent`].
    ///   - An existing editor is refreshed with the latest chunks only while it is not focused.
    /// - **Tool call**: creates views for newly seen content.
    ///   - Each new terminal gets a terminal view ([`ViewEvent::NewTerminal`]).
    ///   - Each new diff gets a diff editor ([`ViewEvent::NewDiff`]).
    ///   - For already-known terminals, [`ViewEvent::TerminalMovedToBackground`] is emitted on
    ///     every sync in which the tool call is completed and the terminal has no output.
    /// - **Assistant message**: lets the [`AssistantMessageEntry`] update its scroll handles.
    /// - **Completed plan**: stores [`Entry::CompletedPlan`].
    ///
    /// A slot that currently holds a different kind of entry is replaced.
    ///
    /// New slots are written through `set_entry`, which panics if `index > self.entries.len()`.
    /// NOTE(review): this suggests callers must sync entries in order; confirm against callers.
    pub fn sync_entry(
        &mut self,
        index: usize,
        thread: &Entity<AcpThread>,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
            return;
        };

        match thread_entry {
            AgentThreadEntry::UserMessage(message) => {
                // Inputs for the read-only decision below, plus an owned copy of the chunks.
                let can_rewind = thread.read(cx).supports_truncate(cx);
                let has_id = message.id.is_some();
                let is_subagent = thread.read(cx).parent_session_id().is_some();
                let chunks = message.chunks.clone();
                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
                    if !editor.focus_handle(cx).is_focused(window) {
                        // Only update if we are not editing.
                        // If we are, cancelling the edit will set the message to the newest content.
                        editor.update(cx, |editor, cx| {
                            editor.set_message(chunks, window, cx);
                        });
                    }
                } else {
                    let message_editor = cx.new(|cx| {
                        let mut editor = MessageEditor::new(
                            self.workspace.clone(),
                            self.project.clone(),
                            self.thread_store.clone(),
                            self.history.clone(),
                            self.prompt_store.clone(),
                            self.session_capabilities.clone(),
                            self.agent_id.clone(),
                            "Edit message - @ to include context",
                            editor::EditorMode::AutoHeight {
                                min_lines: 1,
                                max_lines: None,
                            },
                            window,
                            cx,
                        );
                        // Editable only if the thread supports truncation, the message has an
                        // id, and the thread has no parent session (i.e. is not a subagent).
                        if !can_rewind || !has_id || is_subagent {
                            editor.set_read_only(true, cx);
                        }
                        editor.set_message(chunks, window, cx);
                        editor
                    });
                    // Forward every message editor event, tagged with this entry's index.
                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
                        cx.emit(EntryViewEvent {
                            entry_index: index,
                            view_event: ViewEvent::MessageEditorEvent(editor, event.clone()),
                        })
                    })
                    .detach();
                    self.set_entry(index, Entry::UserMessage(message_editor));
                }
            }
            AgentThreadEntry::ToolCall(tool_call) => {
                // Owned copies of the id, terminals and diffs, used while creating views below.
                let id = tool_call.id.clone();
                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();

                // Reuse this slot's view map if it already holds a tool call. Otherwise store
                // an empty tool call entry (with a new focus handle) and use its map.
                let views = if let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) {
                    &mut tool_call.content
                } else {
                    self.set_entry(
                        index,
                        Entry::ToolCall(ToolCallEntry {
                            content: HashMap::default(),
                            focus_handle: cx.focus_handle(),
                        }),
                    );
                    let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) else {
                        unreachable!()
                    };
                    &mut tool_call.content
                };

                // `tool_call` here is the thread's tool call again (the binding above was
                // scoped to the `if let`).
                let is_tool_call_completed =
                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);

                for terminal in terminals {
                    match views.entry(terminal.entity_id()) {
                        // First time this terminal is seen: create its view and announce it.
                        collections::hash_map::Entry::Vacant(entry) => {
                            let element = create_terminal(
                                self.workspace.clone(),
                                self.project.clone(),
                                terminal.clone(),
                                window,
                                cx,
                            )
                            .into_any();
                            cx.emit(EntryViewEvent {
                                entry_index: index,
                                view_event: ViewEvent::NewTerminal(id.clone()),
                            });
                            entry.insert(element);
                        }
                        // Known terminal: emitted on every sync while the condition holds,
                        // not just once.
                        collections::hash_map::Entry::Occupied(_entry) => {
                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
                                cx.emit(EntryViewEvent {
                                    entry_index: index,
                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
                                });
                            }
                        }
                    }
                }

                for diff in diffs {
                    // Only diffs without a view yet get an editor (and a NewDiff event).
                    views.entry(diff.entity_id()).or_insert_with(|| {
                        let editor = create_editor_diff(diff.clone(), window, cx);
                        // Turn the diff editor's open-excerpts request into an OpenDiffLocation
                        // event. It points at the start of the first range of the first buffer
                        // yielded by `selections_by_buffer`, and is only emitted when the diff
                        // has a file path.
                        // NOTE(review): if `selections_by_buffer` is a hash map, "first" is
                        // arbitrary when several buffers are present; confirm.
                        cx.subscribe(&editor, {
                            let diff = diff.clone();
                            let entry_index = index;
                            move |_this, _editor, event: &EditorEvent, cx| {
                                if let EditorEvent::OpenExcerptsRequested {
                                    selections_by_buffer,
                                    split,
                                } = event
                                {
                                    let multibuffer = diff.read(cx).multibuffer();
                                    if let Some((buffer_id, (ranges, _))) =
                                        selections_by_buffer.iter().next()
                                    {
                                        if let Some(buffer) =
                                            multibuffer.read(cx).buffer(*buffer_id)
                                        {
                                            if let Some(range) = ranges.first() {
                                                let point =
                                                    buffer.read(cx).offset_to_point(range.start.0);
                                                if let Some(path) = diff.read(cx).file_path(cx) {
                                                    cx.emit(EntryViewEvent {
                                                        entry_index,
                                                        view_event: ViewEvent::OpenDiffLocation {
                                                            path,
                                                            position: point,
                                                            split: *split,
                                                        },
                                                    });
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        })
                        .detach();
                        cx.emit(EntryViewEvent {
                            entry_index: index,
                            view_event: ViewEvent::NewDiff(id.clone()),
                        });
                        editor.into_any()
                    });
                }
            }
            AgentThreadEntry::AssistantMessage(message) => {
                // Reuse the existing assistant entry, or store a fresh one with no scroll
                // handles and a new focus handle.
                let entry = if let Some(Entry::AssistantMessage(entry)) =
                    self.entries.get_mut(index)
                {
                    entry
                } else {
                    self.set_entry(
                        index,
                        Entry::AssistantMessage(AssistantMessageEntry {
                            scroll_handles_by_chunk_index: HashMap::default(),
                            focus_handle: cx.focus_handle(),
                        }),
                    );
                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
                        unreachable!()
                    };
                    entry
                };
                entry.sync(message);
            }
            AgentThreadEntry::CompletedPlan(_) => {
                // Completed plans carry no view state; only write the slot if it differs.
                if !matches!(self.entries.get(index), Some(Entry::CompletedPlan)) {
                    self.set_entry(index, Entry::CompletedPlan);
                }
            }
        };
    }

    /// Stores `entry` at `index`.
    ///
    /// - If `index == self.entries.len()`, the entry is appended.
    /// - Otherwise the existing entry is overwritten.
    /// - Panics (out-of-bounds index) if `index > self.entries.len()`.
    fn set_entry(&mut self, index: usize, entry: Entry) {
        if index == self.entries.len() {
            self.entries.push(entry);
        } else {
            self.entries[index] = entry;
        }
    }

    /// Drops the view state for the entries in `range`, shifting later entries down.
    ///
    /// Panics if `range` is out of bounds (per `Vec::drain`).
    pub fn remove(&mut self, range: Range<usize>) {
        self.entries.drain(range);
    }

    /// Re-applies the diff text style to every diff editor so it picks up the current agent
    /// UI font size.
    ///
    /// Only tool call views that downcast to [`Editor`] are touched; terminal views are skipped.
    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
        for entry in self.entries.iter() {
            match entry {
                Entry::UserMessage { .. }
                | Entry::AssistantMessage { .. }
                | Entry::CompletedPlan => {}
                Entry::ToolCall(ToolCallEntry { content, .. }) => {
                    for view in content.values() {
                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
                            diff_editor.update(cx, |diff_editor, cx| {
                                diff_editor.set_text_style_refinement(
                                    diff_editor_text_style_refinement(cx),
                                );
                                cx.notify();
                            })
                        }
                    }
                }
            }
        }
    }
}
282
// `EntryViewState` emits `EntryViewEvent`s from `sync_entry` and from the editor subscriptions
// that `sync_entry` sets up.
impl EventEmitter<EntryViewEvent> for EntryViewState {}
284
/// An event from one of the views owned by [`EntryViewState`], tagged with the index of the
/// thread entry it belongs to.
pub struct EntryViewEvent {
    /// Index of the thread entry whose view produced the event.
    pub entry_index: usize,
    /// What happened.
    pub view_event: ViewEvent,
}
289
/// The kind of change reported by an [`EntryViewEvent`].
pub enum ViewEvent {
    /// A diff editor was created for a diff of the given tool call.
    NewDiff(ToolCallId),
    /// A terminal view was created for a terminal of the given tool call.
    NewTerminal(ToolCallId),
    /// The given tool call is completed while one of its known terminals has no output.
    TerminalMovedToBackground(ToolCallId),
    /// A user-message editor emitted an event. Carries the editor and a clone of the event.
    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
    /// A diff editor requested to open one of its excerpts in the diffed file.
    OpenDiffLocation {
        /// File path reported by the diff.
        path: String,
        /// Position in the file, converted from the requested range's start offset.
        position: Point,
        /// The `split` flag from the editor's open request.
        split: bool,
    },
}
301
/// View state for an assistant message.
#[derive(Debug)]
pub struct AssistantMessageEntry {
    /// Scroll handles keyed by chunk index. They are only created for chunks that were the
    /// trailing thought chunk at sync time.
    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
    /// Focus handle created when the entry was first synced.
    focus_handle: FocusHandle,
}
307
308impl AssistantMessageEntry {
309    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
310        self.scroll_handles_by_chunk_index.get(&ix).cloned()
311    }
312
313    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
314        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
315            let ix = message.chunks.len() - 1;
316            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
317            handle.scroll_to_bottom();
318        }
319    }
320}
321
/// View state for a tool call: the embedded views for its terminals and diffs.
#[derive(Debug)]
pub struct ToolCallEntry {
    /// Views keyed by the entity id of the terminal or diff they display.
    /// Values are `TerminalView` entities for terminals and `Editor` entities for diffs.
    content: HashMap<EntityId, AnyEntity>,
    /// Focus handle created when the entry was first synced.
    focus_handle: FocusHandle,
}
327
/// View state for a single thread entry.
#[derive(Debug)]
pub enum Entry {
    /// Editor showing a user message. It is read-only when the message cannot be edited.
    UserMessage(Entity<MessageEditor>),
    /// Scroll and focus state for an assistant message.
    AssistantMessage(AssistantMessageEntry),
    /// Embedded terminal and diff views for a tool call.
    ToolCall(ToolCallEntry),
    /// A completed plan. It holds no view state.
    CompletedPlan,
}
335
336impl Entry {
337    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
338        match self {
339            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
340            Self::AssistantMessage(message) => Some(message.focus_handle.clone()),
341            Self::ToolCall(tool_call) => Some(tool_call.focus_handle.clone()),
342            Self::CompletedPlan => None,
343        }
344    }
345
346    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
347        match self {
348            Self::UserMessage(editor) => Some(editor),
349            Self::AssistantMessage(_) | Self::ToolCall(_) | Self::CompletedPlan => None,
350        }
351    }
352
353    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
354        self.content_map()?
355            .get(&diff.entity_id())
356            .cloned()
357            .map(|entity| entity.downcast::<Editor>().unwrap())
358    }
359
360    pub fn terminal(
361        &self,
362        terminal: &Entity<acp_thread::Terminal>,
363    ) -> Option<Entity<TerminalView>> {
364        self.content_map()?
365            .get(&terminal.entity_id())
366            .cloned()
367            .map(|entity| entity.downcast::<TerminalView>().unwrap())
368    }
369
370    pub fn scroll_handle_for_assistant_message_chunk(
371        &self,
372        chunk_ix: usize,
373    ) -> Option<ScrollHandle> {
374        match self {
375            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
376            Self::UserMessage(_) | Self::ToolCall(_) | Self::CompletedPlan => None,
377        }
378    }
379
380    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
381        match self {
382            Self::ToolCall(ToolCallEntry { content, .. }) => Some(content),
383            _ => None,
384        }
385    }
386
387    #[cfg(test)]
388    pub fn has_content(&self) -> bool {
389        match self {
390            Self::ToolCall(ToolCallEntry { content, .. }) => !content.is_empty(),
391            Self::UserMessage(_) | Self::AssistantMessage(_) | Self::CompletedPlan => false,
392        }
393    }
394}
395
396impl Focusable for ToolCallEntry {
397    fn focus_handle(&self, _cx: &App) -> FocusHandle {
398        self.focus_handle.clone()
399    }
400}
401
402impl Focusable for Entry {
403    fn focus_handle(&self, cx: &App) -> FocusHandle {
404        match self {
405            Self::UserMessage(editor) => editor.read(cx).focus_handle(cx),
406            Self::AssistantMessage(message) => message.focus_handle.clone(),
407            Self::ToolCall(tool_call) => tool_call.focus_handle.clone(),
408            Self::CompletedPlan => cx.focus_handle(),
409        }
410    }
411}
412
413fn create_terminal(
414    workspace: WeakEntity<Workspace>,
415    project: WeakEntity<Project>,
416    terminal: Entity<acp_thread::Terminal>,
417    window: &mut Window,
418    cx: &mut App,
419) -> Entity<TerminalView> {
420    cx.new(|cx| {
421        let mut view = TerminalView::new(
422            terminal.read(cx).inner().clone(),
423            workspace,
424            None,
425            project,
426            window,
427            cx,
428        );
429        view.set_embedded_mode(Some(1000), cx);
430        view
431    })
432}
433
434fn create_editor_diff(
435    diff: Entity<acp_thread::Diff>,
436    window: &mut Window,
437    cx: &mut App,
438) -> Entity<Editor> {
439    cx.new(|cx| {
440        let mut editor = Editor::new(
441            EditorMode::Full {
442                scale_ui_elements_with_buffer_font_size: false,
443                show_active_line_background: false,
444                sizing_behavior: SizingBehavior::SizeByContent,
445            },
446            diff.read(cx).multibuffer().clone(),
447            None,
448            window,
449            cx,
450        );
451        editor.set_show_gutter(false, cx);
452        editor.disable_inline_diagnostics();
453        editor.disable_expand_excerpt_buttons(cx);
454        editor.set_show_vertical_scrollbar(false, cx);
455        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
456        editor.set_soft_wrap_mode(SoftWrap::None, cx);
457        editor.scroll_manager.set_forbid_vertical_scroll(true);
458        editor.set_show_indent_guides(false, cx);
459        editor.set_read_only(true);
460        editor.set_delegate_open_excerpts(true);
461        editor.set_show_bookmarks(false, cx);
462        editor.set_show_breakpoints(false, cx);
463        editor.set_show_code_actions(false, cx);
464        editor.set_show_git_diff_gutter(false, cx);
465        editor.set_expand_all_diff_hunks(cx);
466        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
467        editor
468    })
469}
470
471fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
472    TextStyleRefinement {
473        font_size: Some(
474            TextSize::Small
475                .rems(cx)
476                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
477                .into(),
478        ),
479        ..Default::default()
480    }
481}
482
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::rc::Rc;
    use std::sync::Arc;

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::RowInfo;
    use fs::FakeFs;
    use gpui::{AppContext as _, TestAppContext};
    use parking_lot::RwLock;

    use crate::entry_view_state::EntryViewState;
    use crate::message_editor::SessionCapabilities;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use workspace::{MultiWorkspace, PathList};

    /// Syncing a tool call that carries a diff creates a diff editor.
    ///
    /// The editor shows the old text as a deleted row followed by the new text as an added row.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/project",
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());

        // An in-progress tool call whose only content is a diff of hello.txt.
        let tool_call = acp::ToolCall::new("tool", "Tool call")
            .status(acp::ToolCallStatus::InProgress)
            .content(vec![acp::ToolCallContent::Diff(
                acp::Diff::new("/project/hello.txt", "hello world").old_text("hi world"),
            )]);
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new(path!("/project"))]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        // Deliver the tool call to the thread through the stub connection.
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = None;
        let history: Option<gpui::WeakEntity<crate::ThreadHistory>> = None;

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.downgrade(),
                thread_store,
                history,
                None,
                Arc::new(RwLock::new(SessionCapabilities::default())),
                "Test Agent".into(),
            )
        });

        // Sync entry 0 (the tool call) into the view state.
        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        // Run pending tasks before inspecting the editor (presumably lets the diff finish
        // computing).
        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        // Row 0 is the removed old line, row 1 the added new line.
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs a test settings store, the base themes and a 0.0.0 release channel.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme_settings::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(semver::Version::new(0, 0, 0), cx);
        });
    }
}