entry_view_state.rs

  1use std::{
  2    cell::{Cell, RefCell},
  3    ops::Range,
  4    rc::Rc,
  5};
  6
  7use acp_thread::{AcpThread, AgentThreadEntry};
  8use agent_client_protocol::{self as acp, ToolCallId};
  9use agent2::HistoryStore;
 10use collections::HashMap;
 11use editor::{Editor, EditorMode, MinimapVisibility};
 12use gpui::{
 13    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 14    ScrollHandle, SharedString, TextStyleRefinement, WeakEntity, Window,
 15};
 16use language::language_settings::SoftWrap;
 17use project::Project;
 18use prompt_store::PromptStore;
 19use settings::Settings as _;
 20use terminal_view::TerminalView;
 21use theme::ThemeSettings;
 22use ui::{Context, TextSize};
 23use workspace::Workspace;
 24
 25use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 26
 27pub struct EntryViewState {
 28    workspace: WeakEntity<Workspace>,
 29    project: Entity<Project>,
 30    history_store: Entity<HistoryStore>,
 31    prompt_store: Option<Entity<PromptStore>>,
 32    entries: Vec<Entry>,
 33    prompt_capabilities: Rc<Cell<acp::PromptCapabilities>>,
 34    available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 35    agent_name: SharedString,
 36}
 37
 38impl EntryViewState {
 39    pub fn new(
 40        workspace: WeakEntity<Workspace>,
 41        project: Entity<Project>,
 42        history_store: Entity<HistoryStore>,
 43        prompt_store: Option<Entity<PromptStore>>,
 44        prompt_capabilities: Rc<Cell<acp::PromptCapabilities>>,
 45        available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 46        agent_name: SharedString,
 47    ) -> Self {
 48        Self {
 49            workspace,
 50            project,
 51            history_store,
 52            prompt_store,
 53            entries: Vec::new(),
 54            prompt_capabilities,
 55            available_commands,
 56            agent_name,
 57        }
 58    }
 59
 60    pub fn entry(&self, index: usize) -> Option<&Entry> {
 61        self.entries.get(index)
 62    }
 63
 64    pub fn sync_entry(
 65        &mut self,
 66        index: usize,
 67        thread: &Entity<AcpThread>,
 68        window: &mut Window,
 69        cx: &mut Context<Self>,
 70    ) {
 71        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 72            return;
 73        };
 74
 75        match thread_entry {
 76            AgentThreadEntry::UserMessage(message) => {
 77                let has_id = message.id.is_some();
 78                let chunks = message.chunks.clone();
 79                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 80                    if !editor.focus_handle(cx).is_focused(window) {
 81                        // Only update if we are not editing.
 82                        // If we are, cancelling the edit will set the message to the newest content.
 83                        editor.update(cx, |editor, cx| {
 84                            editor.set_message(chunks, window, cx);
 85                        });
 86                    }
 87                } else {
 88                    let message_editor = cx.new(|cx| {
 89                        let mut editor = MessageEditor::new(
 90                            self.workspace.clone(),
 91                            self.project.clone(),
 92                            self.history_store.clone(),
 93                            self.prompt_store.clone(),
 94                            self.prompt_capabilities.clone(),
 95                            self.available_commands.clone(),
 96                            self.agent_name.clone(),
 97                            "Edit message - @ to include context",
 98                            editor::EditorMode::AutoHeight {
 99                                min_lines: 1,
100                                max_lines: None,
101                            },
102                            window,
103                            cx,
104                        );
105                        if !has_id {
106                            editor.set_read_only(true, cx);
107                        }
108                        editor.set_message(chunks, window, cx);
109                        editor
110                    });
111                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
112                        cx.emit(EntryViewEvent {
113                            entry_index: index,
114                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
115                        })
116                    })
117                    .detach();
118                    self.set_entry(index, Entry::UserMessage(message_editor));
119                }
120            }
121            AgentThreadEntry::ToolCall(tool_call) => {
122                let id = tool_call.id.clone();
123                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
124                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
125
126                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
127                    views
128                } else {
129                    self.set_entry(index, Entry::empty());
130                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
131                        unreachable!()
132                    };
133                    views
134                };
135
136                let is_tool_call_completed =
137                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
138
139                for terminal in terminals {
140                    match views.entry(terminal.entity_id()) {
141                        collections::hash_map::Entry::Vacant(entry) => {
142                            let element = create_terminal(
143                                self.workspace.clone(),
144                                self.project.clone(),
145                                terminal.clone(),
146                                window,
147                                cx,
148                            )
149                            .into_any();
150                            cx.emit(EntryViewEvent {
151                                entry_index: index,
152                                view_event: ViewEvent::NewTerminal(id.clone()),
153                            });
154                            entry.insert(element);
155                        }
156                        collections::hash_map::Entry::Occupied(_entry) => {
157                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
158                                cx.emit(EntryViewEvent {
159                                    entry_index: index,
160                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
161                                });
162                            }
163                        }
164                    }
165                }
166
167                for diff in diffs {
168                    views.entry(diff.entity_id()).or_insert_with(|| {
169                        let element = create_editor_diff(diff.clone(), window, cx).into_any();
170                        cx.emit(EntryViewEvent {
171                            entry_index: index,
172                            view_event: ViewEvent::NewDiff(id.clone()),
173                        });
174                        element
175                    });
176                }
177            }
178            AgentThreadEntry::AssistantMessage(message) => {
179                let entry = if let Some(Entry::AssistantMessage(entry)) =
180                    self.entries.get_mut(index)
181                {
182                    entry
183                } else {
184                    self.set_entry(
185                        index,
186                        Entry::AssistantMessage(AssistantMessageEntry::default()),
187                    );
188                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
189                        unreachable!()
190                    };
191                    entry
192                };
193                entry.sync(message);
194            }
195        };
196    }
197
198    fn set_entry(&mut self, index: usize, entry: Entry) {
199        if index == self.entries.len() {
200            self.entries.push(entry);
201        } else {
202            self.entries[index] = entry;
203        }
204    }
205
206    pub fn remove(&mut self, range: Range<usize>) {
207        self.entries.drain(range);
208    }
209
210    pub fn agent_font_size_changed(&mut self, cx: &mut App) {
211        for entry in self.entries.iter() {
212            match entry {
213                Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
214                Entry::Content(response_views) => {
215                    for view in response_views.values() {
216                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
217                            diff_editor.update(cx, |diff_editor, cx| {
218                                diff_editor.set_text_style_refinement(
219                                    diff_editor_text_style_refinement(cx),
220                                );
221                                cx.notify();
222                            })
223                        }
224                    }
225                }
226            }
227        }
228    }
229}
230
231impl EventEmitter<EntryViewEvent> for EntryViewState {}
232
233pub struct EntryViewEvent {
234    pub entry_index: usize,
235    pub view_event: ViewEvent,
236}
237
238pub enum ViewEvent {
239    NewDiff(ToolCallId),
240    NewTerminal(ToolCallId),
241    TerminalMovedToBackground(ToolCallId),
242    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
243}
244
245#[derive(Default, Debug)]
246pub struct AssistantMessageEntry {
247    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
248}
249
250impl AssistantMessageEntry {
251    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
252        self.scroll_handles_by_chunk_index.get(&ix).cloned()
253    }
254
255    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
256        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
257            let ix = message.chunks.len() - 1;
258            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
259            handle.scroll_to_bottom();
260        }
261    }
262}
263
264#[derive(Debug)]
265pub enum Entry {
266    UserMessage(Entity<MessageEditor>),
267    AssistantMessage(AssistantMessageEntry),
268    Content(HashMap<EntityId, AnyEntity>),
269}
270
271impl Entry {
272    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
273        match self {
274            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
275            Self::AssistantMessage(_) | Self::Content(_) => None,
276        }
277    }
278
279    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
280        match self {
281            Self::UserMessage(editor) => Some(editor),
282            Self::AssistantMessage(_) | Self::Content(_) => None,
283        }
284    }
285
286    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
287        self.content_map()?
288            .get(&diff.entity_id())
289            .cloned()
290            .map(|entity| entity.downcast::<Editor>().unwrap())
291    }
292
293    pub fn terminal(
294        &self,
295        terminal: &Entity<acp_thread::Terminal>,
296    ) -> Option<Entity<TerminalView>> {
297        self.content_map()?
298            .get(&terminal.entity_id())
299            .cloned()
300            .map(|entity| entity.downcast::<TerminalView>().unwrap())
301    }
302
303    pub fn scroll_handle_for_assistant_message_chunk(
304        &self,
305        chunk_ix: usize,
306    ) -> Option<ScrollHandle> {
307        match self {
308            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
309            Self::UserMessage(_) | Self::Content(_) => None,
310        }
311    }
312
313    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
314        match self {
315            Self::Content(map) => Some(map),
316            _ => None,
317        }
318    }
319
320    fn empty() -> Self {
321        Self::Content(HashMap::default())
322    }
323
324    #[cfg(test)]
325    pub fn has_content(&self) -> bool {
326        match self {
327            Self::Content(map) => !map.is_empty(),
328            Self::UserMessage(_) | Self::AssistantMessage(_) => false,
329        }
330    }
331}
332
333fn create_terminal(
334    workspace: WeakEntity<Workspace>,
335    project: Entity<Project>,
336    terminal: Entity<acp_thread::Terminal>,
337    window: &mut Window,
338    cx: &mut App,
339) -> Entity<TerminalView> {
340    cx.new(|cx| {
341        let mut view = TerminalView::new(
342            terminal.read(cx).inner().clone(),
343            workspace.clone(),
344            None,
345            project.downgrade(),
346            window,
347            cx,
348        );
349        view.set_embedded_mode(Some(1000), cx);
350        view
351    })
352}
353
354fn create_editor_diff(
355    diff: Entity<acp_thread::Diff>,
356    window: &mut Window,
357    cx: &mut App,
358) -> Entity<Editor> {
359    cx.new(|cx| {
360        let mut editor = Editor::new(
361            EditorMode::Full {
362                scale_ui_elements_with_buffer_font_size: false,
363                show_active_line_background: false,
364                sized_by_content: true,
365            },
366            diff.read(cx).multibuffer().clone(),
367            None,
368            window,
369            cx,
370        );
371        editor.set_show_gutter(false, cx);
372        editor.disable_inline_diagnostics();
373        editor.disable_expand_excerpt_buttons(cx);
374        editor.set_show_vertical_scrollbar(false, cx);
375        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
376        editor.set_soft_wrap_mode(SoftWrap::None, cx);
377        editor.scroll_manager.set_forbid_vertical_scroll(true);
378        editor.set_show_indent_guides(false, cx);
379        editor.set_read_only(true);
380        editor.set_show_breakpoints(false, cx);
381        editor.set_show_code_actions(false, cx);
382        editor.set_show_git_diff_gutter(false, cx);
383        editor.set_expand_all_diff_hunks(cx);
384        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
385        editor
386    })
387}
388
389fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
390    TextStyleRefinement {
391        font_size: Some(
392            TextSize::Small
393                .rems(cx)
394                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
395                .into(),
396        ),
397        ..Default::default()
398    }
399}
400
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use agent2::HistoryStore;
    use assistant_context::ContextStore;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// Syncing a tool call that carries a diff creates a diff editor showing the old
    /// line as deleted and the new line as added.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Every absolute path goes through `path!` so the fake tree, the worktree root
        // and the diff path agree even where `path!` rewrites paths per platform.
        fs.insert_tree(
            path!("/project"),
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff {
                diff: acp::Diff {
                    path: path!("/project/hello.txt").into(),
                    old_text: Some("hi world".into()),
                    new_text: "hello world".into(),
                },
            }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        };
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                history_store,
                None,
                Default::default(),
                Default::default(),
                "Test Agent".into(),
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            thread
                .entries()
                .first()
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs the test settings store and registers the settings this test reads.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}