entry_view_state.rs

  1use std::ops::Range;
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent::{TextThreadStore, ThreadStore};
  5use collections::HashMap;
  6use editor::{Editor, EditorMode, MinimapVisibility};
  7use gpui::{
  8    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
  9    TextStyleRefinement, WeakEntity, Window,
 10};
 11use language::language_settings::SoftWrap;
 12use project::Project;
 13use settings::Settings as _;
 14use terminal_view::TerminalView;
 15use theme::ThemeSettings;
 16use ui::{Context, TextSize};
 17use workspace::Workspace;
 18
 19use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 20
 21pub struct EntryViewState {
 22    workspace: WeakEntity<Workspace>,
 23    project: Entity<Project>,
 24    thread_store: Entity<ThreadStore>,
 25    text_thread_store: Entity<TextThreadStore>,
 26    entries: Vec<Entry>,
 27    prevent_slash_commands: bool,
 28}
 29
 30impl EntryViewState {
 31    pub fn new(
 32        workspace: WeakEntity<Workspace>,
 33        project: Entity<Project>,
 34        thread_store: Entity<ThreadStore>,
 35        text_thread_store: Entity<TextThreadStore>,
 36        prevent_slash_commands: bool,
 37    ) -> Self {
 38        Self {
 39            workspace,
 40            project,
 41            thread_store,
 42            text_thread_store,
 43            entries: Vec::new(),
 44            prevent_slash_commands,
 45        }
 46    }
 47
 48    pub fn entry(&self, index: usize) -> Option<&Entry> {
 49        self.entries.get(index)
 50    }
 51
 52    pub fn sync_entry(
 53        &mut self,
 54        index: usize,
 55        thread: &Entity<AcpThread>,
 56        window: &mut Window,
 57        cx: &mut Context<Self>,
 58    ) {
 59        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 60            return;
 61        };
 62
 63        match thread_entry {
 64            AgentThreadEntry::UserMessage(message) => {
 65                let has_id = message.id.is_some();
 66                let chunks = message.chunks.clone();
 67                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 68                    if !editor.focus_handle(cx).is_focused(window) {
 69                        // Only update if we are not editing.
 70                        // If we are, cancelling the edit will set the message to the newest content.
 71                        editor.update(cx, |editor, cx| {
 72                            editor.set_message(chunks, window, cx);
 73                        });
 74                    }
 75                } else {
 76                    let message_editor = cx.new(|cx| {
 77                        let mut editor = MessageEditor::new(
 78                            self.workspace.clone(),
 79                            self.project.clone(),
 80                            self.thread_store.clone(),
 81                            self.text_thread_store.clone(),
 82                            "Edit message - @ to include context",
 83                            self.prevent_slash_commands,
 84                            editor::EditorMode::AutoHeight {
 85                                min_lines: 1,
 86                                max_lines: None,
 87                            },
 88                            window,
 89                            cx,
 90                        );
 91                        if !has_id {
 92                            editor.set_read_only(true, cx);
 93                        }
 94                        editor.set_message(chunks, window, cx);
 95                        editor
 96                    });
 97                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
 98                        cx.emit(EntryViewEvent {
 99                            entry_index: index,
100                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
101                        })
102                    })
103                    .detach();
104                    self.set_entry(index, Entry::UserMessage(message_editor));
105                }
106            }
107            AgentThreadEntry::ToolCall(tool_call) => {
108                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
109                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
110
111                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
112                    views
113                } else {
114                    self.set_entry(index, Entry::empty());
115                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
116                        unreachable!()
117                    };
118                    views
119                };
120
121                for terminal in terminals {
122                    views.entry(terminal.entity_id()).or_insert_with(|| {
123                        create_terminal(
124                            self.workspace.clone(),
125                            self.project.clone(),
126                            terminal.clone(),
127                            window,
128                            cx,
129                        )
130                        .into_any()
131                    });
132                }
133
134                for diff in diffs {
135                    views
136                        .entry(diff.entity_id())
137                        .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
138                }
139            }
140            AgentThreadEntry::AssistantMessage(_) => {
141                if index == self.entries.len() {
142                    self.entries.push(Entry::empty())
143                }
144            }
145        };
146    }
147
148    fn set_entry(&mut self, index: usize, entry: Entry) {
149        if index == self.entries.len() {
150            self.entries.push(entry);
151        } else {
152            self.entries[index] = entry;
153        }
154    }
155
156    pub fn remove(&mut self, range: Range<usize>) {
157        self.entries.drain(range);
158    }
159
160    pub fn settings_changed(&mut self, cx: &mut App) {
161        for entry in self.entries.iter() {
162            match entry {
163                Entry::UserMessage { .. } => {}
164                Entry::Content(response_views) => {
165                    for view in response_views.values() {
166                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
167                            diff_editor.update(cx, |diff_editor, cx| {
168                                diff_editor.set_text_style_refinement(
169                                    diff_editor_text_style_refinement(cx),
170                                );
171                                cx.notify();
172                            })
173                        }
174                    }
175                }
176            }
177        }
178    }
179}
180
181impl EventEmitter<EntryViewEvent> for EntryViewState {}
182
183pub struct EntryViewEvent {
184    pub entry_index: usize,
185    pub view_event: ViewEvent,
186}
187
188pub enum ViewEvent {
189    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
190}
191
192pub enum Entry {
193    UserMessage(Entity<MessageEditor>),
194    Content(HashMap<EntityId, AnyEntity>),
195}
196
197impl Entry {
198    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
199        match self {
200            Self::UserMessage(editor) => Some(editor),
201            Entry::Content(_) => None,
202        }
203    }
204
205    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
206        self.content_map()?
207            .get(&diff.entity_id())
208            .cloned()
209            .map(|entity| entity.downcast::<Editor>().unwrap())
210    }
211
212    pub fn terminal(
213        &self,
214        terminal: &Entity<acp_thread::Terminal>,
215    ) -> Option<Entity<TerminalView>> {
216        self.content_map()?
217            .get(&terminal.entity_id())
218            .cloned()
219            .map(|entity| entity.downcast::<TerminalView>().unwrap())
220    }
221
222    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
223        match self {
224            Self::Content(map) => Some(map),
225            _ => None,
226        }
227    }
228
229    fn empty() -> Self {
230        Self::Content(HashMap::default())
231    }
232
233    #[cfg(test)]
234    pub fn has_content(&self) -> bool {
235        match self {
236            Self::Content(map) => !map.is_empty(),
237            Self::UserMessage(_) => false,
238        }
239    }
240}
241
242fn create_terminal(
243    workspace: WeakEntity<Workspace>,
244    project: Entity<Project>,
245    terminal: Entity<acp_thread::Terminal>,
246    window: &mut Window,
247    cx: &mut App,
248) -> Entity<TerminalView> {
249    cx.new(|cx| {
250        let mut view = TerminalView::new(
251            terminal.read(cx).inner().clone(),
252            workspace.clone(),
253            None,
254            project.downgrade(),
255            window,
256            cx,
257        );
258        view.set_embedded_mode(Some(1000), cx);
259        view
260    })
261}
262
263fn create_editor_diff(
264    diff: Entity<acp_thread::Diff>,
265    window: &mut Window,
266    cx: &mut App,
267) -> Entity<Editor> {
268    cx.new(|cx| {
269        let mut editor = Editor::new(
270            EditorMode::Full {
271                scale_ui_elements_with_buffer_font_size: false,
272                show_active_line_background: false,
273                sized_by_content: true,
274            },
275            diff.read(cx).multibuffer().clone(),
276            None,
277            window,
278            cx,
279        );
280        editor.set_show_gutter(false, cx);
281        editor.disable_inline_diagnostics();
282        editor.disable_expand_excerpt_buttons(cx);
283        editor.set_show_vertical_scrollbar(false, cx);
284        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
285        editor.set_soft_wrap_mode(SoftWrap::None, cx);
286        editor.scroll_manager.set_forbid_vertical_scroll(true);
287        editor.set_show_indent_guides(false, cx);
288        editor.set_read_only(true);
289        editor.set_show_breakpoints(false, cx);
290        editor.set_show_code_actions(false, cx);
291        editor.set_show_git_diff_gutter(false, cx);
292        editor.set_expand_all_diff_hunks(cx);
293        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
294        editor
295    })
296}
297
298fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
299    TextStyleRefinement {
300        font_size: Some(
301            TextSize::Small
302                .rems(cx)
303                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
304                .into(),
305        ),
306        ..Default::default()
307    }
308}
309
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent::{TextThreadStore, ThreadStore};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// Syncing a tool call entry that carries a diff creates a diff editor showing
    /// the old line as deleted followed by the new line as added.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Every path goes through `path!` so the fake tree, the worktree root, and the
        // diff path agree on all platforms (on Windows `path!` adds a drive prefix).
        fs.insert_tree(
            path!("/project"),
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff {
                diff: acp::Diff {
                    path: path!("/project/hello.txt").into(),
                    old_text: Some("hi world".into()),
                    new_text: "hello world".into(),
                },
            }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        };
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
        let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                thread_store,
                text_thread_store,
                false,
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            thread
                .entries()
                .first()
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs a test settings store and registers the settings the views above read.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}