entry_view_state.rs

  1use std::ops::Range;
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent::{TextThreadStore, ThreadStore};
  5use collections::HashMap;
  6use editor::{Editor, EditorMode, MinimapVisibility};
  7use gpui::{
  8    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
  9    TextStyleRefinement, WeakEntity, Window,
 10};
 11use language::language_settings::SoftWrap;
 12use project::Project;
 13use settings::Settings as _;
 14use terminal_view::TerminalView;
 15use theme::ThemeSettings;
 16use ui::{Context, TextSize};
 17use workspace::Workspace;
 18
 19use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 20
 21pub struct EntryViewState {
 22    workspace: WeakEntity<Workspace>,
 23    project: Entity<Project>,
 24    thread_store: Entity<ThreadStore>,
 25    text_thread_store: Entity<TextThreadStore>,
 26    entries: Vec<Entry>,
 27    prevent_slash_commands: bool,
 28}
 29
 30impl EntryViewState {
 31    pub fn new(
 32        workspace: WeakEntity<Workspace>,
 33        project: Entity<Project>,
 34        thread_store: Entity<ThreadStore>,
 35        text_thread_store: Entity<TextThreadStore>,
 36        prevent_slash_commands: bool,
 37    ) -> Self {
 38        Self {
 39            workspace,
 40            project,
 41            thread_store,
 42            text_thread_store,
 43            entries: Vec::new(),
 44            prevent_slash_commands,
 45        }
 46    }
 47
 48    pub fn entry(&self, index: usize) -> Option<&Entry> {
 49        self.entries.get(index)
 50    }
 51
 52    pub fn sync_entry(
 53        &mut self,
 54        index: usize,
 55        thread: &Entity<AcpThread>,
 56        window: &mut Window,
 57        cx: &mut Context<Self>,
 58    ) {
 59        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 60            return;
 61        };
 62
 63        match thread_entry {
 64            AgentThreadEntry::UserMessage(message) => {
 65                let has_id = message.id.is_some();
 66                let chunks = message.chunks.clone();
 67                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 68                    if !editor.focus_handle(cx).is_focused(window) {
 69                        // Only update if we are not editing.
 70                        // If we are, cancelling the edit will set the message to the newest content.
 71                        editor.update(cx, |editor, cx| {
 72                            editor.set_message(chunks, window, cx);
 73                        });
 74                    }
 75                } else {
 76                    let message_editor = cx.new(|cx| {
 77                        let mut editor = MessageEditor::new(
 78                            self.workspace.clone(),
 79                            self.project.clone(),
 80                            self.thread_store.clone(),
 81                            self.text_thread_store.clone(),
 82                            "Edit message - @ to include context",
 83                            self.prevent_slash_commands,
 84                            editor::EditorMode::AutoHeight {
 85                                min_lines: 1,
 86                                max_lines: None,
 87                            },
 88                            window,
 89                            cx,
 90                        );
 91                        if !has_id {
 92                            editor.set_read_only(true, cx);
 93                        }
 94                        editor.set_message(chunks, window, cx);
 95                        editor
 96                    });
 97                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
 98                        cx.emit(EntryViewEvent {
 99                            entry_index: index,
100                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
101                        })
102                    })
103                    .detach();
104                    self.set_entry(index, Entry::UserMessage(message_editor));
105                }
106            }
107            AgentThreadEntry::ToolCall(tool_call) => {
108                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
109                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
110
111                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
112                    views
113                } else {
114                    self.set_entry(index, Entry::empty());
115                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
116                        unreachable!()
117                    };
118                    views
119                };
120
121                for terminal in terminals {
122                    views.entry(terminal.entity_id()).or_insert_with(|| {
123                        create_terminal(
124                            self.workspace.clone(),
125                            self.project.clone(),
126                            terminal.clone(),
127                            window,
128                            cx,
129                        )
130                        .into_any()
131                    });
132                }
133
134                for diff in diffs {
135                    views
136                        .entry(diff.entity_id())
137                        .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
138                }
139            }
140            AgentThreadEntry::AssistantMessage(_) => {
141                if index == self.entries.len() {
142                    self.entries.push(Entry::empty())
143                }
144            }
145        };
146    }
147
148    fn set_entry(&mut self, index: usize, entry: Entry) {
149        if index == self.entries.len() {
150            self.entries.push(entry);
151        } else {
152            self.entries[index] = entry;
153        }
154    }
155
156    pub fn remove(&mut self, range: Range<usize>) {
157        self.entries.drain(range);
158    }
159
160    pub fn settings_changed(&mut self, cx: &mut App) {
161        for entry in self.entries.iter() {
162            match entry {
163                Entry::UserMessage { .. } => {}
164                Entry::Content(response_views) => {
165                    for view in response_views.values() {
166                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
167                            diff_editor.update(cx, |diff_editor, cx| {
168                                diff_editor.set_text_style_refinement(
169                                    diff_editor_text_style_refinement(cx),
170                                );
171                                cx.notify();
172                            })
173                        }
174                    }
175                }
176            }
177        }
178    }
179}
180
181impl EventEmitter<EntryViewEvent> for EntryViewState {}
182
183pub struct EntryViewEvent {
184    pub entry_index: usize,
185    pub view_event: ViewEvent,
186}
187
188pub enum ViewEvent {
189    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
190}
191
192#[derive(Debug)]
193pub enum Entry {
194    UserMessage(Entity<MessageEditor>),
195    Content(HashMap<EntityId, AnyEntity>),
196}
197
198impl Entry {
199    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
200        match self {
201            Self::UserMessage(editor) => Some(editor),
202            Entry::Content(_) => None,
203        }
204    }
205
206    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
207        self.content_map()?
208            .get(&diff.entity_id())
209            .cloned()
210            .map(|entity| entity.downcast::<Editor>().unwrap())
211    }
212
213    pub fn terminal(
214        &self,
215        terminal: &Entity<acp_thread::Terminal>,
216    ) -> Option<Entity<TerminalView>> {
217        self.content_map()?
218            .get(&terminal.entity_id())
219            .cloned()
220            .map(|entity| entity.downcast::<TerminalView>().unwrap())
221    }
222
223    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
224        match self {
225            Self::Content(map) => Some(map),
226            _ => None,
227        }
228    }
229
230    fn empty() -> Self {
231        Self::Content(HashMap::default())
232    }
233
234    #[cfg(test)]
235    pub fn has_content(&self) -> bool {
236        match self {
237            Self::Content(map) => !map.is_empty(),
238            Self::UserMessage(_) => false,
239        }
240    }
241}
242
243fn create_terminal(
244    workspace: WeakEntity<Workspace>,
245    project: Entity<Project>,
246    terminal: Entity<acp_thread::Terminal>,
247    window: &mut Window,
248    cx: &mut App,
249) -> Entity<TerminalView> {
250    cx.new(|cx| {
251        let mut view = TerminalView::new(
252            terminal.read(cx).inner().clone(),
253            workspace.clone(),
254            None,
255            project.downgrade(),
256            window,
257            cx,
258        );
259        view.set_embedded_mode(Some(1000), cx);
260        view
261    })
262}
263
264fn create_editor_diff(
265    diff: Entity<acp_thread::Diff>,
266    window: &mut Window,
267    cx: &mut App,
268) -> Entity<Editor> {
269    cx.new(|cx| {
270        let mut editor = Editor::new(
271            EditorMode::Full {
272                scale_ui_elements_with_buffer_font_size: false,
273                show_active_line_background: false,
274                sized_by_content: true,
275            },
276            diff.read(cx).multibuffer().clone(),
277            None,
278            window,
279            cx,
280        );
281        editor.set_show_gutter(false, cx);
282        editor.disable_inline_diagnostics();
283        editor.disable_expand_excerpt_buttons(cx);
284        editor.set_show_vertical_scrollbar(false, cx);
285        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
286        editor.set_soft_wrap_mode(SoftWrap::None, cx);
287        editor.scroll_manager.set_forbid_vertical_scroll(true);
288        editor.set_show_indent_guides(false, cx);
289        editor.set_read_only(true);
290        editor.set_show_breakpoints(false, cx);
291        editor.set_show_code_actions(false, cx);
292        editor.set_show_git_diff_gutter(false, cx);
293        editor.set_expand_all_diff_hunks(cx);
294        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
295        editor
296    })
297}
298
299fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
300    TextStyleRefinement {
301        font_size: Some(
302            TextSize::Small
303                .rems(cx)
304                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
305                .into(),
306        ),
307        ..Default::default()
308    }
309}
310
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent::{TextThreadStore, ThreadStore};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// Syncing a tool-call entry that carries a diff must create a diff editor
    /// showing the old line as deleted followed by the new line as added.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Every path goes through `path!` so the fake file system, the project
        // root and the diff target agree on platforms with drive-letter paths.
        fs.insert_tree(
            path!("/project"),
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff {
                diff: acp::Diff {
                    path: path!("/project/hello.txt").into(),
                    old_text: Some("hi world".into()),
                    new_text: "hello world".into(),
                },
            }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        };
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.read_with(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
        let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                thread_store,
                text_thread_store,
                false,
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            thread
                .entries()
                .first()
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Registers the global settings and subsystems the views above depend on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}