@@ -0,0 +1,351 @@
+use std::{collections::HashMap, ops::Range};
+
+use acp_thread::AcpThread;
+use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
+use gpui::{
+ AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window,
+};
+use language::language_settings::SoftWrap;
+use settings::Settings as _;
+use terminal_view::TerminalView;
+use theme::ThemeSettings;
+use ui::TextSize;
+use workspace::Workspace;
+
+#[derive(Default)]
+pub struct EntryViewState {
+ entries: Vec<Entry>,
+}
+
+impl EntryViewState {
+ pub fn entry(&self, index: usize) -> Option<&Entry> {
+ self.entries.get(index)
+ }
+
+ pub fn sync_entry(
+ &mut self,
+ workspace: WeakEntity<Workspace>,
+ thread: Entity<AcpThread>,
+ index: usize,
+ window: &mut Window,
+ cx: &mut App,
+ ) {
+ debug_assert!(index <= self.entries.len());
+ let entry = if let Some(entry) = self.entries.get_mut(index) {
+ entry
+ } else {
+ self.entries.push(Entry::default());
+ self.entries.last_mut().unwrap()
+ };
+
+ entry.sync_diff_multibuffers(&thread, index, window, cx);
+ entry.sync_terminals(&workspace, &thread, index, window, cx);
+ }
+
+ pub fn remove(&mut self, range: Range<usize>) {
+ self.entries.drain(range);
+ }
+
+ pub fn settings_changed(&mut self, cx: &mut App) {
+ for entry in self.entries.iter() {
+ for view in entry.views.values() {
+ if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
+ diff_editor.update(cx, |diff_editor, cx| {
+ diff_editor
+ .set_text_style_refinement(diff_editor_text_style_refinement(cx));
+ cx.notify();
+ })
+ }
+ }
+ }
+ }
+}
+
+pub struct Entry {
+ views: HashMap<EntityId, AnyEntity>,
+}
+
+impl Entry {
+ pub fn editor_for_diff(&self, diff: &Entity<MultiBuffer>) -> Option<Entity<Editor>> {
+ self.views
+ .get(&diff.entity_id())
+ .cloned()
+ .map(|entity| entity.downcast::<Editor>().unwrap())
+ }
+
+ pub fn terminal(
+ &self,
+ terminal: &Entity<acp_thread::Terminal>,
+ ) -> Option<Entity<TerminalView>> {
+ self.views
+ .get(&terminal.entity_id())
+ .cloned()
+ .map(|entity| entity.downcast::<TerminalView>().unwrap())
+ }
+
+ fn sync_diff_multibuffers(
+ &mut self,
+ thread: &Entity<AcpThread>,
+ index: usize,
+ window: &mut Window,
+ cx: &mut App,
+ ) {
+ let Some(entry) = thread.read(cx).entries().get(index) else {
+ return;
+ };
+
+ let multibuffers = entry
+ .diffs()
+ .map(|diff| diff.read(cx).multibuffer().clone());
+
+ let multibuffers = multibuffers.collect::<Vec<_>>();
+
+ for multibuffer in multibuffers {
+ if self.views.contains_key(&multibuffer.entity_id()) {
+ return;
+ }
+
+ let editor = cx.new(|cx| {
+ let mut editor = Editor::new(
+ EditorMode::Full {
+ scale_ui_elements_with_buffer_font_size: false,
+ show_active_line_background: false,
+ sized_by_content: true,
+ },
+ multibuffer.clone(),
+ None,
+ window,
+ cx,
+ );
+ editor.set_show_gutter(false, cx);
+ editor.disable_inline_diagnostics();
+ editor.disable_expand_excerpt_buttons(cx);
+ editor.set_show_vertical_scrollbar(false, cx);
+ editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
+ editor.set_soft_wrap_mode(SoftWrap::None, cx);
+ editor.scroll_manager.set_forbid_vertical_scroll(true);
+ editor.set_show_indent_guides(false, cx);
+ editor.set_read_only(true);
+ editor.set_show_breakpoints(false, cx);
+ editor.set_show_code_actions(false, cx);
+ editor.set_show_git_diff_gutter(false, cx);
+ editor.set_expand_all_diff_hunks(cx);
+ editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
+ editor
+ });
+
+ let entity_id = multibuffer.entity_id();
+ self.views.insert(entity_id, editor.into_any());
+ }
+ }
+
+ fn sync_terminals(
+ &mut self,
+ workspace: &WeakEntity<Workspace>,
+ thread: &Entity<AcpThread>,
+ index: usize,
+ window: &mut Window,
+ cx: &mut App,
+ ) {
+ let Some(entry) = thread.read(cx).entries().get(index) else {
+ return;
+ };
+
+ let terminals = entry
+ .terminals()
+ .map(|terminal| terminal.clone())
+ .collect::<Vec<_>>();
+
+ for terminal in terminals {
+ if self.views.contains_key(&terminal.entity_id()) {
+ return;
+ }
+
+ let Some(strong_workspace) = workspace.upgrade() else {
+ return;
+ };
+
+ let terminal_view = cx.new(|cx| {
+ let mut view = TerminalView::new(
+ terminal.read(cx).inner().clone(),
+ workspace.clone(),
+ None,
+ strong_workspace.read(cx).project().downgrade(),
+ window,
+ cx,
+ );
+ view.set_embedded_mode(Some(1000), cx);
+ view
+ });
+
+ let entity_id = terminal.entity_id();
+ self.views.insert(entity_id, terminal_view.into_any());
+ }
+ }
+
+ #[cfg(test)]
+ pub fn len(&self) -> usize {
+ self.views.len()
+ }
+}
+
+fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
+ TextStyleRefinement {
+ font_size: Some(
+ TextSize::Small
+ .rems(cx)
+ .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
+ .into(),
+ ),
+ ..Default::default()
+ }
+}
+
+impl Default for Entry {
+ fn default() -> Self {
+ Self {
+ // Avoid allocating in the heap by default
+ views: HashMap::with_capacity(0),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::{path::Path, rc::Rc};
+
+ use acp_thread::{AgentConnection, StubAgentConnection};
+ use agent_client_protocol as acp;
+ use agent_settings::AgentSettings;
+ use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
+ use editor::{EditorSettings, RowInfo};
+ use fs::FakeFs;
+ use gpui::{SemanticVersion, TestAppContext};
+ use multi_buffer::MultiBufferRow;
+ use pretty_assertions::assert_matches;
+ use project::Project;
+ use serde_json::json;
+ use settings::{Settings as _, SettingsStore};
+ use theme::ThemeSettings;
+ use util::path;
+ use workspace::Workspace;
+
+ use crate::acp::entry_view_state::EntryViewState;
+
+ #[gpui::test]
+ async fn test_diff_sync(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/project",
+ json!({
+ "hello.txt": "hi world"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, [Path::new(path!("/project"))], cx).await;
+
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ let tool_call = acp::ToolCall {
+ id: acp::ToolCallId("tool".into()),
+ title: "Tool call".into(),
+ kind: acp::ToolKind::Other,
+ status: acp::ToolCallStatus::InProgress,
+ content: vec![acp::ToolCallContent::Diff {
+ diff: acp::Diff {
+ path: "/project/hello.txt".into(),
+ old_text: Some("hi world".into()),
+ new_text: "hello world".into(),
+ },
+ }],
+ locations: vec![],
+ raw_input: None,
+ raw_output: None,
+ };
+ let connection = Rc::new(StubAgentConnection::new());
+ let thread = cx
+ .update(|_, cx| {
+ connection
+ .clone()
+ .new_thread(project, Path::new(path!("/project")), cx)
+ })
+ .await
+ .unwrap();
+ let session_id = thread.update(cx, |thread, _| thread.session_id().clone());
+
+ cx.update(|_, cx| {
+ connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
+ });
+
+ let mut view_state = EntryViewState::default();
+ cx.update(|window, cx| {
+ view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx);
+ });
+
+ let multibuffer = thread.read_with(cx, |thread, cx| {
+ thread
+ .entries()
+ .get(0)
+ .unwrap()
+ .diffs()
+ .next()
+ .unwrap()
+ .read(cx)
+ .multibuffer()
+ .clone()
+ });
+
+ cx.run_until_parked();
+
+ let entry = view_state.entry(0).unwrap();
+ let diff_editor = entry.editor_for_diff(&multibuffer).unwrap();
+ assert_eq!(
+ diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
+ "hi world\nhello world"
+ );
+ let row_infos = diff_editor.read_with(cx, |editor, cx| {
+ let multibuffer = editor.buffer().read(cx);
+ multibuffer
+ .snapshot(cx)
+ .row_infos(MultiBufferRow(0))
+ .collect::<Vec<_>>()
+ });
+ assert_matches!(
+ row_infos.as_slice(),
+ [
+ RowInfo {
+ multibuffer_row: Some(MultiBufferRow(0)),
+ diff_status: Some(DiffHunkStatus {
+ kind: DiffHunkStatusKind::Deleted,
+ ..
+ }),
+ ..
+ },
+ RowInfo {
+ multibuffer_row: Some(MultiBufferRow(1)),
+ diff_status: Some(DiffHunkStatus {
+ kind: DiffHunkStatusKind::Added,
+ ..
+ }),
+ ..
+ }
+ ]
+ );
+ }
+
+ fn init_test(cx: &mut TestAppContext) {
+ cx.update(|cx| {
+ let settings_store = SettingsStore::test(cx);
+ cx.set_global(settings_store);
+ language::init(cx);
+ Project::init_settings(cx);
+ AgentSettings::register(cx);
+ workspace::init_settings(cx);
+ ThemeSettings::register(cx);
+ release_channel::init(SemanticVersion::default(), cx);
+ EditorSettings::register(cx);
+ });
+ }
+}
@@ -12,24 +12,22 @@ use audio::{Audio, Sound};
use buffer_diff::BufferDiff;
use collections::{HashMap, HashSet};
use editor::scroll::Autoscroll;
-use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer, PathKey, SelectionEffects};
+use editor::{Editor, EditorMode, MultiBuffer, PathKey, SelectionEffects};
use file_icons::FileIcons;
use gpui::{
Action, Animation, AnimationExt, App, BorderStyle, ClickEvent, EdgesRefinement, Empty, Entity,
- EntityId, FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton,
- PlatformDisplay, SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle,
- TextStyleRefinement, Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div,
- linear_color_stop, linear_gradient, list, percentage, point, prelude::*, pulsating_between,
+ FocusHandle, Focusable, Hsla, Length, ListOffset, ListState, MouseButton, PlatformDisplay,
+ SharedString, Stateful, StyleRefinement, Subscription, Task, TextStyle, TextStyleRefinement,
+ Transformation, UnderlineStyle, WeakEntity, Window, WindowHandle, div, linear_color_stop,
+ linear_gradient, list, percentage, point, prelude::*, pulsating_between,
};
use language::Buffer;
-use language::language_settings::SoftWrap;
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle};
use project::Project;
use prompt_store::PromptId;
use rope::Point;
use settings::{Settings as _, SettingsStore};
use std::{collections::BTreeMap, process::ExitStatus, rc::Rc, time::Duration};
-use terminal_view::TerminalView;
use text::Anchor;
use theme::ThemeSettings;
use ui::{
@@ -41,6 +39,7 @@ use workspace::{CollaboratorId, Workspace};
use zed_actions::agent::{Chat, ToggleModelSelector};
use zed_actions::assistant::OpenRulesLibrary;
+use super::entry_view_state::EntryViewState;
use crate::acp::AcpModelSelectorPopover;
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
use crate::agent_diff::AgentDiff;
@@ -61,8 +60,7 @@ pub struct AcpThreadView {
thread_store: Entity<ThreadStore>,
text_thread_store: Entity<TextThreadStore>,
thread_state: ThreadState,
- diff_editors: HashMap<EntityId, Entity<Editor>>,
- terminal_views: HashMap<EntityId, Entity<TerminalView>>,
+ entry_view_state: EntryViewState,
message_editor: Entity<MessageEditor>,
model_selector: Option<Entity<AcpModelSelectorPopover>>,
notifications: Vec<WindowHandle<AgentNotification>>,
@@ -149,8 +147,7 @@ impl AcpThreadView {
model_selector: None,
notifications: Vec::new(),
notification_subscriptions: HashMap::default(),
- diff_editors: Default::default(),
- terminal_views: Default::default(),
+ entry_view_state: EntryViewState::default(),
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
last_error: None,
@@ -209,11 +206,18 @@ impl AcpThreadView {
// })
// .ok();
- let result = match connection
- .clone()
- .new_thread(project.clone(), &root_dir, cx)
- .await
- {
+ let Some(result) = cx
+ .update(|_, cx| {
+ connection
+ .clone()
+ .new_thread(project.clone(), &root_dir, cx)
+ })
+ .log_err()
+ else {
+ return;
+ };
+
+ let result = match result.await {
Err(e) => {
let mut cx = cx.clone();
if e.is::<acp_thread::AuthRequired>() {
@@ -480,16 +484,29 @@ impl AcpThreadView {
) {
match event {
AcpThreadEvent::NewEntry => {
- let index = thread.read(cx).entries().len() - 1;
- self.sync_thread_entry_view(index, window, cx);
+ let len = thread.read(cx).entries().len();
+ let index = len - 1;
+ self.entry_view_state.sync_entry(
+ self.workspace.clone(),
+ thread.clone(),
+ index,
+ window,
+ cx,
+ );
self.list_state.splice(index..index, 1);
}
AcpThreadEvent::EntryUpdated(index) => {
- self.sync_thread_entry_view(*index, window, cx);
+ self.entry_view_state.sync_entry(
+ self.workspace.clone(),
+ thread.clone(),
+ *index,
+ window,
+ cx,
+ );
self.list_state.splice(*index..index + 1, 1);
}
AcpThreadEvent::EntriesRemoved(range) => {
- // TODO: Clean up unused diff editors and terminal views
+ self.entry_view_state.remove(range.clone());
self.list_state.splice(range.clone(), 0);
}
AcpThreadEvent::ToolAuthorizationRequired => {
@@ -523,128 +540,6 @@ impl AcpThreadView {
cx.notify();
}
- fn sync_thread_entry_view(
- &mut self,
- entry_ix: usize,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- self.sync_diff_multibuffers(entry_ix, window, cx);
- self.sync_terminals(entry_ix, window, cx);
- }
-
- fn sync_diff_multibuffers(
- &mut self,
- entry_ix: usize,
- window: &mut Window,
- cx: &mut Context<Self>,
- ) {
- let Some(multibuffers) = self.entry_diff_multibuffers(entry_ix, cx) else {
- return;
- };
-
- let multibuffers = multibuffers.collect::<Vec<_>>();
-
- for multibuffer in multibuffers {
- if self.diff_editors.contains_key(&multibuffer.entity_id()) {
- return;
- }
-
- let editor = cx.new(|cx| {
- let mut editor = Editor::new(
- EditorMode::Full {
- scale_ui_elements_with_buffer_font_size: false,
- show_active_line_background: false,
- sized_by_content: true,
- },
- multibuffer.clone(),
- None,
- window,
- cx,
- );
- editor.set_show_gutter(false, cx);
- editor.disable_inline_diagnostics();
- editor.disable_expand_excerpt_buttons(cx);
- editor.set_show_vertical_scrollbar(false, cx);
- editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
- editor.set_soft_wrap_mode(SoftWrap::None, cx);
- editor.scroll_manager.set_forbid_vertical_scroll(true);
- editor.set_show_indent_guides(false, cx);
- editor.set_read_only(true);
- editor.set_show_breakpoints(false, cx);
- editor.set_show_code_actions(false, cx);
- editor.set_show_git_diff_gutter(false, cx);
- editor.set_expand_all_diff_hunks(cx);
- editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
- editor
- });
- let entity_id = multibuffer.entity_id();
- cx.observe_release(&multibuffer, move |this, _, _| {
- this.diff_editors.remove(&entity_id);
- })
- .detach();
-
- self.diff_editors.insert(entity_id, editor);
- }
- }
-
- fn entry_diff_multibuffers(
- &self,
- entry_ix: usize,
- cx: &App,
- ) -> Option<impl Iterator<Item = Entity<MultiBuffer>>> {
- let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
- Some(
- entry
- .diffs()
- .map(|diff| diff.read(cx).multibuffer().clone()),
- )
- }
-
- fn sync_terminals(&mut self, entry_ix: usize, window: &mut Window, cx: &mut Context<Self>) {
- let Some(terminals) = self.entry_terminals(entry_ix, cx) else {
- return;
- };
-
- let terminals = terminals.collect::<Vec<_>>();
-
- for terminal in terminals {
- if self.terminal_views.contains_key(&terminal.entity_id()) {
- return;
- }
-
- let terminal_view = cx.new(|cx| {
- let mut view = TerminalView::new(
- terminal.read(cx).inner().clone(),
- self.workspace.clone(),
- None,
- self.project.downgrade(),
- window,
- cx,
- );
- view.set_embedded_mode(Some(1000), cx);
- view
- });
-
- let entity_id = terminal.entity_id();
- cx.observe_release(&terminal, move |this, _, _| {
- this.terminal_views.remove(&entity_id);
- })
- .detach();
-
- self.terminal_views.insert(entity_id, terminal_view);
- }
- }
-
- fn entry_terminals(
- &self,
- entry_ix: usize,
- cx: &App,
- ) -> Option<impl Iterator<Item = Entity<acp_thread::Terminal>>> {
- let entry = self.thread()?.read(cx).entries().get(entry_ix)?;
- Some(entry.terminals().map(|terminal| terminal.clone()))
- }
-
fn authenticate(
&mut self,
method: acp::AuthMethodId,
@@ -712,7 +607,7 @@ impl AcpThreadView {
fn render_entry(
&self,
- index: usize,
+ entry_ix: usize,
total_entries: usize,
entry: &AgentThreadEntry,
window: &mut Window,
@@ -720,7 +615,7 @@ impl AcpThreadView {
) -> AnyElement {
let primary = match &entry {
AgentThreadEntry::UserMessage(message) => div()
- .id(("user_message", index))
+ .id(("user_message", entry_ix))
.py_4()
.px_2()
.children(message.id.clone().and_then(|message_id| {
@@ -749,7 +644,9 @@ impl AcpThreadView {
.text_xs()
.id("message")
.on_click(cx.listener({
- move |this, _, window, cx| this.start_editing_message(index, window, cx)
+ move |this, _, window, cx| {
+ this.start_editing_message(entry_ix, window, cx)
+ }
}))
.children(
if let Some(editing) = self.editing_message.as_ref()
@@ -787,7 +684,7 @@ impl AcpThreadView {
AssistantMessageChunk::Thought { block } => {
block.markdown().map(|md| {
self.render_thinking_block(
- index,
+ entry_ix,
chunk_ix,
md.clone(),
window,
@@ -803,7 +700,7 @@ impl AcpThreadView {
v_flex()
.px_5()
.py_1()
- .when(index + 1 == total_entries, |this| this.pb_4())
+ .when(entry_ix + 1 == total_entries, |this| this.pb_4())
.w_full()
.text_ui(cx)
.child(message_body)
@@ -815,10 +712,12 @@ impl AcpThreadView {
div().w_full().py_1p5().px_5().map(|this| {
if has_terminals {
this.children(tool_call.terminals().map(|terminal| {
- self.render_terminal_tool_call(terminal, tool_call, window, cx)
+ self.render_terminal_tool_call(
+ entry_ix, terminal, tool_call, window, cx,
+ )
}))
} else {
- this.child(self.render_tool_call(index, tool_call, window, cx))
+ this.child(self.render_tool_call(entry_ix, tool_call, window, cx))
}
})
}
@@ -830,7 +729,7 @@ impl AcpThreadView {
};
let is_generating = matches!(thread.read(cx).status(), ThreadStatus::Generating);
- let primary = if index == total_entries - 1 && !is_generating {
+ let primary = if entry_ix == total_entries - 1 && !is_generating {
v_flex()
.w_full()
.child(primary)
@@ -841,10 +740,10 @@ impl AcpThreadView {
};
if let Some(editing) = self.editing_message.as_ref()
- && editing.index < index
+ && editing.index < entry_ix
{
let backdrop = div()
- .id(("backdrop", index))
+ .id(("backdrop", entry_ix))
.size_full()
.absolute()
.inset_0()
@@ -1125,7 +1024,9 @@ impl AcpThreadView {
.w_full()
.children(tool_call.content.iter().map(|content| {
div()
- .child(self.render_tool_call_content(content, tool_call, window, cx))
+ .child(
+ self.render_tool_call_content(entry_ix, content, tool_call, window, cx),
+ )
.into_any_element()
}))
.child(self.render_permission_buttons(
@@ -1139,7 +1040,9 @@ impl AcpThreadView {
.w_full()
.children(tool_call.content.iter().map(|content| {
div()
- .child(self.render_tool_call_content(content, tool_call, window, cx))
+ .child(
+ self.render_tool_call_content(entry_ix, content, tool_call, window, cx),
+ )
.into_any_element()
})),
ToolCallStatus::Rejected => v_flex().size_0(),
@@ -1257,6 +1160,7 @@ impl AcpThreadView {
fn render_tool_call_content(
&self,
+ entry_ix: usize,
content: &ToolCallContent,
tool_call: &ToolCall,
window: &Window,
@@ -1273,10 +1177,10 @@ impl AcpThreadView {
}
}
ToolCallContent::Diff(diff) => {
- self.render_diff_editor(&diff.read(cx).multibuffer(), cx)
+ self.render_diff_editor(entry_ix, &diff.read(cx).multibuffer(), cx)
}
ToolCallContent::Terminal(terminal) => {
- self.render_terminal_tool_call(terminal, tool_call, window, cx)
+ self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx)
}
}
}
@@ -1420,6 +1324,7 @@ impl AcpThreadView {
fn render_diff_editor(
&self,
+ entry_ix: usize,
multibuffer: &Entity<MultiBuffer>,
cx: &Context<Self>,
) -> AnyElement {
@@ -1428,7 +1333,9 @@ impl AcpThreadView {
.border_t_1()
.border_color(self.tool_card_border_color(cx))
.child(
- if let Some(editor) = self.diff_editors.get(&multibuffer.entity_id()) {
+ if let Some(entry) = self.entry_view_state.entry(entry_ix)
+ && let Some(editor) = entry.editor_for_diff(&multibuffer)
+ {
editor.clone().into_any_element()
} else {
Empty.into_any()
@@ -1439,6 +1346,7 @@ impl AcpThreadView {
fn render_terminal_tool_call(
&self,
+ entry_ix: usize,
terminal: &Entity<acp_thread::Terminal>,
tool_call: &ToolCall,
window: &Window,
@@ -1627,8 +1535,11 @@ impl AcpThreadView {
})),
);
- let show_output =
- self.terminal_expanded && self.terminal_views.contains_key(&terminal.entity_id());
+ let terminal_view = self
+ .entry_view_state
+ .entry(entry_ix)
+ .and_then(|entry| entry.terminal(&terminal));
+ let show_output = self.terminal_expanded && terminal_view.is_some();
v_flex()
.mb_2()
@@ -1661,8 +1572,6 @@ impl AcpThreadView {
),
)
.when(show_output, |this| {
- let terminal_view = self.terminal_views.get(&terminal.entity_id()).unwrap();
-
this.child(
div()
.pt_2()
@@ -1672,7 +1581,7 @@ impl AcpThreadView {
.bg(cx.theme().colors().editor_background)
.rounded_b_md()
.text_ui_sm(cx)
- .child(terminal_view.clone()),
+ .children(terminal_view.clone()),
)
})
.into_any()
@@ -3075,12 +2984,7 @@ impl AcpThreadView {
}
fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
- for diff_editor in self.diff_editors.values() {
- diff_editor.update(cx, |diff_editor, cx| {
- diff_editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
- cx.notify();
- })
- }
+ self.entry_view_state.settings_changed(cx);
}
pub(crate) fn insert_dragged_files(
@@ -3379,18 +3283,6 @@ fn plan_label_markdown_style(
}
}
-fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
- TextStyleRefinement {
- font_size: Some(
- TextSize::Small
- .rems(cx)
- .to_pixels(ThemeSettings::get_global(cx).agent_font_size(cx))
- .into(),
- ),
- ..Default::default()
- }
-}
-
fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
let default_md_style = default_markdown_style(true, window, cx);
@@ -3405,16 +3297,16 @@ fn terminal_command_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
#[cfg(test)]
pub(crate) mod tests {
- use std::{path::Path, sync::Arc};
+ use std::path::Path;
+ use acp_thread::StubAgentConnection;
use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol::SessionId;
use editor::EditorSettings;
use fs::FakeFs;
- use futures::future::try_join_all;
use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
- use parking_lot::Mutex;
- use rand::Rng;
+ use project::Project;
+ use serde_json::json;
use settings::SettingsStore;
use super::*;
@@ -3497,8 +3389,8 @@ pub(crate) mod tests {
raw_input: None,
raw_output: None,
};
- let connection = StubAgentConnection::new(vec![acp::SessionUpdate::ToolCall(tool_call)])
- .with_permission_requests(HashMap::from_iter([(
+ let connection =
+ StubAgentConnection::new().with_permission_requests(HashMap::from_iter([(
tool_call_id,
vec![acp::PermissionOption {
id: acp::PermissionOptionId("1".into()),
@@ -3506,6 +3398,9 @@ pub(crate) mod tests {
kind: acp::PermissionOptionKind::AllowOnce,
}],
)]));
+
+ connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(tool_call)]);
+
let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
@@ -3605,115 +3500,6 @@ pub(crate) mod tests {
}
}
- #[derive(Clone, Default)]
- struct StubAgentConnection {
- sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
- permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
- updates: Vec<acp::SessionUpdate>,
- }
-
- impl StubAgentConnection {
- fn new(updates: Vec<acp::SessionUpdate>) -> Self {
- Self {
- updates,
- permission_requests: HashMap::default(),
- sessions: Arc::default(),
- }
- }
-
- fn with_permission_requests(
- mut self,
- permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
- ) -> Self {
- self.permission_requests = permission_requests;
- self
- }
- }
-
- impl AgentConnection for StubAgentConnection {
- fn auth_methods(&self) -> &[acp::AuthMethod] {
- &[]
- }
-
- fn new_thread(
- self: Rc<Self>,
- project: Entity<Project>,
- _cwd: &Path,
- cx: &mut gpui::AsyncApp,
- ) -> Task<gpui::Result<Entity<AcpThread>>> {
- let session_id = SessionId(
- rand::thread_rng()
- .sample_iter(&rand::distributions::Alphanumeric)
- .take(7)
- .map(char::from)
- .collect::<String>()
- .into(),
- );
- let thread = cx
- .new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx))
- .unwrap();
- self.sessions.lock().insert(session_id, thread.downgrade());
- Task::ready(Ok(thread))
- }
-
- fn authenticate(
- &self,
- _method_id: acp::AuthMethodId,
- _cx: &mut App,
- ) -> Task<gpui::Result<()>> {
- unimplemented!()
- }
-
- fn prompt(
- &self,
- _id: Option<acp_thread::UserMessageId>,
- params: acp::PromptRequest,
- cx: &mut App,
- ) -> Task<gpui::Result<acp::PromptResponse>> {
- let sessions = self.sessions.lock();
- let thread = sessions.get(¶ms.session_id).unwrap();
- let mut tasks = vec![];
- for update in &self.updates {
- let thread = thread.clone();
- let update = update.clone();
- let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
- && let Some(options) = self.permission_requests.get(&tool_call.id)
- {
- Some((tool_call.clone(), options.clone()))
- } else {
- None
- };
- let task = cx.spawn(async move |cx| {
- if let Some((tool_call, options)) = permission_request {
- let permission = thread.update(cx, |thread, cx| {
- thread.request_tool_call_authorization(
- tool_call.clone(),
- options.clone(),
- cx,
- )
- })?;
- permission.await?;
- }
- thread.update(cx, |thread, cx| {
- thread.handle_session_update(update.clone(), cx).unwrap();
- })?;
- anyhow::Ok(())
- });
- tasks.push(task);
- }
- cx.spawn(async move |_| {
- try_join_all(tasks).await?;
- Ok(acp::PromptResponse {
- stop_reason: acp::StopReason::EndTurn,
- })
- })
- }
-
- fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
- unimplemented!()
- }
- }
-
#[derive(Clone)]
struct SaboteurAgentConnection;
@@ -3722,19 +3508,17 @@ pub(crate) mod tests {
self: Rc<Self>,
project: Entity<Project>,
_cwd: &Path,
- cx: &mut gpui::AsyncApp,
+ cx: &mut gpui::App,
) -> Task<gpui::Result<Entity<AcpThread>>> {
- Task::ready(Ok(cx
- .new(|cx| {
- AcpThread::new(
- "SaboteurAgentConnection",
- self,
- project,
- SessionId("test".into()),
- cx,
- )
- })
- .unwrap()))
+ Task::ready(Ok(cx.new(|cx| {
+ AcpThread::new(
+ "SaboteurAgentConnection",
+ self,
+ project,
+ SessionId("test".into()),
+ cx,
+ )
+ })))
}
fn auth_methods(&self) -> &[acp::AuthMethod] {
@@ -3776,4 +3560,142 @@ pub(crate) mod tests {
EditorSettings::register(cx);
});
}
+
+ #[gpui::test]
+ async fn test_rewind_views(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/project",
+ json!({
+ "test1.txt": "old content 1",
+ "test2.txt": "old content 2"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, [Path::new("/project")], cx).await;
+ let (workspace, cx) =
+ cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+ let thread_store =
+ cx.update(|_window, cx| cx.new(|cx| ThreadStore::fake(project.clone(), cx)));
+ let text_thread_store =
+ cx.update(|_window, cx| cx.new(|cx| TextThreadStore::fake(project.clone(), cx)));
+
+ let connection = Rc::new(StubAgentConnection::new());
+ let thread_view = cx.update(|window, cx| {
+ cx.new(|cx| {
+ AcpThreadView::new(
+ Rc::new(StubAgentServer::new(connection.as_ref().clone())),
+ workspace.downgrade(),
+ project.clone(),
+ thread_store.clone(),
+ text_thread_store.clone(),
+ window,
+ cx,
+ )
+ })
+ });
+
+ cx.run_until_parked();
+
+ let thread = thread_view
+ .read_with(cx, |view, _| view.thread().cloned())
+ .unwrap();
+
+ // First user message
+ connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall {
+ id: acp::ToolCallId("tool1".into()),
+ title: "Edit file 1".into(),
+ kind: acp::ToolKind::Edit,
+ status: acp::ToolCallStatus::Completed,
+ content: vec![acp::ToolCallContent::Diff {
+ diff: acp::Diff {
+ path: "/project/test1.txt".into(),
+ old_text: Some("old content 1".into()),
+ new_text: "new content 1".into(),
+ },
+ }],
+ locations: vec![],
+ raw_input: None,
+ raw_output: None,
+ })]);
+
+ thread
+ .update(cx, |thread, cx| thread.send_raw("Give me a diff", cx))
+ .await
+ .unwrap();
+ cx.run_until_parked();
+
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.entries().len(), 2);
+ });
+
+ thread_view.read_with(cx, |view, _| {
+ assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
+ assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
+ });
+
+ // Second user message
+ connection.set_next_prompt_updates(vec![acp::SessionUpdate::ToolCall(acp::ToolCall {
+ id: acp::ToolCallId("tool2".into()),
+ title: "Edit file 2".into(),
+ kind: acp::ToolKind::Edit,
+ status: acp::ToolCallStatus::Completed,
+ content: vec![acp::ToolCallContent::Diff {
+ diff: acp::Diff {
+ path: "/project/test2.txt".into(),
+ old_text: Some("old content 2".into()),
+ new_text: "new content 2".into(),
+ },
+ }],
+ locations: vec![],
+ raw_input: None,
+ raw_output: None,
+ })]);
+
+ thread
+ .update(cx, |thread, cx| thread.send_raw("Another one", cx))
+ .await
+ .unwrap();
+ cx.run_until_parked();
+
+ let second_user_message_id = thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.entries().len(), 4);
+ let AgentThreadEntry::UserMessage(user_message) = thread.entries().get(2).unwrap()
+ else {
+ panic!();
+ };
+ user_message.id.clone().unwrap()
+ });
+
+ thread_view.read_with(cx, |view, _| {
+ assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
+ assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
+ assert_eq!(view.entry_view_state.entry(2).unwrap().len(), 0);
+ assert_eq!(view.entry_view_state.entry(3).unwrap().len(), 1);
+ });
+
+ // Rewind to first message
+ thread
+ .update(cx, |thread, cx| thread.rewind(second_user_message_id, cx))
+ .await
+ .unwrap();
+
+ cx.run_until_parked();
+
+ thread.read_with(cx, |thread, _| {
+ assert_eq!(thread.entries().len(), 2);
+ });
+
+ thread_view.read_with(cx, |view, _| {
+ assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
+ assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
+
+ // Old views should be dropped
+ assert!(view.entry_view_state.entry(2).is_none());
+ assert!(view.entry_view_state.entry(3).is_none());
+ });
+ }
}