entry_view_state.rs

  1use std::{cell::RefCell, ops::Range, rc::Rc};
  2
  3use acp_thread::{AcpThread, AgentSessionList, AgentThreadEntry};
  4use agent::ThreadStore;
  5use agent_client_protocol::{self as acp, ToolCallId};
  6use collections::HashMap;
  7use editor::{Editor, EditorMode, MinimapVisibility, SizingBehavior};
  8use gpui::{
  9    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 10    ScrollHandle, SharedString, TextStyleRefinement, WeakEntity, Window,
 11};
 12use language::language_settings::SoftWrap;
 13use project::Project;
 14use prompt_store::PromptStore;
 15use settings::Settings as _;
 16use terminal_view::TerminalView;
 17use theme::ThemeSettings;
 18use ui::{Context, TextSize};
 19use workspace::Workspace;
 20
 21use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 22
/// View state for the entries of an ACP thread: one [`Entry`] per thread entry,
/// holding the message editors, tool-call content views and scroll handles used to
/// render it.
pub struct EntryViewState {
    workspace: WeakEntity<Workspace>,
    project: WeakEntity<Project>,
    thread_store: Option<Entity<ThreadStore>>,
    // Shared handle: cloned (not copied) into every `MessageEditor` this state creates.
    session_list: Rc<RefCell<Option<Rc<dyn AgentSessionList>>>>,
    prompt_store: Option<Entity<PromptStore>>,
    // Indexed by the thread's entry index; populated by `sync_entry`.
    entries: Vec<Entry>,
    // Shared handle: cloned into every `MessageEditor` this state creates.
    prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
    // Shared handle: cloned into every `MessageEditor` this state creates.
    available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
    // Passed to each `MessageEditor` on creation.
    agent_name: SharedString,
}
 34
impl EntryViewState {
    /// Creates a view state with no synced entries.
    ///
    /// The `Rc` handles (`session_list`, `prompt_capabilities`, `available_commands`) are
    /// stored as-is and cloned into every `MessageEditor` created by `sync_entry`, so those
    /// editors share the same underlying `RefCell`s.
    pub fn new(
        workspace: WeakEntity<Workspace>,
        project: WeakEntity<Project>,
        thread_store: Option<Entity<ThreadStore>>,
        session_list: Rc<RefCell<Option<Rc<dyn AgentSessionList>>>>,
        prompt_store: Option<Entity<PromptStore>>,
        prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
        available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
        agent_name: SharedString,
    ) -> Self {
        Self {
            workspace,
            project,
            thread_store,
            session_list,
            prompt_store,
            entries: Vec::new(),
            prompt_capabilities,
            available_commands,
            agent_name,
        }
    }

    /// Returns the view state at `index`, or `None` if nothing has been synced there.
    pub fn entry(&self, index: usize) -> Option<&Entry> {
        self.entries.get(index)
    }

    /// Brings the view state at `index` in line with `thread`'s entry at the same index.
    ///
    /// Returns without doing anything if `thread` has no entry at `index`. Otherwise:
    /// - user message: creates a `MessageEditor` on first sync (read-only if the message
    ///   has no id) and forwards its events; on later syncs refreshes its content unless
    ///   the editor is focused;
    /// - tool call: creates a view the first time each terminal / diff is seen, keyed by
    ///   that entity's id, emitting `NewTerminal` / `NewDiff`;
    /// - assistant message: syncs the entry's per-chunk scroll handles.
    ///
    /// If the slot currently holds a different kind of entry, it is replaced.
    ///
    /// NOTE(review): `set_entry` panics when `index` is past the end of the synced
    /// entries; presumably callers always sync entries in order — confirm.
    pub fn sync_entry(
        &mut self,
        index: usize,
        thread: &Entity<AcpThread>,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
            return;
        };

