entry_view_state.rs

  1use std::ops::Range;
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent::{TextThreadStore, ThreadStore};
  5use collections::HashMap;
  6use editor::{Editor, EditorMode, MinimapVisibility};
  7use gpui::{
  8    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement,
  9    WeakEntity, Window,
 10};
 11use language::language_settings::SoftWrap;
 12use project::Project;
 13use settings::Settings as _;
 14use terminal_view::TerminalView;
 15use theme::ThemeSettings;
 16use ui::{Context, TextSize};
 17use workspace::Workspace;
 18
 19use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 20
 21pub struct EntryViewState {
 22    workspace: WeakEntity<Workspace>,
 23    project: Entity<Project>,
 24    thread_store: Entity<ThreadStore>,
 25    text_thread_store: Entity<TextThreadStore>,
 26    entries: Vec<Entry>,
 27}
 28
 29impl EntryViewState {
 30    pub fn new(
 31        workspace: WeakEntity<Workspace>,
 32        project: Entity<Project>,
 33        thread_store: Entity<ThreadStore>,
 34        text_thread_store: Entity<TextThreadStore>,
 35    ) -> Self {
 36        Self {
 37            workspace,
 38            project,
 39            thread_store,
 40            text_thread_store,
 41            entries: Vec::new(),
 42        }
 43    }
 44
 45    pub fn entry(&self, index: usize) -> Option<&Entry> {
 46        self.entries.get(index)
 47    }
 48
 49    pub fn sync_entry(
 50        &mut self,
 51        index: usize,
 52        thread: &Entity<AcpThread>,
 53        window: &mut Window,
 54        cx: &mut Context<Self>,
 55    ) {
 56        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 57            return;
 58        };
 59
 60        match thread_entry {
 61            AgentThreadEntry::UserMessage(message) => {
 62                let has_id = message.id.is_some();
 63                let chunks = message.chunks.clone();
 64                let message_editor = cx.new(|cx| {
 65                    let mut editor = MessageEditor::new(
 66                        self.workspace.clone(),
 67                        self.project.clone(),
 68                        self.thread_store.clone(),
 69                        self.text_thread_store.clone(),
 70                        editor::EditorMode::AutoHeight {
 71                            min_lines: 1,
 72                            max_lines: None,
 73                        },
 74                        window,
 75                        cx,
 76                    );
 77                    if !has_id {
 78                        editor.set_read_only(true, cx);
 79                    }
 80                    editor.set_message(chunks, window, cx);
 81                    editor
 82                });
 83                cx.subscribe(&message_editor, move |_, editor, event, cx| {
 84                    cx.emit(EntryViewEvent {
 85                        entry_index: index,
 86                        view_event: ViewEvent::MessageEditorEvent(editor, *event),
 87                    })
 88                })
 89                .detach();
 90                self.set_entry(index, Entry::UserMessage(message_editor));
 91            }
 92            AgentThreadEntry::ToolCall(tool_call) => {
 93                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
 94                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
 95
 96                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
 97                    views
 98                } else {
 99                    self.set_entry(index, Entry::empty());
100                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
101                        unreachable!()
102                    };
103                    views
104                };
105
106                for terminal in terminals {
107                    views.entry(terminal.entity_id()).or_insert_with(|| {
108                        create_terminal(
109                            self.workspace.clone(),
110                            self.project.clone(),
111                            terminal.clone(),
112                            window,
113                            cx,
114                        )
115                        .into_any()
116                    });
117                }
118
119                for diff in diffs {
120                    views
121                        .entry(diff.entity_id())
122                        .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
123                }
124            }
125            AgentThreadEntry::AssistantMessage(_) => {
126                if index == self.entries.len() {
127                    self.entries.push(Entry::empty())
128                }
129            }
130        };
131    }
132
133    fn set_entry(&mut self, index: usize, entry: Entry) {
134        if index == self.entries.len() {
135            self.entries.push(entry);
136        } else {
137            self.entries[index] = entry;
138        }
139    }
140
141    pub fn remove(&mut self, range: Range<usize>) {
142        self.entries.drain(range);
143    }
144
145    pub fn settings_changed(&mut self, cx: &mut App) {
146        for entry in self.entries.iter() {
147            match entry {
148                Entry::UserMessage { .. } => {}
149                Entry::Content(response_views) => {
150                    for view in response_views.values() {
151                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
152                            diff_editor.update(cx, |diff_editor, cx| {
153                                diff_editor.set_text_style_refinement(
154                                    diff_editor_text_style_refinement(cx),
155                                );
156                                cx.notify();
157                            })
158                        }
159                    }
160                }
161            }
162        }
163    }
164}
165
166impl EventEmitter<EntryViewEvent> for EntryViewState {}
167
168pub struct EntryViewEvent {
169    pub entry_index: usize,
170    pub view_event: ViewEvent,
171}
172
173pub enum ViewEvent {
174    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
175}
176
177pub enum Entry {
178    UserMessage(Entity<MessageEditor>),
179    Content(HashMap<EntityId, AnyEntity>),
180}
181
182impl Entry {
183    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
184        match self {
185            Self::UserMessage(editor) => Some(editor),
186            Entry::Content(_) => None,
187        }
188    }
189
190    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
191        self.content_map()?
192            .get(&diff.entity_id())
193            .cloned()
194            .map(|entity| entity.downcast::<Editor>().unwrap())
195    }
196
197    pub fn terminal(
198        &self,
199        terminal: &Entity<acp_thread::Terminal>,
200    ) -> Option<Entity<TerminalView>> {
201        self.content_map()?
202            .get(&terminal.entity_id())
203            .cloned()
204            .map(|entity| entity.downcast::<TerminalView>().unwrap())
205    }
206
207    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
208        match self {
209            Self::Content(map) => Some(map),
210            _ => None,
211        }
212    }
213
214    fn empty() -> Self {
215        Self::Content(HashMap::default())
216    }
217
218    #[cfg(test)]
219    pub fn has_content(&self) -> bool {
220        match self {
221            Self::Content(map) => !map.is_empty(),
222            Self::UserMessage(_) => false,
223        }
224    }
225}
226
227fn create_terminal(
228    workspace: WeakEntity<Workspace>,
229    project: Entity<Project>,
230    terminal: Entity<acp_thread::Terminal>,
231    window: &mut Window,
232    cx: &mut App,
233) -> Entity<TerminalView> {
234    cx.new(|cx| {
235        let mut view = TerminalView::new(
236            terminal.read(cx).inner().clone(),
237            workspace.clone(),
238            None,
239            project.downgrade(),
240            window,
241            cx,
242        );
243        view.set_embedded_mode(Some(1000), cx);
244        view
245    })
246}
247
248fn create_editor_diff(
249    diff: Entity<acp_thread::Diff>,
250    window: &mut Window,
251    cx: &mut App,
252) -> Entity<Editor> {
253    cx.new(|cx| {
254        let mut editor = Editor::new(
255            EditorMode::Full {
256                scale_ui_elements_with_buffer_font_size: false,
257                show_active_line_background: false,
258                sized_by_content: true,
259            },
260            diff.read(cx).multibuffer().clone(),
261            None,
262            window,
263            cx,
264        );
265        editor.set_show_gutter(false, cx);
266        editor.disable_inline_diagnostics();
267        editor.disable_expand_excerpt_buttons(cx);
268        editor.set_show_vertical_scrollbar(false, cx);
269        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
270        editor.set_soft_wrap_mode(SoftWrap::None, cx);
271        editor.scroll_manager.set_forbid_vertical_scroll(true);
272        editor.set_show_indent_guides(false, cx);
273        editor.set_read_only(true);
274        editor.set_show_breakpoints(false, cx);
275        editor.set_show_code_actions(false, cx);
276        editor.set_show_git_diff_gutter(false, cx);
277        editor.set_expand_all_diff_hunks(cx);
278        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
279        editor
280    })
281}
282
283fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
284    TextStyleRefinement {
285        font_size: Some(
286            TextSize::Small
287                .rems(cx)
288                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
289                .into(),
290        ),
291        ..Default::default()
292    }
293}
294
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent::{TextThreadStore, ThreadStore};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// Syncing a tool call with diff content creates a diff editor that shows the
    /// old text as a deleted row followed by the new text as an added row.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // `Project::test` below is rooted at `path!("/project")`; build the fake
        // tree and the diff path with `path!` too so they always agree with that
        // root, whatever platform adjustment `path!` applies.
        fs.insert_tree(
            path!("/project"),
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff {
                diff: acp::Diff {
                    path: path!("/project/hello.txt").into(),
                    old_text: Some("hi world".into()),
                    new_text: "hello world".into(),
                },
            }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        };
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        // Push the tool call into the thread so it becomes entry 0.
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
        let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                thread_store,
                text_thread_store,
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        // Let pending work (e.g. diff computation) settle before inspecting rows.
        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs a test settings store and registers the settings these views read.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}