entry_view_state.rs

  1use std::{collections::HashMap, ops::Range};
  2
  3use acp_thread::AcpThread;
  4use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
  5use gpui::{
  6    AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window,
  7};
  8use language::language_settings::SoftWrap;
  9use settings::Settings as _;
 10use terminal_view::TerminalView;
 11use theme::ThemeSettings;
 12use ui::TextSize;
 13use workspace::Workspace;
 14
/// Per-entry view state, stored in a list that is addressed by entry index.
///
/// Each element holds the views that were created for the thread entry at the same index
/// (see `sync_entry`).
#[derive(Default)]
pub struct EntryViewState {
    /// One `Entry` per synced thread entry, in index order.
    entries: Vec<Entry>,
}
 19
 20impl EntryViewState {
 21    pub fn entry(&self, index: usize) -> Option<&Entry> {
 22        self.entries.get(index)
 23    }
 24
 25    pub fn sync_entry(
 26        &mut self,
 27        workspace: WeakEntity<Workspace>,
 28        thread: Entity<AcpThread>,
 29        index: usize,
 30        window: &mut Window,
 31        cx: &mut App,
 32    ) {
 33        debug_assert!(index <= self.entries.len());
 34        let entry = if let Some(entry) = self.entries.get_mut(index) {
 35            entry
 36        } else {
 37            self.entries.push(Entry::default());
 38            self.entries.last_mut().unwrap()
 39        };
 40
 41        entry.sync_diff_multibuffers(&thread, index, window, cx);
 42        entry.sync_terminals(&workspace, &thread, index, window, cx);
 43    }
 44
 45    pub fn remove(&mut self, range: Range<usize>) {
 46        self.entries.drain(range);
 47    }
 48
 49    pub fn settings_changed(&mut self, cx: &mut App) {
 50        for entry in self.entries.iter() {
 51            for view in entry.views.values() {
 52                if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
 53                    diff_editor.update(cx, |diff_editor, cx| {
 54                        diff_editor
 55                            .set_text_style_refinement(diff_editor_text_style_refinement(cx));
 56                        cx.notify();
 57                    })
 58                }
 59            }
 60        }
 61    }
 62}
 63
