1use std::{cell::RefCell, ops::Range, rc::Rc};
  2
  3use acp_thread::{AcpThread, AgentThreadEntry};
  4use agent::HistoryStore;
  5use agent_client_protocol::{self as acp, ToolCallId};
  6use collections::HashMap;
  7use editor::{Editor, EditorMode, MinimapVisibility};
  8use gpui::{
  9    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, FocusHandle, Focusable,
 10    ScrollHandle, SharedString, TextStyleRefinement, WeakEntity, Window,
 11};
 12use language::language_settings::SoftWrap;
 13use project::Project;
 14use prompt_store::PromptStore;
 15use settings::Settings as _;
 16use terminal_view::TerminalView;
 17use theme::ThemeSettings;
 18use ui::{Context, TextSize};
 19use workspace::Workspace;
 20
 21use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
 22
/// View state for the entries of an [`AcpThread`].
///
/// Holds one [`Entry`] per synced thread entry, stored at the same index the
/// entry has in `AcpThread::entries()` (see `EntryViewState::sync_entry`).
pub struct EntryViewState {
    // Passed to every `MessageEditor` and terminal view this state creates.
    workspace: WeakEntity<Workspace>,
    // Passed to every `MessageEditor` and terminal view this state creates.
    project: Entity<Project>,
    // Passed to every `MessageEditor` this state creates.
    history_store: Entity<HistoryStore>,
    // Passed to every `MessageEditor` this state creates.
    prompt_store: Option<Entity<PromptStore>>,
    // One view-state entry per synced thread entry, same indexing as the thread.
    entries: Vec<Entry>,
    // Shared handle; each `MessageEditor` receives a clone of this `Rc`.
    prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
    // Shared handle; each `MessageEditor` receives a clone of this `Rc`.
    available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
    // Passed to every `MessageEditor` this state creates.
    agent_name: SharedString,
}
 33
 34impl EntryViewState {
 35    pub const fn new(
 36        workspace: WeakEntity<Workspace>,
 37        project: Entity<Project>,
 38        history_store: Entity<HistoryStore>,
 39        prompt_store: Option<Entity<PromptStore>>,
 40        prompt_capabilities: Rc<RefCell<acp::PromptCapabilities>>,
 41        available_commands: Rc<RefCell<Vec<acp::AvailableCommand>>>,
 42        agent_name: SharedString,
 43    ) -> Self {
 44        Self {
 45            workspace,
 46            project,
 47            history_store,
 48            prompt_store,
 49            entries: Vec::new(),
 50            prompt_capabilities,
 51            available_commands,
 52            agent_name,
 53        }
 54    }
 55
 56    pub fn entry(&self, index: usize) -> Option<&Entry> {
 57        self.entries.get(index)
 58    }
 59
 60    pub fn sync_entry(
 61        &mut self,
 62        index: usize,
 63        thread: &Entity<AcpThread>,
 64        window: &mut Window,
 65        cx: &mut Context<Self>,
 66    ) {
 67        let Some(thread_entry) = thread.read(cx).entries().get(index) else {
 68            return;
 69        };
 70
 71        match thread_entry {
 72            AgentThreadEntry::UserMessage(message) => {
 73                let has_id = message.id.is_some();
 74                let chunks = message.chunks.clone();
 75                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
 76                    if !editor.focus_handle(cx).is_focused(window) {
 77                        // Only update if we are not editing.
 78                        // If we are, cancelling the edit will set the message to the newest content.
 79                        editor.update(cx, |editor, cx| {
 80                            editor.set_message(chunks, window, cx);
 81                        });
 82                    }
 83                } else {
 84                    let message_editor = cx.new(|cx| {
 85                        let mut editor = MessageEditor::new(
 86                            self.workspace.clone(),
 87                            self.project.clone(),
 88                            self.history_store.clone(),
 89                            self.prompt_store.clone(),
 90                            self.prompt_capabilities.clone(),
 91                            self.available_commands.clone(),
 92                            self.agent_name.clone(),
 93                            "Edit message - @ to include context",
 94                            editor::EditorMode::AutoHeight {
 95                                min_lines: 1,
 96                                max_lines: None,
 97                            },
 98                            window,
 99                            cx,
100                        );
101                        if !has_id {
102                            editor.set_read_only(true, cx);
103                        }
104                        editor.set_message(chunks, window, cx);
105                        editor
106                    });
107                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
108                        cx.emit(EntryViewEvent {
109                            entry_index: index,
110                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
111                        })
112                    })
113                    .detach();
114                    self.set_entry(index, Entry::UserMessage(message_editor));
115                }
116            }
117            AgentThreadEntry::ToolCall(tool_call) => {
118                let id = tool_call.id.clone();
119                let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
120                let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
121
122                let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
123                    views
124                } else {
125                    self.set_entry(index, Entry::empty());
126                    let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
127                        unreachable!()
128                    };
129                    views
130                };
131
132                let is_tool_call_completed =
133                    matches!(tool_call.status, acp_thread::ToolCallStatus::Completed);
134
135                for terminal in terminals {
136                    match views.entry(terminal.entity_id()) {
137                        collections::hash_map::Entry::Vacant(entry) => {
138                            let element = create_terminal(
139                                self.workspace.clone(),
140                                self.project.clone(),
141                                terminal.clone(),
142                                window,
143                                cx,
144                            )
145                            .into_any();
146                            cx.emit(EntryViewEvent {
147                                entry_index: index,
148                                view_event: ViewEvent::NewTerminal(id.clone()),
149                            });
150                            entry.insert(element);
151                        }
152                        collections::hash_map::Entry::Occupied(_entry) => {
153                            if is_tool_call_completed && terminal.read(cx).output().is_none() {
154                                cx.emit(EntryViewEvent {
155                                    entry_index: index,
156                                    view_event: ViewEvent::TerminalMovedToBackground(id.clone()),
157                                });
158                            }
159                        }
160                    }
161                }
162
163                for diff in diffs {
164                    views.entry(diff.entity_id()).or_insert_with(|| {
165                        let element = create_editor_diff(diff.clone(), window, cx).into_any();
166                        cx.emit(EntryViewEvent {
167                            entry_index: index,
168                            view_event: ViewEvent::NewDiff(id.clone()),
169                        });
170                        element
171                    });
172                }
173            }
174            AgentThreadEntry::AssistantMessage(message) => {
175                let entry = if let Some(Entry::AssistantMessage(entry)) =
176                    self.entries.get_mut(index)
177                {
178                    entry
179                } else {
180                    self.set_entry(
181                        index,
182                        Entry::AssistantMessage(AssistantMessageEntry::default()),
183                    );
184                    let Some(Entry::AssistantMessage(entry)) = self.entries.get_mut(index) else {
185                        unreachable!()
186                    };
187                    entry
188                };
189                entry.sync(message);
190            }
191        };
192    }
193
194    fn set_entry(&mut self, index: usize, entry: Entry) {
195        if index == self.entries.len() {
196            self.entries.push(entry);
197        } else {
198            self.entries[index] = entry;
199        }
200    }
201
202    pub fn remove(&mut self, range: Range<usize>) {
203        self.entries.drain(range);
204    }
205
206    pub fn agent_ui_font_size_changed(&mut self, cx: &mut App) {
207        for entry in self.entries.iter() {
208            match entry {
209                Entry::UserMessage { .. } | Entry::AssistantMessage { .. } => {}
210                Entry::Content(response_views) => {
211                    for view in response_views.values() {
212                        if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
213                            diff_editor.update(cx, |diff_editor, cx| {
214                                diff_editor.set_text_style_refinement(
215                                    diff_editor_text_style_refinement(cx),
216                                );
217                                cx.notify();
218                            })
219                        }
220                    }
221                }
222            }
223        }
224    }
225}
226
/// `EntryViewState` emits an [`EntryViewEvent`] when it creates views or
/// forwards message-editor events (see `EntryViewState::sync_entry`).
impl EventEmitter<EntryViewEvent> for EntryViewState {}
228
/// An event emitted by [`EntryViewState`], tagged with the entry it concerns.
pub struct EntryViewEvent {
    /// Index of the thread entry whose view state produced the event.
    pub entry_index: usize,
    /// What happened to that entry's views.
    pub view_event: ViewEvent,
}
233
/// The kinds of events [`EntryViewState`] reports through [`EntryViewEvent`].
pub enum ViewEvent {
    /// A diff editor was created for a diff of the tool call with this id.
    NewDiff(ToolCallId),
    /// A terminal view was created for a terminal of the tool call with this id.
    NewTerminal(ToolCallId),
    /// The tool call with this id is completed while one of its already-displayed
    /// terminals reports no output. This is emitted on each sync while that holds.
    TerminalMovedToBackground(ToolCallId),
    /// An event re-emitted from the editor of a user-message entry.
    MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
}
240
/// View state for an assistant message: scroll handles for its chunks.
#[derive(Default, Debug)]
pub struct AssistantMessageEntry {
    // Keyed by chunk index; a handle is created on demand when the last chunk
    // is a thought (see `AssistantMessageEntry::sync`).
    scroll_handles_by_chunk_index: HashMap<usize, ScrollHandle>,
}
245
246impl AssistantMessageEntry {
247    pub fn scroll_handle_for_chunk(&self, ix: usize) -> Option<ScrollHandle> {
248        self.scroll_handles_by_chunk_index.get(&ix).cloned()
249    }
250
251    pub fn sync(&mut self, message: &acp_thread::AssistantMessage) {
252        if let Some(acp_thread::AssistantMessageChunk::Thought { .. }) = message.chunks.last() {
253            let ix = message.chunks.len() - 1;
254            let handle = self.scroll_handles_by_chunk_index.entry(ix).or_default();
255            handle.scroll_to_bottom();
256        }
257    }
258}
259
/// The view state kept for a single thread entry.
#[derive(Debug)]
pub enum Entry {
    /// The editor that displays (and allows editing) a user message.
    UserMessage(Entity<MessageEditor>),
    /// Scroll state for an assistant message's chunks.
    AssistantMessage(AssistantMessageEntry),
    /// Embedded views for a tool call, keyed by the entity id of the terminal
    /// or diff they display. Values are `TerminalView` or diff `Editor` entities.
    Content(HashMap<EntityId, AnyEntity>),
}
266
267impl Entry {
268    pub fn focus_handle(&self, cx: &App) -> Option<FocusHandle> {
269        match self {
270            Self::UserMessage(editor) => Some(editor.read(cx).focus_handle(cx)),
271            Self::AssistantMessage(_) | Self::Content(_) => None,
272        }
273    }
274
275    pub const fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
276        match self {
277            Self::UserMessage(editor) => Some(editor),
278            Self::AssistantMessage(_) | Self::Content(_) => None,
279        }
280    }
281
282    pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
283        self.content_map()?
284            .get(&diff.entity_id())
285            .cloned()
286            .map(|entity| entity.downcast::<Editor>().unwrap())
287    }
288
289    pub fn terminal(
290        &self,
291        terminal: &Entity<acp_thread::Terminal>,
292    ) -> Option<Entity<TerminalView>> {
293        self.content_map()?
294            .get(&terminal.entity_id())
295            .cloned()
296            .map(|entity| entity.downcast::<TerminalView>().unwrap())
297    }
298
299    pub fn scroll_handle_for_assistant_message_chunk(
300        &self,
301        chunk_ix: usize,
302    ) -> Option<ScrollHandle> {
303        match self {
304            Self::AssistantMessage(message) => message.scroll_handle_for_chunk(chunk_ix),
305            Self::UserMessage(_) | Self::Content(_) => None,
306        }
307    }
308
309    const fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
310        match self {
311            Self::Content(map) => Some(map),
312            _ => None,
313        }
314    }
315
316    fn empty() -> Self {
317        Self::Content(HashMap::default())
318    }
319
320    #[cfg(test)]
321    pub fn has_content(&self) -> bool {
322        match self {
323            Self::Content(map) => !map.is_empty(),
324            Self::UserMessage(_) | Self::AssistantMessage(_) => false,
325        }
326    }
327}
328
329fn create_terminal(
330    workspace: WeakEntity<Workspace>,
331    project: Entity<Project>,
332    terminal: Entity<acp_thread::Terminal>,
333    window: &mut Window,
334    cx: &mut App,
335) -> Entity<TerminalView> {
336    cx.new(|cx| {
337        let mut view = TerminalView::new(
338            terminal.read(cx).inner().clone(),
339            workspace.clone(),
340            None,
341            project.downgrade(),
342            window,
343            cx,
344        );
345        view.set_embedded_mode(Some(1000), cx);
346        view
347    })
348}
349
/// Builds a read-only [`Editor`] over `diff`'s multibuffer, configured for
/// inline display inside the thread.
///
/// The editor has no gutter, vertical scrollbar, minimap, soft wrap, indent
/// guides, breakpoints, code actions or git diff gutter, and vertical
/// scrolling is forbidden. All diff hunks are expanded, and the agent-UI-sized
/// text style from `diff_editor_text_style_refinement` is applied.
fn create_editor_diff(
    diff: Entity<acp_thread::Diff>,
    window: &mut Window,
    cx: &mut App,
) -> Entity<Editor> {
    cx.new(|cx| {
        let mut editor = Editor::new(
            // NOTE(review): `sized_by_content: true` presumably lets the editor
            // grow with its content instead of filling its container; confirm.
            EditorMode::Full {
                scale_ui_elements_with_buffer_font_size: false,
                show_active_line_background: false,
                sized_by_content: true,
            },
            diff.read(cx).multibuffer().clone(),
            None,
            window,
            cx,
        );
        editor.set_show_gutter(false, cx);
        editor.disable_inline_diagnostics();
        editor.disable_expand_excerpt_buttons(cx);
        editor.set_show_vertical_scrollbar(false, cx);
        editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
        editor.set_soft_wrap_mode(SoftWrap::None, cx);
        editor.scroll_manager.set_forbid_vertical_scroll(true);
        editor.set_show_indent_guides(false, cx);
        editor.set_read_only(true);
        editor.set_show_breakpoints(false, cx);
        editor.set_show_code_actions(false, cx);
        editor.set_show_git_diff_gutter(false, cx);
        // Show every hunk expanded, so old and new lines are both visible.
        editor.set_expand_all_diff_hunks(cx);
        // Same refinement that `agent_ui_font_size_changed` re-applies later.
        editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
        editor
    })
}
384
385fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
386    TextStyleRefinement {
387        font_size: Some(
388            TextSize::Small
389                .rems(cx)
390                .to_pixels(ThemeSettings::get_global(cx).agent_ui_font_size(cx))
391                .into(),
392        ),
393        ..Default::default()
394    }
395}
396
#[cfg(test)]
mod tests {
    use std::{path::Path, rc::Rc};

