// entry_view_state.rs

  1use std::{cell::RefCell, ops::Range, rc::Rc};
  2
  3use super::thread_history::ThreadHistory;
  4use acp_thread::{AcpThread, AgentThreadEntry};
  5use agent::ThreadStore;
  6use agent_client_protocol::{self as acp, ToolCallId};
  7use collections::HashMap;
  8use editor::{Editor, EditorEvent, EditorMode, MinimapVisibility, SizingBehavior};
  9use gpui::{
 10    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 11    ScrollHandle, TextStyleRefinement, WeakEntity, Window,
 12};
 13use language::language_settings::SoftWrap;
 14use project::{AgentId, Project};
 15use prompt_store::PromptStore;
 16use rope::Point;
 17use settings::Settings as _;
 18use terminal_view::TerminalView;
 19use theme::ThemeSettings;
 20use ui::{Context, TextSize};
 21use workspace::Workspace;
 22
 23use crate::message_editor::{MessageEditor, MessageEditorEvent};
 24
 25pub struct EntryViewState {
 26    workspace: WeakEntity<Workspace>,
 27    project: WeakEntity<Project>,
 28    thread_store: Option<Entity<ThreadStore>>,
 29    history: Option<WeakEntity<ThreadHistory>>,
 30    prompt_store: Option<Entity<PromptStore>>,
 31    entries: Vec<Entry>,
 32    prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 33    available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 34    agent_id: AgentId,
 35}
 36
 37impl EntryViewState {
 38    pub fn new(
 39        workspace: WeakEntity<Workspace>,
 40        project: WeakEntity<Project>,
 41        thread_store: Option<Entity<ThreadStore>>,
 42        history: Option<WeakEntity<ThreadHistory>>,
 43        prompt_store: Option<Entity<PromptStore>>,
 44        prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 45        available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 46        agent_id: AgentId,
 47    ) -> Self {
 48        Self {
 49            workspace,
 50            project,
 51            thread_store,
 52            history,
 53            prompt_store,
 54            entries: Vec::new(),
 55            prompt_capabilities,
 56            available_commands,
 57            agent_id,
 58        }
 59    }
 60
 61    pub fn entry(&self, index: usize) -> Option<&Entry> {
 62        self.entries.get(index)
 63    }
 64
 65    pub fn sync_entry(
 66        &mut self,
 67        index: usize,
 68        thread: &Entity<AcpThread>,
 69        window: &mut Window,
 70        cx: &mut Context<Self>,
 71    ) {
 72        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 73            return;
 74        };
 75
 76        match thread_entry {
 77            AgentThreadEntry::UserMessage(message) => {
 78                let has_id = message.id.is_some();
 79                let is_subagent = thread.read(cx).parent_session_id().is_some();
 80                let chunks = message.chunks.clone();
 81                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 82                    if !editor.focus_handle(cx).is_focused(window) {
 83                        // Only update if we are not editing.
 84                        // If we are, cancelling the edit will set the message to the newest content.
 85                        editor.update(cx, |editor, cx| {
 86                            editor.set_message(chunks, window, cx);
 87                        });
 88                    }
 89                } else {
 90                    let message_editor = cx.new(|cx| {
 91                        let mut editor = MessageEditor::new(
 92                            self.workspace.clone(),
 93                            self.project.clone(),
 94                            self.thread_store.clone(),
 95                            self.history.clone(),
 96                            self.prompt_store.clone(),
 97                            self.prompt_capabilities.clone(),
 98                            self.available_commands.clone(),
 99                            self.agent_id.clone(),
100                            "Edit message - @ to include context",
101                            editor::EditorMode::AutoHeight {
102                                min_lines: 1,
103                                max_lines: None,
104                            },
105                            window,
106                            cx,
107                        );
108                        if !has_id || is_subagent {
109                            editor.set_read_only(true, cx);
110                        }
111                        editor.set_message(chunks, window, cx);
112                        editor
113                    });
114                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
115                        cx.emit(EntryViewEvent {
116                            entry_index: index,
117                            view_event: ViewEvent::MessageEditorEvent(editor, event.clone()),
118                        })
119                    })
120                    .detach();
121                    self.set_entry(index, Entry::UserMessage(message_editor));
122                }
123            }
124            AgentThreadEntry::ToolCall(tool_call) => {
125                let id = tool_call.id.clone();
126                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
127                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
128
129                let views = if let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) {
130                    &mut tool_call.content
131                } else {
132                    self.set_entry(
133                        index,
134                        Entry::ToolCall(ToolCallEntry {
135                            content: HashMap::default(),
136                        }),
137                    );
138                    let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) else {
139                        unreachable!()
140                    };
141                    &mut tool_call.content
142                };
143
144                let is_tool_call_completed =
145                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
146
147                for terminal in terminals {
148                    match views.entry(terminal.entity_id()) {
149                        collections::hash_map::Entry::Vacant(entry) => {
150                            let element = create_terminal(
151                                self.workspace.clone(),
152                                self.project.clone(),
153                                terminal.clone(),
154                                window,
155                                cx,
156                            )
157                            .into_any();
158                            cx.emit(EntryViewEvent {
159                                entry_index: index,
160                                view_event: ViewEvent::NewTerminal(id.clone()),
161                            });
162                            entry.insert(element);
163                        }
164                        collections::hash_map::Entry::Occupied(_entry) => {
165                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
166                                cx.emit(EntryViewEvent {
167                                    entry_index: index,
168                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
169                                });
170                            }
171                        }
172                    }
173                }
174
175                for diff in diffs {
176                    views.entry(diff.entity_id()).or_insert_with(|| {
177                        let editor = create_editor_diff(diff.clone(), window, cx);
178                        cx.subscribe(&editor, {
179                            let diff = diff.clone();
180                            let entry_index = index;
181                            move |_this, _editor, event: &EditorEvent, cx| {
182                                if let EditorEvent::OpenExcerptsRequested {
183                                    selections_by_buffer,
184                                    split,
185                                } = event
186                                {
187                                    let multibuffer = diff.read(cx).multibuffer();
188                                    if let Some((buffer_id, (ranges, _))) =
189                                        selections_by_buffer.iter().next()
190                                    {
191                                        if let Some(buffer) =
192                                            multibuffer.read(cx).buffer(*buffer_id)
193                                        {
194                                            if let Some(range) = ranges.first() {
195                                                let point =
196                                                    buffer.read(cx).offset_to_point(range.start.0);
197                                                if let Some(path) = diff.read(cx).file_path(cx) {
198                                                    cx.emit(EntryViewEvent {
199                                                        entry_index,
200                                                        view_event: ViewEvent::OpenDiffLocation {
201                                                            path,
202                                                            position: point,
203                                                            split: *split,
204                                                        },
205                                                    });
206                                                }
207                                            }
208                                        }
209                                    }
210                                }
211                            }
212                        })
213                        .detach();
214                        cx.emit(EntryViewEvent {
215                            entry_index: index,
216                            view_event: ViewEvent::NewDiff(id.clone()),
217                        });
218                        editor.into_any()
219                    });
220                }
221            }
222            AgentThreadEntry::AssistantMessage(message) => {
223                let entry = if let Some(Entry::AssistantMessage(entry)) =
224                    self.entries.get_mut(index)
225                {
226                    entry
227                } else {
228                    self.set_entry(
229                        index,
230                        Entry::AssistantMessage(AssistantMessageEntry {
231                            scroll_handles_by_chunk_index: HashMap::default(),
232                            focus_handle: cx.focus_handle(),
233                        }),
234                    );
235                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
236                        unreachable!()
237                    };
238                    entry
239                };
240                entry.sync(message);
241            }
242        };
243    }
244
245    fn set_entry(&mut self, index: usize, entry: Entry) {
246        if index == self.entries.len() {
247            self.entries.push(entry);
248        } else {
249            self.entries[index] = entry;
250        }
251    }
252
253    pub fn remove(&mut self, range: Range<usize>) {
254        self.entries.drain(range);
255    }
256
257    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
258        for entry in self.entries.iter() {
259            match entry {
260                Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
261                Entry::ToolCall(ToolCallEntry { content }) => {
262                    for view in content.values() {
263                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
264                            diff_editor.update(cx, |diff_editor, cx| {
265                                diff_editor.set_text_style_refinement(
266                                    diff_editor_text_style_refinement(cx),
267                                );
268                                cx.notify();
269                            })
270                        }
271                    }
272                }
273            }
274        }
275    }
276}
277
278impl EventEmitter<EntryViewEvent> for EntryViewState {}
279
280pub struct EntryViewEvent {
281    pub entry_index: usize,
282    pub view_event: ViewEvent,
283}
284
285pub enum ViewEvent {
286    NewDiff(ToolCallId),
287    NewTerminal(ToolCallId),
288    TerminalMovedToBackground(ToolCallId),
289    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
290    OpenDiffLocation {
291        path: String,
292        position: Point,
293        split: bool,
294    },
295}
296
297#[derive(Debug)]
298pub struct AssistantMessageEntry {
299    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
300    focus_handle: FocusHandle,
301}
302
303impl AssistantMessageEntry {
304    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
305        self.scroll_handles_by_chunk_index.get(&ix).cloned()
306    }
307
308    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
309        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
310            let ix = message.chunks.len() - 1;
311            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
312            handle.scroll_to_bottom();
313        }
314    }
315}
316
317#[derive(Debug)]
318pub struct ToolCallEntry {
319    content: HashMap<EntityId, AnyEntity>,
320}
321
322#[derive(Debug)]
323pub enum Entry {
324    UserMessage(Entity<MessageEditor>),
325    AssistantMessage(AssistantMessageEntry),
326    ToolCall(ToolCallEntry),
327}
328
329impl Entry {
330    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
331        match self {
332            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
333            Self::AssistantMessage(message) => Some(message.focus_handle.clone()),
334            Self::ToolCall(_) => None,
335        }
336    }
337
338    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
339        match self {
340            Self::UserMessage(editor) => Some(editor),
341            Self::AssistantMessage(_) | Self::ToolCall(_) => None,
342        }
343    }
344
345    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
346        self.content_map()?
347            .get(&diff.entity_id())
348            .cloned()
349            .map(|entity| entity.downcast::<Editor>().unwrap())
350    }
351
352    pub fn terminal(
353        &self,
354        terminal: &Entity<acp_thread::Terminal>,
355    ) -> Option<Entity<TerminalView>> {
356        self.content_map()?
357            .get(&terminal.entity_id())
358            .cloned()
359            .map(|entity| entity.downcast::<TerminalView>().unwrap())
360    }
361
362    pub fn scroll_handle_for_assistant_message_chunk(
363        &self,
364        chunk_ix: usize,
365    ) -> Option<ScrollHandle> {
366        match self {
367            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
368            Self::UserMessage(_) | Self::ToolCall(_) => None,
369        }
370    }
371
372    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
373        match self {
374            Self::ToolCall(ToolCallEntry { content }) => Some(content),
375            _ => None,
376        }
377    }
378
379    #[cfg(test)]
380    pub fn has_content(&self) -> bool {
381        match self {
382            Self::ToolCall(ToolCallEntry { content }) => !content.is_empty(),
383            Self::UserMessage(_) | Self::AssistantMessage(_) => false,
384        }
385    }
386}
387
388fn create_terminal(
389    workspace: WeakEntity<Workspace>,
390    project: WeakEntity<Project>,
391    terminal: Entity<acp_thread::Terminal>,
392    window: &mut Window,
393    cx: &mut App,
394) -> Entity<TerminalView> {
395    cx.new(|cx| {
396        let mut view = TerminalView::new(
397            terminal.read(cx).inner().clone(),
398            workspace,
399            None,
400            project,
401            window,
402            cx,
403        );
404        view.set_embedded_mode(Some(1000), cx);
405        view
406    })
407}
408
409fn create_editor_diff(
410    diff: Entity<acp_thread::Diff>,
411    window: &mut Window,
412    cx: &mut App,
413) -> Entity<Editor> {
414    cx.new(|cx| {
415        let mut editor = Editor::new(
416            EditorMode::Full {
417                scale_ui_elements_with_buffer_font_size: false,
418                show_active_line_background: false,
419                sizing_behavior: SizingBehavior::SizeByContent,
420            },
421            diff.read(cx).multibuffer().clone(),
422            None,
423            window,
424            cx,
425        );
426        editor.set_show_gutter(false, cx);
427        editor.disable_inline_diagnostics();
428        editor.disable_expand_excerpt_buttons(cx);
429        editor.set_show_vertical_scrollbar(false, cx);
430        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
431        editor.set_soft_wrap_mode(SoftWrap::None, cx);
432        editor.scroll_manager.set_forbid_vertical_scroll(true);
433        editor.set_show_indent_guides(false, cx);
434        editor.set_read_only(true);
435        editor.set_delegate_open_excerpts(true);
436        editor.set_show_breakpoints(false, cx);
437        editor.set_show_code_actions(false, cx);
438        editor.set_show_git_diff_gutter(false, cx);
439        editor.set_expand_all_diff_hunks(cx);
440        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
441        editor
442    })
443}
444
445fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
446    TextStyleRefinement {
447        font_size: Some(
448            TextSize::Small
449                .rems(cx)
450                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
451                .into(),
452        ),
453        ..Default::default()
454    }
455}
456
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::rc::Rc;

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::RowInfo;
    use fs::FakeFs;
    use gpui::{AppContext as _, TestAppContext};

    use crate::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use workspace::{MultiWorkspace, PathList};

    /// Syncing a tool call that carries a diff creates a diff editor whose rows
    /// show the old line as deleted and the new line as added.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // `path!` is used for every path so the fake tree, the worktree root and
        // the diff path all agree on the host platform's path form.
        fs.insert_tree(
            path!("/project"),
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());

        let tool_call = acp::ToolCall::new("tool", "Tool call")
            .status(acp::ToolCallStatus::InProgress)
            .content(vec![acp::ToolCallContent::Diff(
                acp::Diff::new(path!("/project/hello.txt"), "hello world").old_text("hi world"),
            )]);
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new(path!("/project"))]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = None;
        let history: Option<gpui::WeakEntity<crate::ThreadHistory>> = None;

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.downgrade(),
                thread_store,
                history,
                None,
                Default::default(),
                Default::default(),
                "Test Agent".into(),
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs the test settings store, base themes and a release channel.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(semver::Version::new(0, 0, 0), cx);
        });
    }
}