        match thread_entry {
            AgentThreadEntry::UserMessage(message) => {
                // Take owned copies so nothing borrowed from the thread (through `cx`)
                // has to outlive the mutable uses of `cx` below.
                let has_id = message.id.is_some();
                let chunks = message.chunks.clone();
                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
                    if !editor.focus_handle(cx).is_focused(window) {
                        // Only update if we are not editing.
                        // If we are, cancelling the edit will set the message to the newest content.
                        editor.update(cx, |editor, cx| {
                            editor.set_message(chunks, window, cx);
                        });
                    }
                } else {
                    // No editor in this slot yet (or a different entry kind): create one.
                    let message_editor = cx.new(|cx| {
                        let mut editor = MessageEditor::new(
                            self.workspace.clone(),
                            self.project.clone(),
                            self.thread_store.clone(),
                            self.session_list.clone(),
                            self.prompt_store.clone(),
                            self.prompt_capabilities.clone(),
                            self.available_commands.clone(),
                            self.agent_name.clone(),
                            "Edit message - @ to include context",
                            editor::EditorMode::AutoHeight {
                                min_lines: 1,
                                max_lines: None,
                            },
                            window,
                            cx,
                        );
                        // Messages without an id are shown but cannot be edited.
                        if !has_id {
                            editor.set_read_only(true, cx);
                        }
                        editor.set_message(chunks, window, cx);
                        editor
                    });
                    // Forward every editor event tagged with this entry's index.
                    // NOTE(review): `index` is captured at creation time and is not updated
                    // if `remove` later shifts this entry down — confirm callers only remove
                    // trailing ranges.
                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
                        cx.emit(EntryViewEvent {
                            entry_index: index,
                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
                        })
                    })
                    .detach();
                    self.set_entry(index, Entry::UserMessage(message_editor));
                }
            }
            AgentThreadEntry::ToolCall(tool_call) => {
                let id = tool_call.id.clone();
                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();

                // Reuse the existing content map, or replace this slot with an empty one.
                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
                    views
                } else {
                    self.set_entry(index, Entry::empty());
                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
                        unreachable!()
                    };
                    views
                };

                let is_tool_call_completed =
                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);

                for terminal in terminals {
                    match views.entry(terminal.entity_id()) {
                        // First time this terminal is seen: build its view and announce it.
                        collections::hash_map::Entry::Vacant(entry) => {
                            let element = create_terminal(
                                self.workspace.clone(),
                                self.project.clone(),
                                terminal.clone(),
                                window,
                                cx,
                            )
                            .into_any();
                            cx.emit(EntryViewEvent {
                                entry_index: index,
                                view_event: ViewEvent::NewTerminal(id.clone()),
                            });
                            entry.insert(element);
                        }
                        // Known terminal: once the tool call has completed and the terminal
                        // still has no `output()`, report it as moved to the background.
                        // This is re-emitted on every sync while that condition holds.
                        collections::hash_map::Entry::Occupied(_entry) => {
                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
                                cx.emit(EntryViewEvent {
                                    entry_index: index,
                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
                                });
                            }
                        }
                    }
                }

                // Diff editors are created once per diff entity; `NewDiff` fires only then.
                for diff in diffs {
                    views.entry(diff.entity_id()).or_insert_with(|| {
                        let element = create_editor_diff(diff.clone(), window, cx).into_any();
                        cx.emit(EntryViewEvent {
                            entry_index: index,
                            view_event: ViewEvent::NewDiff(id.clone()),
                        });
                        element
                    });
                }
            }
            AgentThreadEntry::AssistantMessage(message) => {
                // Reuse the existing assistant entry, or replace this slot with a fresh one.
                let entry = if let Some(Entry::AssistantMessage(entry)) =
                    self.entries.get_mut(index)
                {
                    entry
                } else {
                    self.set_entry(
                        index,
                        Entry::AssistantMessage(AssistantMessageEntry::default()),
                    );
                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
                        unreachable!()
                    };
                    entry
                };
                entry.sync(message);
            }
        };
    }

    /// Stores `entry` at `index`: appends when `index == len`, overwrites otherwise.
    ///
    /// Panics (index out of bounds) if `index > len`.
    fn set_entry(&mut self, index: usize, entry: Entry) {
        if index == self.entries.len() {
            self.entries.push(entry);
        } else {
            self.entries[index] = entry;
        }
    }

    /// Drops the view state for `range`; entries after it shift down.
    ///
    /// Panics under the same conditions as `Vec::drain` (start > end, or end > len).
    pub fn remove(&mut self, range: Range<usize>) {
        self.entries.drain(range);
    }

    /// Re-applies the diff text style to every content view that is an `Editor` (the
    /// diff editors), so they pick up a changed agent UI font size. Message editors,
    /// assistant entries and terminal views are left untouched.
    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
        for entry in self.entries.iter() {
            match entry {
                Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
                Entry::Content(response_views) => {
                    for view in response_views.values() {
                        // Only diff views are `Editor`s; terminal views fail this downcast.
                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
                            diff_editor.update(cx, |diff_editor, cx| {
                                diff_editor.set_text_style_refinement(
                                    diff_editor_text_style_refinement(cx),
                                );
                                cx.notify();
                            })
                        }
                    }
                }
            }
        }
    }
}
230
impl EventEmitter<EntryViewEvent> for EntryViewState {}

