entry_view_state.rs

  1use std::ops::Range;
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent_client_protocol::ToolCallId;
  5use agent2::HistoryStore;
  6use collections::HashMap;
  7use editor::{Editor, EditorMode, MinimapVisibility};
  8use gpui::{
  9    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
 10    TextStyleRefinement, WeakEntity, Window,
 11};
 12use language::language_settings::SoftWrap;
 13use project::Project;
 14use prompt_store::PromptStore;
 15use settings::Settings as _;
 16use terminal_view::TerminalView;
 17use theme::ThemeSettings;
 18use ui::{Context, TextSize};
 19use workspace::Workspace;
 20
 21use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 22
 23pub struct EntryViewState {
 24    workspace: WeakEntity<Workspace>,
 25    project: Entity<Project>,
 26    history_store: Entity<HistoryStore>,
 27    prompt_store: Option<Entity<PromptStore>>,
 28    entries: Vec<Entry>,
 29    prevent_slash_commands: bool,
 30}
 31
 32impl EntryViewState {
 33    pub fn new(
 34        workspace: WeakEntity<Workspace>,
 35        project: Entity<Project>,
 36        history_store: Entity<HistoryStore>,
 37        prompt_store: Option<Entity<PromptStore>>,
 38        prevent_slash_commands: bool,
 39    ) -> Self {
 40        Self {
 41            workspace,
 42            project,
 43            history_store,
 44            prompt_store,
 45            entries: Vec::new(),
 46            prevent_slash_commands,
 47        }
 48    }
 49
 50    pub fn entry(&self, index: usize) -> Option<&Entry> {
 51        self.entries.get(index)
 52    }
 53
 54    pub fn sync_entry(
 55        &mut self,
 56        index: usize,
 57        thread: &Entity<AcpThread>,
 58        window: &mut Window,
 59        cx: &mut Context<Self>,
 60    ) {
 61        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 62            return;
 63        };
 64
 65        match thread_entry {
 66            AgentThreadEntry::UserMessage(message) => {
 67                let has_id = message.id.is_some();
 68                let chunks = message.chunks.clone();
 69                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 70                    if !editor.focus_handle(cx).is_focused(window) {
 71                        // Only update if we are not editing.
 72                        // If we are, cancelling the edit will set the message to the newest content.
 73                        editor.update(cx, |editor, cx| {
 74                            editor.set_message(chunks, window, cx);
 75                        });
 76                    }
 77                } else {
 78                    let message_editor = cx.new(|cx| {
 79                        let mut editor = MessageEditor::new(
 80                            self.workspace.clone(),
 81                            self.project.clone(),
 82                            self.history_store.clone(),
 83                            self.prompt_store.clone(),
 84                            "Edit message - @ to include context",
 85                            self.prevent_slash_commands,
 86                            editor::EditorMode::AutoHeight {
 87                                min_lines: 1,
 88                                max_lines: None,
 89                            },
 90                            window,
 91                            cx,
 92                        );
 93                        if !has_id {
 94                            editor.set_read_only(true, cx);
 95                        }
 96                        editor.set_message(chunks, window, cx);
 97                        editor
 98                    });
 99                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
100                        cx.emit(EntryViewEvent {
101                            entry_index: index,
102                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
103                        })
104                    })
105                    .detach();
106                    self.set_entry(index, Entry::UserMessage(message_editor));
107                }
108            }
109            AgentThreadEntry::ToolCall(tool_call) => {
110                let id = tool_call.id.clone();
111                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
112                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
113
114                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
115                    views
116                } else {
117                    self.set_entry(index, Entry::empty());
118                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
119                        unreachable!()
120                    };
121                    views
122                };
123
124                for terminal in terminals {
125                    views.entry(terminal.entity_id()).or_insert_with(|| {
126                        let element = create_terminal(
127                            self.workspace.clone(),
128                            self.project.clone(),
129                            terminal.clone(),
130                            window,
131                            cx,
132                        )
133                        .into_any();
134                        cx.emit(EntryViewEvent {
135                            entry_index: index,
136                            view_event: ViewEvent::NewTerminal(id.clone()),
137                        });
138                        element
139                    });
140                }
141
142                for diff in diffs {
143                    views.entry(diff.entity_id()).or_insert_with(|| {
144                        let element = create_editor_diff(diff.clone(), window, cx).into_any();
145                        cx.emit(EntryViewEvent {
146                            entry_index: index,
147                            view_event: ViewEvent::NewDiff(id.clone()),
148                        });
149                        element
150                    });
151                }
152            }
153            AgentThreadEntry::AssistantMessage(_) => {
154                if index == self.entries.len() {
155                    self.entries.push(Entry::empty())
156                }
157            }
158        };
159    }
160
161    fn set_entry(&mut self, index: usize, entry: Entry) {
162        if index == self.entries.len() {
163            self.entries.push(entry);
164        } else {
165            self.entries[index] = entry;
166        }
167    }
168
169    pub fn remove(&mut self, range: Range<usize>) {
170        self.entries.drain(range);
171    }
172
173    pub fn settings_changed(&mut self, cx: &mut App) {
174        for entry in self.entries.iter() {
175            match entry {
176                Entry::UserMessage { .. } => {}
177                Entry::Content(response_views) => {
178                    for view in response_views.values() {
179                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
180                            diff_editor.update(cx, |diff_editor, cx| {
181                                diff_editor.set_text_style_refinement(
182                                    diff_editor_text_style_refinement(cx),
183                                );
184                                cx.notify();
185                            })
186                        }
187                    }
188                }
189            }
190        }
191    }
192}
193
194impl EventEmitter<EntryViewEvent> for EntryViewState {}
195
196pub struct EntryViewEvent {
197    pub entry_index: usize,
198    pub view_event: ViewEvent,
199}
200
201pub enum ViewEvent {
202    NewDiff(ToolCallId),
203    NewTerminal(ToolCallId),
204    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
205}
206
207#[derive(Debug)]
208pub enum Entry {
209    UserMessage(Entity<MessageEditor>),
210    Content(HashMap<EntityId, AnyEntity>),
211}
212
213impl Entry {
214    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
215        match self {
216            Self::UserMessage(editor) => Some(editor),
217            Entry::Content(_) => None,
218        }
219    }
220
221    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
222        self.content_map()?
223            .get(&diff.entity_id())
224            .cloned()
225            .map(|entity| entity.downcast::<Editor>().unwrap())
226    }
227
228    pub fn terminal(
229        &self,
230        terminal: &Entity<acp_thread::Terminal>,
231    ) -> Option<Entity<TerminalView>> {
232        self.content_map()?
233            .get(&terminal.entity_id())
234            .cloned()
235            .map(|entity| entity.downcast::<TerminalView>().unwrap())
236    }
237
238    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
239        match self {
240            Self::Content(map) => Some(map),
241            _ => None,
242        }
243    }
244
245    fn empty() -> Self {
246        Self::Content(HashMap::default())
247    }
248
249    #[cfg(test)]
250    pub fn has_content(&self) -> bool {
251        match self {
252            Self::Content(map) => !map.is_empty(),
253            Self::UserMessage(_) => false,
254        }
255    }
256}
257
258fn create_terminal(
259    workspace: WeakEntity<Workspace>,
260    project: Entity<Project>,
261    terminal: Entity<acp_thread::Terminal>,
262    window: &mut Window,
263    cx: &mut App,
264) -> Entity<TerminalView> {
265    cx.new(|cx| {
266        let mut view = TerminalView::new(
267            terminal.read(cx).inner().clone(),
268            workspace.clone(),
269            None,
270            project.downgrade(),
271            window,
272            cx,
273        );
274        view.set_embedded_mode(Some(1000), cx);
275        view
276    })
277}
278
279fn create_editor_diff(
280    diff: Entity<acp_thread::Diff>,
281    window: &mut Window,
282    cx: &mut App,
283) -> Entity<Editor> {
284    cx.new(|cx| {
285        let mut editor = Editor::new(
286            EditorMode::Full {
287                scale_ui_elements_with_buffer_font_size: false,
288                show_active_line_background: false,
289                sized_by_content: true,
290            },
291            diff.read(cx).multibuffer().clone(),
292            None,
293            window,
294            cx,
295        );
296        editor.set_show_gutter(false, cx);
297        editor.disable_inline_diagnostics();
298        editor.disable_expand_excerpt_buttons(cx);
299        editor.set_show_vertical_scrollbar(false, cx);
300        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
301        editor.set_soft_wrap_mode(SoftWrap::None, cx);
302        editor.scroll_manager.set_forbid_vertical_scroll(true);
303        editor.set_show_indent_guides(false, cx);
304        editor.set_read_only(true);
305        editor.set_show_breakpoints(false, cx);
306        editor.set_show_code_actions(false, cx);
307        editor.set_show_git_diff_gutter(false, cx);
308        editor.set_expand_all_diff_hunks(cx);
309        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
310        editor
311    })
312}
313
314fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
315    TextStyleRefinement {
316        font_size: Some(
317            TextSize::Small
318                .rems(cx)
319                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
320                .into(),
321        ),
322        ..Default::default()
323    }
324}
325
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use agent2::HistoryStore;
    use assistant_context::ContextStore;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// Syncing a tool-call entry that carries a diff creates a diff editor. The
    /// editor shows the deleted old line followed by the added new line.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Every path goes through `path!` so the fixture, the worktree root and the
        // diff path agree on platforms with drive-letter paths.
        fs.insert_tree(
            path!("/project"),
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff {
                diff: acp::Diff {
                    path: path!("/project/hello.txt").into(),
                    old_text: Some("hi world".into()),
                    new_text: "hello world".into(),
                },
            }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        };
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.read_with(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                history_store,
                None,
                false,
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            thread
                .entries()
                .first()
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs the test settings store and registers the settings and subsystems
    /// used by the tests above.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}