entry_view_state.rs

  1use std::ops::Range;
  2
  3use super::thread_history::ThreadHistory;
  4use acp_thread::{AcpThread, AgentThreadEntry};
  5use agent::ThreadStore;
  6use agent_client_protocol::ToolCallId;
  7use collections::HashMap;
  8use editor::{Editor, EditorEvent, EditorMode, MinimapVisibility, SizingBehavior};
  9use gpui::{
 10    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 11    ScrollHandle, TextStyleRefinement, WeakEntity, Window,
 12};
 13use language::language_settings::SoftWrap;
 14use project::{AgentId, Project};
 15use prompt_store::PromptStore;
 16use rope::Point;
 17use settings::Settings as _;
 18use terminal_view::TerminalView;
 19use theme_settings::ThemeSettings;
 20use ui::{Context, TextSize};
 21use workspace::Workspace;
 22
 23use crate::message_editor::{MessageEditor, MessageEditorEvent, SharedSessionCapabilities};
 24
/// View-layer state for the entries of an [`AcpThread`].
///
/// `entries[i]` holds the views (message editor, diff editors, terminal views,
/// scroll/focus handles) for the thread entry at index `i`; it is populated and
/// refreshed by `EntryViewState::sync_entry`.
pub struct EntryViewState {
    // Passed to user-message editors and embedded terminal views.
    workspace: WeakEntity<Workspace>,
    // Passed to user-message editors and embedded terminal views.
    project: WeakEntity<Project>,
    // Cloned into every user-message `MessageEditor`.
    thread_store: Option<Entity<ThreadStore>>,
    // Cloned into every user-message `MessageEditor`.
    history: Option<WeakEntity<ThreadHistory>>,
    // Cloned into every user-message `MessageEditor`.
    prompt_store: Option<Entity<PromptStore>>,
    // One view-state entry per synced thread entry, in thread order.
    entries: Vec<Entry>,
    // Shared with every user-message `MessageEditor`.
    session_capabilities: SharedSessionCapabilities,
    // Identifies the agent to user-message `MessageEditor`s.
    agent_id: AgentId,
}
 35
impl EntryViewState {
    /// Creates an empty view-state store.
    ///
    /// The collaborators passed here are cloned into every [`MessageEditor`]
    /// created for a user message; `workspace` and `project` are also handed to
    /// embedded terminal views. No entries exist until [`Self::sync_entry`] runs.
    pub fn new(
        workspace: WeakEntity<Workspace>,
        project: WeakEntity<Project>,
        thread_store: Option<Entity<ThreadStore>>,
        history: Option<WeakEntity<ThreadHistory>>,
        prompt_store: Option<Entity<PromptStore>>,
        session_capabilities: SharedSessionCapabilities,
        agent_id: AgentId,
    ) -> Self {
        Self {
            workspace,
            project,
            thread_store,
            history,
            prompt_store,
            entries: Vec::new(),
            session_capabilities,
            agent_id,
        }
    }

    /// Returns the view state for the entry at `index`, or `None` if no entry
    /// has been synced at that index.
    pub fn entry(&self, index: usize) -> Option<&Entry> {
        self.entries.get(index)
    }

    /// Brings the view state at `index` up to date with the thread entry at the
    /// same index.
    ///
    /// Does nothing if the thread has no entry at `index`. If the stored entry
    /// already has the matching kind it is updated in place; if it is missing or
    /// of a different kind it is replaced through [`Self::set_entry`], which
    /// panics when `index > self.entries.len()`.
    ///
    /// Emits [`EntryViewEvent`]s when diff editors or terminal views are created,
    /// when a completed tool call has a terminal with no output, and when
    /// forwarding events from user-message editors and diff editors.
    pub fn sync_entry(
        &mut self,
        index: usize,
        thread: &Entity<AcpThread>,
        window: &mut Window,
        cx: &mut Context<Self>,
    ) {
        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
            return;
        };