/// The views created for a single thread entry.
pub struct Entry {
    /// Views keyed by the entity id of the model they display: the diff's `MultiBuffer` for
    /// editor views and the `acp_thread::Terminal` for terminal views.
    views: HashMap<EntityId, AnyEntity>,
}
 67
 68impl Entry {
 69    pub fn editor_for_diff(&self, diff: &Entity<MultiBuffer>) -> Option<Entity<Editor>> {
 70        self.views
 71            .get(&diff.entity_id())
 72            .cloned()
 73            .map(|entity| entity.downcast::<Editor>().unwrap())
 74    }
 75
 76    pub fn terminal(
 77        &self,
 78        terminal: &Entity<acp_thread::Terminal>,
 79    ) -> Option<Entity<TerminalView>> {
 80        self.views
 81            .get(&terminal.entity_id())
 82            .cloned()
 83            .map(|entity| entity.downcast::<TerminalView>().unwrap())
 84    }
 85
 86    fn sync_diff_multibuffers(
 87        &mut self,
 88        thread: &Entity<AcpThread>,
 89        index: usize,
 90        window: &mut Window,
 91        cx: &mut App,
 92    ) {
 93        let Some(entry) = thread.read(cx).entries().get(index) else {
 94            return;
 95        };
 96
 97        let multibuffers = entry
 98            .diffs()
 99            .map(|diff| diff.read(cx).multibuffer().clone());
100
101        let multibuffers = multibuffers.collect::<Vec<_>>();
102
103        for multibuffer in multibuffers {
104            if self.views.contains_key(&multibuffer.entity_id()) {
105                return;
106            }
107
108            let editor = cx.new(|cx| {
109                let mut editor = Editor::new(
110                    EditorMode::Full {
111                        scale_ui_elements_with_buffer_font_size: false,
112                        show_active_line_background: false,
113                        sized_by_content: true,
114                    },
115                    multibuffer.clone(),
116                    None,
117                    window,
118                    cx,
119                );
120                editor.set_show_gutter(false, cx);
121                editor.disable_inline_diagnostics();
122                editor.disable_expand_excerpt_buttons(cx);
123                editor.set_show_vertical_scrollbar(false, cx);
124                editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
125                editor.set_soft_wrap_mode(SoftWrap::None, cx);
126                editor.scroll_manager.set_forbid_vertical_scroll(true);
127                editor.set_show_indent_guides(false, cx);
128                editor.set_read_only(true);
129                editor.set_show_breakpoints(false, cx);
130                editor.set_show_code_actions(false, cx);
131                editor.set_show_git_diff_gutter(false, cx);
132                editor.set_expand_all_diff_hunks(cx);
133                editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
134                editor
135            });
136
137            let entity_id = multibuffer.entity_id();
138            self.views.insert(entity_id, editor.into_any());
139        }
140    }
141
142    fn sync_terminals(
143        &mut self,
144        workspace: &WeakEntity<Workspace>,
145        thread: &Entity<AcpThread>,
146        index: usize,
147        window: &mut Window,
148        cx: &mut App,
149    ) {
150        let Some(entry) = thread.read(cx).entries().get(index) else {
151            return;
152        };
153
154        let terminals = entry
155            .terminals()
156            .map(|terminal| terminal.clone())
157            .collect::<Vec<_>>();
158
159        for terminal in terminals {
160            if self.views.contains_key(&terminal.entity_id()) {
161                return;
162            }
163
164            let Some(strong_workspace) = workspace.upgrade() else {
165                return;
166            };
167
168            let terminal_view = cx.new(|cx| {
169                let mut view = TerminalView::new(
170                    terminal.read(cx).inner().clone(),
171                    workspace.clone(),
172                    None,
173                    strong_workspace.read(cx).project().downgrade(),
174                    window,
175                    cx,
176                );
177                view.set_embedded_mode(Some(1000), cx);
178                view
179            });
180
181            let entity_id = terminal.entity_id();
182            self.views.insert(entity_id, terminal_view.into_any());
183        }
184    }
185
186    #[cfg(test)]
187    pub fn len(&self) -> usize {
188        self.views.len()
189    }
190}
191
192fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
193    TextStyleRefinement {
194        font_size: Some(
195            TextSize::Small
196                .rems(cx)
197                .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
198                .into(),
199        ),
200        ..Default::default()
201    }
202}
203
204impl Default for Entry {
205    fn default() -> Self {
206        Self {
207            // Avoid allocating in the heap by default
208            views: HashMap::with_capacity(0),
209        }
210    }
211}
212
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{SemanticVersion, TestAppContext};
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    use crate::acp::entry_view_state::EntryViewState;

    /// Syncing an entry whose tool call carries a diff must create an editor that shows the
    /// old line as deleted followed by the new line as added.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);

        // Project with a single file whose contents match the diff's `old_text`.
        let fs = FakeFs::new(cx.executor());
        let tree = json!({
            "hello.txt": "hi world"
        });
        fs.insert_tree("/project", tree).await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // Open a thread on a stub connection and push the diff tool call into it.
        let connection = Rc::new(StubAgentConnection::new());
        let new_thread = cx.update(|_, cx| {
            connection
                .clone()
                .new_thread(project, Path::new(path!("/project")), cx)
        });
        let thread = new_thread.await.unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        cx.update(|_, cx| {
            let update = acp::SessionUpdate::ToolCall(diff_tool_call());
            connection.send_update(session_id, update, cx)
        });

        let mut view_state = EntryViewState::default();
        cx.update(|window, cx| {
            view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx);
        });

        // The multibuffer of the first diff in the first thread entry keys the editor view.
        let multibuffer = thread.read_with(cx, |thread, cx| {
            let first_entry = thread.entries().get(0).unwrap();
            let first_diff = first_entry.diffs().next().unwrap();
            first_diff.read(cx).multibuffer().clone()
        });

        cx.run_until_parked();

        let entry = view_state.entry(0).unwrap();
        let diff_editor = entry.editor_for_diff(&multibuffer).unwrap();

        let editor_text = diff_editor.read_with(cx, |editor, cx| editor.text(cx));
        assert_eq!(editor_text, "hi world\nhello world");

        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let snapshot = editor.buffer().read(cx).snapshot(cx);
            let rows = snapshot
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>();
            rows
        });
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Builds an in-progress tool call whose only content is a diff of `/project/hello.txt`
    /// from "hi world" to "hello world".
    fn diff_tool_call() -> acp::ToolCall {
        let diff = acp::Diff {
            path: "/project/hello.txt".into(),
            old_text: Some("hi world".into()),
            new_text: "hello world".into(),
        };

        acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff { diff }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
        }
    }

    /// Installs a test settings store and registers the settings this test relies on.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store is installed first; the registrations below follow it.
            let store = SettingsStore::test(cx);
            cx.set_global(store);

            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}