entry_view_state.rs

  1use std::{cell::RefCell, ops::Range, rc::Rc};
  2
  3use super::thread_history::ThreadHistory;
  4use acp_thread::{AcpThread, AgentThreadEntry};
  5use agent::ThreadStore;
  6use agent_client_protocol::{self as acp, ToolCallId};
  7use collections::HashMap;
  8use editor::{Editor, EditorEvent, EditorMode, MinimapVisibility, SizingBehavior};
  9use gpui::{
 10    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 11    ScrollHandle, SharedString, TextStyleRefinement, WeakEntity, Window,
 12};
 13use language::language_settings::SoftWrap;
 14use project::Project;
 15use prompt_store::PromptStore;
 16use rope::Point;
 17use settings::Settings as _;
 18use terminal_view::TerminalView;
 19use theme::ThemeSettings;
 20use ui::{Context, TextSize};
 21use workspace::Workspace;
 22
 23use crate::message_editor::{MessageEditor, MessageEditorEvent};
 24
 25pub struct EntryViewState {
 26    workspace: WeakEntity<Workspace>,
 27    project: WeakEntity<Project>,
 28    thread_store: Option<Entity<ThreadStore>>,
 29    history: WeakEntity<ThreadHistory>,
 30    prompt_store: Option<Entity<PromptStore>>,
 31    entries: Vec<Entry>,
 32    prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 33    available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 34    agent_name: SharedString,
 35}
 36
 37impl EntryViewState {
 38    pub fn new(
 39        workspace: WeakEntity<Workspace>,
 40        project: WeakEntity<Project>,
 41        thread_store: Option<Entity<ThreadStore>>,
 42        history: WeakEntity<ThreadHistory>,
 43        prompt_store: Option<Entity<PromptStore>>,
 44        prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 45        available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 46        agent_name: SharedString,
 47    ) -> Self {
 48        Self {
 49            workspace,
 50            project,
 51            thread_store,
 52            history,
 53            prompt_store,
 54            entries: Vec::new(),
 55            prompt_capabilities,
 56            available_commands,
 57            agent_name,
 58        }
 59    }
 60
 61    pub fn entry(&self, index: usize) -> Option<&Entry> {
 62        self.entries.get(index)
 63    }
 64
 65    pub fn sync_entry(
 66        &mut self,
 67        index: usize,
 68        thread: &Entity<AcpThread>,
 69        window: &mut Window,
 70        cx: &mut Context<Self>,
 71    ) {
 72        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 73            return;
 74        };
 75
 76        match thread_entry {
 77            AgentThreadEntry::UserMessage(message) => {
 78                let has_id = message.id.is_some();
 79                let is_subagent = thread.read(cx).parent_session_id().is_some();
 80                let chunks = message.chunks.clone();
 81                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 82                    if !editor.focus_handle(cx).is_focused(window) {
 83                        // Only update if we are not editing.
 84                        // If we are, cancelling the edit will set the message to the newest content.
 85                        editor.update(cx, |editor, cx| {
 86                            editor.set_message(chunks, window, cx);
 87                        });
 88                    }
 89                } else {
 90                    let message_editor = cx.new(|cx| {
 91                        let mut editor = MessageEditor::new(
 92                            self.workspace.clone(),
 93                            self.project.clone(),
 94                            self.thread_store.clone(),
 95                            self.history.clone(),
 96                            self.prompt_store.clone(),
 97                            self.prompt_capabilities.clone(),
 98                            self.available_commands.clone(),
 99                            self.agent_name.clone(),
100                            "Edit message - @ to include context",
101                            editor::EditorMode::AutoHeight {
102                                min_lines: 1,
103                                max_lines: None,
104                            },
105                            window,
106                            cx,
107                        );
108                        if !has_id || is_subagent {
109                            editor.set_read_only(true, cx);
110                        }
111                        editor.set_message(chunks, window, cx);
112                        editor
113                    });
114                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
115                        cx.emit(EntryViewEvent {
116                            entry_index: index,
117                            view_event: ViewEvent::MessageEditorEvent(editor, event.clone()),
118                        })
119                    })
120                    .detach();
121                    self.set_entry(index, Entry::UserMessage(message_editor));
122                }
123            }
124            AgentThreadEntry::ToolCall(tool_call) => {
125                let id = tool_call.id.clone();
126                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
127                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
128
129                let views = if let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) {
130                    &mut tool_call.content
131                } else {
132                    self.set_entry(
133                        index,
134                        Entry::ToolCall(ToolCallEntry {
135                            content: HashMap::default(),
136                        }),
137                    );
138                    let Some(Entry::ToolCall(tool_call)) = self.entries.get_mut(index) else {
139                        unreachable!()
140                    };
141                    &mut tool_call.content
142                };
143
144                let is_tool_call_completed =
145                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
146
147                for terminal in terminals {
148                    match views.entry(terminal.entity_id()) {
149                        collections::hash_map::Entry::Vacant(entry) => {
150                            let element = create_terminal(
151                                self.workspace.clone(),
152                                self.project.clone(),
153                                terminal.clone(),
154                                window,
155                                cx,
156                            )
157                            .into_any();
158                            cx.emit(EntryViewEvent {
159                                entry_index: index,
160                                view_event: ViewEvent::NewTerminal(id.clone()),
161                            });
162                            entry.insert(element);
163                        }
164                        collections::hash_map::Entry::Occupied(_entry) => {
165                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
166                                cx.emit(EntryViewEvent {
167                                    entry_index: index,
168                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
169                                });
170                            }
171                        }
172                    }
173                }
174
175                for diff in diffs {
176                    views.entry(diff.entity_id()).or_insert_with(|| {
177                        let editor = create_editor_diff(diff.clone(), window, cx);
178                        cx.subscribe(&editor, {
179                            let diff = diff.clone();
180                            let entry_index = index;
181                            move |_this, _editor, event: &EditorEvent, cx| {
182                                if let EditorEvent::OpenExcerptsRequested {
183                                    selections_by_buffer,
184                                    split,
185                                } = event
186                                {
187                                    let multibuffer = diff.read(cx).multibuffer();
188                                    if let Some((buffer_id, (ranges, _))) =
189                                        selections_by_buffer.iter().next()
190                                    {
191                                        if let Some(buffer) =
192                                            multibuffer.read(cx).buffer(*buffer_id)
193                                        {
194                                            if let Some(range) = ranges.first() {
195                                                let point =
196                                                    buffer.read(cx).offset_to_point(range.start.0);
197                                                if let Some(path) = diff.read(cx).file_path(cx) {
198                                                    cx.emit(EntryViewEvent {
199                                                        entry_index,
200                                                        view_event: ViewEvent::OpenDiffLocation {
201                                                            path,
202                                                            position: point,
203                                                            split: *split,
204                                                        },
205                                                    });
206                                                }
207                                            }
208                                        }
209                                    }
210                                }
211                            }
212                        })
213                        .detach();
214                        cx.emit(EntryViewEvent {
215                            entry_index: index,
216                            view_event: ViewEvent::NewDiff(id.clone()),
217                        });
218                        editor.into_any()
219                    });
220                }
221            }
222            AgentThreadEntry::AssistantMessage(message) => {
223                let entry = if let Some(Entry::AssistantMessage(entry)) =
224                    self.entries.get_mut(index)
225                {
226                    entry
227                } else {
228                    self.set_entry(
229                        index,
230                        Entry::AssistantMessage(AssistantMessageEntry::default()),
231                    );
232                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
233                        unreachable!()
234                    };
235                    entry
236                };
237                entry.sync(message);
238            }
239        };
240    }
241
242    fn set_entry(&mut self, index: usize, entry: Entry) {
243        if index == self.entries.len() {
244            self.entries.push(entry);
245        } else {
246            self.entries[index] = entry;
247        }
248    }
249
250    pub fn remove(&mut self, range: Range<usize>) {
251        self.entries.drain(range);
252    }
253
254    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
255        for entry in self.entries.iter() {
256            match entry {
257                Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
258                Entry::ToolCall(ToolCallEntry { content }) => {
259                    for view in content.values() {
260                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
261                            diff_editor.update(cx, |diff_editor, cx| {
262                                diff_editor.set_text_style_refinement(
263                                    diff_editor_text_style_refinement(cx),
264                                );
265                                cx.notify();
266                            })
267                        }
268                    }
269                }
270            }
271        }
272    }
273}
274
275impl EventEmitter<EntryViewEvent> for EntryViewState {}
276
277pub struct EntryViewEvent {
278    pub entry_index: usize,
279    pub view_event: ViewEvent,
280}
281
282pub enum ViewEvent {
283    NewDiff(ToolCallId),
284    NewTerminal(ToolCallId),
285    TerminalMovedToBackground(ToolCallId),
286    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
287    OpenDiffLocation {
288        path: String,
289        position: Point,
290        split: bool,
291    },
292}
293
294#[derive(Default, Debug)]
295pub struct AssistantMessageEntry {
296    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
297}
298
299impl AssistantMessageEntry {
300    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
301        self.scroll_handles_by_chunk_index.get(&ix).cloned()
302    }
303
304    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
305        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
306            let ix = message.chunks.len() - 1;
307            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
308            handle.scroll_to_bottom();
309        }
310    }
311}
312
313#[derive(Debug)]
314pub struct ToolCallEntry {
315    content: HashMap<EntityId, AnyEntity>,
316}
317
318#[derive(Debug)]
319pub enum Entry {
320    UserMessage(Entity<MessageEditor>),
321    AssistantMessage(AssistantMessageEntry),
322    ToolCall(ToolCallEntry),
323}
324
325impl Entry {
326    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
327        match self {
328            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
329            Self::AssistantMessage(_) | Self::ToolCall(_) => None,
330        }
331    }
332
333    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
334        match self {
335            Self::UserMessage(editor) => Some(editor),
336            Self::AssistantMessage(_) | Self::ToolCall(_) => None,
337        }
338    }
339
340    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
341        self.content_map()?
342            .get(&diff.entity_id())
343            .cloned()
344            .map(|entity| entity.downcast::<Editor>().unwrap())
345    }
346
347    pub fn terminal(
348        &self,
349        terminal: &Entity<acp_thread::Terminal>,
350    ) -> Option<Entity<TerminalView>> {
351        self.content_map()?
352            .get(&terminal.entity_id())
353            .cloned()
354            .map(|entity| entity.downcast::<TerminalView>().unwrap())
355    }
356
357    pub fn scroll_handle_for_assistant_message_chunk(
358        &self,
359        chunk_ix: usize,
360    ) -> Option<ScrollHandle> {
361        match self {
362            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
363            Self::UserMessage(_) | Self::ToolCall(_) => None,
364        }
365    }
366
367    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
368        match self {
369            Self::ToolCall(ToolCallEntry { content }) => Some(content),
370            _ => None,
371        }
372    }
373
374    #[cfg(test)]
375    pub fn has_content(&self) -> bool {
376        match self {
377            Self::ToolCall(ToolCallEntry { content }) => !content.is_empty(),
378            Self::UserMessage(_) | Self::AssistantMessage(_) => false,
379        }
380    }
381}
382
383fn create_terminal(
384    workspace: WeakEntity<Workspace>,
385    project: WeakEntity<Project>,
386    terminal: Entity<acp_thread::Terminal>,
387    window: &mut Window,
388    cx: &mut App,
389) -> Entity<TerminalView> {
390    cx.new(|cx| {
391        let mut view = TerminalView::new(
392            terminal.read(cx).inner().clone(),
393            workspace,
394            None,
395            project,
396            window,
397            cx,
398        );
399        view.set_embedded_mode(Some(1000), cx);
400        view
401    })
402}
403
404fn create_editor_diff(
405    diff: Entity<acp_thread::Diff>,
406    window: &mut Window,
407    cx: &mut App,
408) -> Entity<Editor> {
409    cx.new(|cx| {
410        let mut editor = Editor::new(
411            EditorMode::Full {
412                scale_ui_elements_with_buffer_font_size: false,
413                show_active_line_background: false,
414                sizing_behavior: SizingBehavior::SizeByContent,
415            },
416            diff.read(cx).multibuffer().clone(),
417            None,
418            window,
419            cx,
420        );
421        editor.set_show_gutter(false, cx);
422        editor.disable_inline_diagnostics();
423        editor.disable_expand_excerpt_buttons(cx);
424        editor.set_show_vertical_scrollbar(false, cx);
425        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
426        editor.set_soft_wrap_mode(SoftWrap::None, cx);
427        editor.scroll_manager.set_forbid_vertical_scroll(true);
428        editor.set_show_indent_guides(false, cx);
429        editor.set_read_only(true);
430        editor.set_delegate_open_excerpts(true);
431        editor.set_show_breakpoints(false, cx);
432        editor.set_show_code_actions(false, cx);
433        editor.set_show_git_diff_gutter(false, cx);
434        editor.set_expand_all_diff_hunks(cx);
435        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
436        editor
437    })
438}
439
440fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
441    TextStyleRefinement {
442        font_size: Some(
443            TextSize::Small
444                .rems(cx)
445                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
446                .into(),
447        ),
448        ..Default::default()
449    }
450}
451
#[cfg(test)]
mod tests {
    use std::path::Path;
    use std::rc::Rc;

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::RowInfo;
    use fs::FakeFs;
    use gpui::{AppContext as _, TestAppContext};