/// Event emitted by [`EntryViewState`] when a view for a thread entry is created, or
/// when one of its user-message editors emits an event.
pub struct EntryViewEvent {
    /// Index of the thread entry the event relates to.
    pub entry_index: usize,
    /// What happened.
    pub view_event: ViewEvent,
}

/// What happened to the views of a thread entry.
pub enum ViewEvent {
    /// A diff editor was created for the tool call with this id.
    NewDiff(ToolCallId),
    /// A terminal view was created for the tool call with this id.
    NewTerminal(ToolCallId),
    /// The tool call with this id completed while its terminal had no `output()` yet.
    TerminalMovedToBackground(ToolCallId),
    /// An event forwarded from the user-message editor of this entry.
    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
}
244
/// View state for an assistant message.
#[derive(Default, Debug)]
pub struct AssistantMessageEntry {
    // Scroll handles keyed by chunk index; created lazily for a trailing thought chunk.
    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
}
249
250impl AssistantMessageEntry {
251    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
252        self.scroll_handles_by_chunk_index.get(&ix).cloned()
253    }
254
255    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
256        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
257            let ix = message.chunks.len() - 1;
258            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
259            handle.scroll_to_bottom();
260        }
261    }
262}
263
/// View state for a single thread entry.
#[derive(Debug)]
pub enum Entry {
    /// Editor for a user message (read-only when the message has no id).
    UserMessage(Entity<MessageEditor>),
    /// Per-chunk scroll state for an assistant message.
    AssistantMessage(AssistantMessageEntry),
    /// Tool-call content views (diff `Editor`s and `TerminalView`s), keyed by the entity
    /// id of the `acp_thread` diff or terminal they display.
    Content(HashMap<EntityId, AnyEntity>),
}
270
271impl Entry {
272    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
273        match self {
274            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
275            Self::AssistantMessage(_) | Self::Content(_) => None,
276        }
277    }
278
279    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
280        match self {
281            Self::UserMessage(editor) => Some(editor),
282            Self::AssistantMessage(_) | Self::Content(_) => None,
283        }
284    }
285
286    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
287        self.content_map()?
288            .get(&diff.entity_id())
289            .cloned()
290            .map(|entity| entity.downcast::<Editor>().unwrap())
291    }
292
293    pub fn terminal(
294        &self,
295        terminal: &Entity<acp_thread::Terminal>,
296    ) -> Option<Entity<TerminalView>> {
297        self.content_map()?
298            .get(&terminal.entity_id())
299            .cloned()
300            .map(|entity| entity.downcast::<TerminalView>().unwrap())
301    }
302
303    pub fn scroll_handle_for_assistant_message_chunk(
304        &self,
305        chunk_ix: usize,
306    ) -> Option<ScrollHandle> {
307        match self {
308            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
309            Self::UserMessage(_) | Self::Content(_) => None,
310        }
311    }
312
313    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
314        match self {
315            Self::Content(map) => Some(map),
316            _ => None,
317        }
318    }
319
320    fn empty() -> Self {
321        Self::Content(HashMap::default())
322    }
323
324    #[cfg(test)]
325    pub fn has_content(&self) -> bool {
326        match self {
327            Self::Content(map) => !map.is_empty(),
328            Self::UserMessage(_) | Self::AssistantMessage(_) => false,
329        }
330    }
331}
332
333fn create_terminal(
334    workspace: WeakEntity<Workspace>,
335    project: WeakEntity<Project>,
336    terminal: Entity<acp_thread::Terminal>,
337    window: &mut Window,
338    cx: &mut App,
339) -> Entity<TerminalView> {
340    cx.new(|cx| {
341        let mut view = TerminalView::new(
342            terminal.read(cx).inner().clone(),
343            workspace,
344            None,
345            project,
346            window,
347            cx,
348        );
349        view.set_embedded_mode(Some(1000), cx);
350        view
351    })
352}
353
354fn create_editor_diff(
355    diff: Entity<acp_thread::Diff>,
356    window: &mut Window,
357    cx: &mut App,
358) -> Entity<Editor> {
359    cx.new(|cx| {
360        let mut editor = Editor::new(
361            EditorMode::Full {
362                scale_ui_elements_with_buffer_font_size: false,
363                show_active_line_background: false,
364                sizing_behavior: SizingBehavior::SizeByContent,
365            },
366            diff.read(cx).multibuffer().clone(),
367            None,
368            window,
369            cx,
370        );
371        editor.set_show_gutter(false, cx);
372        editor.disable_inline_diagnostics();
373        editor.disable_expand_excerpt_buttons(cx);
374        editor.set_show_vertical_scrollbar(false, cx);
375        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
376        editor.set_soft_wrap_mode(SoftWrap::None, cx);
377        editor.scroll_manager.set_forbid_vertical_scroll(true);
378        editor.set_show_indent_guides(false, cx);
379        editor.set_read_only(true);
380        editor.set_show_breakpoints(false, cx);
381        editor.set_show_code_actions(false, cx);
382        editor.set_show_git_diff_gutter(false, cx);
383        editor.set_expand_all_diff_hunks(cx);
384        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
385        editor
386    })
387}
388
389fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
390    TextStyleRefinement {
391        font_size: Some(
392            TextSize::Small
393                .rems(cx)
394                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
395                .into(),
396        ),
397        ..Default::default()
398    }
399}
400
#[cfg(test)]
mod tests {
    use std::{cell::RefCell, path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::RowInfo;
    use fs::FakeFs;
    use gpui::{AppContext as _, TestAppContext};
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use workspace::Workspace;

    use super::EntryViewState;

    /// Syncing a tool call that carries a diff creates a diff editor showing the old
    /// line as deleted, followed by the new line as added.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/project", json!({ "hello.txt": "hi world" }))
            .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // An in-progress tool call whose only content is a diff of hello.txt.
        let tool_call = acp::ToolCall::new("tool", "Tool call")
            .status(acp::ToolCallStatus::InProgress)
            .content(vec![acp::ToolCallContent::Diff(
                acp::Diff::new("/project/hello.txt", "hello world").old_text("hi world"),
            )]);

        // Start a thread on a stub connection and push the tool call into it.
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        // No thread store, session list or prompt store is needed to render diffs.
        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.downgrade(),
                None,
                Rc::new(RefCell::new(None)),
                None,
                Default::default(),
                Default::default(),
                "Test Agent".into(),
            )
        });
        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            let tool_call_entry = thread.entries().first().unwrap();
            tool_call_entry.diffs().next().unwrap().clone()
        });

        cx.run_until_parked();

        let editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        let text = editor.read_with(cx, |editor, cx| editor.text(cx));
        assert_eq!(text, "hi world\nhello world");

        let rows = editor.read_with(cx, |editor, cx| {
            editor
                .buffer()
                .read(cx)
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            rows.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs the globals the views under test read: settings, base themes and the
    /// release channel.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            theme::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(semver::Version::new(0, 0, 0), cx);
        });
    }
}