entry_view_state.rs

  1use std::{cell::RefCell, ops::Range, rc::Rc};
  2
  3use super::thread_history::ThreadHistory;
  4use acp_thread::{AcpThread, AgentThreadEntry};
  5use agent::ThreadStore;
  6use agent_client_protocol::{self as acp, ToolCallId};
  7use collections::HashMap;
  8use editor::{Editor, EditorMode, MinimapVisibility, SizingBehavior};
  9use gpui::{
 10    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 11    ScrollHandle, SharedString, TextStyleRefinement, WeakEntity, Window,
 12};
 13use language::language_settings::SoftWrap;
 14use project::Project;
 15use prompt_store::PromptStore;
 16use settings::Settings as _;
 17use terminal_view::TerminalView;
 18use theme::ThemeSettings;
 19use ui::{Context, TextSize};
 20use workspace::Workspace;
 21
 22use crate::message_editor::{MessageEditor, MessageEditorEvent};
 23
 24pub struct EntryViewState {
 25    workspace: WeakEntity<Workspace>,
 26    project: WeakEntity<Project>,
 27    thread_store: Option<Entity<ThreadStore>>,
 28    history: WeakEntity<ThreadHistory>,
 29    prompt_store: Option<Entity<PromptStore>>,
 30    entries: Vec<Entry>,
 31    prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 32    available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 33    agent_name: SharedString,
 34}
 35
 36impl EntryViewState {
 37    pub fn new(
 38        workspace: WeakEntity<Workspace>,
 39        project: WeakEntity<Project>,
 40        thread_store: Option<Entity<ThreadStore>>,
 41        history: WeakEntity<ThreadHistory>,
 42        prompt_store: Option<Entity<PromptStore>>,
 43        prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 44        available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 45        agent_name: SharedString,
 46    ) -> Self {
 47        Self {
 48            workspace,
 49            project,
 50            thread_store,
 51            history,
 52            prompt_store,
 53            entries: Vec::new(),
 54            prompt_capabilities,
 55            available_commands,
 56            agent_name,
 57        }
 58    }
 59
 60    pub fn entry(&self, index: usize) -> Option<&Entry> {
 61        self.entries.get(index)
 62    }
 63
 64    pub fn sync_entry(
 65        &mut self,
 66        index: usize,
 67        thread: &Entity<AcpThread>,
 68        window: &mut Window,
 69        cx: &mut Context<Self>,
 70    ) {
 71        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 72            return;
 73        };
 74
 75        match thread_entry {
 76            AgentThreadEntry::UserMessage(message) => {
 77                let has_id = message.id.is_some();
 78                let is_subagent = thread.read(cx).parent_session_id().is_some();
 79                let chunks = message.chunks.clone();
 80                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 81                    if !editor.focus_handle(cx).is_focused(window) {
 82                        // Only update if we are not editing.
 83                        // If we are, cancelling the edit will set the message to the newest content.
 84                        editor.update(cx, |editor, cx| {
 85                            editor.set_message(chunks, window, cx);
 86                        });
 87                    }
 88                } else {
 89                    let message_editor = cx.new(|cx| {
 90                        let mut editor = MessageEditor::new(
 91                            self.workspace.clone(),
 92                            self.project.clone(),
 93                            self.thread_store.clone(),
 94                            self.history.clone(),
 95                            self.prompt_store.clone(),
 96                            self.prompt_capabilities.clone(),
 97                            self.available_commands.clone(),
 98                            self.agent_name.clone(),
 99                            "Edit message - @ to include context",
100                            editor::EditorMode::AutoHeight {
101                                min_lines: 1,
102                                max_lines: None,
103                            },
104                            window,
105                            cx,
106                        );
107                        if !has_id || is_subagent {
108                            editor.set_read_only(true, cx);
109                        }
110                        editor.set_message(chunks, window, cx);
111                        editor
112                    });
113                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
114                        cx.emit(EntryViewEvent {
115                            entry_index: index,
116                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
117                        })
118                    })
119                    .detach();
120                    self.set_entry(index, Entry::UserMessage(message_editor));
121                }
122            }
123            AgentThreadEntry::ToolCall(tool_call) => {
124                let id = tool_call.id.clone();
125                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
126                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
127
128                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
129                    views
130                } else {
131                    self.set_entry(index, Entry::empty());
132                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
133                        unreachable!()
134                    };
135                    views
136                };
137
138                let is_tool_call_completed =
139                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
140
141                for terminal in terminals {
142                    match views.entry(terminal.entity_id()) {
143                        collections::hash_map::Entry::Vacant(entry) => {
144                            let element = create_terminal(
145                                self.workspace.clone(),
146                                self.project.clone(),
147                                terminal.clone(),
148                                window,
149                                cx,
150                            )
151                            .into_any();
152                            cx.emit(EntryViewEvent {
153                                entry_index: index,
154                                view_event: ViewEvent::NewTerminal(id.clone()),
155                            });
156                            entry.insert(element);
157                        }
158                        collections::hash_map::Entry::Occupied(_entry) => {
159                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
160                                cx.emit(EntryViewEvent {
161                                    entry_index: index,
162                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
163                                });
164                            }
165                        }
166                    }
167                }
168
169                for diff in diffs {
170                    views.entry(diff.entity_id()).or_insert_with(|| {
171                        let element = create_editor_diff(diff.clone(), window, cx).into_any();
172                        cx.emit(EntryViewEvent {
173                            entry_index: index,
174                            view_event: ViewEvent::NewDiff(id.clone()),
175                        });
176                        element
177                    });
178                }
179            }
180            AgentThreadEntry::AssistantMessage(message) => {
181                let entry = if let Some(Entry::AssistantMessage(entry)) =
182                    self.entries.get_mut(index)
183                {
184                    entry
185                } else {
186                    self.set_entry(
187                        index,
188                        Entry::AssistantMessage(AssistantMessageEntry::default()),
189                    );
190                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
191                        unreachable!()
192                    };
193                    entry
194                };
195                entry.sync(message);
196            }
197        };
198    }
199
200    fn set_entry(&mut self, index: usize, entry: Entry) {
201        if index == self.entries.len() {
202            self.entries.push(entry);
203        } else {
204            self.entries[index] = entry;
205        }
206    }
207
208    pub fn remove(&mut self, range: Range<usize>) {
209        self.entries.drain(range);
210    }
211
212    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
213        for entry in self.entries.iter() {
214            match entry {
215                Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
216                Entry::Content(response_views) => {
217                    for view in response_views.values() {
218                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
219                            diff_editor.update(cx, |diff_editor, cx| {
220                                diff_editor.set_text_style_refinement(
221                                    diff_editor_text_style_refinement(cx),
222                                );
223                                cx.notify();
224                            })
225                        }
226                    }
227                }
228            }
229        }
230    }
231}
232
233impl EventEmitter<EntryViewEvent> for EntryViewState {}
234
235pub struct EntryViewEvent {
236    pub entry_index: usize,
237    pub view_event: ViewEvent,
238}
239
240pub enum ViewEvent {
241    NewDiff(ToolCallId),
242    NewTerminal(ToolCallId),
243    TerminalMovedToBackground(ToolCallId),
244    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
245}
246
247#[derive(Default, Debug)]
248pub struct AssistantMessageEntry {
249    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
250}
251
252impl AssistantMessageEntry {
253    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
254        self.scroll_handles_by_chunk_index.get(&ix).cloned()
255    }
256
257    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
258        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
259            let ix = message.chunks.len() - 1;
260            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
261            handle.scroll_to_bottom();
262        }
263    }
264}
265
266#[derive(Debug)]
267pub enum Entry {
268    UserMessage(Entity<MessageEditor>),
269    AssistantMessage(AssistantMessageEntry),
270    Content(HashMap<EntityId, AnyEntity>),
271}
272
273impl Entry {
274    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
275        match self {
276            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
277            Self::AssistantMessage(_) | Self::Content(_) => None,
278        }
279    }
280
281    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
282        match self {
283            Self::UserMessage(editor) => Some(editor),
284            Self::AssistantMessage(_) | Self::Content(_) => None,
285        }
286    }
287
288    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
289        self.content_map()?
290            .get(&diff.entity_id())
291            .cloned()
292            .map(|entity| entity.downcast::<Editor>().unwrap())
293    }
294
295    pub fn terminal(
296        &self,
297        terminal: &Entity<acp_thread::Terminal>,
298    ) -> Option<Entity<TerminalView>> {
299        self.content_map()?
300            .get(&terminal.entity_id())
301            .cloned()
302            .map(|entity| entity.downcast::<TerminalView>().unwrap())
303    }
304
305    pub fn scroll_handle_for_assistant_message_chunk(
306        &self,
307        chunk_ix: usize,
308    ) -> Option<ScrollHandle> {
309        match self {
310            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
311            Self::UserMessage(_) | Self::Content(_) => None,
312        }
313    }
314
315    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
316        match self {
317            Self::Content(map) => Some(map),
318            _ => None,
319        }
320    }
321
322    fn empty() -> Self {
323        Self::Content(HashMap::default())
324    }
325
326    #[cfg(test)]
327    pub fn has_content(&self) -> bool {
328        match self {
329            Self::Content(map) => !map.is_empty(),
330            Self::UserMessage(_) | Self::AssistantMessage(_) => false,
331        }
332    }
333}
334
335fn create_terminal(
336    workspace: WeakEntity<Workspace>,
337    project: WeakEntity<Project>,
338    terminal: Entity<acp_thread::Terminal>,
339    window: &mut Window,
340    cx: &mut App,
341) -> Entity<TerminalView> {
342    cx.new(|cx| {
343        let mut view = TerminalView::new(
344            terminal.read(cx).inner().clone(),
345            workspace,
346            None,
347            project,
348            window,
349            cx,
350        );
351        view.set_embedded_mode(Some(1000), cx);
352        view
353    })
354}
355
356fn create_editor_diff(
357    diff: Entity<acp_thread::Diff>,
358    window: &mut Window,
359    cx: &mut App,
360) -> Entity<Editor> {
361    cx.new(|cx| {
362        let mut editor = Editor::new(
363            EditorMode::Full {
364                scale_ui_elements_with_buffer_font_size: false,
365                show_active_line_background: false,
366                sizing_behavior: SizingBehavior::SizeByContent,
367            },
368            diff.read(cx).multibuffer().clone(),
369            None,
370            window,
371            cx,
372        );
373        editor.set_show_gutter(false, cx);
374        editor.disable_inline_diagnostics();
375        editor.disable_expand_excerpt_buttons(cx);
376        editor.set_show_vertical_scrollbar(false, cx);
377        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
378        editor.set_soft_wrap_mode(SoftWrap::None, cx);
379        editor.scroll_manager.set_forbid_vertical_scroll(true);
380        editor.set_show_indent_guides(false, cx);
381        editor.set_read_only(true);
382        editor.set_show_breakpoints(false, cx);
383        editor.set_show_code_actions(false, cx);
384        editor.set_show_git_diff_gutter(false, cx);
385        editor.set_expand_all_diff_hunks(cx);
386        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
387        editor
388    })
389}
390
391fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
392    TextStyleRefinement {
393        font_size: Some(
394            TextSize::Small
395                .rems(cx)
396                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
397                .into(),
398        ),
399        ..Default::default()
400    }
401}
402
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::rc::Rc;

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::RowInfo;
    use fs::FakeFs;
    use gpui::{AppContext as _, TestAppContext};

    use crate::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use workspace::MultiWorkspace;

    /// Syncing a tool call that carries a diff yields a diff editor whose rows show the old
    /// line as deleted and the new line as added.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);

        let fake_fs = FakeFs::new(cx.executor());
        fake_fs
            .insert_tree("/project", json!({ "hello.txt": "hi world" }))
            .await;
        let project = Project::test(fake_fs, [Path::new(path!("/project"))], cx).await;

        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());

        // Open a stub agent session and report one in-progress tool call with a single diff.
        let agent_connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                agent_connection
                    .clone()
                    .new_session(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        let tool_call = acp::ToolCall::new("tool", "Tool call")
            .status(acp::ToolCallStatus::InProgress)
            .content(vec![acp::ToolCallContent::Diff(
                acp::Diff::new("/project/hello.txt", "hello world").old_text("hi world"),
            )]);
        cx.update(|_, cx| {
            agent_connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let history =
            cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.downgrade(),
                None,
                history.downgrade(),
                None,
                Default::default(),
                Default::default(),
                "Test Agent".into(),
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _| {
            let first_entry = thread.entries().first().unwrap();
            first_entry.diffs().next().unwrap().clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        let editor_text = diff_editor.read_with(cx, |editor, cx| editor.text(cx));
        assert_eq!(editor_text, "hi world\nhello world");

        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            editor
                .buffer()
                .read(cx)
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs the test settings store, base themes and a release channel.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            theme::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(semver::Version::new(0, 0, 0), cx);
        });
    }
}