entry_view_state.rs

  1use std::ops::Range;
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent2::HistoryStore;
  5use collections::HashMap;
  6use editor::{Editor, EditorMode, MinimapVisibility};
  7use gpui::{
  8    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
  9    TextStyleRefinement, WeakEntity, Window,
 10};
 11use language::language_settings::SoftWrap;
 12use project::Project;
 13use prompt_store::PromptStore;
 14use settings::Settings as _;
 15use terminal_view::TerminalView;
 16use theme::ThemeSettings;
 17use ui::{Context, TextSize};
 18use workspace::Workspace;
 19
 20use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 21
/// View state kept alongside an [`AcpThread`]'s entries.
///
/// Item `i` of `entries` holds the views rendered for thread entry `i` (see [`Entry`]).
/// The remaining fields are the handles needed to construct new views on demand.
pub struct EntryViewState {
    workspace: WeakEntity<Workspace>,
    project: Entity<Project>,
    history_store: Entity<HistoryStore>,
    prompt_store: Option<Entity<PromptStore>>,
    /// One item per synced thread entry, indexed the same way as the thread's entries.
    entries: Vec<Entry>,
    /// Forwarded to every `MessageEditor` created for a user message.
    prevent_slash_commands: bool,
}
 30
 31impl EntryViewState {
 32    pub fn new(
 33        workspace: WeakEntity<Workspace>,
 34        project: Entity<Project>,
 35        history_store: Entity<HistoryStore>,
 36        prompt_store: Option<Entity<PromptStore>>,
 37        prevent_slash_commands: bool,
 38    ) -> Self {
 39        Self {
 40            workspace,
 41            project,
 42            history_store,
 43            prompt_store,
 44            entries: Vec::new(),
 45            prevent_slash_commands,
 46        }
 47    }
 48
 49    pub fn entry(&self, index: usize) -> Option<&Entry> {
 50        self.entries.get(index)
 51    }
 52
 53    pub fn sync_entry(
 54        &mut self,
 55        index: usize,
 56        thread: &Entity<AcpThread>,
 57        window: &mut Window,
 58        cx: &mut Context<Self>,
 59    ) {
 60        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 61            return;
 62        };
 63
 64        match thread_entry {
 65            AgentThreadEntry::UserMessage(message) => {
 66                let has_id = message.id.is_some();
 67                let chunks = message.chunks.clone();
 68                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 69                    if !editor.focus_handle(cx).is_focused(window) {
 70                        // Only update if we are not editing.
 71                        // If we are, cancelling the edit will set the message to the newest content.
 72                        editor.update(cx, |editor, cx| {
 73                            editor.set_message(chunks, window, cx);
 74                        });
 75                    }
 76                } else {
 77                    let message_editor = cx.new(|cx| {
 78                        let mut editor = MessageEditor::new(
 79                            self.workspace.clone(),
 80                            self.project.clone(),
 81                            self.history_store.clone(),
 82                            self.prompt_store.clone(),
 83                            "Edit message - @ to include context",
 84                            self.prevent_slash_commands,
 85                            editor::EditorMode::AutoHeight {
 86                                min_lines: 1,
 87                                max_lines: None,
 88                            },
 89                            window,
 90                            cx,
 91                        );
 92                        if !has_id {
 93                            editor.set_read_only(true, cx);
 94                        }
 95                        editor.set_message(chunks, window, cx);
 96                        editor
 97                    });
 98                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
 99                        cx.emit(EntryViewEvent {
100                            entry_index: index,
101                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
102                        })
103                    })
104                    .detach();
105                    self.set_entry(index, Entry::UserMessage(message_editor));
106                }
107            }
108            AgentThreadEntry::ToolCall(tool_call) => {
109                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
110                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
111
112                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
113                    views
114                } else {
115                    self.set_entry(index, Entry::empty());
116                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
117                        unreachable!()
118                    };
119                    views
120                };
121
122                for terminal in terminals {
123                    views.entry(terminal.entity_id()).or_insert_with(|| {
124                        create_terminal(
125                            self.workspace.clone(),
126                            self.project.clone(),
127                            terminal.clone(),
128                            window,
129                            cx,
130                        )
131                        .into_any()
132                    });
133                }
134
135                for diff in diffs {
136                    views
137                        .entry(diff.entity_id())
138                        .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
139                }
140            }
141            AgentThreadEntry::AssistantMessage(_) => {
142                if index == self.entries.len() {
143                    self.entries.push(Entry::empty())
144                }
145            }
146        };
147    }
148
149    fn set_entry(&mut self, index: usize, entry: Entry) {
150        if index == self.entries.len() {
151            self.entries.push(entry);
152        } else {
153            self.entries[index] = entry;
154        }
155    }
156
157    pub fn remove(&mut self, range: Range<usize>) {
158        self.entries.drain(range);
159    }
160
161    pub fn settings_changed(&mut self, cx: &mut App) {
162        for entry in self.entries.iter() {
163            match entry {
164                Entry::UserMessage { .. } => {}
165                Entry::Content(response_views) => {
166                    for view in response_views.values() {
167                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
168                            diff_editor.update(cx, |diff_editor, cx| {
169                                diff_editor.set_text_style_refinement(
170                                    diff_editor_text_style_refinement(cx),
171                                );
172                                cx.notify();
173                            })
174                        }
175                    }
176                }
177            }
178        }
179    }
180}
181
/// `EntryViewState` re-emits events from the message editors it creates as
/// [`EntryViewEvent`]s tagged with the entry index.
impl EventEmitter<EntryViewEvent> for EntryViewState {}
183
/// An event from a view owned by [`EntryViewState`], tagged with the thread entry it belongs to.
pub struct EntryViewEvent {
    /// Index of the thread entry whose view emitted the event.
    pub entry_index: usize,
    /// The forwarded view event.
    pub view_event: ViewEvent,
}
188
/// Events forwarded from per-entry views.
pub enum ViewEvent {
    /// An event emitted by a user message's editor, together with that editor.
    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
}
192
/// The view state for a single thread entry.
#[derive(Debug)]
pub enum Entry {
    /// Editor for a user message.
    UserMessage(Entity<MessageEditor>),
    /// Views for a tool call's terminals and diffs, keyed by the entity id of the
    /// terminal/diff each one renders. Assistant messages use an empty map.
    Content(HashMap<EntityId, AnyEntity>),
}
198
199impl Entry {
200    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
201        match self {
202            Self::UserMessage(editor) => Some(editor),
203            Entry::Content(_) => None,
204        }
205    }
206
207    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
208        self.content_map()?
209            .get(&diff.entity_id())
210            .cloned()
211            .map(|entity| entity.downcast::<Editor>().unwrap())
212    }
213
214    pub fn terminal(
215        &self,
216        terminal: &Entity<acp_thread::Terminal>,
217    ) -> Option<Entity<TerminalView>> {
218        self.content_map()?
219            .get(&terminal.entity_id())
220            .cloned()
221            .map(|entity| entity.downcast::<TerminalView>().unwrap())
222    }
223
224    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
225        match self {
226            Self::Content(map) => Some(map),
227            _ => None,
228        }
229    }
230
231    fn empty() -> Self {
232        Self::Content(HashMap::default())
233    }
234
235    #[cfg(test)]
236    pub fn has_content(&self) -> bool {
237        match self {
238            Self::Content(map) => !map.is_empty(),
239            Self::UserMessage(_) => false,
240        }
241    }
242}
243
244fn create_terminal(
245    workspace: WeakEntity<Workspace>,
246    project: Entity<Project>,
247    terminal: Entity<acp_thread::Terminal>,
248    window: &mut Window,
249    cx: &mut App,
250) -> Entity<TerminalView> {
251    cx.new(|cx| {
252        let mut view = TerminalView::new(
253            terminal.read(cx).inner().clone(),
254            workspace.clone(),
255            None,
256            project.downgrade(),
257            window,
258            cx,
259        );
260        view.set_embedded_mode(Some(1000), cx);
261        view
262    })
263}
264
265fn create_editor_diff(
266    diff: Entity<acp_thread::Diff>,
267    window: &mut Window,
268    cx: &mut App,
269) -> Entity<Editor> {
270    cx.new(|cx| {
271        let mut editor = Editor::new(
272            EditorMode::Full {
273                scale_ui_elements_with_buffer_font_size: false,
274                show_active_line_background: false,
275                sized_by_content: true,
276            },
277            diff.read(cx).multibuffer().clone(),
278            None,
279            window,
280            cx,
281        );
282        editor.set_show_gutter(false, cx);
283        editor.disable_inline_diagnostics();
284        editor.disable_expand_excerpt_buttons(cx);
285        editor.set_show_vertical_scrollbar(false, cx);
286        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
287        editor.set_soft_wrap_mode(SoftWrap::None, cx);
288        editor.scroll_manager.set_forbid_vertical_scroll(true);
289        editor.set_show_indent_guides(false, cx);
290        editor.set_read_only(true);
291        editor.set_show_breakpoints(false, cx);
292        editor.set_show_code_actions(false, cx);
293        editor.set_show_git_diff_gutter(false, cx);
294        editor.set_expand_all_diff_hunks(cx);
295        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
296        editor
297    })
298}
299
300fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
301    TextStyleRefinement {
302        font_size: Some(
303            TextSize::Small
304                .rems(cx)
305                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
306                .into(),
307        ),
308        ..Default::default()
309    }
310}
311
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use agent2::HistoryStore;
    use assistant_context::ContextStore;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// Syncing a tool-call entry that carries a diff creates a diff editor whose
    /// multibuffer shows the old line as deleted and the new line as added.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/project",
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // A tool call whose only content is a one-line diff of hello.txt.
        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff {
                diff: acp::Diff {
                    path: "/project/hello.txt".into(),
                    old_text: Some("hi world".into()),
                    new_text: "hello world".into(),
                },
            }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        };
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        // Deliver the tool call so it becomes thread entry 0.
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                history_store,
                None,
                false,
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        // Let pending background work settle before reading the editor's contents.
        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        // Row 0 is the removed old line, row 1 the inserted new line.
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs a test settings store and registers the settings the views above read.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}