// entry_view_state.rs

  1use std::ops::Range;
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent::{TextThreadStore, ThreadStore};
  5use collections::HashMap;
  6use editor::{Editor, EditorMode, MinimapVisibility};
  7use gpui::{
  8    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement,
  9    WeakEntity, Window,
 10};
 11use language::language_settings::SoftWrap;
 12use project::Project;
 13use settings::Settings as _;
 14use terminal_view::TerminalView;
 15use theme::ThemeSettings;
 16use ui::{Context, TextSize};
 17use workspace::Workspace;
 18
 19use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 20
 21pub struct EntryViewState {
 22    workspace: WeakEntity<Workspace>,
 23    project: Entity<Project>,
 24    thread_store: Entity<ThreadStore>,
 25    text_thread_store: Entity<TextThreadStore>,
 26    entries: Vec<Entry>,
 27}
 28
 29impl EntryViewState {
 30    pub fn new(
 31        workspace: WeakEntity<Workspace>,
 32        project: Entity<Project>,
 33        thread_store: Entity<ThreadStore>,
 34        text_thread_store: Entity<TextThreadStore>,
 35    ) -> Self {
 36        Self {
 37            workspace,
 38            project,
 39            thread_store,
 40            text_thread_store,
 41            entries: Vec::new(),
 42        }
 43    }
 44
 45    pub fn entry(&self, index: usize) -> Option<&Entry> {
 46        self.entries.get(index)
 47    }
 48
 49    pub fn sync_entry(
 50        &mut self,
 51        index: usize,
 52        thread: &Entity<AcpThread>,
 53        window: &mut Window,
 54        cx: &mut Context<Self>,
 55    ) {
 56        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 57            return;
 58        };
 59
 60        match thread_entry {
 61            AgentThreadEntry::UserMessage(message) => {
 62                let has_id = message.id.is_some();
 63                let chunks = message.chunks.clone();
 64                let message_editor = cx.new(|cx| {
 65                    let mut editor = MessageEditor::new(
 66                        self.workspace.clone(),
 67                        self.project.clone(),
 68                        self.thread_store.clone(),
 69                        self.text_thread_store.clone(),
 70                        "Edit message - @ to include context",
 71                        editor::EditorMode::AutoHeight {
 72                            min_lines: 1,
 73                            max_lines: None,
 74                        },
 75                        window,
 76                        cx,
 77                    );
 78                    if !has_id {
 79                        editor.set_read_only(true, cx);
 80                    }
 81                    editor.set_message(chunks, window, cx);
 82                    editor
 83                });
 84                cx.subscribe(&message_editor, move |_, editor, event, cx| {
 85                    cx.emit(EntryViewEvent {
 86                        entry_index: index,
 87                        view_event: ViewEvent::MessageEditorEvent(editor, *event),
 88                    })
 89                })
 90                .detach();
 91                self.set_entry(index, Entry::UserMessage(message_editor));
 92            }
 93            AgentThreadEntry::ToolCall(tool_call) => {
 94                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
 95                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
 96
 97                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
 98                    views
 99                } else {
100                    self.set_entry(index, Entry::empty());
101                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
102                        unreachable!()
103                    };
104                    views
105                };
106
107                for terminal in terminals {
108                    views.entry(terminal.entity_id()).or_insert_with(|| {
109                        create_terminal(
110                            self.workspace.clone(),
111                            self.project.clone(),
112                            terminal.clone(),
113                            window,
114                            cx,
115                        )
116                        .into_any()
117                    });
118                }
119
120                for diff in diffs {
121                    views
122                        .entry(diff.entity_id())
123                        .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
124                }
125            }
126            AgentThreadEntry::AssistantMessage(_) => {
127                if index == self.entries.len() {
128                    self.entries.push(Entry::empty())
129                }
130            }
131        };
132    }
133
134    fn set_entry(&mut self, index: usize, entry: Entry) {
135        if index == self.entries.len() {
136            self.entries.push(entry);
137        } else {
138            self.entries[index] = entry;
139        }
140    }
141
142    pub fn remove(&mut self, range: Range<usize>) {
143        self.entries.drain(range);
144    }
145
146    pub fn settings_changed(&mut self, cx: &mut App) {
147        for entry in self.entries.iter() {
148            match entry {
149                Entry::UserMessage { .. } => {}
150                Entry::Content(response_views) => {
151                    for view in response_views.values() {
152                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
153                            diff_editor.update(cx, |diff_editor, cx| {
154                                diff_editor.set_text_style_refinement(
155                                    diff_editor_text_style_refinement(cx),
156                                );
157                                cx.notify();
158                            })
159                        }
160                    }
161                }
162            }
163        }
164    }
165}
166
167impl EventEmitter<EntryViewEvent> for EntryViewState {}
168
169pub struct EntryViewEvent {
170    pub entry_index: usize,
171    pub view_event: ViewEvent,
172}
173
174pub enum ViewEvent {
175    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
176}
177
178pub enum Entry {
179    UserMessage(Entity<MessageEditor>),
180    Content(HashMap<EntityId, AnyEntity>),
181}
182
183impl Entry {
184    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
185        match self {
186            Self::UserMessage(editor) => Some(editor),
187            Entry::Content(_) => None,
188        }
189    }
190
191    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
192        self.content_map()?
193            .get(&diff.entity_id())
194            .cloned()
195            .map(|entity| entity.downcast::<Editor>().unwrap())
196    }
197
198    pub fn terminal(
199        &self,
200        terminal: &Entity<acp_thread::Terminal>,
201    ) -> Option<Entity<TerminalView>> {
202        self.content_map()?
203            .get(&terminal.entity_id())
204            .cloned()
205            .map(|entity| entity.downcast::<TerminalView>().unwrap())
206    }
207
208    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
209        match self {
210            Self::Content(map) => Some(map),
211            _ => None,
212        }
213    }
214
215    fn empty() -> Self {
216        Self::Content(HashMap::default())
217    }
218
219    #[cfg(test)]
220    pub fn has_content(&self) -> bool {
221        match self {
222            Self::Content(map) => !map.is_empty(),
223            Self::UserMessage(_) => false,
224        }
225    }
226}
227
228fn create_terminal(
229    workspace: WeakEntity<Workspace>,
230    project: Entity<Project>,
231    terminal: Entity<acp_thread::Terminal>,
232    window: &mut Window,
233    cx: &mut App,
234) -> Entity<TerminalView> {
235    cx.new(|cx| {
236        let mut view = TerminalView::new(
237            terminal.read(cx).inner().clone(),
238            workspace.clone(),
239            None,
240            project.downgrade(),
241            window,
242            cx,
243        );
244        view.set_embedded_mode(Some(1000), cx);
245        view
246    })
247}
248
249fn create_editor_diff(
250    diff: Entity<acp_thread::Diff>,
251    window: &mut Window,
252    cx: &mut App,
253) -> Entity<Editor> {
254    cx.new(|cx| {
255        let mut editor = Editor::new(
256            EditorMode::Full {
257                scale_ui_elements_with_buffer_font_size: false,
258                show_active_line_background: false,
259                sized_by_content: true,
260            },
261            diff.read(cx).multibuffer().clone(),
262            None,
263            window,
264            cx,
265        );
266        editor.set_show_gutter(false, cx);
267        editor.disable_inline_diagnostics();
268        editor.disable_expand_excerpt_buttons(cx);
269        editor.set_show_vertical_scrollbar(false, cx);
270        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
271        editor.set_soft_wrap_mode(SoftWrap::None, cx);
272        editor.scroll_manager.set_forbid_vertical_scroll(true);
273        editor.set_show_indent_guides(false, cx);
274        editor.set_read_only(true);
275        editor.set_show_breakpoints(false, cx);
276        editor.set_show_code_actions(false, cx);
277        editor.set_show_git_diff_gutter(false, cx);
278        editor.set_expand_all_diff_hunks(cx);
279        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
280        editor
281    })
282}
283
284fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
285    TextStyleRefinement {
286        font_size: Some(
287            TextSize::Small
288                .rems(cx)
289                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
290                .into(),
291        ),
292        ..Default::default()
293    }
294}
295
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent::{TextThreadStore, ThreadStore};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// Syncing a tool call that carries a diff should create an editor whose text shows
    /// the deleted line followed by the added line.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Use `path!` for every path so the fake tree, the project root and the diff path
        // agree on platforms where `path!` rewrites absolute paths.
        fs.insert_tree(
            path!("/project"),
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff {
                diff: acp::Diff {
                    path: path!("/project/hello.txt").into(),
                    old_text: Some("hi world".into()),
                    new_text: "hello world".into(),
                },
            }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        };
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.read_with(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
        let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                thread_store,
                text_thread_store,
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            thread
                .entries()
                .first()
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Registers the global settings the views under test read.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}