    use acp_thread::{AgentConnection, StubAgentConnection};
    use agent::HistoryStore;
    use agent_client_protocol as acp;
    use agent_settings::AgentSettings;
    use assistant_context::ContextStore;
    use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
    use editor::{EditorSettings, RowInfo};
    use fs::FakeFs;
    use gpui::{AppContext as _, SemanticVersion, TestAppContext};

    use crate::acp::entry_view_state::EntryViewState;
    use multi_buffer::MultiBufferRow;
    use pretty_assertions::assert_matches;
    use project::Project;
    use serde_json::json;
    use settings::{Settings as _, SettingsStore};
    use util::path;
    use workspace::Workspace;

    // Syncing a tool-call entry whose content is a single diff should create a
    // diff editor for it. That editor shows the old line as a deleted row
    // followed by the new line as an added row.
    #[gpui::test]
    async fn test_diff_sync(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/project",
            json!({
                "hello.txt": "hi world"
            }),
        )
        .await;
        let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;

        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        // A tool call whose only content is a diff of `hello.txt`.
        let tool_call = acp::ToolCall {
            id: acp::ToolCallId("tool".into()),
            title: "Tool call".into(),
            kind: acp::ToolKind::Other,
            status: acp::ToolCallStatus::InProgress,
            content: vec![acp::ToolCallContent::Diff {
                diff: acp::Diff {
                    path: "/project/hello.txt".into(),
                    old_text: Some("hi world".into()),
                    new_text: "hello world".into(),
                    meta: None,
                },
            }],
            locations: vec![],
            raw_input: None,
            raw_output: None,
            meta: None,
        };
        let connection = Rc::new(StubAgentConnection::new());
        let thread = cx
            .update(|_, cx| {
                connection
                    .clone()
                    .new_thread(project.clone(), Path::new(path!("/project")), cx)
            })
            .await
            .unwrap();
        let session_id = thread.update(cx, |thread, _| thread.session_id().clone());

