entry_view_state.rs

  1use std::{cell::RefCell, ops::Range, rc::Rc};
  2
  3use super::thread_history::AcpThreadHistory;
  4use crate::user_slash_command::{CommandLoadError, UserSlashCommand};
  5use acp_thread::{AcpThread, AgentThreadEntry};
  6use agent::ThreadStore;
  7use agent_client_protocol::{self as acp, ToolCallId};
  8use collections::HashMap;
  9use editor::{Editor, EditorMode, MinimapVisibility, SizingBehavior};
 10use gpui::{
 11    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 12    ScrollHandle, SharedString, TextStyleRefinement, WeakEntity, Window,
 13};
 14use language::language_settings::SoftWrap;
 15use project::Project;
 16use prompt_store::PromptStore;
 17use settings::Settings as _;
 18use terminal_view::TerminalView;
 19use theme::ThemeSettings;
 20use ui::{Context, TextSize};
 21use workspace::Workspace;
 22
 23use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 24
/// View-side state for the entries of an [`AcpThread`].
///
/// `entries` is kept index-aligned with `AcpThread::entries()`:
/// [`EntryViewState::sync_entry`] reads thread entry `i` and stores the
/// matching view state (message editor, embedded tool-call views, or assistant
/// scroll handles) at position `i`.
pub struct EntryViewState {
    // Passed to every message editor and embedded terminal view this state creates.
    workspace: WeakEntity<Workspace>,
    project: WeakEntity<Project>,
    // Forwarded unchanged to `MessageEditor::new_with_cache` for user-message editors.
    thread_store: Option<Entity<ThreadStore>>,
    history: WeakEntity<AcpThreadHistory>,
    prompt_store: Option<Entity<PromptStore>>,
    // One view-state entry per synced thread entry, at the same index.
    entries: Vec<Entry>,
    // Shared handles; each user-message editor receives a clone of the `Rc`,
    // so all editors observe the same underlying values.
    prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
    available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
    cached_user_commands: Rc<RefCell<HashMap<String, UserSlashCommand>>>,
    cached_user_command_errors: Rc<RefCell<Vec<CommandLoadError>>>,
    // Forwarded to user-message editors.
    agent_name: SharedString,
}
 38
 39impl EntryViewState {
 40    pub fn new(
 41        workspace: WeakEntity<Workspace>,
 42        project: WeakEntity<Project>,
 43        thread_store: Option<Entity<ThreadStore>>,
 44        history: WeakEntity<AcpThreadHistory>,
 45        prompt_store: Option<Entity<PromptStore>>,
 46        prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 47        available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 48        cached_user_commands: Rc<RefCell<HashMap<String, UserSlashCommand>>>,
 49        cached_user_command_errors: Rc<RefCell<Vec<CommandLoadError>>>,
 50        agent_name: SharedString,
 51    ) -> Self {
 52        Self {
 53            workspace,
 54            project,
 55            thread_store,
 56            history,
 57            prompt_store,
 58            entries: Vec::new(),
 59            prompt_capabilities,
 60            available_commands,
 61            cached_user_commands,
 62            cached_user_command_errors,
 63            agent_name,
 64        }
 65    }
 66
 67    pub fn entry(&self, index: usize) -> Option<&Entry> {
 68        self.entries.get(index)
 69    }
 70
 71    pub fn sync_entry(
 72        &mut self,
 73        index: usize,
 74        thread: &Entity<AcpThread>,
 75        window: &mut Window,
 76        cx: &mut Context<Self>,
 77    ) {
 78        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 79            return;
 80        };
 81
 82        match thread_entry {
 83            AgentThreadEntry::UserMessage(message) => {
 84                let has_id = message.id.is_some();
 85                let chunks = message.chunks.clone();
 86                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 87                    if !editor.focus_handle(cx).is_focused(window) {
 88                        // Only update if we are not editing.
 89                        // If we are, cancelling the edit will set the message to the newest content.
 90                        editor.update(cx, |editor, cx| {
 91                            editor.set_message(chunks, window, cx);
 92                        });
 93                    }
 94                } else {
 95                    let message_editor = cx.new(|cx| {
 96                        let mut editor = MessageEditor::new_with_cache(
 97                            self.workspace.clone(),
 98                            self.project.clone(),
 99                            self.thread_store.clone(),
100                            self.history.clone(),
101                            self.prompt_store.clone(),
102                            self.prompt_capabilities.clone(),
103                            self.available_commands.clone(),
104                            self.cached_user_commands.clone(),
105                            self.cached_user_command_errors.clone(),
106                            self.agent_name.clone(),
107                            "Edit message - @ to include context",
108                            editor::EditorMode::AutoHeight {
109                                min_lines: 1,
110                                max_lines: None,
111                            },
112                            window,
113                            cx,
114                        );
115                        if !has_id {
116                            editor.set_read_only(true, cx);
117                        }
118                        editor.set_message(chunks, window, cx);
119                        editor
120                    });
121                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
122                        cx.emit(EntryViewEvent {
123                            entry_index: index,
124                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
125                        })
126                    })
127                    .detach();
128                    self.set_entry(index, Entry::UserMessage(message_editor));
129                }
130            }
131            AgentThreadEntry::ToolCall(tool_call) => {
132                let id = tool_call.id.clone();
133                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
134                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
135
136                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
137                    views
138                } else {
139                    self.set_entry(index, Entry::empty());
140                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
141                        unreachable!()
142                    };
143                    views
144                };
145
146                let is_tool_call_completed =
147                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
148
149                for terminal in terminals {
150                    match views.entry(terminal.entity_id()) {
151                        collections::hash_map::Entry::Vacant(entry) => {
152                            let element = create_terminal(
153                                self.workspace.clone(),
154                                self.project.clone(),
155                                terminal.clone(),
156                                window,
157                                cx,
158                            )
159                            .into_any();
160                            cx.emit(EntryViewEvent {
161                                entry_index: index,
162                                view_event: ViewEvent::NewTerminal(id.clone()),
163                            });
164                            entry.insert(element);
165                        }
166                        collections::hash_map::Entry::Occupied(_entry) => {
167                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
168                                cx.emit(EntryViewEvent {
169                                    entry_index: index,
170                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
171                                });
172                            }
173                        }
174                    }
175                }
176
177                for diff in diffs {
178                    views.entry(diff.entity_id()).or_insert_with(|| {
179                        let element = create_editor_diff(diff.clone(), window, cx).into_any();
180                        cx.emit(EntryViewEvent {
181                            entry_index: index,
182                            view_event: ViewEvent::NewDiff(id.clone()),
183                        });
184                        element
185                    });
186                }
187            }
188            AgentThreadEntry::AssistantMessage(message) => {
189                let entry = if let Some(Entry::AssistantMessage(entry)) =
190                    self.entries.get_mut(index)
191                {
192                    entry
193                } else {
194                    self.set_entry(
195                        index,
196                        Entry::AssistantMessage(AssistantMessageEntry::default()),
197                    );
198                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
199                        unreachable!()
200                    };
201                    entry
202                };
203                entry.sync(message);
204            }
205        };
206    }
207
208    fn set_entry(&mut self, index: usize, entry: Entry) {
209        if index == self.entries.len() {
210            self.entries.push(entry);
211        } else {
212            self.entries[index] = entry;
213        }
214    }
215
216    pub fn remove(&mut self, range: Range<usize>) {
217        self.entries.drain(range);
218    }
219
220    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
221        for entry in self.entries.iter() {
222            match entry {
223                Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
224                Entry::Content(response_views) => {
225                    for view in response_views.values() {
226                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
227                            diff_editor.update(cx, |diff_editor, cx| {
228                                diff_editor.set_text_style_refinement(
229                                    diff_editor_text_style_refinement(cx),
230                                );
231                                cx.notify();
232                            })
233                        }
234                    }
235                }
236            }
237        }
238    }
239}
240
// `EntryViewState` emits an `EntryViewEvent` when `sync_entry` creates a terminal
// or diff view, when a completed tool call's terminal has no output, and when a
// user-message editor emits one of its own events.
impl EventEmitter<EntryViewEvent> for EntryViewState {}
242
/// Event emitted by [`EntryViewState`] about the view state of one thread entry.
pub struct EntryViewEvent {
    /// Index of the thread entry the event refers to, as passed to `sync_entry`.
    pub entry_index: usize,
    /// What happened to that entry's views.
    pub view_event: ViewEvent,
}
247
/// The kind of change reported by an [`EntryViewEvent`].
pub enum ViewEvent {
    /// A diff editor was created for a diff belonging to the tool call with this id.
    NewDiff(ToolCallId),
    /// A terminal view was created for a terminal belonging to the tool call with this id.
    NewTerminal(ToolCallId),
    /// Emitted on sync for a terminal that already has a view when its tool call
    /// is `Completed` but the terminal has no output yet.
    TerminalMovedToBackground(ToolCallId),
    /// An event forwarded from a user message's editor, together with that editor.
    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
}
254
/// View state for an assistant message: scroll handles keyed by chunk index.
///
/// Handles are only created by `sync`, for a message's trailing thought chunk.
#[derive(Default, Debug)]
pub struct AssistantMessageEntry {
    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
}
259
260impl AssistantMessageEntry {
261    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
262        self.scroll_handles_by_chunk_index.get(&ix).cloned()
263    }
264
265    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
266        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
267            let ix = message.chunks.len() - 1;
268            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
269            handle.scroll_to_bottom();
270        }
271    }
272}
273
/// View state for a single thread entry.
#[derive(Debug)]
pub enum Entry {
    /// Editor displaying (and allowing edits to) a user message.
    UserMessage(Entity<MessageEditor>),
    /// Scroll state for an assistant message's chunks.
    AssistantMessage(AssistantMessageEntry),
    /// Embedded views for a tool call (diff editors and terminal views), keyed by
    /// the entity id of the underlying diff or terminal.
    Content(HashMap<EntityId, AnyEntity>),
}
280
281impl Entry {
282    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
283        match self {
284            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
285            Self::AssistantMessage(_) | Self::Content(_) => None,
286        }
287    }
288
289    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
290        match self {
291            Self::UserMessage(editor) => Some(editor),
292            Self::AssistantMessage(_) | Self::Content(_) => None,
293        }
294    }
295
296    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
297        self.content_map()?
298            .get(&diff.entity_id())
299            .cloned()
300            .map(|entity| entity.downcast::<Editor>().unwrap())
301    }
302
303    pub fn terminal(
304        &self,
305        terminal: &Entity<acp_thread::Terminal>,
306    ) -> Option<Entity<TerminalView>> {
307        self.content_map()?
308            .get(&terminal.entity_id())
309            .cloned()
310            .map(|entity| entity.downcast::<TerminalView>().unwrap())
311    }
312
313    pub fn scroll_handle_for_assistant_message_chunk(
314        &self,
315        chunk_ix: usize,
316    ) -> Option<ScrollHandle> {
317        match self {
318            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
319            Self::UserMessage(_) | Self::Content(_) => None,
320        }
321    }
322
323    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
324        match self {
325            Self::Content(map) => Some(map),
326            _ => None,
327        }
328    }
329
330    fn empty() -> Self {
331        Self::Content(HashMap::default())
332    }
333
334    #[cfg(test)]
335    pub fn has_content(&self) -> bool {
336        match self {
337            Self::Content(map) => !map.is_empty(),
338            Self::UserMessage(_) | Self::AssistantMessage(_) => false,
339        }
340    }
341}
342
343fn create_terminal(
344    workspace: WeakEntity<Workspace>,
345    project: WeakEntity<Project>,
346    terminal: Entity<acp_thread::Terminal>,
347    window: &mut Window,
348    cx: &mut App,
349) -> Entity<TerminalView> {
350    cx.new(|cx| {
351        let mut view = TerminalView::new(
352            terminal.read(cx).inner().clone(),
353            workspace,
354            None,
355            project,
356            window,
357            cx,
358        );
359        view.set_embedded_mode(Some(1000), cx);
360        view
361    })
362}
363
364fn create_editor_diff(
365    diff: Entity<acp_thread::Diff>,
366    window: &mut Window,
367    cx: &mut App,
368) -> Entity<Editor> {
369    cx.new(|cx| {
370        let mut editor = Editor::new(
371            EditorMode::Full {
372                scale_ui_elements_with_buffer_font_size: false,
373                show_active_line_background: false,
374                sizing_behavior: SizingBehavior::SizeByContent,
375            },
376            diff.read(cx).multibuffer().clone(),
377            None,
378            window,
379            cx,
380        );
381        editor.set_show_gutter(false, cx);
382        editor.disable_inline_diagnostics();
383        editor.disable_expand_excerpt_buttons(cx);
384        editor.set_show_vertical_scrollbar(false, cx);
385        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
386        editor.set_soft_wrap_mode(SoftWrap::None, cx);
387        editor.scroll_manager.set_forbid_vertical_scroll(true);
388        editor.set_show_indent_guides(false, cx);
389        editor.set_read_only(true);
390        editor.set_show_breakpoints(false, cx);
391        editor.set_show_code_actions(false, cx);
392        editor.set_show_git_diff_gutter(false, cx);
393        editor.set_expand_all_diff_hunks(cx);
394        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
395        editor
396    })
397}
398
399fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
400    TextStyleRefinement {
401        font_size: Some(
402            TextSize::Small
403                .rems(cx)
404                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
405                .into(),
406        ),
407        ..Default::default()
408    }
409}
410
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::rc::Rc;

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::RowInfo;
    use fs::FakeFs;
    use gpui::{AppContext as _, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use workspace::Workspace;

    /// Syncing a tool-call entry whose content is a diff creates a diff editor that
    /// shows the old text as a deleted row followed by the new text as an added row.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        // Single-file project; the file contents match the diff's `old_text` below.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/project",
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // An in-progress tool call whose only content is a diff replacing
        // "hi world" with "hello world".
        let tool_call = acp::ToolCall::new("tool", "Tool call")
            .status(acp::ToolCallStatus::InProgress)
            .content(vec![acp::ToolCallContent::Diff(
                acp::Diff::new("/project/hello.txt", "hello world").old_text("hi world"),
            )]);
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        // Deliver the tool call to the thread as a session update from the stub agent.
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = None;
        let history = cx
            .update(|window, cx| cx.new(|cx| crate::acp::AcpThreadHistory::new(None, window, cx)));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.downgrade(),
                thread_store,
                history.downgrade(),
                None,
                Default::default(),
                Default::default(),
                Default::default(),
                Default::default(),
                "Test Agent".into(),
            )
        });

        // Entry 0 is the tool call; syncing it should create the diff editor.
        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        // Let pending background work settle before reading the editor's contents.
        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        // With all hunks expanded, the deleted line and the inserted line both appear.
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs the globals this test needs: a test settings store, the base
    /// themes, and a release channel at version 0.0.0.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(semver::Version::new(0, 0, 0), cx);
        });
    }
}