        match thread_entry {
            AgentThreadEntry::UserMessage(message) => {
                // Owned copies are taken up front: `message` borrows from
                // `thread.read(cx)`, while `cx` is used mutably below.
                let has_id = message.id.is_some();
                let is_subagent = thread.read(cx).parent_session_id().is_some();
                let chunks = message.chunks.clone();
                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
                    if !editor.focus_handle(cx).is_focused(window) {
                        // Only update if we are not editing.
                        // If we are, cancelling the edit will set the message to the newest content.
                        editor.update(cx, |editor, cx| {
                            editor.set_message(chunks, window, cx);
                        });
                    }
                } else {
                    // NOTE(review): the read-only flag is only applied here, when the
                    // editor is first created; a later sync in which `has_id` or
                    // `is_subagent` changes does not update it. Confirm this is intended.
                    let message_editor = cx.new(|cx| {
                        let mut editor = MessageEditor::new(
                            self.workspace.clone(),
                            self.project.clone(),
                            self.thread_store.clone(),
                            self.history.clone(),
                            self.prompt_store.clone(),
                            self.session_capabilities.clone(),
                            self.agent_id.clone(),
                            "Edit message - @ to include context",
                            editor::EditorMode::AutoHeight {
                                min_lines: 1,
                                max_lines: None,
                            },
                            window,
                            cx,
                        );
                        // Messages without an id, and messages in a thread that has a
                        // parent session (a subagent thread), are shown read-only.
                        if !has_id || is_subagent {
                            editor.set_read_only(true, cx);
                        }
                        editor.set_message(chunks, window, cx);
                        editor
                    });
                    // Re-emit every event from this editor, tagged with the entry index.
                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
                        cx.emit(EntryViewEvent {
                            entry_index: index,
                            view_event: ViewEvent::MessageEditorEvent(editor, event.clone()),
                        })
                    })
                    .detach();
                    self.set_entry(index, Entry::UserMessage(message_editor));
                }
            }
            AgentThreadEntry::ToolCall(tool_call) => {
                let id = tool_call.id.clone();
                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();

                // Reuse the existing tool-call view map, or install an empty one
                // (replacing whatever entry kind was at `index` before).
                let views = if let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) {
                    &mut tool_call.content
                } else {
                    self.set_entry(
                        index,
                        Entry::ToolCall(ToolCallEntry {
                            content: HashMap::default(),
                        }),
                    );
                    let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) else {
                        unreachable!()
                    };
                    &mut tool_call.content
                };

                let is_tool_call_completed =
                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);

                for terminal in terminals {
                    // Views are keyed by the terminal model's entity id, so each
                    // terminal gets at most one view across repeated syncs.
                    match views.entry(terminal.entity_id()) {
                        collections::hash_map::Entry::Vacant(entry) => {
                            let element = create_terminal(
                                self.workspace.clone(),
                                self.project.clone(),
                                terminal.clone(),
                                window,
                                cx,
                            )
                            .into_any();
                            cx.emit(EntryViewEvent {
                                entry_index: index,
                                view_event: ViewEvent::NewTerminal(id.clone()),
                            });
                            entry.insert(element);
                        }
                        collections::hash_map::Entry::Occupied(_entry) => {
                            // NOTE(review): nothing records that this was already sent,
                            // so the event is emitted on every sync while the tool call
                            // is `Completed` and the terminal's `output()` is `None`.
                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
                                cx.emit(EntryViewEvent {
                                    entry_index: index,
                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
                                });
                            }
                        }
                    }
                }

                for diff in diffs {
                    // Diff editors are likewise created once per diff entity id.
                    views.entry(diff.entity_id()).or_insert_with(|| {
                        let editor = create_editor_diff(diff.clone(), window, cx);
                        // Translate "open excerpts" requests from the diff editor into an
                        // `OpenDiffLocation` event. Only the first buffer yielded by
                        // `selections_by_buffer` and its first range are used, and nothing
                        // is emitted when the diff has no file path.
                        cx.subscribe(&editor, {
                            let diff = diff.clone();
                            let entry_index = index;
                            move |_this, _editor, event: &EditorEvent, cx| {
                                if let EditorEvent::OpenExcerptsRequested {
                                    selections_by_buffer,
                                    split,
                                } = event
                                {
                                    let multibuffer = diff.read(cx).multibuffer();
                                    if let Some((buffer_id, (ranges, _))) =
                                        selections_by_buffer.iter().next()
                                    {
                                        if let Some(buffer) =
                                            multibuffer.read(cx).buffer(*buffer_id)
                                        {
                                            if let Some(range) = ranges.first() {
                                                let point =
                                                    buffer.read(cx).offset_to_point(range.start.0);
                                                if let Some(path) = diff.read(cx).file_path(cx) {
                                                    cx.emit(EntryViewEvent {
                                                        entry_index,
                                                        view_event: ViewEvent::OpenDiffLocation {
                                                            path,
                                                            position: point,
                                                            split: *split,
                                                        },
                                                    });
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        })
                        .detach();
                        cx.emit(EntryViewEvent {
                            entry_index: index,
                            view_event: ViewEvent::NewDiff(id.clone()),
                        });
                        editor.into_any()
                    });
                }
            }
            AgentThreadEntry::AssistantMessage(message) => {
                // Create the entry (with its own focus handle) on first sight, then
                // let it react to the latest message content.
                let entry = if let Some(Entry::AssistantMessage(entry)) =
                    self.entries.get_mut(index)
                {
                    entry
                } else {
                    self.set_entry(
                        index,
                        Entry::AssistantMessage(AssistantMessageEntry {
                            scroll_handles_by_chunk_index: HashMap::default(),
                            focus_handle: cx.focus_handle(),
                        }),
                    );
                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
                        unreachable!()
                    };
                    entry
                };
                entry.sync(message);
            }
            AgentThreadEntry::CompletedPlan(_) => {
                // Completed plans carry no view state; only ensure the marker variant.
                if !matches!(self.entries.get(index), Some(Entry::CompletedPlan)) {
                    self.set_entry(index, Entry::CompletedPlan);
                }
            }
        };
    }

    /// Stores `entry` at `index`, appending when `index` equals the current length
    /// and overwriting the existing entry otherwise.
    ///
    /// # Panics
    /// Panics if `index > self.entries.len()`.
    fn set_entry(&mut self, index: usize, entry: Entry) {
        if index == self.entries.len() {
            self.entries.push(entry);
        } else {
            self.entries[index] = entry;
        }
    }

    /// Drops the view state for the entries in `range`; later entries shift down.
    ///
    /// # Panics
    /// Panics if `range` is out of bounds for the stored entries (as `Vec::drain` does).
    pub fn remove(&mut self, range: Range<usize>) {
        self.entries.drain(range);
    }

    /// Reapplies the diff-editor text style (derived from the agent UI font size)
    /// to every diff editor and calls `cx.notify()` for each one.
    ///
    /// Only tool-call views that downcast to [`Editor`] are touched; terminal
    /// views fail the downcast and are skipped.
    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
        for entry in self.entries.iter() {
            match entry {
                Entry::UserMessage { .. }
                | Entry::AssistantMessage { .. }
                | Entry::CompletedPlan => {}
                Entry::ToolCall(ToolCallEntry { content }) => {
                    for view in content.values() {
                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
                            diff_editor.update(cx, |diff_editor, cx| {
                                diff_editor.set_text_style_refinement(
                                    diff_editor_text_style_refinement(cx),
                                );
                                cx.notify();
                            })
                        }
                    }
                }
            }
        }
    }
}
280
// `EntryViewState` emits `EntryViewEvent`s from `sync_entry` and from the
// subscriptions it installs on message and diff editors.
impl EventEmitter<EntryViewEvent> for EntryViewState {}
282
/// An event emitted by [`EntryViewState`], tagged with the entry it concerns.
pub struct EntryViewEvent {
    /// Index of the thread entry whose views produced the event.
    pub entry_index: usize,
    /// What happened.
    pub view_event: ViewEvent,
}
287
/// The kinds of events carried by [`EntryViewEvent`].
pub enum ViewEvent {
    /// A diff editor was created for the given tool call.
    NewDiff(ToolCallId),
    /// A terminal view was created for the given tool call.
    NewTerminal(ToolCallId),
    /// Emitted during sync when the tool call is `Completed` but one of its
    /// already-viewed terminals reports no output (`output()` is `None`).
    TerminalMovedToBackground(ToolCallId),
    /// An event forwarded from a user-message editor.
    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
    /// A diff editor asked to open `path` at `position`; `split` is forwarded
    /// from the editor's `OpenExcerptsRequested` event.
    OpenDiffLocation {
        path: String,
        position: Point,
        split: bool,
    },
}
299
/// View state for an assistant message entry.
#[derive(Debug)]
pub struct AssistantMessageEntry {
    // Scroll handles keyed by chunk index; created on demand by `sync` for a
    // trailing thought chunk.
    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
    // Focus handle created when the entry is first synced.
    focus_handle: FocusHandle,
}
305
306impl AssistantMessageEntry {
307    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
308        self.scroll_handles_by_chunk_index.get(&ix).cloned()
309    }
310
311    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
312        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
313            let ix = message.chunks.len() - 1;
314            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
315            handle.scroll_to_bottom();
316        }
317    }
318}
319
/// View state for a tool-call entry.
#[derive(Debug)]
pub struct ToolCallEntry {
    // Views keyed by the entity id of the model they render: diff entities map
    // to diff `Editor`s, terminal entities map to `TerminalView`s.
    content: HashMap<EntityId, AnyEntity>,
}
324
/// View state for one thread entry; the variant mirrors the thread entry's kind.
#[derive(Debug)]
pub enum Entry {
    /// A user message, rendered through an (optionally editable) message editor.
    UserMessage(Entity<MessageEditor>),
    /// An assistant message with its focus handle and per-chunk scroll handles.
    AssistantMessage(AssistantMessageEntry),
    /// A tool call with its diff editors and terminal views.
    ToolCall(ToolCallEntry),
    /// A completed plan; no view state is kept for it.
    CompletedPlan,
}
332
333impl Entry {
334    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
335        match self {
336            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
337            Self::AssistantMessage(message) => Some(message.focus_handle.clone()),
338            Self::ToolCall(_) | Self::CompletedPlan => None,
339        }
340    }
341
342    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
343        match self {
344            Self::UserMessage(editor) => Some(editor),
345            Self::AssistantMessage(_) | Self::ToolCall(_) | Self::CompletedPlan => None,
346        }
347    }
348
349    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
350        self.content_map()?
351            .get(&diff.entity_id())
352            .cloned()
353            .map(|entity| entity.downcast::<Editor>().unwrap())
354    }
355
356    pub fn terminal(
357        &self,
358        terminal: &Entity<acp_thread::Terminal>,
359    ) -> Option<Entity<TerminalView>> {
360        self.content_map()?
361            .get(&terminal.entity_id())
362            .cloned()
363            .map(|entity| entity.downcast::<TerminalView>().unwrap())
364    }
365
366    pub fn scroll_handle_for_assistant_message_chunk(
367        &self,
368        chunk_ix: usize,
369    ) -> Option<ScrollHandle> {
370        match self {
371            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
372            Self::UserMessage(_) | Self::ToolCall(_) | Self::CompletedPlan => None,
373        }
374    }
375
376    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
377        match self {
378            Self::ToolCall(ToolCallEntry { content }) => Some(content),
379            _ => None,
380        }
381    }
382
383    #[cfg(test)]
384    pub fn has_content(&self) -> bool {
385        match self {
386            Self::ToolCall(ToolCallEntry { content }) => !content.is_empty(),
387            Self::UserMessage(_) | Self::AssistantMessage(_) | Self::CompletedPlan => false,
388        }
389    }
390}
391
392fn create_terminal(
393    workspace: WeakEntity<Workspace>,
394    project: WeakEntity<Project>,
395    terminal: Entity<acp_thread::Terminal>,
396    window: &mut Window,
397    cx: &mut App,
398) -> Entity<TerminalView> {
399    cx.new(|cx| {
400        let mut view = TerminalView::new(
401            terminal.read(cx).inner().clone(),
402            workspace,
403            None,
404            project,
405            window,
406            cx,
407        );
408        view.set_embedded_mode(Some(1000), cx);
409        view
410    })
411}
412
413fn create_editor_diff(
414    diff: Entity<acp_thread::Diff>,
415    window: &mut Window,
416    cx: &mut App,
417) -> Entity<Editor> {
418    cx.new(|cx| {
419        let mut editor = Editor::new(
420            EditorMode::Full {
421                scale_ui_elements_with_buffer_font_size: false,
422                show_active_line_background: false,
423                sizing_behavior: SizingBehavior::SizeByContent,
424            },
425            diff.read(cx).multibuffer().clone(),
426            None,
427            window,
428            cx,
429        );
430        editor.set_show_gutter(false, cx);
431        editor.disable_inline_diagnostics();
432        editor.disable_expand_excerpt_buttons(cx);
433        editor.set_show_vertical_scrollbar(false, cx);
434        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
435        editor.set_soft_wrap_mode(SoftWrap::None, cx);
436        editor.scroll_manager.set_forbid_vertical_scroll(true);
437        editor.set_show_indent_guides(false, cx);
438        editor.set_read_only(true);
439        editor.set_delegate_open_excerpts(true);
440        editor.set_show_breakpoints(false, cx);
441        editor.set_show_code_actions(false, cx);
442        editor.set_show_git_diff_gutter(false, cx);
443        editor.set_expand_all_diff_hunks(cx);
444        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
445        editor
446    })
447}
448
449fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
450    TextStyleRefinement {
451        font_size: Some(
452            TextSize::Small
453                .rems(cx)
454                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
455                .into(),
456        ),
457        ..Default::default()
458    }
459}
460
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::rc::Rc;
    use std::sync::Arc;

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::RowInfo;
    use fs::FakeFs;
    use gpui::{AppContext as _, TestAppContext};
    use parking_lot::RwLock;

    use crate::entry_view_state::EntryViewState;
    use crate::message_editor::SessionCapabilities;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use workspace::{MultiWorkspace, PathList};

    /// Syncing a tool-call entry that carries a diff creates a diff editor whose
    /// text shows the old line (as deleted) followed by the new line (as added).
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // NOTE(review): `insert_tree` and `acp::Diff::new` below use a bare
        // "/project" path while `Project::test`/`PathList` use `path!`; confirm
        // this is intended on platforms where `path!` rewrites the path.
        fs.insert_tree(
            "/project",
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let workspace = multi_workspace.read_with(cx, |mw, _| mw.workspace().clone());

        // An in-progress tool call whose only content is a diff of hello.txt.
        let tool_call = acp::ToolCall::new("tool", "Tool call")
            .status(acp::ToolCallStatus::InProgress)
            .content(vec![acp::ToolCallContent::Diff(
                acp::Diff::new("/project/hello.txt", "hello world").old_text("hi world"),
            )]);
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection.clone().new_session(
                    project.clone(),
                    PathList::new(&[Path::new(path!("/project"))]),
                    cx,
                )
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = None;
        let history: Option<gpui::WeakEntity<crate::ThreadHistory>> = None;

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.downgrade(),
                thread_store,
                history,
                None,
                Arc::new(RwLock::new(SessionCapabilities::default())),
                "Test Agent".into(),
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        // Row 0 is the removed "hi world" line, row 1 the added "hello world" line.
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs a test settings store as a global, initializes theme settings
    /// with the base themes, and sets up the release channel at version 0.0.0.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme_settings::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(semver::Version::new(0, 0, 0), cx);
        });
    }
}