entry_view_state.rs

  1use std::ops::Range;
  2
  3use super::thread_history::ThreadHistory;
  4use acp_thread::{AcpThread, AgentThreadEntry};
  5use agent::ThreadStore;
  6use agent_client_protocol::ToolCallId;
  7use collections::HashMap;
  8use editor::{Editor, EditorEvent, EditorMode, MinimapVisibility, SizingBehavior};
  9use gpui::{
 10    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 11    ScrollHandle, TextStyleRefinement, WeakEntity, Window,
 12};
 13use language::language_settings::SoftWrap;
 14use project::{AgentId, Project};
 15use prompt_store::PromptStore;
 16use rope::Point;
 17use settings::Settings as _;
 18use terminal_view::TerminalView;
 19use theme_settings::ThemeSettings;
 20use ui::{Context, TextSize};
 21use workspace::Workspace;
 22
 23use crate::message_editor::{MessageEditor, MessageEditorEvent, SharedSessionCapabilities};
 24
 25pub struct EntryViewState {
 26    workspace: WeakEntity<Workspace>,
 27    project: WeakEntity<Project>,
 28    thread_store: Option<Entity<ThreadStore>>,
 29    history: Option<WeakEntity<ThreadHistory>>,
 30    prompt_store: Option<Entity<PromptStore>>,
 31    entries: Vec<Entry>,
 32    session_capabilities: SharedSessionCapabilities,
 33    agent_id: AgentId,
 34}
 35
 36impl EntryViewState {
 37    pub fn new(
 38        workspace: WeakEntity<Workspace>,
 39        project: WeakEntity<Project>,
 40        thread_store: Option<Entity<ThreadStore>>,
 41        history: Option<WeakEntity<ThreadHistory>>,
 42        prompt_store: Option<Entity<PromptStore>>,
 43        session_capabilities: SharedSessionCapabilities,
 44        agent_id: AgentId,
 45    ) -> Self {
 46        Self {
 47            workspace,
 48            project,
 49            thread_store,
 50            history,
 51            prompt_store,
 52            entries: Vec::new(),
 53            session_capabilities,
 54            agent_id,
 55        }
 56    }
 57
 58    pub fn entry(&self, index: usize) -> Option<&Entry> {
 59        self.entries.get(index)
 60    }
 61
 62    pub fn sync_entry(
 63        &mut self,
 64        index: usize,
 65        thread: &Entity<AcpThread>,
 66        window: &mut Window,
 67        cx: &mut Context<Self>,
 68    ) {
 69        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 70            return;
 71        };
 72
 73        match thread_entry {
 74            AgentThreadEntry::UserMessage(message) => {
 75                let can_rewind = thread.read(cx).supports_truncate(cx);
 76                let has_id = message.id.is_some();
 77                let is_subagent = thread.read(cx).parent_session_id().is_some();
 78                let chunks = message.chunks.clone();
 79                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 80                    if !editor.focus_handle(cx).is_focused(window) {
 81                        // Only update if we are not editing.
 82                        // If we are, cancelling the edit will set the message to the newest content.
 83                        editor.update(cx, |editor, cx| {
 84                            editor.set_message(chunks, window, cx);
 85                        });
 86                    }
 87                } else {
 88                    let message_editor = cx.new(|cx| {
 89                        let mut editor = MessageEditor::new(
 90                            self.workspace.clone(),
 91                            self.project.clone(),
 92                            self.thread_store.clone(),
 93                            self.history.clone(),
 94                            self.prompt_store.clone(),
 95                            self.session_capabilities.clone(),
 96                            self.agent_id.clone(),
 97                            "Edit message - @ to include context",
 98                            editor::EditorMode::AutoHeight {
 99                                min_lines: 1,
100                                max_lines: None,
101                            },
102                            window,
103                            cx,
104                        );
105                        if !can_rewind || !has_id || is_subagent {
106                            editor.set_read_only(true, cx);
107                        }
108                        editor.set_message(chunks, window, cx);
109                        editor
110                    });
111                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
112                        cx.emit(EntryViewEvent {
113                            entry_index: index,
114                            view_event: ViewEvent::MessageEditorEvent(editor, event.clone()),
115                        })
116                    })
117                    .detach();
118                    self.set_entry(index, Entry::UserMessage(message_editor));
119                }
120            }
121            AgentThreadEntry::ToolCall(tool_call) => {
122                let id = tool_call.id.clone();
123                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
124                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
125
126                let views = if let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) {
127                    &mut tool_call.content
128                } else {
129                    self.set_entry(
130                        index,
131                        Entry::ToolCall(ToolCallEntry {
132                            content: HashMap::default(),
133                        }),
134                    );
135                    let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) else {
136                        unreachable!()
137                    };
138                    &mut tool_call.content
139                };
140
141                let is_tool_call_completed =
142                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
143
144                for terminal in terminals {
145                    match views.entry(terminal.entity_id()) {
146                        collections::hash_map::Entry::Vacant(entry) => {
147                            let element = create_terminal(
148                                self.workspace.clone(),
149                                self.project.clone(),
150                                terminal.clone(),
151                                window,
152                                cx,
153                            )
154                            .into_any();
155                            cx.emit(EntryViewEvent {
156                                entry_index: index,
157                                view_event: ViewEvent::NewTerminal(id.clone()),
158                            });
159                            entry.insert(element);
160                        }
161                        collections::hash_map::Entry::Occupied(_entry) => {
162                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
163                                cx.emit(EntryViewEvent {
164                                    entry_index: index,
165                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
166                                });
167                            }
168                        }
169                    }
170                }
171
172                for diff in diffs {
173                    views.entry(diff.entity_id()).or_insert_with(|| {
174                        let editor = create_editor_diff(diff.clone(), window, cx);
175                        cx.subscribe(&editor, {
176                            let diff = diff.clone();
177                            let entry_index = index;
178                            move |_this, _editor, event: &EditorEvent, cx| {
179                                if let EditorEvent::OpenExcerptsRequested {
180                                    selections_by_buffer,
181                                    split,
182                                } = event
183                                {
184                                    let multibuffer = diff.read(cx).multibuffer();
185                                    if let Some((buffer_id, (ranges, _))) =
186                                        selections_by_buffer.iter().next()
187                                    {
188                                        if let Some(buffer) =
189                                            multibuffer.read(cx).buffer(*buffer_id)
190                                        {
191                                            if let Some(range) = ranges.first() {
192                                                let point =
193                                                    buffer.read(cx).offset_to_point(range.start.0);
194                                                if let Some(path) = diff.read(cx).file_path(cx) {
195                                                    cx.emit(EntryViewEvent {
196                                                        entry_index,
197                                                        view_event: ViewEvent::OpenDiffLocation {
198                                                            path,
199                                                            position: point,
200                                                            split: *split,
201                                                        },
202                                                    });
203                                                }
204                                            }
205                                        }
206                                    }
207                                }
208                            }
209                        })
210                        .detach();
211                        cx.emit(EntryViewEvent {
212                            entry_index: index,
213                            view_event: ViewEvent::NewDiff(id.clone()),
214                        });
215                        editor.into_any()
216                    });
217                }
218            }
219            AgentThreadEntry::AssistantMessage(message) => {
220                let entry = if let Some(Entry::AssistantMessage(entry)) =
221                    self.entries.get_mut(index)
222                {
223                    entry
224                } else {
225                    self.set_entry(
226                        index,
227                        Entry::AssistantMessage(AssistantMessageEntry {
228                            scroll_handles_by_chunk_index: HashMap::default(),
229                            focus_handle: cx.focus_handle(),
230                        }),
231                    );
232                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
233                        unreachable!()
234                    };
235                    entry
236                };
237                entry.sync(message);
238            }
239            AgentThreadEntry::CompletedPlan(_) => {
240                if !matches!(self.entries.get(index), Some(Entry::CompletedPlan)) {
241                    self.set_entry(index, Entry::CompletedPlan);
242                }
243            }
244        };
245    }
246
247    fn set_entry(&mut self, index: usize, entry: Entry) {
248        if index == self.entries.len() {
249            self.entries.push(entry);
250        } else {
251            self.entries[index] = entry;
252        }
253    }
254
255    pub fn remove(&mut self, range: Range<usize>) {
256        self.entries.drain(range);
257    }
258
259    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
260        for entry in self.entries.iter() {
261            match entry {
262                Entry::UserMessage { .. }
263                | Entry::AssistantMessage { .. }
264                | Entry::CompletedPlan => {}
265                Entry::ToolCall(ToolCallEntry { content }) => {
266                    for view in content.values() {
267                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
268                            diff_editor.update(cx, |diff_editor, cx| {
269                                diff_editor.set_text_style_refinement(
270                                    diff_editor_text_style_refinement(cx),
271                                );
272                                cx.notify();
273                            })
274                        }
275                    }
276                }
277            }
278        }
279    }
280}
281
282impl EventEmitter<EntryViewEvent> for EntryViewState {}
283
284pub struct EntryViewEvent {
285    pub entry_index: usize,
286    pub view_event: ViewEvent,
287}
288
289pub enum ViewEvent {
290    NewDiff(ToolCallId),
291    NewTerminal(ToolCallId),
292    TerminalMovedToBackground(ToolCallId),
293    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
294    OpenDiffLocation {
295        path: String,
296        position: Point,
297        split: bool,
298    },
299}
300
301#[derive(Debug)]
302pub struct AssistantMessageEntry {
303    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
304    focus_handle: FocusHandle,
305}
306
307impl AssistantMessageEntry {
308    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
309        self.scroll_handles_by_chunk_index.get(&ix).cloned()
310    }
311
312    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
313        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
314            let ix = message.chunks.len() - 1;
315            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
316            handle.scroll_to_bottom();
317        }
318    }
319}
320
321#[derive(Debug)]
322pub struct ToolCallEntry {
323    content: HashMap<EntityId, AnyEntity>,
324}
325
326#[derive(Debug)]
327pub enum Entry {
328    UserMessage(Entity<MessageEditor>),
329    AssistantMessage(AssistantMessageEntry),
330    ToolCall(ToolCallEntry),
331    CompletedPlan,
332}
333
334impl Entry {
335    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
336        match self {
337            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
338            Self::AssistantMessage(message) => Some(message.focus_handle.clone()),
339            Self::ToolCall(_) | Self::CompletedPlan => None,
340        }
341    }
342
343    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
344        match self {
345            Self::UserMessage(editor) => Some(editor),
346            Self::AssistantMessage(_) | Self::ToolCall(_) | Self::CompletedPlan => None,
347        }
348    }
349
350    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
351        self.content_map()?
352            .get(&diff.entity_id())
353            .cloned()
354            .map(|entity| entity.downcast::<Editor>().unwrap())
355    }
356
357    pub fn terminal(
358        &self,
359        terminal: &Entity<acp_thread::Terminal>,
360    ) -> Option<Entity<TerminalView>> {
361        self.content_map()?
362            .get(&terminal.entity_id())
363            .cloned()
364            .map(|entity| entity.downcast::<TerminalView>().unwrap())
365    }
366
367    pub fn scroll_handle_for_assistant_message_chunk(
368        &self,
369        chunk_ix: usize,
370    ) -> Option<ScrollHandle> {
371        match self {
372            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
373            Self::UserMessage(_) | Self::ToolCall(_) | Self::CompletedPlan => None,
374        }
375    }
376
377    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
378        match self {
379            Self::ToolCall(ToolCallEntry { content }) => Some(content),
380            _ => None,
381        }
382    }
383
384    #[cfg(test)]
385    pub fn has_content(&self) -> bool {
386        match self {
387            Self::ToolCall(ToolCallEntry { content }) => !content.is_empty(),
388            Self::UserMessage(_) | Self::AssistantMessage(_) | Self::CompletedPlan => false,
389        }
390    }
391}
392
393fn create_terminal(
394    workspace: WeakEntity<Workspace>,
395    project: WeakEntity<Project>,
396    terminal: Entity<acp_thread::Terminal>,
397    window: &mut Window,
398    cx: &mut App,
399) -> Entity<TerminalView> {
400    cx.new(|cx| {
401        let mut view = TerminalView::new(
402            terminal.read(cx).inner().clone(),
403            workspace,
404            None,
405            project,
406            window,
407            cx,
408        );
409        view.set_embedded_mode(Some(1000), cx);
410        view
411    })
412}
413
414fn create_editor_diff(
415    diff: Entity<acp_thread::Diff>,
416    window: &mut Window,
417    cx: &mut App,
418) -> Entity<Editor> {
419    cx.new(|cx| {
420        let mut editor = Editor::new(
421            EditorMode::Full {
422                scale_ui_elements_with_buffer_font_size: false,
423                show_active_line_background: false,
424                sizing_behavior: SizingBehavior::SizeByContent,
425            },
426            diff.read(cx).multibuffer().clone(),
427            None,
428            window,
429            cx,
430        );
431        editor.set_show_gutter(false, cx);
432        editor.disable_inline_diagnostics();
433        editor.disable_expand_excerpt_buttons(cx);
434        editor.set_show_vertical_scrollbar(false, cx);
435        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
436        editor.set_soft_wrap_mode(SoftWrap::None, cx);
437        editor.scroll_manager.set_forbid_vertical_scroll(true);
438        editor.set_show_indent_guides(false, cx);
439        editor.set_read_only(true);
440        editor.set_delegate_open_excerpts(true);
441        editor.set_show_breakpoints(false, cx);
442        editor.set_show_code_actions(false, cx);
443        editor.set_show_git_diff_gutter(false, cx);
444        editor.set_expand_all_diff_hunks(cx);
445        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
446        editor
447    })
448}
449
450fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
451    TextStyleRefinement {
452        font_size: Some(
453            TextSize::Small
454                .rems(cx)
455                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
456                .into(),
457        ),
458        ..Default::default()
459    }
460}
461
462#[cfg(test)]
463mod tests {
464    use std::path::Path;
465    use std::rc::Rc;
466    use std::sync::Arc;
467
468    use acp_thread::{AgentConnection, StubAgentConnection};
469    use agent_client_protocol as acp;
470    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
471    use editor::RowInfo;
472    use fs::FakeFs;
473    use gpui::{AppContext as _, TestAppContext};
474    use parking_lot::RwLock;
475
476    use crate::entry_view_state::EntryViewState;
477    use crate::message_editor::SessionCapabilities;
478    use multi_buffer::MultiBufferRow;
479    use pretty_assertions::assert_matches;
480    use project::Project;
481    use serde_json::json;
482    use settings::SettingsStore;
483    use util::path;
484    use workspace::{MultiWorkspace, PathList};
485
486    #[gpui::test]
487    async fn test_diff_sync(cx: &mut TestAppContext) {
488        init_test(cx);
489        let fs = FakeFs::new(cx.executor());
490        fs.insert_tree(
491            "/project",
492            json!({
493                "hello.txt": "hi world"
494            }),
495        )
496        .await;
497        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
498
499        let (multi_workspace, cx) =
500            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
501        let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());
502
503        let tool_call = acp::ToolCall::new("tool", "Tool call")
504            .status(acp::ToolCallStatus::InProgress)
505            .content(vec![acp::ToolCallContent::Diff(
506                acp::Diff::new("/project/hello.txt", "hello world").old_text("hi world"),
507            )]);
508        let connection = Rc::new(StubAgentConnection::new());
509        let thread = cx
510            .update(|_, cx| {
511                connection.clone().new_session(
512                    project.clone(),
513                    PathList::new(&[Path::new(path!("/project"))]),
514                    cx,
515                )
516            })
517            .await
518            .unwrap();
519        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());
520
521        cx.update(|_, cx| {
522            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
523        });
524
525        let thread_store = None;
526        let history: Option<gpui::WeakEntity<crate::ThreadHistory>> = None;
527
528        let view_state = cx.new(|_cx| {
529            EntryViewState::new(
530                workspace.downgrade(),
531                project.downgrade(),
532                thread_store,
533                history,
534                None,
535                Arc::new(RwLock::new(SessionCapabilities::default())),
536                "Test Agent".into(),
537            )
538        });
539
540        view_state.update_in(cx, |view_state, window, cx| {
541            view_state.sync_entry(0, &thread, window, cx)
542        });
543
544        let diff = thread.read_with(cx, |thread, _| {
545            thread
546                .entries()
547                .get(0)
548                .unwrap()
549                .diffs()
550                .next()
551                .unwrap()
552                .clone()
553        });
554
555        cx.run_until_parked();
556
557        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
558            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
559        });
560        assert_eq!(
561            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
562            "hi world\nhello world"
563        );
564        let row_infos = diff_editor.read_with(cx, |editor, cx| {
565            let multibuffer = editor.buffer().read(cx);
566            multibuffer
567                .snapshot(cx)
568                .row_infos(MultiBufferRow(0))
569                .collect::<Vec<_>>()
570        });
571        assert_matches!(
572            row_infos.as_slice(),
573            [
574                RowInfo {
575                    multibuffer_row: Some(MultiBufferRow(0)),
576                    diff_status: Some(DiffHunkStatus {
577                        kind: DiffHunkStatusKind::Deleted,
578                        ..
579                    }),
580                    ..
581                },
582                RowInfo {
583                    multibuffer_row: Some(MultiBufferRow(1)),
584                    diff_status: Some(DiffHunkStatus {
585                        kind: DiffHunkStatusKind::Added,
586                        ..
587                    }),
588                    ..
589                }
590            ]
591        );
592    }
593
594    fn init_test(cx: &mut TestAppContext) {
595        cx.update(|cx| {
596            let settings_store = SettingsStore::test(cx);
597            cx.set_global(settings_store);
598            theme_settings::init(theme::LoadThemes::JustBase, cx);
599            release_channel::init(semver::Version::new(0, 0, 0), cx);
600        });
601    }
602}