entry_view_state.rs

  1use std::{cell::Cell, ops::Range, rc::Rc};
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent_client_protocol::{PromptCapabilities, ToolCallId};
  5use agent2::HistoryStore;
  6use collections::HashMap;
  7use editor::{Editor, EditorMode, MinimapVisibility};
  8use gpui::{
  9    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 10    ScrollHandle, TextStyleRefinement, WeakEntity, Window,
 11};
 12use language::language_settings::SoftWrap;
 13use project::Project;
 14use prompt_store::PromptStore;
 15use settings::Settings as _;
 16use terminal_view::TerminalView;
 17use theme::ThemeSettings;
 18use ui::{Context, TextSize};
 19use workspace::Workspace;
 20
 21use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 22
 23pub struct EntryViewState {
 24    workspace: WeakEntity<Workspace>,
 25    project: Entity<Project>,
 26    history_store: Entity<HistoryStore>,
 27    prompt_store: Option<Entity<PromptStore>>,
 28    entries: Vec<Entry>,
 29    prevent_slash_commands: bool,
 30    prompt_capabilities: Rc<Cell<PromptCapabilities>>,
 31}
 32
 33impl EntryViewState {
 34    pub fn new(
 35        workspace: WeakEntity<Workspace>,
 36        project: Entity<Project>,
 37        history_store: Entity<HistoryStore>,
 38        prompt_store: Option<Entity<PromptStore>>,
 39        prompt_capabilities: Rc<Cell<PromptCapabilities>>,
 40        prevent_slash_commands: bool,
 41    ) -> Self {
 42        Self {
 43            workspace,
 44            project,
 45            history_store,
 46            prompt_store,
 47            entries: Vec::new(),
 48            prevent_slash_commands,
 49            prompt_capabilities,
 50        }
 51    }
 52
 53    pub fn entry(&self, index: usize) -> Option<&Entry> {
 54        self.entries.get(index)
 55    }
 56
 57    pub fn sync_entry(
 58        &mut self,
 59        index: usize,
 60        thread: &Entity<AcpThread>,
 61        window: &mut Window,
 62        cx: &mut Context<Self>,
 63    ) {
 64        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 65            return;
 66        };
 67
 68        match thread_entry {
 69            AgentThreadEntry::UserMessage(message) => {
 70                let has_id = message.id.is_some();
 71                let chunks = message.chunks.clone();
 72                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 73                    if !editor.focus_handle(cx).is_focused(window) {
 74                        // Only update if we are not editing.
 75                        // If we are, cancelling the edit will set the message to the newest content.
 76                        editor.update(cx, |editor, cx| {
 77                            editor.set_message(chunks, window, cx);
 78                        });
 79                    }
 80                } else {
 81                    let message_editor = cx.new(|cx| {
 82                        let mut editor = MessageEditor::new(
 83                            self.workspace.clone(),
 84                            self.project.clone(),
 85                            self.history_store.clone(),
 86                            self.prompt_store.clone(),
 87                            self.prompt_capabilities.clone(),
 88                            "Edit message - @ to include context",
 89                            self.prevent_slash_commands,
 90                            editor::EditorMode::AutoHeight {
 91                                min_lines: 1,
 92                                max_lines: None,
 93                            },
 94                            window,
 95                            cx,
 96                        );
 97                        if !has_id {
 98                            editor.set_read_only(true, cx);
 99                        }
100                        editor.set_message(chunks, window, cx);
101                        editor
102                    });
103                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
104                        cx.emit(EntryViewEvent {
105                            entry_index: index,
106                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
107                        })
108                    })
109                    .detach();
110                    self.set_entry(index, Entry::UserMessage(message_editor));
111                }
112            }
113            AgentThreadEntry::ToolCall(tool_call) => {
114                let id = tool_call.id.clone();
115                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
116                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
117
118                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
119                    views
120                } else {
121                    self.set_entry(index, Entry::empty());
122                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
123                        unreachable!()
124                    };
125                    views
126                };
127
128                let is_tool_call_completed =
129                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
130
131                for terminal in terminals {
132                    match views.entry(terminal.entity_id()) {
133                        collections::hash_map::Entry::Vacant(entry) => {
134                            let element = create_terminal(
135                                self.workspace.clone(),
136                                self.project.clone(),
137                                terminal.clone(),
138                                window,
139                                cx,
140                            )
141                            .into_any();
142                            cx.emit(EntryViewEvent {
143                                entry_index: index,
144                                view_event: ViewEvent::NewTerminal(id.clone()),
145                            });
146                            entry.insert(element);
147                        }
148                        collections::hash_map::Entry::Occupied(_entry) => {
149                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
150                                cx.emit(EntryViewEvent {
151                                    entry_index: index,
152                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
153                                });
154                            }
155                        }
156                    }
157                }
158
159                for diff in diffs {
160                    views.entry(diff.entity_id()).or_insert_with(|| {
161                        let element = create_editor_diff(diff.clone(), window, cx).into_any();
162                        cx.emit(EntryViewEvent {
163                            entry_index: index,
164                            view_event: ViewEvent::NewDiff(id.clone()),
165                        });
166                        element
167                    });
168                }
169            }
170            AgentThreadEntry::AssistantMessage(message) => {
171                let entry = if let Some(Entry::AssistantMessage(entry)) =
172                    self.entries.get_mut(index)
173                {
174                    entry
175                } else {
176                    self.set_entry(
177                        index,
178                        Entry::AssistantMessage(AssistantMessageEntry::default()),
179                    );
180                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
181                        unreachable!()
182                    };
183                    entry
184                };
185                entry.sync(message);
186            }
187        };
188    }
189
190    fn set_entry(&mut self, index: usize, entry: Entry) {
191        if index == self.entries.len() {
192            self.entries.push(entry);
193        } else {
194            self.entries[index] = entry;
195        }
196    }
197
198    pub fn remove(&mut self, range: Range<usize>) {
199        self.entries.drain(range);
200    }
201
202    pub fn settings_changed(&mut self, cx: &mut App) {
203        for entry in self.entries.iter() {
204            match entry {
205                Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
206                Entry::Content(response_views) => {
207                    for view in response_views.values() {
208                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
209                            diff_editor.update(cx, |diff_editor, cx| {
210                                diff_editor.set_text_style_refinement(
211                                    diff_editor_text_style_refinement(cx),
212                                );
213                                cx.notify();
214                            })
215                        }
216                    }
217                }
218            }
219        }
220    }
221}
222
223impl EventEmitter<EntryViewEvent> for EntryViewState {}
224
225pub struct EntryViewEvent {
226    pub entry_index: usize,
227    pub view_event: ViewEvent,
228}
229
230pub enum ViewEvent {
231    NewDiff(ToolCallId),
232    NewTerminal(ToolCallId),
233    TerminalMovedToBackground(ToolCallId),
234    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
235}
236
237#[derive(Default, Debug)]
238pub struct AssistantMessageEntry {
239    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
240}
241
242impl AssistantMessageEntry {
243    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
244        self.scroll_handles_by_chunk_index.get(&ix).cloned()
245    }
246
247    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
248        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
249            let ix = message.chunks.len() - 1;
250            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
251            handle.scroll_to_bottom();
252        }
253    }
254}
255
256#[derive(Debug)]
257pub enum Entry {
258    UserMessage(Entity<MessageEditor>),
259    AssistantMessage(AssistantMessageEntry),
260    Content(HashMap<EntityId, AnyEntity>),
261}
262
263impl Entry {
264    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
265        match self {
266            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
267            Self::AssistantMessage(_) | Self::Content(_) => None,
268        }
269    }
270
271    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
272        match self {
273            Self::UserMessage(editor) => Some(editor),
274            Self::AssistantMessage(_) | Self::Content(_) => None,
275        }
276    }
277
278    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
279        self.content_map()?
280            .get(&diff.entity_id())
281            .cloned()
282            .map(|entity| entity.downcast::<Editor>().unwrap())
283    }
284
285    pub fn terminal(
286        &self,
287        terminal: &Entity<acp_thread::Terminal>,
288    ) -> Option<Entity<TerminalView>> {
289        self.content_map()?
290            .get(&terminal.entity_id())
291            .cloned()
292            .map(|entity| entity.downcast::<TerminalView>().unwrap())
293    }
294
295    pub fn scroll_handle_for_assistant_message_chunk(
296        &self,
297        chunk_ix: usize,
298    ) -> Option<ScrollHandle> {
299        match self {
300            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
301            Self::UserMessage(_) | Self::Content(_) => None,
302        }
303    }
304
305    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
306        match self {
307            Self::Content(map) => Some(map),
308            _ => None,
309        }
310    }
311
312    fn empty() -> Self {
313        Self::Content(HashMap::default())
314    }
315
316    #[cfg(test)]
317    pub fn has_content(&self) -> bool {
318        match self {
319            Self::Content(map) => !map.is_empty(),
320            Self::UserMessage(_) | Self::AssistantMessage(_) => false,
321        }
322    }
323}
324
325fn create_terminal(
326    workspace: WeakEntity<Workspace>,
327    project: Entity<Project>,
328    terminal: Entity<acp_thread::Terminal>,
329    window: &mut Window,
330    cx: &mut App,
331) -> Entity<TerminalView> {
332    cx.new(|cx| {
333        let mut view = TerminalView::new(
334            terminal.read(cx).inner().clone(),
335            workspace.clone(),
336            None,
337            project.downgrade(),
338            window,
339            cx,
340        );
341        view.set_embedded_mode(Some(1000), cx);
342        view
343    })
344}
345
346fn create_editor_diff(
347    diff: Entity<acp_thread::Diff>,
348    window: &mut Window,
349    cx: &mut App,
350) -> Entity<Editor> {
351    cx.new(|cx| {
352        let mut editor = Editor::new(
353            EditorMode::Full {
354                scale_ui_elements_with_buffer_font_size: false,
355                show_active_line_background: false,
356                sized_by_content: true,
357            },
358            diff.read(cx).multibuffer().clone(),
359            None,
360            window,
361            cx,
362        );
363        editor.set_show_gutter(false, cx);
364        editor.disable_inline_diagnostics();
365        editor.disable_expand_excerpt_buttons(cx);
366        editor.set_show_vertical_scrollbar(false, cx);
367        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
368        editor.set_soft_wrap_mode(SoftWrap::None, cx);
369        editor.scroll_manager.set_forbid_vertical_scroll(true);
370        editor.set_show_indent_guides(false, cx);
371        editor.set_read_only(true);
372        editor.set_show_breakpoints(false, cx);
373        editor.set_show_code_actions(false, cx);
374        editor.set_show_git_diff_gutter(false, cx);
375        editor.set_expand_all_diff_hunks(cx);
376        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
377        editor
378    })
379}
380
381fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
382    TextStyleRefinement {
383        font_size: Some(
384            TextSize::Small
385                .rems(cx)
386                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
387                .into(),
388        ),
389        ..Default::default()
390    }
391}
392
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use agent2::HistoryStore;
    use assistant_context::ContextStore;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// Syncing a tool call that carries a diff must create a diff editor
    /// showing the old line as deleted and the new line as added.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);

        // Project with a single file whose content matches the diff's old text.
        let fs = FakeFs::new(cx.executor());
        let tree = json!({
            "hello.txt": "hi world"
        });
        fs.insert_tree("/project", tree).await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // An in-progress tool call whose only content is a one-line diff.
        let diff_content = acp::ToolCallContent::Diff {
            diff: acp::Diff {
                path: "/project/hello.txt".into(),
                old_text: Some("hi world".into()),
                new_text: "hello world".into(),
            },
        };
        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![diff_content],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        };

        let connection = Rc::new(StubAgentConnection::new());
        let new_thread = cx.update(|_, cx| {
            connection
                .clone()
                .new_thread(project.clone(), Path::new(path!("/project")), cx)
        });
        let thread = new_thread.await.unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        let entry_views = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                history_store,
                None,
                Default::default(),
                false,
            )
        });

        entry_views.update_in(cx, |entry_views, window, cx| {
            entry_views.sync_entry(0, &thread, window, cx)
        });

        // The first diff of the thread's first entry is the one just sent.
        let thread_diff = thread.read_with(cx, |thread, _cx| {
            let first_entry = thread.entries().get(0).unwrap();
            first_entry.diffs().next().unwrap().clone()
        });

        cx.run_until_parked();

        let diff_editor = entry_views.read_with(cx, |entry_views, _cx| {
            let entry = entry_views.entry(0).unwrap();
            entry.editor_for_diff(&thread_diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );

        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let snapshot = editor.buffer().read(cx).snapshot(cx);
            snapshot.row_infos(MultiBufferRow(0)).collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs the test settings store and registers the settings and
    /// globals this test relies on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store global is set before anything registers
            // settings against it.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);

            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}