entry_view_state.rs

  1use std::{cell::Cell, ops::Range, rc::Rc};
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent_client_protocol::{PromptCapabilities, ToolCallId};
  5use agent2::HistoryStore;
  6use collections::HashMap;
  7use editor::{Editor, EditorMode, MinimapVisibility};
  8use gpui::{
  9    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
 10    TextStyleRefinement, WeakEntity, Window,
 11};
 12use language::language_settings::SoftWrap;
 13use project::Project;
 14use prompt_store::PromptStore;
 15use settings::Settings as _;
 16use terminal_view::TerminalView;
 17use theme::ThemeSettings;
 18use ui::{Context, TextSize};
 19use workspace::Workspace;
 20
 21use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 22
 23pub struct EntryViewState {
 24    workspace: WeakEntity<Workspace>,
 25    project: Entity<Project>,
 26    history_store: Entity<HistoryStore>,
 27    prompt_store: Option<Entity<PromptStore>>,
 28    entries: Vec<Entry>,
 29    prevent_slash_commands: bool,
 30    prompt_capabilities: Rc<Cell<PromptCapabilities>>,
 31}
 32
 33impl EntryViewState {
 34    pub fn new(
 35        workspace: WeakEntity<Workspace>,
 36        project: Entity<Project>,
 37        history_store: Entity<HistoryStore>,
 38        prompt_store: Option<Entity<PromptStore>>,
 39        prompt_capabilities: Rc<Cell<PromptCapabilities>>,
 40        prevent_slash_commands: bool,
 41    ) -> Self {
 42        Self {
 43            workspace,
 44            project,
 45            history_store,
 46            prompt_store,
 47            entries: Vec::new(),
 48            prevent_slash_commands,
 49            prompt_capabilities,
 50        }
 51    }
 52
 53    pub fn entry(&self, index: usize) -> Option<&Entry> {
 54        self.entries.get(index)
 55    }
 56
 57    pub fn sync_entry(
 58        &mut self,
 59        index: usize,
 60        thread: &Entity<AcpThread>,
 61        window: &mut Window,
 62        cx: &mut Context<Self>,
 63    ) {
 64        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 65            return;
 66        };
 67
 68        match thread_entry {
 69            AgentThreadEntry::UserMessage(message) => {
 70                let has_id = message.id.is_some();
 71                let chunks = message.chunks.clone();
 72                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 73                    if !editor.focus_handle(cx).is_focused(window) {
 74                        // Only update if we are not editing.
 75                        // If we are, cancelling the edit will set the message to the newest content.
 76                        editor.update(cx, |editor, cx| {
 77                            editor.set_message(chunks, window, cx);
 78                        });
 79                    }
 80                } else {
 81                    let message_editor = cx.new(|cx| {
 82                        let mut editor = MessageEditor::new(
 83                            self.workspace.clone(),
 84                            self.project.clone(),
 85                            self.history_store.clone(),
 86                            self.prompt_store.clone(),
 87                            self.prompt_capabilities.clone(),
 88                            "Edit message - @ to include context",
 89                            self.prevent_slash_commands,
 90                            editor::EditorMode::AutoHeight {
 91                                min_lines: 1,
 92                                max_lines: None,
 93                            },
 94                            window,
 95                            cx,
 96                        );
 97                        if !has_id {
 98                            editor.set_read_only(true, cx);
 99                        }
100                        editor.set_message(chunks, window, cx);
101                        editor
102                    });
103                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
104                        cx.emit(EntryViewEvent {
105                            entry_index: index,
106                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
107                        })
108                    })
109                    .detach();
110                    self.set_entry(index, Entry::UserMessage(message_editor));
111                }
112            }
113            AgentThreadEntry::ToolCall(tool_call) => {
114                let id = tool_call.id.clone();
115                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
116                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
117
118                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
119                    views
120                } else {
121                    self.set_entry(index, Entry::empty());
122                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
123                        unreachable!()
124                    };
125                    views
126                };
127
128                for terminal in terminals {
129                    views.entry(terminal.entity_id()).or_insert_with(|| {
130                        let element = create_terminal(
131                            self.workspace.clone(),
132                            self.project.clone(),
133                            terminal.clone(),
134                            window,
135                            cx,
136                        )
137                        .into_any();
138                        cx.emit(EntryViewEvent {
139                            entry_index: index,
140                            view_event: ViewEvent::NewTerminal(id.clone()),
141                        });
142                        element
143                    });
144                }
145
146                for diff in diffs {
147                    views.entry(diff.entity_id()).or_insert_with(|| {
148                        let element = create_editor_diff(diff.clone(), window, cx).into_any();
149                        cx.emit(EntryViewEvent {
150                            entry_index: index,
151                            view_event: ViewEvent::NewDiff(id.clone()),
152                        });
153                        element
154                    });
155                }
156            }
157            AgentThreadEntry::AssistantMessage(_) => {
158                if index == self.entries.len() {
159                    self.entries.push(Entry::empty())
160                }
161            }
162        };
163    }
164
165    fn set_entry(&mut self, index: usize, entry: Entry) {
166        if index == self.entries.len() {
167            self.entries.push(entry);
168        } else {
169            self.entries[index] = entry;
170        }
171    }
172
173    pub fn remove(&mut self, range: Range<usize>) {
174        self.entries.drain(range);
175    }
176
177    pub fn settings_changed(&mut self, cx: &mut App) {
178        for entry in self.entries.iter() {
179            match entry {
180                Entry::UserMessage { .. } => {}
181                Entry::Content(response_views) => {
182                    for view in response_views.values() {
183                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
184                            diff_editor.update(cx, |diff_editor, cx| {
185                                diff_editor.set_text_style_refinement(
186                                    diff_editor_text_style_refinement(cx),
187                                );
188                                cx.notify();
189                            })
190                        }
191                    }
192                }
193            }
194        }
195    }
196}
197
198impl EventEmitter<EntryViewEvent> for EntryViewState {}
199
200pub struct EntryViewEvent {
201    pub entry_index: usize,
202    pub view_event: ViewEvent,
203}
204
205pub enum ViewEvent {
206    NewDiff(ToolCallId),
207    NewTerminal(ToolCallId),
208    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
209}
210
211#[derive(Debug)]
212pub enum Entry {
213    UserMessage(Entity<MessageEditor>),
214    Content(HashMap<EntityId, AnyEntity>),
215}
216
217impl Entry {
218    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
219        match self {
220            Self::UserMessage(editor) => Some(editor),
221            Entry::Content(_) => None,
222        }
223    }
224
225    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
226        self.content_map()?
227            .get(&diff.entity_id())
228            .cloned()
229            .map(|entity| entity.downcast::<Editor>().unwrap())
230    }
231
232    pub fn terminal(
233        &self,
234        terminal: &Entity<acp_thread::Terminal>,
235    ) -> Option<Entity<TerminalView>> {
236        self.content_map()?
237            .get(&terminal.entity_id())
238            .cloned()
239            .map(|entity| entity.downcast::<TerminalView>().unwrap())
240    }
241
242    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
243        match self {
244            Self::Content(map) => Some(map),
245            _ => None,
246        }
247    }
248
249    fn empty() -> Self {
250        Self::Content(HashMap::default())
251    }
252
253    #[cfg(test)]
254    pub fn has_content(&self) -> bool {
255        match self {
256            Self::Content(map) => !map.is_empty(),
257            Self::UserMessage(_) => false,
258        }
259    }
260}
261
262fn create_terminal(
263    workspace: WeakEntity<Workspace>,
264    project: Entity<Project>,
265    terminal: Entity<acp_thread::Terminal>,
266    window: &mut Window,
267    cx: &mut App,
268) -> Entity<TerminalView> {
269    cx.new(|cx| {
270        let mut view = TerminalView::new(
271            terminal.read(cx).inner().clone(),
272            workspace.clone(),
273            None,
274            project.downgrade(),
275            window,
276            cx,
277        );
278        view.set_embedded_mode(Some(1000), cx);
279        view
280    })
281}
282
283fn create_editor_diff(
284    diff: Entity<acp_thread::Diff>,
285    window: &mut Window,
286    cx: &mut App,
287) -> Entity<Editor> {
288    cx.new(|cx| {
289        let mut editor = Editor::new(
290            EditorMode::Full {
291                scale_ui_elements_with_buffer_font_size: false,
292                show_active_line_background: false,
293                sized_by_content: true,
294            },
295            diff.read(cx).multibuffer().clone(),
296            None,
297            window,
298            cx,
299        );
300        editor.set_show_gutter(false, cx);
301        editor.disable_inline_diagnostics();
302        editor.disable_expand_excerpt_buttons(cx);
303        editor.set_show_vertical_scrollbar(false, cx);
304        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
305        editor.set_soft_wrap_mode(SoftWrap::None, cx);
306        editor.scroll_manager.set_forbid_vertical_scroll(true);
307        editor.set_show_indent_guides(false, cx);
308        editor.set_read_only(true);
309        editor.set_show_breakpoints(false, cx);
310        editor.set_show_code_actions(false, cx);
311        editor.set_show_git_diff_gutter(false, cx);
312        editor.set_expand_all_diff_hunks(cx);
313        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
314        editor
315    })
316}
317
318fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
319    TextStyleRefinement {
320        font_size: Some(
321            TextSize::Small
322                .rems(cx)
323                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
324                .into(),
325        ),
326        ..Default::default()
327    }
328}
329
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use agent2::HistoryStore;
    use assistant_context::ContextStore;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// An in-progress tool call whose only content is a diff turning
    /// `hello.txt` from "hi world" into "hello world".
    fn hello_diff_tool_call() -> acp::ToolCall {
        let diff = acp::Diff {
            path: "/project/hello.txt".into(),
            old_text: Some("hi world".into()),
            new_text: "hello world".into(),
        };
        acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff { diff }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        }
    }

    /// Syncing a tool-call entry that carries a diff must create a diff
    /// editor showing the deleted and the added line.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree("/project", json!({ "hello.txt": "hi world" }))
            .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let tool_call = hello_diff_tool_call();
        let connection = Rc::new(StubAgentConnection::new());
        let new_thread = cx.update(|_, cx| {
            connection
                .clone()
                .new_thread(project.clone(), Path::new(path!("/project")), cx)
        });
        let thread = new_thread.await.unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                history_store,
                None,
                Default::default(),
                false,
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        // The first (and only) diff of the first thread entry.
        let diff = thread.read_with(cx, |thread, _cx| {
            let first_entry = thread.entries().first().unwrap();
            first_entry.diffs().next().unwrap().clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            let entry = view_state.entry(0).unwrap();
            entry.editor_for_diff(&diff).unwrap()
        });
        let editor_text = diff_editor.read_with(cx, |editor, cx| editor.text(cx));
        assert_eq!(editor_text, "hi world\nhello world");

        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let snapshot = editor.buffer().read(cx).snapshot(cx);
            snapshot.row_infos(MultiBufferRow(0)).collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs the test settings store and registers the settings and
    /// globals the test relies on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}