    use crate::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use workspace::MultiWorkspace;

    /// Syncing a tool call that carries a diff should create a diff editor. That editor
    /// shows the old text as a deleted row followed by the new text as an added row.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/project", json!({ "hello.txt": "hi world" }))
            .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (multi_workspace, cx) =
            cx.add_window_view(|window, cx| MultiWorkspace::test_new(project.clone(), window, cx));
        let workspace =
            multi_workspace.read_with(cx, |this, _| this.workspace().clone());

        let tool_call = acp::ToolCall::new("tool", "Tool call")
            .status(acp::ToolCallStatus::InProgress)
            .content(vec![acp::ToolCallContent::Diff(
                acp::Diff::new("/project/hello.txt", "hello world").old_text("hi world"),
            )]);

        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_session(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let history =
            cx.update(|window, cx| cx.new(|cx| crate::ThreadHistory::new(None, window, cx)));
        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.downgrade(),
                None,
                history.downgrade(),
                None,
                Default::default(),
                Default::default(),
                "Test Agent".into(),
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _| {
            let first_entry = thread.entries().first().unwrap();
            first_entry.diffs().next().unwrap().clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        let editor_text = diff_editor.read_with(cx, |editor, cx| editor.text(cx));
        assert_eq!(editor_text, "hi world\nhello world");

        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            editor
                .buffer()
                .read(cx)
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs the test settings store, base themes, and a zero release channel version.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            cx.set_global(SettingsStore::test(cx));
            theme::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(semver::Version::new(0, 0, 0), cx);
        });
    }
}