entry_view_state.rs

  1use std::{cell::RefCell, ops::Range, rc::Rc};
  2
  3use super::thread_history::AcpThreadHistory;
  4use acp_thread::{AcpThread, AgentThreadEntry};
  5use agent::ThreadStore;
  6use agent_client_protocol::{self as acp, ToolCallId};
  7use collections::HashMap;
  8use editor::{Editor, EditorMode, MinimapVisibility, SizingBehavior};
  9use gpui::{
 10    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 11    ScrollHandle, SharedString, TextStyleRefinement, WeakEntity, Window,
 12};
 13use language::language_settings::SoftWrap;
 14use project::Project;
 15use prompt_store::PromptStore;
 16use settings::Settings as _;
 17use terminal_view::TerminalView;
 18use theme::ThemeSettings;
 19use ui::{Context, TextSize};
 20use workspace::Workspace;
 21
 22use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 23
 24pub struct EntryViewState {
 25    workspace: WeakEntity<Workspace>,
 26    project: WeakEntity<Project>,
 27    thread_store: Option<Entity<ThreadStore>>,
 28    history: WeakEntity<AcpThreadHistory>,
 29    prompt_store: Option<Entity<PromptStore>>,
 30    entries: Vec<Entry>,
 31    prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 32    available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 33    agent_name: SharedString,
 34}
 35
 36impl EntryViewState {
 37    pub fn new(
 38        workspace: WeakEntity<Workspace>,
 39        project: WeakEntity<Project>,
 40        thread_store: Option<Entity<ThreadStore>>,
 41        history: WeakEntity<AcpThreadHistory>,
 42        prompt_store: Option<Entity<PromptStore>>,
 43        prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 44        available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 45        agent_name: SharedString,
 46    ) -> Self {
 47        Self {
 48            workspace,
 49            project,
 50            thread_store,
 51            history,
 52            prompt_store,
 53            entries: Vec::new(),
 54            prompt_capabilities,
 55            available_commands,
 56            agent_name,
 57        }
 58    }
 59
 60    pub fn entry(&self, index: usize) -> Option<&Entry> {
 61        self.entries.get(index)
 62    }
 63
 64    pub fn sync_entry(
 65        &mut self,
 66        index: usize,
 67        thread: &Entity<AcpThread>,
 68        window: &mut Window,
 69        cx: &mut Context<Self>,
 70    ) {
 71        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 72            return;
 73        };
 74
 75        match thread_entry {
 76            AgentThreadEntry::UserMessage(message) => {
 77                let has_id = message.id.is_some();
 78                let chunks = message.chunks.clone();
 79                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 80                    if !editor.focus_handle(cx).is_focused(window) {
 81                        // Only update if we are not editing.
 82                        // If we are, cancelling the edit will set the message to the newest content.
 83                        editor.update(cx, |editor, cx| {
 84                            editor.set_message(chunks, window, cx);
 85                        });
 86                    }
 87                } else {
 88                    let message_editor = cx.new(|cx| {
 89                        let mut editor = MessageEditor::new(
 90                            self.workspace.clone(),
 91                            self.project.clone(),
 92                            self.thread_store.clone(),
 93                            self.history.clone(),
 94                            self.prompt_store.clone(),
 95                            self.prompt_capabilities.clone(),
 96                            self.available_commands.clone(),
 97                            self.agent_name.clone(),
 98                            "Edit message - @ to include context",
 99                            editor::EditorMode::AutoHeight {
100                                min_lines: 1,
101                                max_lines: None,
102                            },
103                            window,
104                            cx,
105                        );
106                        if !has_id {
107                            editor.set_read_only(true, cx);
108                        }
109                        editor.set_message(chunks, window, cx);
110                        editor
111                    });
112                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
113                        cx.emit(EntryViewEvent {
114                            entry_index: index,
115                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
116                        })
117                    })
118                    .detach();
119                    self.set_entry(index, Entry::UserMessage(message_editor));
120                }
121            }
122            AgentThreadEntry::ToolCall(tool_call) => {
123                let id = tool_call.id.clone();
124                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
125                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
126
127                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
128                    views
129                } else {
130                    self.set_entry(index, Entry::empty());
131                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
132                        unreachable!()
133                    };
134                    views
135                };
136
137                let is_tool_call_completed =
138                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
139
140                for terminal in terminals {
141                    match views.entry(terminal.entity_id()) {
142                        collections::hash_map::Entry::Vacant(entry) => {
143                            let element = create_terminal(
144                                self.workspace.clone(),
145                                self.project.clone(),
146                                terminal.clone(),
147                                window,
148                                cx,
149                            )
150                            .into_any();
151                            cx.emit(EntryViewEvent {
152                                entry_index: index,
153                                view_event: ViewEvent::NewTerminal(id.clone()),
154                            });
155                            entry.insert(element);
156                        }
157                        collections::hash_map::Entry::Occupied(_entry) => {
158                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
159                                cx.emit(EntryViewEvent {
160                                    entry_index: index,
161                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
162                                });
163                            }
164                        }
165                    }
166                }
167
168                for diff in diffs {
169                    views.entry(diff.entity_id()).or_insert_with(|| {
170                        let element = create_editor_diff(diff.clone(), window, cx).into_any();
171                        cx.emit(EntryViewEvent {
172                            entry_index: index,
173                            view_event: ViewEvent::NewDiff(id.clone()),
174                        });
175                        element
176                    });
177                }
178            }
179            AgentThreadEntry::AssistantMessage(message) => {
180                let entry = if let Some(Entry::AssistantMessage(entry)) =
181                    self.entries.get_mut(index)
182                {
183                    entry
184                } else {
185                    self.set_entry(
186                        index,
187                        Entry::AssistantMessage(AssistantMessageEntry::default()),
188                    );
189                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
190                        unreachable!()
191                    };
192                    entry
193                };
194                entry.sync(message);
195            }
196        };
197    }
198
199    fn set_entry(&mut self, index: usize, entry: Entry) {
200        if index == self.entries.len() {
201            self.entries.push(entry);
202        } else {
203            self.entries[index] = entry;
204        }
205    }
206
207    pub fn remove(&mut self, range: Range<usize>) {
208        self.entries.drain(range);
209    }
210
211    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
212        for entry in self.entries.iter() {
213            match entry {
214                Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
215                Entry::Content(response_views) => {
216                    for view in response_views.values() {
217                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
218                            diff_editor.update(cx, |diff_editor, cx| {
219                                diff_editor.set_text_style_refinement(
220                                    diff_editor_text_style_refinement(cx),
221                                );
222                                cx.notify();
223                            })
224                        }
225                    }
226                }
227            }
228        }
229    }
230}
231
232impl EventEmitter<EntryViewEvent> for EntryViewState {}
233
234pub struct EntryViewEvent {
235    pub entry_index: usize,
236    pub view_event: ViewEvent,
237}
238
239pub enum ViewEvent {
240    NewDiff(ToolCallId),
241    NewTerminal(ToolCallId),
242    TerminalMovedToBackground(ToolCallId),
243    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
244}
245
246#[derive(Default, Debug)]
247pub struct AssistantMessageEntry {
248    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
249}
250
251impl AssistantMessageEntry {
252    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
253        self.scroll_handles_by_chunk_index.get(&ix).cloned()
254    }
255
256    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
257        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
258            let ix = message.chunks.len() - 1;
259            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
260            handle.scroll_to_bottom();
261        }
262    }
263}
264
265#[derive(Debug)]
266pub enum Entry {
267    UserMessage(Entity<MessageEditor>),
268    AssistantMessage(AssistantMessageEntry),
269    Content(HashMap<EntityId, AnyEntity>),
270}
271
272impl Entry {
273    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
274        match self {
275            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
276            Self::AssistantMessage(_) | Self::Content(_) => None,
277        }
278    }
279
280    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
281        match self {
282            Self::UserMessage(editor) => Some(editor),
283            Self::AssistantMessage(_) | Self::Content(_) => None,
284        }
285    }
286
287    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
288        self.content_map()?
289            .get(&diff.entity_id())
290            .cloned()
291            .map(|entity| entity.downcast::<Editor>().unwrap())
292    }
293
294    pub fn terminal(
295        &self,
296        terminal: &Entity<acp_thread::Terminal>,
297    ) -> Option<Entity<TerminalView>> {
298        self.content_map()?
299            .get(&terminal.entity_id())
300            .cloned()
301            .map(|entity| entity.downcast::<TerminalView>().unwrap())
302    }
303
304    pub fn scroll_handle_for_assistant_message_chunk(
305        &self,
306        chunk_ix: usize,
307    ) -> Option<ScrollHandle> {
308        match self {
309            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
310            Self::UserMessage(_) | Self::Content(_) => None,
311        }
312    }
313
314    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
315        match self {
316            Self::Content(map) => Some(map),
317            _ => None,
318        }
319    }
320
321    fn empty() -> Self {
322        Self::Content(HashMap::default())
323    }
324
325    #[cfg(test)]
326    pub fn has_content(&self) -> bool {
327        match self {
328            Self::Content(map) => !map.is_empty(),
329            Self::UserMessage(_) | Self::AssistantMessage(_) => false,
330        }
331    }
332}
333
334fn create_terminal(
335    workspace: WeakEntity<Workspace>,
336    project: WeakEntity<Project>,
337    terminal: Entity<acp_thread::Terminal>,
338    window: &mut Window,
339    cx: &mut App,
340) -> Entity<TerminalView> {
341    cx.new(|cx| {
342        let mut view = TerminalView::new(
343            terminal.read(cx).inner().clone(),
344            workspace,
345            None,
346            project,
347            window,
348            cx,
349        );
350        view.set_embedded_mode(Some(1000), cx);
351        view
352    })
353}
354
355fn create_editor_diff(
356    diff: Entity<acp_thread::Diff>,
357    window: &mut Window,
358    cx: &mut App,
359) -> Entity<Editor> {
360    cx.new(|cx| {
361        let mut editor = Editor::new(
362            EditorMode::Full {
363                scale_ui_elements_with_buffer_font_size: false,
364                show_active_line_background: false,
365                sizing_behavior: SizingBehavior::SizeByContent,
366            },
367            diff.read(cx).multibuffer().clone(),
368            None,
369            window,
370            cx,
371        );
372        editor.set_show_gutter(false, cx);
373        editor.disable_inline_diagnostics();
374        editor.disable_expand_excerpt_buttons(cx);
375        editor.set_show_vertical_scrollbar(false, cx);
376        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
377        editor.set_soft_wrap_mode(SoftWrap::None, cx);
378        editor.scroll_manager.set_forbid_vertical_scroll(true);
379        editor.set_show_indent_guides(false, cx);
380        editor.set_read_only(true);
381        editor.set_show_breakpoints(false, cx);
382        editor.set_show_code_actions(false, cx);
383        editor.set_show_git_diff_gutter(false, cx);
384        editor.set_expand_all_diff_hunks(cx);
385        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
386        editor
387    })
388}
389
390fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
391    TextStyleRefinement {
392        font_size: Some(
393            TextSize::Small
394                .rems(cx)
395                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
396                .into(),
397        ),
398        ..Default::default()
399    }
400}
401
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::rc::Rc;

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::RowInfo;
    use fs::FakeFs;
    use gpui::{AppContext as _, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use workspace::Workspace;

    // Syncing a tool-call entry that carries a diff creates a diff editor whose
    // buffer shows the old line as a deleted hunk followed by the new line as an
    // added hunk.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        // A one-file project whose file matches the diff's `old_text`.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/project",
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // An in-progress tool call whose only content is a diff of hello.txt.
        let tool_call = acp::ToolCall::new("tool", "Tool call")
            .status(acp::ToolCallStatus::InProgress)
            .content(vec![acp::ToolCallContent::Diff(
                acp::Diff::new("/project/hello.txt", "hello world").old_text("hi world"),
            )]);
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        // Deliver the tool call so it becomes thread entry 0.
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = None;
        let history = cx
            .update(|window, cx| cx.new(|cx| crate::acp::AcpThreadHistory::new(None, window, cx)));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.downgrade(),
                thread_store,
                history.downgrade(),
                None,
                Default::default(),
                Default::default(),
                "Test Agent".into(),
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        // Let pending background work settle before inspecting the editor.
        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        // Both the old and the new line are present (all hunks expanded).
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        // Row 0 is the deleted old line, row 1 the added new line.
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    // Installs the test settings store, base themes and release channel
    // globals needed to construct editors and workspaces.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(semver::Version::new(0, 0, 0), cx);
        });
    }
}