@@ -1,45 +1,141 @@
-use std::{collections::HashMap, ops::Range};
+use std::ops::Range;
-use acp_thread::AcpThread;
-use editor::{Editor, EditorMode, MinimapVisibility, MultiBuffer};
+use acp_thread::{AcpThread, AgentThreadEntry};
+use agent::{TextThreadStore, ThreadStore};
+use collections::HashMap;
+use editor::{Editor, EditorMode, MinimapVisibility};
use gpui::{
- AnyEntity, App, AppContext as _, Entity, EntityId, TextStyleRefinement, WeakEntity, Window,
+ AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement,
+ WeakEntity, Window,
};
use language::language_settings::SoftWrap;
+use project::Project;
use settings::Settings as _;
use terminal_view::TerminalView;
use theme::ThemeSettings;
-use ui::TextSize;
+use ui::{Context, TextSize};
use workspace::Workspace;
-#[derive(Default)]
+use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
+
pub struct EntryViewState {
+ workspace: WeakEntity<Workspace>,
+ project: Entity<Project>,
+ thread_store: Entity<ThreadStore>,
+ text_thread_store: Entity<TextThreadStore>,
entries: Vec<Entry>,
}
impl EntryViewState {
+ pub fn new(
+ workspace: WeakEntity<Workspace>,
+ project: Entity<Project>,
+ thread_store: Entity<ThreadStore>,
+ text_thread_store: Entity<TextThreadStore>,
+ ) -> Self {
+ Self {
+ workspace,
+ project,
+ thread_store,
+ text_thread_store,
+ entries: Vec::new(),
+ }
+ }
+
pub fn entry(&self, index: usize) -> Option<&Entry> {
self.entries.get(index)
}
pub fn sync_entry(
&mut self,
- workspace: WeakEntity<Workspace>,
- thread: Entity<AcpThread>,
index: usize,
+ thread: &Entity<AcpThread>,
window: &mut Window,
- cx: &mut App,
+ cx: &mut Context<Self>,
) {
- debug_assert!(index <= self.entries.len());
- let entry = if let Some(entry) = self.entries.get_mut(index) {
- entry
- } else {
- self.entries.push(Entry::default());
- self.entries.last_mut().unwrap()
+ let Some(thread_entry) = thread.read(cx).entries().get(index) else {
+ return;
+ };
+
+ match thread_entry {
+ AgentThreadEntry::UserMessage(message) => {
+ let has_id = message.id.is_some();
+ let chunks = message.chunks.clone();
+ let message_editor = cx.new(|cx| {
+ let mut editor = MessageEditor::new(
+ self.workspace.clone(),
+ self.project.clone(),
+ self.thread_store.clone(),
+ self.text_thread_store.clone(),
+ editor::EditorMode::AutoHeight {
+ min_lines: 1,
+ max_lines: None,
+ },
+ window,
+ cx,
+ );
+ if !has_id {
+ editor.set_read_only(true, cx);
+ }
+ editor.set_message(chunks, window, cx);
+ editor
+ });
+ cx.subscribe(&message_editor, move |_, editor, event, cx| {
+ cx.emit(EntryViewEvent {
+ entry_index: index,
+ view_event: ViewEvent::MessageEditorEvent(editor, *event),
+ })
+ })
+ .detach();
+ self.set_entry(index, Entry::UserMessage(message_editor));
+ }
+ AgentThreadEntry::ToolCall(tool_call) => {
+ let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
+ let diffs = tool_call.diffs().cloned().collect::<Vec<_>>();
+
+ let views = if let Some(Entry::Content(views)) = self.entries.get_mut(index) {
+ views
+ } else {
+ self.set_entry(index, Entry::empty());
+ let Some(Entry::Content(views)) = self.entries.get_mut(index) else {
+ unreachable!()
+ };
+ views
+ };
+
+ for terminal in terminals {
+ views.entry(terminal.entity_id()).or_insert_with(|| {
+ create_terminal(
+ self.workspace.clone(),
+ self.project.clone(),
+ terminal.clone(),
+ window,
+ cx,
+ )
+ .into_any()
+ });
+ }
+
+ for diff in diffs {
+ views
+ .entry(diff.entity_id())
+ .or_insert_with(|| create_editor_diff(diff.clone(), window, cx).into_any());
+ }
+ }
+ AgentThreadEntry::AssistantMessage(_) => {
+ if index == self.entries.len() {
+ self.entries.push(Entry::empty())
+ }
+ }
};
+ }
- entry.sync_diff_multibuffers(&thread, index, window, cx);
- entry.sync_terminals(&workspace, &thread, index, window, cx);
+ fn set_entry(&mut self, index: usize, entry: Entry) {
+ if index == self.entries.len() {
+ self.entries.push(entry);
+ } else {
+ self.entries[index] = entry;
+ }
}
pub fn remove(&mut self, range: Range<usize>) {
@@ -48,26 +144,51 @@ impl EntryViewState {
pub fn settings_changed(&mut self, cx: &mut App) {
for entry in self.entries.iter() {
- for view in entry.views.values() {
- if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
- diff_editor.update(cx, |diff_editor, cx| {
- diff_editor
- .set_text_style_refinement(diff_editor_text_style_refinement(cx));
- cx.notify();
- })
+ match entry {
+ Entry::UserMessage { .. } => {}
+ Entry::Content(response_views) => {
+ for view in response_views.values() {
+ if let Ok(diff_editor) = view.clone().downcast::<Editor>() {
+ diff_editor.update(cx, |diff_editor, cx| {
+ diff_editor.set_text_style_refinement(
+ diff_editor_text_style_refinement(cx),
+ );
+ cx.notify();
+ })
+ }
+ }
}
}
}
}
}
-pub struct Entry {
- views: HashMap<EntityId, AnyEntity>,
+impl EventEmitter<EntryViewEvent> for EntryViewState {}
+
+pub struct EntryViewEvent {
+ pub entry_index: usize,
+ pub view_event: ViewEvent,
+}
+
+pub enum ViewEvent {
+ MessageEditorEvent(Entity<MessageEditor>, MessageEditorEvent),
+}
+
+pub enum Entry {
+ UserMessage(Entity<MessageEditor>),
+ Content(HashMap<EntityId, AnyEntity>),
}
impl Entry {
- pub fn editor_for_diff(&self, diff: &Entity<MultiBuffer>) -> Option<Entity<Editor>> {
- self.views
+ pub fn message_editor(&self) -> Option<&Entity<MessageEditor>> {
+ match self {
+ Self::UserMessage(editor) => Some(editor),
+ Entry::Content(_) => None,
+ }
+ }
+
+ pub fn editor_for_diff(&self, diff: &Entity<acp_thread::Diff>) -> Option<Entity<Editor>> {
+ self.content_map()?
.get(&diff.entity_id())
.cloned()
.map(|entity| entity.downcast::<Editor>().unwrap())
@@ -77,118 +198,88 @@ impl Entry {
&self,
terminal: &Entity<acp_thread::Terminal>,
) -> Option<Entity<TerminalView>> {
- self.views
+ self.content_map()?
.get(&terminal.entity_id())
.cloned()
.map(|entity| entity.downcast::<TerminalView>().unwrap())
}
- fn sync_diff_multibuffers(
- &mut self,
- thread: &Entity<AcpThread>,
- index: usize,
- window: &mut Window,
- cx: &mut App,
- ) {
- let Some(entry) = thread.read(cx).entries().get(index) else {
- return;
- };
-
- let multibuffers = entry
- .diffs()
- .map(|diff| diff.read(cx).multibuffer().clone());
-
- let multibuffers = multibuffers.collect::<Vec<_>>();
-
- for multibuffer in multibuffers {
- if self.views.contains_key(&multibuffer.entity_id()) {
- return;
- }
-
- let editor = cx.new(|cx| {
- let mut editor = Editor::new(
- EditorMode::Full {
- scale_ui_elements_with_buffer_font_size: false,
- show_active_line_background: false,
- sized_by_content: true,
- },
- multibuffer.clone(),
- None,
- window,
- cx,
- );
- editor.set_show_gutter(false, cx);
- editor.disable_inline_diagnostics();
- editor.disable_expand_excerpt_buttons(cx);
- editor.set_show_vertical_scrollbar(false, cx);
- editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
- editor.set_soft_wrap_mode(SoftWrap::None, cx);
- editor.scroll_manager.set_forbid_vertical_scroll(true);
- editor.set_show_indent_guides(false, cx);
- editor.set_read_only(true);
- editor.set_show_breakpoints(false, cx);
- editor.set_show_code_actions(false, cx);
- editor.set_show_git_diff_gutter(false, cx);
- editor.set_expand_all_diff_hunks(cx);
- editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
- editor
- });
-
- let entity_id = multibuffer.entity_id();
- self.views.insert(entity_id, editor.into_any());
+ fn content_map(&self) -> Option<&HashMap<EntityId, AnyEntity>> {
+ match self {
+ Self::Content(map) => Some(map),
+ _ => None,
}
}
- fn sync_terminals(
- &mut self,
- workspace: &WeakEntity<Workspace>,
- thread: &Entity<AcpThread>,
- index: usize,
- window: &mut Window,
- cx: &mut App,
- ) {
- let Some(entry) = thread.read(cx).entries().get(index) else {
- return;
- };
-
- let terminals = entry
- .terminals()
- .map(|terminal| terminal.clone())
- .collect::<Vec<_>>();
-
- for terminal in terminals {
- if self.views.contains_key(&terminal.entity_id()) {
- return;
- }
-
- let Some(strong_workspace) = workspace.upgrade() else {
- return;
- };
-
- let terminal_view = cx.new(|cx| {
- let mut view = TerminalView::new(
- terminal.read(cx).inner().clone(),
- workspace.clone(),
- None,
- strong_workspace.read(cx).project().downgrade(),
- window,
- cx,
- );
- view.set_embedded_mode(Some(1000), cx);
- view
- });
-
- let entity_id = terminal.entity_id();
- self.views.insert(entity_id, terminal_view.into_any());
- }
+ fn empty() -> Self {
+ Self::Content(HashMap::default())
}
#[cfg(test)]
- pub fn len(&self) -> usize {
- self.views.len()
+ pub fn has_content(&self) -> bool {
+ match self {
+ Self::Content(map) => !map.is_empty(),
+ Self::UserMessage(_) => false,
+ }
}
}
+fn create_terminal(
+ workspace: WeakEntity<Workspace>,
+ project: Entity<Project>,
+ terminal: Entity<acp_thread::Terminal>,
+ window: &mut Window,
+ cx: &mut App,
+) -> Entity<TerminalView> {
+ cx.new(|cx| {
+ let mut view = TerminalView::new(
+ terminal.read(cx).inner().clone(),
+ workspace.clone(),
+ None,
+ project.downgrade(),
+ window,
+ cx,
+ );
+ view.set_embedded_mode(Some(1000), cx);
+ view
+ })
+}
+
+fn create_editor_diff(
+ diff: Entity<acp_thread::Diff>,
+ window: &mut Window,
+ cx: &mut App,
+) -> Entity<Editor> {
+ cx.new(|cx| {
+ let mut editor = Editor::new(
+ EditorMode::Full {
+ scale_ui_elements_with_buffer_font_size: false,
+ show_active_line_background: false,
+ sized_by_content: true,
+ },
+ diff.read(cx).multibuffer().clone(),
+ None,
+ window,
+ cx,
+ );
+ editor.set_show_gutter(false, cx);
+ editor.disable_inline_diagnostics();
+ editor.disable_expand_excerpt_buttons(cx);
+ editor.set_show_vertical_scrollbar(false, cx);
+ editor.set_minimap_visibility(MinimapVisibility::Disabled, window, cx);
+ editor.set_soft_wrap_mode(SoftWrap::None, cx);
+ editor.scroll_manager.set_forbid_vertical_scroll(true);
+ editor.set_show_indent_guides(false, cx);
+ editor.set_read_only(true);
+ editor.set_show_breakpoints(false, cx);
+ editor.set_show_code_actions(false, cx);
+ editor.set_show_git_diff_gutter(false, cx);
+ editor.set_expand_all_diff_hunks(cx);
+ editor.set_text_style_refinement(diff_editor_text_style_refinement(cx));
+ editor
+ })
+}
+
fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
TextStyleRefinement {
font_size: Some(
@@ -201,26 +292,20 @@ fn diff_editor_text_style_refinement(cx: &mut App) -> TextStyleRefinement {
}
}
-impl Default for Entry {
- fn default() -> Self {
- Self {
- // Avoid allocating in the heap by default
- views: HashMap::with_capacity(0),
- }
- }
-}
-
#[cfg(test)]
mod tests {
use std::{path::Path, rc::Rc};
use acp_thread::{AgentConnection, StubAgentConnection};
+ use agent::{TextThreadStore, ThreadStore};
use agent_client_protocol as acp;
use agent_settings::AgentSettings;
use buffer_diff::{DiffHunkStatus, DiffHunkStatusKind};
use editor::{EditorSettings, RowInfo};
use fs::FakeFs;
- use gpui::{SemanticVersion, TestAppContext};
+ use gpui::{AppContext as _, SemanticVersion, TestAppContext};
+
+ use crate::acp::entry_view_state::EntryViewState;
use multi_buffer::MultiBufferRow;
use pretty_assertions::assert_matches;
use project::Project;
@@ -230,8 +315,6 @@ mod tests {
use util::path;
use workspace::Workspace;
- use crate::acp::entry_view_state::EntryViewState;
-
#[gpui::test]
async fn test_diff_sync(cx: &mut TestAppContext) {
init_test(cx);
@@ -269,7 +352,7 @@ mod tests {
.update(|_, cx| {
connection
.clone()
- .new_thread(project, Path::new(path!("/project")), cx)
+ .new_thread(project.clone(), Path::new(path!("/project")), cx)
})
.await
.unwrap();
@@ -279,12 +362,23 @@ mod tests {
connection.send_update(session_id, acp::SessionUpdate::ToolCall(tool_call), cx)
});
- let mut view_state = EntryViewState::default();
- cx.update(|window, cx| {
- view_state.sync_entry(workspace.downgrade(), thread.clone(), 0, window, cx);
+ let thread_store = cx.new(|cx| ThreadStore::fake(project.clone(), cx));
+ let text_thread_store = cx.new(|cx| TextThreadStore::fake(project.clone(), cx));
+
+ let view_state = cx.new(|_cx| {
+ EntryViewState::new(
+ workspace.downgrade(),
+ project.clone(),
+ thread_store,
+ text_thread_store,
+ )
+ });
+
+ view_state.update_in(cx, |view_state, window, cx| {
+ view_state.sync_entry(0, &thread, window, cx)
});
- let multibuffer = thread.read_with(cx, |thread, cx| {
+ let diff = thread.read_with(cx, |thread, _cx| {
thread
.entries()
.get(0)
@@ -292,15 +386,14 @@ mod tests {
.diffs()
.next()
.unwrap()
- .read(cx)
- .multibuffer()
.clone()
});
cx.run_until_parked();
- let entry = view_state.entry(0).unwrap();
- let diff_editor = entry.editor_for_diff(&multibuffer).unwrap();
+ let diff_editor = view_state.read_with(cx, |view_state, _cx| {
+ view_state.entry(0).unwrap().editor_for_diff(&diff).unwrap()
+ });
assert_eq!(
diff_editor.read_with(cx, |editor, cx| editor.text(cx)),
"hi world\nhello world"
@@ -45,6 +45,7 @@ use zed_actions::assistant::OpenRulesLibrary;
use super::entry_view_state::EntryViewState;
use crate::acp::AcpModelSelectorPopover;
+use crate::acp::entry_view_state::{EntryViewEvent, ViewEvent};
use crate::acp::message_editor::{MessageEditor, MessageEditorEvent};
use crate::agent_diff::AgentDiff;
use crate::profile_selector::{ProfileProvider, ProfileSelector};
@@ -101,10 +102,8 @@ pub struct AcpThreadView {
agent: Rc<dyn AgentServer>,
workspace: WeakEntity<Workspace>,
project: Entity<Project>,
- thread_store: Entity<ThreadStore>,
- text_thread_store: Entity<TextThreadStore>,
thread_state: ThreadState,
- entry_view_state: EntryViewState,
+ entry_view_state: Entity<EntryViewState>,
message_editor: Entity<MessageEditor>,
model_selector: Option<Entity<AcpModelSelectorPopover>>,
profile_selector: Option<Entity<ProfileSelector>>,
@@ -120,16 +119,9 @@ pub struct AcpThreadView {
plan_expanded: bool,
editor_expanded: bool,
terminal_expanded: bool,
- editing_message: Option<EditingMessage>,
+ editing_message: Option<usize>,
_cancel_task: Option<Task<()>>,
- _subscriptions: [Subscription; 2],
-}
-
-struct EditingMessage {
- index: usize,
- message_id: UserMessageId,
- editor: Entity<MessageEditor>,
- _subscription: Subscription,
+ _subscriptions: [Subscription; 3],
}
enum ThreadState {
@@ -176,24 +168,32 @@ impl AcpThreadView {
let list_state = ListState::new(0, gpui::ListAlignment::Bottom, px(2048.0));
+ let entry_view_state = cx.new(|_| {
+ EntryViewState::new(
+ workspace.clone(),
+ project.clone(),
+ thread_store.clone(),
+ text_thread_store.clone(),
+ )
+ });
+
let subscriptions = [
cx.observe_global_in::<SettingsStore>(window, Self::settings_changed),
- cx.subscribe_in(&message_editor, window, Self::on_message_editor_event),
+ cx.subscribe_in(&message_editor, window, Self::handle_message_editor_event),
+ cx.subscribe_in(&entry_view_state, window, Self::handle_entry_view_event),
];
Self {
agent: agent.clone(),
workspace: workspace.clone(),
project: project.clone(),
- thread_store,
- text_thread_store,
+ entry_view_state,
thread_state: Self::initial_state(agent, workspace, project, window, cx),
message_editor,
model_selector: None,
profile_selector: None,
notifications: Vec::new(),
notification_subscriptions: HashMap::default(),
- entry_view_state: EntryViewState::default(),
list_state: list_state.clone(),
scrollbar_state: ScrollbarState::new(list_state).parent_entity(&cx.entity()),
thread_error: None,
@@ -414,7 +414,7 @@ impl AcpThreadView {
cx.notify();
}
- pub fn on_message_editor_event(
+ pub fn handle_message_editor_event(
&mut self,
_: &Entity<MessageEditor>,
event: &MessageEditorEvent,
@@ -424,6 +424,28 @@ impl AcpThreadView {
match event {
MessageEditorEvent::Send => self.send(window, cx),
MessageEditorEvent::Cancel => self.cancel_generation(cx),
+ MessageEditorEvent::Focus => {}
+ }
+ }
+
+ pub fn handle_entry_view_event(
+ &mut self,
+ _: &Entity<EntryViewState>,
+ event: &EntryViewEvent,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ match &event.view_event {
+ ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Focus) => {
+ self.editing_message = Some(event.entry_index);
+ cx.notify();
+ }
+ ViewEvent::MessageEditorEvent(editor, MessageEditorEvent::Send) => {
+ self.regenerate(event.entry_index, editor, window, cx);
+ }
+ ViewEvent::MessageEditorEvent(_editor, MessageEditorEvent::Cancel) => {
+ self.cancel_editing(&Default::default(), window, cx);
+ }
}
}
@@ -494,27 +516,56 @@ impl AcpThreadView {
.detach();
}
- fn cancel_editing(&mut self, _: &ClickEvent, _window: &mut Window, cx: &mut Context<Self>) {
- self.editing_message.take();
+ fn cancel_editing(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
+ let Some(thread) = self.thread().cloned() else {
+ return;
+ };
+
+ if let Some(index) = self.editing_message.take() {
+ if let Some(editor) = self
+ .entry_view_state
+ .read(cx)
+ .entry(index)
+ .and_then(|e| e.message_editor())
+ .cloned()
+ {
+ editor.update(cx, |editor, cx| {
+ if let Some(user_message) = thread
+ .read(cx)
+ .entries()
+ .get(index)
+ .and_then(|e| e.user_message())
+ {
+ editor.set_message(user_message.chunks.clone(), window, cx);
+ }
+ })
+ }
+ };
+ self.focus_handle(cx).focus(window);
cx.notify();
}
- fn regenerate(&mut self, _: &ClickEvent, window: &mut Window, cx: &mut Context<Self>) {
- let Some(editing_message) = self.editing_message.take() else {
+ fn regenerate(
+ &mut self,
+ entry_ix: usize,
+ message_editor: &Entity<MessageEditor>,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(thread) = self.thread().cloned() else {
return;
};
- let Some(thread) = self.thread().cloned() else {
+ let Some(rewind) = thread.update(cx, |thread, cx| {
+ let user_message_id = thread.entries().get(entry_ix)?.user_message()?.id.clone()?;
+ Some(thread.rewind(user_message_id, cx))
+ }) else {
return;
};
- let rewind = thread.update(cx, |thread, cx| {
- thread.rewind(editing_message.message_id, cx)
- });
+ let contents =
+ message_editor.update(cx, |message_editor, cx| message_editor.contents(window, cx));
- let contents = editing_message
- .editor
- .update(cx, |message_editor, cx| message_editor.contents(window, cx));
let task = cx.foreground_executor().spawn(async move {
rewind.await?;
contents.await
@@ -570,27 +621,20 @@ impl AcpThreadView {
AcpThreadEvent::NewEntry => {
let len = thread.read(cx).entries().len();
let index = len - 1;
- self.entry_view_state.sync_entry(
- self.workspace.clone(),
- thread.clone(),
- index,
- window,
- cx,
- );
+ self.entry_view_state.update(cx, |view_state, cx| {
+ view_state.sync_entry(index, &thread, window, cx)
+ });
self.list_state.splice(index..index, 1);
}
AcpThreadEvent::EntryUpdated(index) => {
- self.entry_view_state.sync_entry(
- self.workspace.clone(),
- thread.clone(),
- *index,
- window,
- cx,
- );
+ self.entry_view_state.update(cx, |view_state, cx| {
+ view_state.sync_entry(*index, &thread, window, cx)
+ });
self.list_state.splice(*index..index + 1, 1);
}
AcpThreadEvent::EntriesRemoved(range) => {
- self.entry_view_state.remove(range.clone());
+ self.entry_view_state
+ .update(cx, |view_state, _cx| view_state.remove(range.clone()));
self.list_state.splice(range.clone(), 0);
}
AcpThreadEvent::ToolAuthorizationRequired => {
@@ -722,29 +766,15 @@ impl AcpThreadView {
.border_1()
.border_color(cx.theme().colors().border)
.text_xs()
- .id("message")
- .on_click(cx.listener({
- move |this, _, window, cx| {
- this.start_editing_message(entry_ix, window, cx)
- }
- }))
.children(
- if let Some(editing) = self.editing_message.as_ref()
- && Some(&editing.message_id) == message.id.as_ref()
- {
- Some(
- self.render_edit_message_editor(editing, cx)
- .into_any_element(),
- )
- } else {
- message.content.markdown().map(|md| {
- self.render_markdown(
- md.clone(),
- user_message_markdown_style(window, cx),
- )
- .into_any_element()
- })
- },
+ self.entry_view_state
+ .read(cx)
+ .entry(entry_ix)
+ .and_then(|entry| entry.message_editor())
+ .map(|editor| {
+ self.render_sent_message_editor(entry_ix, editor, cx)
+ .into_any_element()
+ }),
),
)
.into_any(),
@@ -819,8 +849,8 @@ impl AcpThreadView {
primary
};
- if let Some(editing) = self.editing_message.as_ref()
- && editing.index < entry_ix
+ if let Some(editing_index) = self.editing_message.as_ref()
+ && *editing_index < entry_ix
{
let backdrop = div()
.id(("backdrop", entry_ix))
@@ -834,8 +864,8 @@ impl AcpThreadView {
div()
.relative()
- .child(backdrop)
.child(primary)
+ .child(backdrop)
.into_any_element()
} else {
primary
@@ -1256,9 +1286,7 @@ impl AcpThreadView {
Empty.into_any_element()
}
}
- ToolCallContent::Diff(diff) => {
- self.render_diff_editor(entry_ix, &diff.read(cx).multibuffer(), cx)
- }
+ ToolCallContent::Diff(diff) => self.render_diff_editor(entry_ix, &diff, cx),
ToolCallContent::Terminal(terminal) => {
self.render_terminal_tool_call(entry_ix, terminal, tool_call, window, cx)
}
@@ -1405,7 +1433,7 @@ impl AcpThreadView {
fn render_diff_editor(
&self,
entry_ix: usize,
- multibuffer: &Entity<MultiBuffer>,
+ diff: &Entity<acp_thread::Diff>,
cx: &Context<Self>,
) -> AnyElement {
v_flex()
@@ -1413,8 +1441,8 @@ impl AcpThreadView {
.border_t_1()
.border_color(self.tool_card_border_color(cx))
.child(
- if let Some(entry) = self.entry_view_state.entry(entry_ix)
- && let Some(editor) = entry.editor_for_diff(&multibuffer)
+ if let Some(entry) = self.entry_view_state.read(cx).entry(entry_ix)
+ && let Some(editor) = entry.editor_for_diff(&diff)
{
editor.clone().into_any_element()
} else {
@@ -1617,6 +1645,7 @@ impl AcpThreadView {
let terminal_view = self
.entry_view_state
+ .read(cx)
.entry(entry_ix)
.and_then(|entry| entry.terminal(&terminal));
let show_output = self.terminal_expanded && terminal_view.is_some();
@@ -2485,82 +2514,38 @@ impl AcpThreadView {
)
}
- fn start_editing_message(&mut self, index: usize, window: &mut Window, cx: &mut Context<Self>) {
- let Some(thread) = self.thread() else {
- return;
- };
- let Some(AgentThreadEntry::UserMessage(message)) = thread.read(cx).entries().get(index)
- else {
- return;
- };
- let Some(message_id) = message.id.clone() else {
- return;
- };
-
- self.list_state.scroll_to_reveal_item(index);
-
- let chunks = message.chunks.clone();
- let editor = cx.new(|cx| {
- let mut editor = MessageEditor::new(
- self.workspace.clone(),
- self.project.clone(),
- self.thread_store.clone(),
- self.text_thread_store.clone(),
- editor::EditorMode::AutoHeight {
- min_lines: 1,
- max_lines: None,
- },
- window,
- cx,
- );
- editor.set_message(chunks, window, cx);
- editor
- });
- let subscription =
- cx.subscribe_in(&editor, window, |this, _, event, window, cx| match event {
- MessageEditorEvent::Send => {
- this.regenerate(&Default::default(), window, cx);
- }
- MessageEditorEvent::Cancel => {
- this.cancel_editing(&Default::default(), window, cx);
- }
- });
- editor.focus_handle(cx).focus(window);
-
- self.editing_message.replace(EditingMessage {
- index: index,
- message_id: message_id.clone(),
- editor,
- _subscription: subscription,
- });
- cx.notify();
- }
-
- fn render_edit_message_editor(&self, editing: &EditingMessage, cx: &Context<Self>) -> Div {
- v_flex()
- .w_full()
- .gap_2()
- .child(editing.editor.clone())
- .child(
- h_flex()
- .gap_1()
- .child(
- Icon::new(IconName::Warning)
- .color(Color::Warning)
- .size(IconSize::XSmall),
- )
- .child(
- Label::new("Editing will restart the thread from this point.")
- .color(Color::Muted)
- .size(LabelSize::XSmall),
- )
- .child(self.render_editing_message_editor_buttons(editing, cx)),
- )
+ fn render_sent_message_editor(
+ &self,
+ entry_ix: usize,
+ editor: &Entity<MessageEditor>,
+ cx: &Context<Self>,
+ ) -> Div {
+ v_flex().w_full().gap_2().child(editor.clone()).when(
+ self.editing_message == Some(entry_ix),
+ |el| {
+ el.child(
+ h_flex()
+ .gap_1()
+ .child(
+ Icon::new(IconName::Warning)
+ .color(Color::Warning)
+ .size(IconSize::XSmall),
+ )
+ .child(
+ Label::new("Editing will restart the thread from this point.")
+ .color(Color::Muted)
+ .size(LabelSize::XSmall),
+ )
+ .child(self.render_sent_message_editor_buttons(entry_ix, editor, cx)),
+ )
+ },
+ )
}
- fn render_editing_message_editor_buttons(
+ fn render_sent_message_editor_buttons(
&self,
- editing: &EditingMessage,
+ entry_ix: usize,
+ editor: &Entity<MessageEditor>,
cx: &Context<Self>,
) -> Div {
h_flex()
@@ -2573,7 +2558,7 @@ impl AcpThreadView {
.icon_color(Color::Error)
.icon_size(IconSize::Small)
.tooltip({
- let focus_handle = editing.editor.focus_handle(cx);
+ let focus_handle = editor.focus_handle(cx);
move |window, cx| {
Tooltip::for_action_in(
"Cancel Edit",
@@ -2588,12 +2573,12 @@ impl AcpThreadView {
)
.child(
IconButton::new("confirm-edit-message", IconName::Return)
- .disabled(editing.editor.read(cx).is_empty(cx))
+ .disabled(editor.read(cx).is_empty(cx))
.shape(ui::IconButtonShape::Square)
.icon_color(Color::Muted)
.icon_size(IconSize::Small)
.tooltip({
- let focus_handle = editing.editor.focus_handle(cx);
+ let focus_handle = editor.focus_handle(cx);
move |window, cx| {
Tooltip::for_action_in(
"Regenerate",
@@ -2604,7 +2589,12 @@ impl AcpThreadView {
)
}
})
- .on_click(cx.listener(Self::regenerate)),
+ .on_click(cx.listener({
+ let editor = editor.clone();
+ move |this, _, window, cx| {
+ this.regenerate(entry_ix, &editor, window, cx);
+ }
+ })),
)
}
@@ -3137,7 +3127,9 @@ impl AcpThreadView {
}
fn settings_changed(&mut self, _window: &mut Window, cx: &mut Context<Self>) {
- self.entry_view_state.settings_changed(cx);
+ self.entry_view_state.update(cx, |entry_view_state, cx| {
+ entry_view_state.settings_changed(cx);
+ });
}
pub(crate) fn insert_dragged_files(
@@ -3152,9 +3144,7 @@ impl AcpThreadView {
drop(added_worktrees);
})
}
-}
-impl AcpThreadView {
fn render_thread_error(&self, window: &mut Window, cx: &mut Context<'_, Self>) -> Option<Div> {
let content = match self.thread_error.as_ref()? {
ThreadError::Other(error) => self.render_any_thread_error(error.clone(), cx),
@@ -3439,35 +3429,6 @@ impl Render for AcpThreadView {
}
}
-fn user_message_markdown_style(window: &Window, cx: &App) -> MarkdownStyle {
- let mut style = default_markdown_style(false, window, cx);
- let mut text_style = window.text_style();
- let theme_settings = ThemeSettings::get_global(cx);
-
- let buffer_font = theme_settings.buffer_font.family.clone();
- let buffer_font_size = TextSize::Small.rems(cx);
-
- text_style.refine(&TextStyleRefinement {
- font_family: Some(buffer_font),
- font_size: Some(buffer_font_size.into()),
- ..Default::default()
- });
-
- style.base_text_style = text_style;
- style.link_callback = Some(Rc::new(move |url, cx| {
- if MentionUri::parse(url).is_ok() {
- let colors = cx.theme().colors();
- Some(TextStyleRefinement {
- background_color: Some(colors.element_background),
- ..Default::default()
- })
- } else {
- None
- }
- }));
- style
-}
-
fn default_markdown_style(buffer_font: bool, window: &Window, cx: &App) -> MarkdownStyle {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
@@ -3626,12 +3587,13 @@ pub(crate) mod tests {
use agent_client_protocol::SessionId;
use editor::EditorSettings;
use fs::FakeFs;
- use gpui::{SemanticVersion, TestAppContext, VisualTestContext};
+ use gpui::{EventEmitter, SemanticVersion, TestAppContext, VisualTestContext};
use project::Project;
use serde_json::json;
use settings::SettingsStore;
use std::any::Any;
use std::path::Path;
+ use workspace::Item;
use super::*;
@@ -3778,6 +3740,50 @@ pub(crate) mod tests {
(thread_view, cx)
}
+ fn add_to_workspace(thread_view: Entity<AcpThreadView>, cx: &mut VisualTestContext) {
+ let workspace = thread_view.read_with(cx, |thread_view, _cx| thread_view.workspace.clone());
+
+ workspace
+ .update_in(cx, |workspace, window, cx| {
+ workspace.add_item_to_active_pane(
+ Box::new(cx.new(|_| ThreadViewItem(thread_view.clone()))),
+ None,
+ true,
+ window,
+ cx,
+ );
+ })
+ .unwrap();
+ }
+
+ struct ThreadViewItem(Entity<AcpThreadView>);
+
+ impl Item for ThreadViewItem {
+ type Event = ();
+
+ fn include_in_nav_history() -> bool {
+ false
+ }
+
+ fn tab_content_text(&self, _detail: usize, _cx: &App) -> SharedString {
+ "Test".into()
+ }
+ }
+
+ impl EventEmitter<()> for ThreadViewItem {}
+
+ impl Focusable for ThreadViewItem {
+ fn focus_handle(&self, cx: &App) -> FocusHandle {
+ self.0.read(cx).focus_handle(cx).clone()
+ }
+ }
+
+ impl Render for ThreadViewItem {
+ fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
+ self.0.clone().into_any_element()
+ }
+ }
+
struct StubAgentServer<C> {
connection: C,
}
@@ -3799,19 +3805,19 @@ pub(crate) mod tests {
C: 'static + AgentConnection + Send + Clone,
{
fn logo(&self) -> ui::IconName {
- unimplemented!()
+ ui::IconName::Ai
}
fn name(&self) -> &'static str {
- unimplemented!()
+ "Test"
}
fn empty_state_headline(&self) -> &'static str {
- unimplemented!()
+ "Test"
}
fn empty_state_message(&self) -> &'static str {
- unimplemented!()
+ "Test"
}
fn connect(
@@ -3960,9 +3966,17 @@ pub(crate) mod tests {
assert_eq!(thread.entries().len(), 2);
});
- thread_view.read_with(cx, |view, _| {
- assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
- assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
+ thread_view.read_with(cx, |view, cx| {
+ view.entry_view_state.read_with(cx, |entry_view_state, _| {
+ assert!(
+ entry_view_state
+ .entry(0)
+ .unwrap()
+ .message_editor()
+ .is_some()
+ );
+ assert!(entry_view_state.entry(1).unwrap().has_content());
+ });
});
// Second user message
@@ -3991,18 +4005,31 @@ pub(crate) mod tests {
let second_user_message_id = thread.read_with(cx, |thread, _| {
assert_eq!(thread.entries().len(), 4);
- let AgentThreadEntry::UserMessage(user_message) = thread.entries().get(2).unwrap()
- else {
+ let AgentThreadEntry::UserMessage(user_message) = &thread.entries()[2] else {
panic!();
};
user_message.id.clone().unwrap()
});
- thread_view.read_with(cx, |view, _| {
- assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
- assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
- assert_eq!(view.entry_view_state.entry(2).unwrap().len(), 0);
- assert_eq!(view.entry_view_state.entry(3).unwrap().len(), 1);
+ thread_view.read_with(cx, |view, cx| {
+ view.entry_view_state.read_with(cx, |entry_view_state, _| {
+ assert!(
+ entry_view_state
+ .entry(0)
+ .unwrap()
+ .message_editor()
+ .is_some()
+ );
+ assert!(entry_view_state.entry(1).unwrap().has_content());
+ assert!(
+ entry_view_state
+ .entry(2)
+ .unwrap()
+ .message_editor()
+ .is_some()
+ );
+ assert!(entry_view_state.entry(3).unwrap().has_content());
+ });
});
// Rewind to first message
@@ -4017,13 +4044,169 @@ pub(crate) mod tests {
assert_eq!(thread.entries().len(), 2);
});
- thread_view.read_with(cx, |view, _| {
- assert_eq!(view.entry_view_state.entry(0).unwrap().len(), 0);
- assert_eq!(view.entry_view_state.entry(1).unwrap().len(), 1);
+ thread_view.read_with(cx, |view, cx| {
+ view.entry_view_state.read_with(cx, |entry_view_state, _| {
+ assert!(
+ entry_view_state
+ .entry(0)
+ .unwrap()
+ .message_editor()
+ .is_some()
+ );
+ assert!(entry_view_state.entry(1).unwrap().has_content());
- // Old views should be dropped
- assert!(view.entry_view_state.entry(2).is_none());
- assert!(view.entry_view_state.entry(3).is_none());
+ // Old views should be dropped
+ assert!(entry_view_state.entry(2).is_none());
+ assert!(entry_view_state.entry(3).is_none());
+ });
});
}
+
+ #[gpui::test]
+ async fn test_message_editing_cancel(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let connection = StubAgentConnection::new();
+
+ connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
+ content: acp::ContentBlock::Text(acp::TextContent {
+ text: "Response".into(),
+ annotations: None,
+ }),
+ }]);
+
+ let (thread_view, cx) = setup_thread_view(StubAgentServer::new(connection), cx).await;
+ add_to_workspace(thread_view.clone(), cx);
+
+ let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
+ message_editor.update_in(cx, |editor, window, cx| {
+ editor.set_text("Original message to edit", window, cx);
+ });
+ thread_view.update_in(cx, |thread_view, window, cx| {
+ thread_view.send(window, cx);
+ });
+
+ cx.run_until_parked();
+
+ let user_message_editor = thread_view.read_with(cx, |view, cx| {
+ assert_eq!(view.editing_message, None);
+
+ view.entry_view_state
+ .read(cx)
+ .entry(0)
+ .unwrap()
+ .message_editor()
+ .unwrap()
+ .clone()
+ });
+
+ // Focus
+ cx.focus(&user_message_editor);
+ thread_view.read_with(cx, |view, _cx| {
+ assert_eq!(view.editing_message, Some(0));
+ });
+
+ // Edit
+ user_message_editor.update_in(cx, |editor, window, cx| {
+ editor.set_text("Edited message content", window, cx);
+ });
+
+ // Cancel
+ user_message_editor.update_in(cx, |_editor, window, cx| {
+ window.dispatch_action(Box::new(editor::actions::Cancel), cx);
+ });
+
+ thread_view.read_with(cx, |view, _cx| {
+ assert_eq!(view.editing_message, None);
+ });
+
+ user_message_editor.read_with(cx, |editor, cx| {
+ assert_eq!(editor.text(cx), "Original message to edit");
+ });
+ }
+
+ #[gpui::test]
+ async fn test_message_editing_regenerate(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let connection = StubAgentConnection::new();
+
+ connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
+ content: acp::ContentBlock::Text(acp::TextContent {
+ text: "Response".into(),
+ annotations: None,
+ }),
+ }]);
+
+ let (thread_view, cx) =
+ setup_thread_view(StubAgentServer::new(connection.clone()), cx).await;
+ add_to_workspace(thread_view.clone(), cx);
+
+ let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
+ message_editor.update_in(cx, |editor, window, cx| {
+ editor.set_text("Original message to edit", window, cx);
+ });
+ thread_view.update_in(cx, |thread_view, window, cx| {
+ thread_view.send(window, cx);
+ });
+
+ cx.run_until_parked();
+
+ let user_message_editor = thread_view.read_with(cx, |view, cx| {
+ assert_eq!(view.editing_message, None);
+ assert_eq!(view.thread().unwrap().read(cx).entries().len(), 2);
+
+ view.entry_view_state
+ .read(cx)
+ .entry(0)
+ .unwrap()
+ .message_editor()
+ .unwrap()
+ .clone()
+ });
+
+ // Focus
+ cx.focus(&user_message_editor);
+
+ // Edit
+ user_message_editor.update_in(cx, |editor, window, cx| {
+ editor.set_text("Edited message content", window, cx);
+ });
+
+ // Send
+ connection.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
+ content: acp::ContentBlock::Text(acp::TextContent {
+ text: "New Response".into(),
+ annotations: None,
+ }),
+ }]);
+
+ user_message_editor.update_in(cx, |_editor, window, cx| {
+ window.dispatch_action(Box::new(Chat), cx);
+ });
+
+ cx.run_until_parked();
+
+ thread_view.read_with(cx, |view, cx| {
+ assert_eq!(view.editing_message, None);
+
+ let entries = view.thread().unwrap().read(cx).entries();
+ assert_eq!(entries.len(), 2);
+ assert_eq!(
+ entries[0].to_markdown(cx),
+ "## User\n\nEdited message content\n\n"
+ );
+ assert_eq!(
+ entries[1].to_markdown(cx),
+ "## Assistant\n\nNew Response\n\n"
+ );
+
+ let new_editor = view.entry_view_state.read_with(cx, |state, _cx| {
+ assert!(!state.entry(1).unwrap().has_content());
+ state.entry(0).unwrap().message_editor().unwrap().clone()
+ });
+
+ assert_eq!(new_editor.read(cx).text(cx), "Edited message content");
+ })
+ }
}