entry_view_state.rs

  1use std::{cell::RefCell, ops::Range, rc::Rc};
  2
  3use super::thread_history::ThreadHistory;
  4use acp_thread::{AcpThread, AgentThreadEntry};
  5use agent::ThreadStore;
  6use agent_client_protocol::{self as acp, ToolCallId};
  7use collections::HashMap;
  8use editor::{Editor, EditorEvent, EditorMode, MinimapVisibility, SizingBehavior};
  9use gpui::{
 10    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 11    ScrollHandle, SharedString, TextStyleRefinement, WeakEntity, Window,
 12};
 13use language::language_settings::SoftWrap;
 14use project::Project;
 15use prompt_store::PromptStore;
 16use rope::Point;
 17use settings::Settings as _;
 18use terminal_view::TerminalView;
 19use theme::ThemeSettings;
 20use ui::{Context, TextSize};
 21use workspace::Workspace;
 22
 23use crate::message_editor::{MessageEditor, MessageEditorEvent};
 24
/// View state for the entries of an [`AcpThread`].
///
/// Holds, per thread entry, the views rendered for it: user-message editors,
/// tool-call terminal/diff views, and assistant-message scroll handles.
/// Entries are addressed by the same index as the thread's entries; see
/// `sync_entry` and `remove`.
pub struct EntryViewState {
    /// Workspace handed to the message editors and terminal views created here.
    workspace: WeakEntity<Workspace>,
    /// Project handed to the message editors and terminal views created here.
    project: WeakEntity<Project>,
    /// Forwarded to each user-message [`MessageEditor`].
    thread_store: Option<Entity<ThreadStore>>,
    /// Forwarded to each user-message [`MessageEditor`].
    history: WeakEntity<ThreadHistory>,
    /// Forwarded to each user-message [`MessageEditor`].
    prompt_store: Option<Entity<PromptStore>>,
    /// One element per synced thread entry, at the same index as in the thread.
    entries: Vec<Entry>,
    /// Shared (via `Rc` clone) with each user-message [`MessageEditor`].
    prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
    /// Shared (via `Rc` clone) with each user-message [`MessageEditor`].
    available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
    /// Forwarded to each user-message [`MessageEditor`].
    agent_name: SharedString,
}
 36
 37impl EntryViewState {
 38    pub fn new(
 39        workspace: WeakEntity<Workspace>,
 40        project: WeakEntity<Project>,
 41        thread_store: Option<Entity<ThreadStore>>,
 42        history: WeakEntity<ThreadHistory>,
 43        prompt_store: Option<Entity<PromptStore>>,
 44        prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 45        available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 46        agent_name: SharedString,
 47    ) -> Self {
 48        Self {
 49            workspace,
 50            project,
 51            thread_store,
 52            history,
 53            prompt_store,
 54            entries: Vec::new(),
 55            prompt_capabilities,
 56            available_commands,
 57            agent_name,
 58        }
 59    }
 60
 61    pub fn entry(&self, index: usize) -> Option<&Entry> {
 62        self.entries.get(index)
 63    }
 64
 65    pub fn sync_entry(
 66        &mut self,
 67        index: usize,
 68        thread: &Entity<AcpThread>,
 69        window: &mut Window,
 70        cx: &mut Context<Self>,
 71    ) {
 72        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 73            return;
 74        };
 75
 76        match thread_entry {
 77            AgentThreadEntry::UserMessage(message) => {
 78                let has_id = message.id.is_some();
 79                let is_subagent = thread.read(cx).parent_session_id().is_some();
 80                let chunks = message.chunks.clone();
 81                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 82                    if !editor.focus_handle(cx).is_focused(window) {
 83                        // Only update if we are not editing.
 84                        // If we are, cancelling the edit will set the message to the newest content.
 85                        editor.update(cx, |editor, cx| {
 86                            editor.set_message(chunks, window, cx);
 87                        });
 88                    }
 89                } else {
 90                    let message_editor = cx.new(|cx| {
 91                        let mut editor = MessageEditor::new(
 92                            self.workspace.clone(),
 93                            self.project.clone(),
 94                            self.thread_store.clone(),
 95                            self.history.clone(),
 96                            self.prompt_store.clone(),
 97                            self.prompt_capabilities.clone(),
 98                            self.available_commands.clone(),
 99                            self.agent_name.clone(),
100                            "Edit message - @ to include context",
101                            editor::EditorMode::AutoHeight {
102                                min_lines: 1,
103                                max_lines: None,
104                            },
105                            window,
106                            cx,
107                        );
108                        if !has_id || is_subagent {
109                            editor.set_read_only(true, cx);
110                        }
111                        editor.set_message(chunks, window, cx);
112                        editor
113                    });
114                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
115                        cx.emit(EntryViewEvent {
116                            entry_index: index,
117                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
118                        })
119                    })
120                    .detach();
121                    self.set_entry(index, Entry::UserMessage(message_editor));
122                }
123            }
124            AgentThreadEntry::ToolCall(tool_call) => {
125                let id = tool_call.id.clone();
126                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
127                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
128
129                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
130                    views
131                } else {
132                    self.set_entry(index, Entry::empty());
133                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
134                        unreachable!()
135                    };
136                    views
137                };
138
139                let is_tool_call_completed =
140                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
141
142                for terminal in terminals {
143                    match views.entry(terminal.entity_id()) {
144                        collections::hash_map::Entry::Vacant(entry) => {
145                            let element = create_terminal(
146                                self.workspace.clone(),
147                                self.project.clone(),
148                                terminal.clone(),
149                                window,
150                                cx,
151                            )
152                            .into_any();
153                            cx.emit(EntryViewEvent {
154                                entry_index: index,
155                                view_event: ViewEvent::NewTerminal(id.clone()),
156                            });
157                            entry.insert(element);
158                        }
159                        collections::hash_map::Entry::Occupied(_entry) => {
160                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
161                                cx.emit(EntryViewEvent {
162                                    entry_index: index,
163                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
164                                });
165                            }
166                        }
167                    }
168                }
169
170                for diff in diffs {
171                    views.entry(diff.entity_id()).or_insert_with(|| {
172                        let editor = create_editor_diff(diff.clone(), window, cx);
173                        cx.subscribe(&editor, {
174                            let diff = diff.clone();
175                            let entry_index = index;
176                            move |_this, _editor, event: &EditorEvent, cx| {
177                                if let EditorEvent::OpenExcerptsRequested {
178                                    selections_by_buffer,
179                                    split,
180                                } = event
181                                {
182                                    let multibuffer = diff.read(cx).multibuffer();
183                                    if let Some((buffer_id, (ranges, _))) =
184                                        selections_by_buffer.iter().next()
185                                    {
186                                        if let Some(buffer) =
187                                            multibuffer.read(cx).buffer(*buffer_id)
188                                        {
189                                            if let Some(range) = ranges.first() {
190                                                let point =
191                                                    buffer.read(cx).offset_to_point(range.start.0);
192                                                if let Some(path) = diff.read(cx).file_path(cx) {
193                                                    cx.emit(EntryViewEvent {
194                                                        entry_index,
195                                                        view_event: ViewEvent::OpenDiffLocation {
196                                                            path,
197                                                            position: point,
198                                                            split: *split,
199                                                        },
200                                                    });
201                                                }
202                                            }
203                                        }
204                                    }
205                                }
206                            }
207                        })
208                        .detach();
209                        cx.emit(EntryViewEvent {
210                            entry_index: index,
211                            view_event: ViewEvent::NewDiff(id.clone()),
212                        });
213                        editor.into_any()
214                    });
215                }
216            }
217            AgentThreadEntry::AssistantMessage(message) => {
218                let entry = if let Some(Entry::AssistantMessage(entry)) =
219                    self.entries.get_mut(index)
220                {
221                    entry
222                } else {
223                    self.set_entry(
224                        index,
225                        Entry::AssistantMessage(AssistantMessageEntry::default()),
226                    );
227                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
228                        unreachable!()
229                    };
230                    entry
231                };
232                entry.sync(message);
233            }
234        };
235    }
236
237    fn set_entry(&mut self, index: usize, entry: Entry) {
238        if index == self.entries.len() {
239            self.entries.push(entry);
240        } else {
241            self.entries[index] = entry;
242        }
243    }
244
245    pub fn remove(&mut self, range: Range<usize>) {
246        self.entries.drain(range);
247    }
248
249    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
250        for entry in self.entries.iter() {
251            match entry {
252                Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
253                Entry::Content(response_views) => {
254                    for view in response_views.values() {
255                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
256                            diff_editor.update(cx, |diff_editor, cx| {
257                                diff_editor.set_text_style_refinement(
258                                    diff_editor_text_style_refinement(cx),
259                                );
260                                cx.notify();
261                            })
262                        }
263                    }
264                }
265            }
266        }
267    }
268}
269
// Allows `EntryViewState` to emit `EntryViewEvent`s to subscribers via `cx.emit`.
impl EventEmitter<EntryViewEvent> for EntryViewState {}
271
/// Event emitted by [`EntryViewState`] about the views of one thread entry.
pub struct EntryViewEvent {
    /// Index of the thread entry the event originates from.
    pub entry_index: usize,
    /// What happened to that entry's views.
    pub view_event: ViewEvent,
}
276
/// The kinds of event [`EntryViewState`] emits for an entry.
pub enum ViewEvent {
    /// A diff editor was created for the tool call with this id.
    NewDiff(ToolCallId),
    /// A terminal view was created for the tool call with this id.
    NewTerminal(ToolCallId),
    /// The tool call is completed while one of its already-shown terminals
    /// still reports no output.
    TerminalMovedToBackground(ToolCallId),
    /// An event forwarded unchanged from a user-message editor.
    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
    /// An "open excerpts" request from a diff editor, resolved to a file location.
    OpenDiffLocation {
        /// Path reported by the diff's `file_path`.
        path: String,
        /// Start of the first selected range, as a buffer point.
        position: Point,
        /// The `split` flag carried by the originating editor request.
        split: bool,
    },
}
288
/// View state for an assistant message: scroll handles keyed by chunk index.
#[derive(Default, Debug)]
pub struct AssistantMessageEntry {
    /// Populated lazily by `sync` when the message's last chunk is a thought.
    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
}
293
294impl AssistantMessageEntry {
295    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
296        self.scroll_handles_by_chunk_index.get(&ix).cloned()
297    }
298
299    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
300        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
301            let ix = message.chunks.len() - 1;
302            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
303            handle.scroll_to_bottom();
304        }
305    }
306}
307
/// View state for a single thread entry.
#[derive(Debug)]
pub enum Entry {
    /// Editor displaying a user message.
    UserMessage(Entity<MessageEditor>),
    /// Scroll state for an assistant message.
    AssistantMessage(AssistantMessageEntry),
    /// Views for a tool call's terminals and diffs, keyed by the entity id of
    /// the underlying terminal or diff.
    Content(HashMap<EntityId, AnyEntity>),
}
314
315impl Entry {
316    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
317        match self {
318            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
319            Self::AssistantMessage(_) | Self::Content(_) => None,
320        }
321    }
322
323    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
324        match self {
325            Self::UserMessage(editor) => Some(editor),
326            Self::AssistantMessage(_) | Self::Content(_) => None,
327        }
328    }
329
330    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
331        self.content_map()?
332            .get(&diff.entity_id())
333            .cloned()
334            .map(|entity| entity.downcast::<Editor>().unwrap())
335    }
336
337    pub fn terminal(
338        &self,
339        terminal: &Entity<acp_thread::Terminal>,
340    ) -> Option<Entity<TerminalView>> {
341        self.content_map()?
342            .get(&terminal.entity_id())
343            .cloned()
344            .map(|entity| entity.downcast::<TerminalView>().unwrap())
345    }
346
347    pub fn scroll_handle_for_assistant_message_chunk(
348        &self,
349        chunk_ix: usize,
350    ) -> Option<ScrollHandle> {
351        match self {
352            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
353            Self::UserMessage(_) | Self::Content(_) => None,
354        }
355    }
356
357    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
358        match self {
359            Self::Content(map) => Some(map),
360            _ => None,
361        }
362    }
363
364    fn empty() -> Self {
365        Self::Content(HashMap::default())
366    }
367
368    #[cfg(test)]
369    pub fn has_content(&self) -> bool {
370        match self {
371            Self::Content(map) => !map.is_empty(),
372            Self::UserMessage(_) | Self::AssistantMessage(_) => false,
373        }
374    }
375}
376
377fn create_terminal(
378    workspace: WeakEntity<Workspace>,
379    project: WeakEntity<Project>,
380    terminal: Entity<acp_thread::Terminal>,
381    window: &mut Window,
382    cx: &mut App,
383) -> Entity<TerminalView> {
384    cx.new(|cx| {
385        let mut view = TerminalView::new(
386            terminal.read(cx).inner().clone(),
387            workspace,
388            None,
389            project,
390            window,
391            cx,
392        );
393        view.set_embedded_mode(Some(1000), cx);
394        view
395    })
396}
397
398fn create_editor_diff(
399    diff: Entity<acp_thread::Diff>,
400    window: &mut Window,
401    cx: &mut App,
402) -> Entity<Editor> {
403    cx.new(|cx| {
404        let mut editor = Editor::new(
405            EditorMode::Full {
406                scale_ui_elements_with_buffer_font_size: false,
407                show_active_line_background: false,
408                sizing_behavior: SizingBehavior::SizeByContent,
409            },
410            diff.read(cx).multibuffer().clone(),
411            None,
412            window,
413            cx,
414        );
415        editor.set_show_gutter(false, cx);
416        editor.disable_inline_diagnostics();
417        editor.disable_expand_excerpt_buttons(cx);
418        editor.set_show_vertical_scrollbar(false, cx);
419        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
420        editor.set_soft_wrap_mode(SoftWrap::None, cx);
421        editor.scroll_manager.set_forbid_vertical_scroll(true);
422        editor.set_show_indent_guides(false, cx);
423        editor.set_read_only(true);
424        editor.set_delegate_open_excerpts(true);
425        editor.set_show_breakpoints(false, cx);
426        editor.set_show_code_actions(false, cx);
427        editor.set_show_git_diff_gutter(false, cx);
428        editor.set_expand_all_diff_hunks(cx);
429        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
430        editor
431    })
432}
433
434fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
435    TextStyleRefinement {
436        font_size: Some(
437            TextSize::Small
438                .rems(cx)
439                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
440                .into(),
441        ),
442        ..Default::default()
443    }
444}
445
/// Tests for [`EntryViewState`].
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::rc::Rc;

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::RowInfo;
    use fs::FakeFs;
    use gpui::{AppContext as _, TestAppContext};

    use crate::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use workspace::MultiWorkspace;

    /// Syncing a tool-call entry that carries a diff should create a diff
    /// editor showing the old line (as deleted) followed by the new line
    /// (as added).
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // NOTE(review): "/project" here and in `acp::Diff::new` below is not
        // wrapped in `path!`, unlike the `Project::test` / `new_session` paths;
        // confirm this is intended for non-Unix path handling.
        fs.insert_tree(
            "/project",
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());

        // An in-progress tool call whose single content item is a diff of hello.txt.
        let tool_call = acp::ToolCall::new("tool", "Tool call")
            .status(acp::ToolCallStatus::InProgress)
            .content(vec![acp::ToolCallContent::Diff(
                acp::Diff::new("/project/hello.txt", "hello world").old_text("hi world"),
            )]);
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        // Deliver the tool call to the thread so it becomes entry 0.
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = None;
        let history =
            cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.downgrade(),
                thread_store,
                history.downgrade(),
                None,
                Default::default(),
                Default::default(),
                "Test Agent".into(),
            )
        });

        // Create the views for entry 0 (the tool call).
        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        // Presumably lets any asynchronous diff computation settle before
        // inspecting the editor — confirm if this test becomes flaky.
        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        // Row 0 is the removed line, row 1 the inserted line.
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs the test settings store, base theme, and release channel.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(semver::Version::new(0, 0, 0), cx);
        });
    }
}