entry_view_state.rs

  1use std::ops::Range;
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent2::HistoryStore;
  5use collections::HashMap;
  6use editor::{Editor, EditorMode, MinimapVisibility};
  7use gpui::{
  8    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
  9    TextStyleRefinement, WeakEntity, Window,
 10};
 11use language::language_settings::SoftWrap;
 12use project::Project;
 13use prompt_store::PromptStore;
 14use settings::Settings as _;
 15use terminal_view::TerminalView;
 16use theme::ThemeSettings;
 17use ui::{Context, TextSize};
 18use workspace::Workspace;
 19
 20use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 21
 22pub struct EntryViewState {
 23    workspace: WeakEntity<Workspace>,
 24    project: Entity<Project>,
 25    history_store: Entity<HistoryStore>,
 26    prompt_store: Option<Entity<PromptStore>>,
 27    entries: Vec<Entry>,
 28    prevent_slash_commands: bool,
 29}
 30
 31impl EntryViewState {
 32    pub fn new(
 33        workspace: WeakEntity<Workspace>,
 34        project: Entity<Project>,
 35        history_store: Entity<HistoryStore>,
 36        prompt_store: Option<Entity<PromptStore>>,
 37        prevent_slash_commands: bool,
 38    ) -> Self {
 39        Self {
 40            workspace,
 41            project,
 42            history_store,
 43            prompt_store,
 44            entries: Vec::new(),
 45            prevent_slash_commands,
 46        }
 47    }
 48
 49    pub fn entry(&self, index: usize) -> Option<&Entry> {
 50        self.entries.get(index)
 51    }
 52
 53    pub fn sync_entry(
 54        &mut self,
 55        index: usize,
 56        thread: &Entity<AcpThread>,
 57        window: &mut Window,
 58        cx: &mut Context<Self>,
 59    ) {
 60        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 61            return;
 62        };
 63
 64        match thread_entry {
 65            AgentThreadEntry::UserMessage(message) => {
 66                let has_id = message.id.is_some();
 67                let chunks = message.chunks.clone();
 68                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 69                    if !editor.focus_handle(cx).is_focused(window) {
 70                        // Only update if we are not editing.
 71                        // If we are, cancelling the edit will set the message to the newest content.
 72                        editor.update(cx, |editor, cx| {
 73                            editor.set_message(chunks, window, cx);
 74                        });
 75                    }
 76                } else {
 77                    let message_editor = cx.new(|cx| {
 78                        let mut editor = MessageEditor::new(
 79                            self.workspace.clone(),
 80                            self.project.clone(),
 81                            self.history_store.clone(),
 82                            self.prompt_store.clone(),
 83                            "Edit message - @ to include context",
 84                            self.prevent_slash_commands,
 85                            editor::EditorMode::AutoHeight {
 86                                min_lines: 1,
 87                                max_lines: None,
 88                            },
 89                            window,
 90                            cx,
 91                        );
 92                        if !has_id {
 93                            editor.set_read_only(true, cx);
 94                        }
 95                        editor.set_message(chunks, window, cx);
 96                        editor
 97                    });
 98                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
 99                        cx.emit(EntryViewEvent {
100                            entry_index: index,
101                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
102                        })
103                    })
104                    .detach();
105                    self.set_entry(index, Entry::UserMessage(message_editor));
106                }
107            }
108            AgentThreadEntry::ToolCall(tool_call) => {
109                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
110                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
111
112                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
113                    views
114                } else {
115                    self.set_entry(index, Entry::empty());
116                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
117                        unreachable!()
118                    };
119                    views
120                };
121
122                for terminal in terminals {
123                    views.entry(terminal.entity_id()).or_insert_with(|| {
124                        let element = create_terminal(
125                            self.workspace.clone(),
126                            self.project.clone(),
127                            terminal.clone(),
128                            window,
129                            cx,
130                        )
131                        .into_any();
132                        cx.emit(EntryViewEvent {
133                            entry_index: index,
134                            view_event: ViewEvent::NewTerminal(terminal.entity_id()),
135                        });
136                        element
137                    });
138                }
139
140                for diff in diffs {
141                    views
142                        .entry(diff.entity_id())
143                        .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
144                }
145            }
146            AgentThreadEntry::AssistantMessage(_) => {
147                if index == self.entries.len() {
148                    self.entries.push(Entry::empty())
149                }
150            }
151        };
152    }
153
154    fn set_entry(&mut self, index: usize, entry: Entry) {
155        if index == self.entries.len() {
156            self.entries.push(entry);
157        } else {
158            self.entries[index] = entry;
159        }
160    }
161
162    pub fn remove(&mut self, range: Range<usize>) {
163        self.entries.drain(range);
164    }
165
166    pub fn settings_changed(&mut self, cx: &mut App) {
167        for entry in self.entries.iter() {
168            match entry {
169                Entry::UserMessage { .. } => {}
170                Entry::Content(response_views) => {
171                    for view in response_views.values() {
172                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
173                            diff_editor.update(cx, |diff_editor, cx| {
174                                diff_editor.set_text_style_refinement(
175                                    diff_editor_text_style_refinement(cx),
176                                );
177                                cx.notify();
178                            })
179                        }
180                    }
181                }
182            }
183        }
184    }
185}
186
187impl EventEmitter<EntryViewEvent> for EntryViewState {}
188
189pub struct EntryViewEvent {
190    pub entry_index: usize,
191    pub view_event: ViewEvent,
192}
193
194pub enum ViewEvent {
195    NewTerminal(EntityId),
196    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
197}
198
199#[derive(Debug)]
200pub enum Entry {
201    UserMessage(Entity<MessageEditor>),
202    Content(HashMap<EntityId, AnyEntity>),
203}
204
205impl Entry {
206    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
207        match self {
208            Self::UserMessage(editor) => Some(editor),
209            Entry::Content(_) => None,
210        }
211    }
212
213    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
214        self.content_map()?
215            .get(&diff.entity_id())
216            .cloned()
217            .map(|entity| entity.downcast::<Editor>().unwrap())
218    }
219
220    pub fn terminal(
221        &self,
222        terminal: &Entity<acp_thread::Terminal>,
223    ) -> Option<Entity<TerminalView>> {
224        self.content_map()?
225            .get(&terminal.entity_id())
226            .cloned()
227            .map(|entity| entity.downcast::<TerminalView>().unwrap())
228    }
229
230    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
231        match self {
232            Self::Content(map) => Some(map),
233            _ => None,
234        }
235    }
236
237    fn empty() -> Self {
238        Self::Content(HashMap::default())
239    }
240
241    #[cfg(test)]
242    pub fn has_content(&self) -> bool {
243        match self {
244            Self::Content(map) => !map.is_empty(),
245            Self::UserMessage(_) => false,
246        }
247    }
248}
249
250fn create_terminal(
251    workspace: WeakEntity<Workspace>,
252    project: Entity<Project>,
253    terminal: Entity<acp_thread::Terminal>,
254    window: &mut Window,
255    cx: &mut App,
256) -> Entity<TerminalView> {
257    cx.new(|cx| {
258        let mut view = TerminalView::new(
259            terminal.read(cx).inner().clone(),
260            workspace.clone(),
261            None,
262            project.downgrade(),
263            window,
264            cx,
265        );
266        view.set_embedded_mode(Some(1000), cx);
267        view
268    })
269}
270
271fn create_editor_diff(
272    diff: Entity<acp_thread::Diff>,
273    window: &mut Window,
274    cx: &mut App,
275) -> Entity<Editor> {
276    cx.new(|cx| {
277        let mut editor = Editor::new(
278            EditorMode::Full {
279                scale_ui_elements_with_buffer_font_size: false,
280                show_active_line_background: false,
281                sized_by_content: true,
282            },
283            diff.read(cx).multibuffer().clone(),
284            None,
285            window,
286            cx,
287        );
288        editor.set_show_gutter(false, cx);
289        editor.disable_inline_diagnostics();
290        editor.disable_expand_excerpt_buttons(cx);
291        editor.set_show_vertical_scrollbar(false, cx);
292        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
293        editor.set_soft_wrap_mode(SoftWrap::None, cx);
294        editor.scroll_manager.set_forbid_vertical_scroll(true);
295        editor.set_show_indent_guides(false, cx);
296        editor.set_read_only(true);
297        editor.set_show_breakpoints(false, cx);
298        editor.set_show_code_actions(false, cx);
299        editor.set_show_git_diff_gutter(false, cx);
300        editor.set_expand_all_diff_hunks(cx);
301        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
302        editor
303    })
304}
305
306fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
307    TextStyleRefinement {
308        font_size: Some(
309            TextSize::Small
310                .rems(cx)
311                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
312                .into(),
313        ),
314        ..Default::default()
315    }
316}
317
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use agent2::HistoryStore;
    use assistant_context::ContextStore;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// Syncing a tool-call entry that contains a diff creates a diff editor for it, and that
    /// editor shows the old text as a deleted row followed by the new text as an added row.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Every path goes through `path!` so the fake tree, the worktree root and the tool
        // call's diff path agree on all platforms (`path!` rewrites them on Windows).
        fs.insert_tree(
            path!("/project"),
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff {
                diff: acp::Diff {
                    path: path!("/project/hello.txt").into(),
                    old_text: Some("hi world".into()),
                    new_text: "hello world".into(),
                },
            }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        };
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        // Deliver the tool call to the freshly created thread; it becomes entry 0.
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                history_store,
                None,
                false,
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            thread
                .entries()
                .first()
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        // Let pending work (presumably the diff computation) settle before inspecting it.
        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs a test settings store and registers the settings and subsystems this test uses.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}