entry_view_state.rs

  1use std::ops::Range;
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent::{TextThreadStore, ThreadStore};
  5use collections::HashMap;
  6use editor::{Editor, EditorMode, MinimapVisibility};
  7use gpui::{
  8    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
  9    TextStyleRefinement, WeakEntity, Window,
 10};
 11use language::language_settings::SoftWrap;
 12use project::Project;
 13use settings::Settings as _;
 14use terminal_view::TerminalView;
 15use theme::ThemeSettings;
 16use ui::{Context, TextSize};
 17use workspace::Workspace;
 18
 19use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 20
 21pub struct EntryViewState {
 22    workspace: WeakEntity<Workspace>,
 23    project: Entity<Project>,
 24    thread_store: Entity<ThreadStore>,
 25    text_thread_store: Entity<TextThreadStore>,
 26    entries: Vec<Entry>,
 27}
 28
 29impl EntryViewState {
 30    pub fn new(
 31        workspace: WeakEntity<Workspace>,
 32        project: Entity<Project>,
 33        thread_store: Entity<ThreadStore>,
 34        text_thread_store: Entity<TextThreadStore>,
 35    ) -> Self {
 36        Self {
 37            workspace,
 38            project,
 39            thread_store,
 40            text_thread_store,
 41            entries: Vec::new(),
 42        }
 43    }
 44
 45    pub fn entry(&self, index: usize) -> Option<&Entry> {
 46        self.entries.get(index)
 47    }
 48
 49    pub fn sync_entry(
 50        &mut self,
 51        index: usize,
 52        thread: &Entity<AcpThread>,
 53        window: &mut Window,
 54        cx: &mut Context<Self>,
 55    ) {
 56        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 57            return;
 58        };
 59
 60        match thread_entry {
 61            AgentThreadEntry::UserMessage(message) => {
 62                let has_id = message.id.is_some();
 63                let chunks = message.chunks.clone();
 64                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 65                    if !editor.focus_handle(cx).is_focused(window) {
 66                        // Only update if we are not editing.
 67                        // If we are, cancelling the edit will set the message to the newest content.
 68                        editor.update(cx, |editor, cx| {
 69                            editor.set_message(chunks, window, cx);
 70                        });
 71                    }
 72                } else {
 73                    let message_editor = cx.new(|cx| {
 74                        let mut editor = MessageEditor::new(
 75                            self.workspace.clone(),
 76                            self.project.clone(),
 77                            self.thread_store.clone(),
 78                            self.text_thread_store.clone(),
 79                            "Edit message - @ to include context",
 80                            editor::EditorMode::AutoHeight {
 81                                min_lines: 1,
 82                                max_lines: None,
 83                            },
 84                            window,
 85                            cx,
 86                        );
 87                        if !has_id {
 88                            editor.set_read_only(true, cx);
 89                        }
 90                        editor.set_message(chunks, window, cx);
 91                        editor
 92                    });
 93                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
 94                        cx.emit(EntryViewEvent {
 95                            entry_index: index,
 96                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
 97                        })
 98                    })
 99                    .detach();
100                    self.set_entry(index, Entry::UserMessage(message_editor));
101                }
102            }
103            AgentThreadEntry::ToolCall(tool_call) => {
104                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
105                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
106
107                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
108                    views
109                } else {
110                    self.set_entry(index, Entry::empty());
111                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
112                        unreachable!()
113                    };
114                    views
115                };
116
117                for terminal in terminals {
118                    views.entry(terminal.entity_id()).or_insert_with(|| {
119                        create_terminal(
120                            self.workspace.clone(),
121                            self.project.clone(),
122                            terminal.clone(),
123                            window,
124                            cx,
125                        )
126                        .into_any()
127                    });
128                }
129
130                for diff in diffs {
131                    views
132                        .entry(diff.entity_id())
133                        .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
134                }
135            }
136            AgentThreadEntry::AssistantMessage(_) => {
137                if index == self.entries.len() {
138                    self.entries.push(Entry::empty())
139                }
140            }
141        };
142    }
143
144    fn set_entry(&mut self, index: usize, entry: Entry) {
145        if index == self.entries.len() {
146            self.entries.push(entry);
147        } else {
148            self.entries[index] = entry;
149        }
150    }
151
152    pub fn remove(&mut self, range: Range<usize>) {
153        self.entries.drain(range);
154    }
155
156    pub fn settings_changed(&mut self, cx: &mut App) {
157        for entry in self.entries.iter() {
158            match entry {
159                Entry::UserMessage { .. } => {}
160                Entry::Content(response_views) => {
161                    for view in response_views.values() {
162                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
163                            diff_editor.update(cx, |diff_editor, cx| {
164                                diff_editor.set_text_style_refinement(
165                                    diff_editor_text_style_refinement(cx),
166                                );
167                                cx.notify();
168                            })
169                        }
170                    }
171                }
172            }
173        }
174    }
175}
176
177impl EventEmitter<EntryViewEvent> for EntryViewState {}
178
179pub struct EntryViewEvent {
180    pub entry_index: usize,
181    pub view_event: ViewEvent,
182}
183
184pub enum ViewEvent {
185    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
186}
187
188pub enum Entry {
189    UserMessage(Entity<MessageEditor>),
190    Content(HashMap<EntityId, AnyEntity>),
191}
192
193impl Entry {
194    pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
195        match self {
196            Self::UserMessage(editor) => Some(editor),
197            Entry::Content(_) => None,
198        }
199    }
200
201    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
202        self.content_map()?
203            .get(&diff.entity_id())
204            .cloned()
205            .map(|entity| entity.downcast::<Editor>().unwrap())
206    }
207
208    pub fn terminal(
209        &self,
210        terminal: &Entity<acp_thread::Terminal>,
211    ) -> Option<Entity<TerminalView>> {
212        self.content_map()?
213            .get(&terminal.entity_id())
214            .cloned()
215            .map(|entity| entity.downcast::<TerminalView>().unwrap())
216    }
217
218    fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
219        match self {
220            Self::Content(map) => Some(map),
221            _ => None,
222        }
223    }
224
225    fn empty() -> Self {
226        Self::Content(HashMap::default())
227    }
228
229    #[cfg(test)]
230    pub fn has_content(&self) -> bool {
231        match self {
232            Self::Content(map) => !map.is_empty(),
233            Self::UserMessage(_) => false,
234        }
235    }
236}
237
238fn create_terminal(
239    workspace: WeakEntity<Workspace>,
240    project: Entity<Project>,
241    terminal: Entity<acp_thread::Terminal>,
242    window: &mut Window,
243    cx: &mut App,
244) -> Entity<TerminalView> {
245    cx.new(|cx| {
246        let mut view = TerminalView::new(
247            terminal.read(cx).inner().clone(),
248            workspace.clone(),
249            None,
250            project.downgrade(),
251            window,
252            cx,
253        );
254        view.set_embedded_mode(Some(1000), cx);
255        view
256    })
257}
258
259fn create_editor_diff(
260    diff: Entity<acp_thread::Diff>,
261    window: &mut Window,
262    cx: &mut App,
263) -> Entity<Editor> {
264    cx.new(|cx| {
265        let mut editor = Editor::new(
266            EditorMode::Full {
267                scale_ui_elements_with_buffer_font_size: false,
268                show_active_line_background: false,
269                sized_by_content: true,
270            },
271            diff.read(cx).multibuffer().clone(),
272            None,
273            window,
274            cx,
275        );
276        editor.set_show_gutter(false, cx);
277        editor.disable_inline_diagnostics();
278        editor.disable_expand_excerpt_buttons(cx);
279        editor.set_show_vertical_scrollbar(false, cx);
280        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
281        editor.set_soft_wrap_mode(SoftWrap::None, cx);
282        editor.scroll_manager.set_forbid_vertical_scroll(true);
283        editor.set_show_indent_guides(false, cx);
284        editor.set_read_only(true);
285        editor.set_show_breakpoints(false, cx);
286        editor.set_show_code_actions(false, cx);
287        editor.set_show_git_diff_gutter(false, cx);
288        editor.set_expand_all_diff_hunks(cx);
289        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
290        editor
291    })
292}
293
294fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
295    TextStyleRefinement {
296        font_size: Some(
297            TextSize::Small
298                .rems(cx)
299                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
300                .into(),
301        ),
302        ..Default::default()
303    }
304}
305
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent::{TextThreadStore, ThreadStore};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// Syncs a tool call carrying a diff and checks that the resulting diff editor shows the
    /// old line as a deleted row followed by the new line as an added row.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        // Use `path!` everywhere so the paths agree on every platform (on Windows it
        // adds a drive prefix).
        fs.insert_tree(
            path!("/project"),
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff {
                diff: acp::Diff {
                    path: path!("/project/hello.txt").into(),
                    old_text: Some("hi world".into()),
                    new_text: "hello world".into(),
                },
            }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        };
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
        let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                thread_store,
                text_thread_store,
            )
        });

        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs a test settings store and registers the settings the views read.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}