        // Deliver the tool call to the thread through the stub connection.
        cx.update(|_, cx| {
            connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
        });

        let context_store = cx.new(|cx| ContextStore::fake(project.clone(), cx));
        let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));

        let view_state = cx.new(|_cx| {
            EntryViewState::new(
                workspace.downgrade(),
                project.clone(),
                history_store,
                None,
                Default::default(),
                Default::default(),
                "Test Agent".into(),
            )
        });

        // Entry 0 is the tool call pushed above.
        view_state.update_in(cx, |view_state, window, cx| {
            view_state.sync_entry(0, &thread, window, cx)
        });

        let diff = thread.read_with(cx, |thread, _cx| {
            thread
                .entries()
                .get(0)
                .unwrap()
                .diffs()
                .next()
                .unwrap()
                .clone()
        });

        // Let pending tasks settle before inspecting the editor's contents.
        cx.run_until_parked();

        let diff_editor = view_state.read_with(cx, |view_state, _cx| {
            view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
        });
        // With all hunks expanded, the editor text contains both the old and the new line.
        assert_eq!(
            diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
            "hi world\nhello world"
        );
        let row_infos = diff_editor.read_with(cx, |editor, cx| {
            let multibuffer = editor.buffer().read(cx);
            multibuffer
                .snapshot(cx)
                .row_infos(MultiBufferRow(0))
                .collect::<Vec<_>>()
        });
        // Row 0 is the deleted old line; row 1 is the added new line.
        assert_matches!(
            row_infos.as_slice(),
            [
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(0)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Deleted,
                        ..
                    }),
                    ..
                },
                RowInfo {
                    multibuffer_row: Some(MultiBufferRow(1)),
                    diff_status: Some(DiffHunkStatus {
                        kind: DiffHunkStatusKind::Added,
                        ..
                    }),
                    ..
                }
            ]
        );
    }

    /// Installs a test settings store and registers the settings and
    /// subsystems (language, project, agent, workspace, theme, release
    /// channel, editor) used by the tests in this module.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AgentSettings::register(cx);
            workspace::init_settings(cx);
            theme::init(theme::LoadThemes::JustBase, cx);
            release_channel::init(SemanticVersion::default(), cx);
            EditorSettings::register(cx);
        });
    }
}