From f4899d92a4f60bb40f139348ed732395033b3a7b Mon Sep 17 00:00:00 2001 From: Bennet Bo Fenner Date: Tue, 4 Mar 2025 17:57:42 +0100 Subject: [PATCH] assistant2: Add support for editing the last message sent by the user (#26037) https://github.com/user-attachments/assets/df46632b-dfeb-4991-ab2e-86829b72be9b Closes #ISSUE Release Notes: - N/A --- assets/keymaps/default-linux.json | 9 + assets/keymaps/default-macos.json | 9 + crates/assistant2/src/active_thread.rs | 266 +++++++++++++++++++++++-- crates/assistant2/src/thread.rs | 38 +++- 4 files changed, 304 insertions(+), 18 deletions(-) diff --git a/assets/keymaps/default-linux.json b/assets/keymaps/default-linux.json index 46da421588591e453e73b76c225e52e0c63a9730..655b77b86747c4b765a0c5cac73a5c4c76ff45ce 100644 --- a/assets/keymaps/default-linux.json +++ b/assets/keymaps/default-linux.json @@ -626,6 +626,15 @@ "enter": "assistant2::Chat" } }, + { + "context": "EditMessageEditor > Editor", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel", + "enter": "menu::Confirm", + "alt-enter": "editor::Newline" + } + }, { "context": "ContextStrip", "bindings": { diff --git a/assets/keymaps/default-macos.json b/assets/keymaps/default-macos.json index b2bced5e7dc54ce58f4872677c22235c281822aa..6c266fbd38ce5845ae7dda97a721f79ebdc48c6a 100644 --- a/assets/keymaps/default-macos.json +++ b/assets/keymaps/default-macos.json @@ -271,6 +271,15 @@ "enter": "assistant2::Chat" } }, + { + "context": "EditMessageEditor > Editor", + "use_key_equivalents": true, + "bindings": { + "escape": "menu::Cancel", + "enter": "menu::Confirm", + "alt-enter": "editor::Newline" + } + }, { "context": "ContextStrip", "use_key_equivalents": true, diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 6e8b73054f6ca24864ec07008da429f9db365bcf..227320a3c7058ee49af8168d979dfaa11dccf806 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -2,17 +2,18 @@ use std::sync::Arc; use assistant_tool::ToolWorkingSet; use collections::HashMap; +use editor::{Editor, MultiBuffer}; use gpui::{ - list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length, - ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement, - UnderlineStyle, WeakEntity, + list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, + Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, + TextStyleRefinement, UnderlineStyle, WeakEntity, }; -use language::LanguageRegistry; +use language::{Buffer, LanguageRegistry}; use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role}; use markdown::{Markdown, MarkdownStyle}; use settings::Settings as _; use theme::ThemeSettings; -use ui::{prelude::*, Disclosure}; +use ui::{prelude::*, Disclosure, KeyBinding}; use workspace::Workspace; use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent}; @@ -29,11 +30,16 @@ pub struct ActiveThread { messages: Vec, list_state: ListState, rendered_messages_by_id: HashMap>, + editing_message: Option<(MessageId, EditMessageState)>, expanded_tool_uses: HashMap, last_error: Option, _subscriptions: Vec, } +struct EditMessageState { + editor: Entity, +} + impl ActiveThread { pub fn new( thread: Entity, @@ -60,11 +66,12 @@ impl ActiveThread { expanded_tool_uses: HashMap::default(), list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { let this = cx.entity().downgrade(); - move |ix, _: &mut Window, cx: &mut App| { - this.update(cx, |this, cx| this.render_message(ix, cx)) + move |ix, window: &mut Window, cx: &mut App| { + this.update(cx, |this, cx| this.render_message(ix, window, cx)) .unwrap() } }), + editing_message: None, last_error: None, _subscriptions: subscriptions, }; @@ -117,6 +124,44 @@ impl ActiveThread { self.messages.push(*id); self.list_state.splice(old_len..old_len, 1); + let markdown = self.render_markdown(text.into(), window, cx); + self.rendered_messages_by_id.insert(*id, markdown); + self.list_state.scroll_to(ListOffset { + item_ix: old_len, + offset_in_item: Pixels(0.0), + }); + } + + fn edited_message( + &mut self, + id: &MessageId, + text: String, + window: &mut Window, + cx: &mut Context, + ) { + let Some(index) = self.messages.iter().position(|message_id| message_id == id) else { + return; + }; + self.list_state.splice(index..index + 1, 1); + let markdown = self.render_markdown(text.into(), window, cx); + self.rendered_messages_by_id.insert(*id, markdown); + } + + fn deleted_message(&mut self, id: &MessageId) { + let Some(index) = self.messages.iter().position(|message_id| message_id == id) else { + return; + }; + self.messages.remove(index); + self.list_state.splice(index..index + 1, 0); + self.rendered_messages_by_id.remove(id); + } + + fn render_markdown( + &self, + text: SharedString, + window: &Window, + cx: &mut Context, + ) -> Entity { let theme_settings = ThemeSettings::get_global(cx); let colors = cx.theme().colors(); let ui_font_size = TextSize::Default.rems(cx); @@ -182,20 +227,15 @@ impl ActiveThread { ..Default::default() }; - let markdown = cx.new(|cx| { + cx.new(|cx| { Markdown::new( - text.into(), + text, markdown_style, Some(self.language_registry.clone()), None, cx, ) - }); - self.rendered_messages_by_id.insert(*id, markdown); - self.list_state.scroll_to(ListOffset { - item_ix: old_len, - offset_in_item: Pixels(0.0), - }); + }) } fn handle_thread_event( @@ -241,6 +281,35 @@ impl ActiveThread { cx.notify(); } + ThreadEvent::MessageEdited(message_id) => { + if let Some(message_text) = self + .thread + .read(cx) + .message(*message_id) + .map(|message| message.text.clone()) + { + self.edited_message(message_id, message_text, window, cx); + } + + self.thread_store + .update(cx, |thread_store, cx| { + thread_store.save_thread(&self.thread, cx) + }) + .detach_and_log_err(cx); + + cx.notify(); + } + ThreadEvent::MessageDeleted(message_id) => { + self.deleted_message(message_id); + + self.thread_store + .update(cx, |thread_store, cx| { + thread_store.save_thread(&self.thread, cx) + }) + .detach_and_log_err(cx); + + cx.notify(); + } ThreadEvent::UsePendingTools => { let pending_tool_uses = self .thread @@ -289,7 +358,101 @@ impl ActiveThread { } } - fn render_message(&self, ix: usize, cx: &mut Context) -> AnyElement { + fn start_editing_message( + &mut self, + message_id: MessageId, + message_text: String, + window: &mut Window, + cx: &mut Context, + ) { + let buffer = cx.new(|cx| { + MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx) + }); + let editor = cx.new(|cx| { + let mut editor = Editor::new( + editor::EditorMode::AutoHeight { max_lines: 8 }, + buffer, + None, + false, + window, + cx, + ); + editor.focus_handle(cx).focus(window); + editor.move_to_end(&editor::actions::MoveToEnd, window, cx); + editor + }); + self.editing_message = Some(( + message_id, + EditMessageState { + editor: editor.clone(), + }, + )); + cx.notify(); + } + + fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context) { + self.editing_message.take(); + cx.notify(); + } + + fn confirm_editing_message( + &mut self, + _: &menu::Confirm, + _: &mut Window, + cx: &mut Context, + ) { + let Some((message_id, state)) = self.editing_message.take() else { + return; + }; + let edited_text = state.editor.read(cx).text(cx); + self.thread.update(cx, |thread, cx| { + thread.edit_message(message_id, Role::User, edited_text, cx); + for message_id in self.messages_after(message_id) { + thread.delete_message(*message_id, cx); + } + }); + + let provider = LanguageModelRegistry::read_global(cx).active_provider(); + if provider + .as_ref() + .map_or(false, |provider| provider.must_accept_terms(cx)) + { + cx.notify(); + return; + } + let model_registry = LanguageModelRegistry::read_global(cx); + let Some(model) = model_registry.active_model() else { + return; + }; + + self.thread.update(cx, |thread, cx| { + thread.send_to_model(model, RequestKind::Chat, false, cx) + }); + cx.notify(); + } + + fn last_user_message(&self, cx: &Context) -> Option { + self.messages + .iter() + .rev() + .find(|message_id| { + self.thread + .read(cx) + .message(**message_id) + .map_or(false, |message| message.role == Role::User) + }) + .cloned() + } + + fn messages_after(&self, message_id: MessageId) -> &[MessageId] { + self.messages + .iter() + .position(|id| *id == message_id) + .map(|index| &self.messages[index + 1..]) + .unwrap_or(&[]) + } + + fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context) -> AnyElement { let message_id = self.messages[ix]; let Some(message) = self.thread.read(cx).message(message_id) else { return Empty.into_any(); @@ -308,8 +471,28 @@ impl ActiveThread { return Empty.into_any(); } + let allow_editing_message = + message.role == Role::User && self.last_user_message(cx) == Some(message_id); + + let edit_message_editor = self + .editing_message + .as_ref() + .filter(|(id, _)| *id == message_id) + .map(|(_, state)| state.editor.clone()); + let message_content = v_flex() - .child(div().p_2p5().text_ui(cx).child(markdown.clone())) + .child( + if let Some(edit_message_editor) = edit_message_editor.clone() { + div() + .key_context("EditMessageEditor") + .on_action(cx.listener(Self::cancel_editing_message)) + .on_action(cx.listener(Self::confirm_editing_message)) + .p_2p5() + .child(edit_message_editor) + } else { + div().p_2p5().text_ui(cx).child(markdown.clone()) + }, + ) .when_some(context, |parent, context| { if !context.is_empty() { parent.child( @@ -358,6 +541,55 @@ impl ActiveThread { .size(LabelSize::Small) .color(Color::Muted), ), + ) + .when_some( + edit_message_editor.clone(), + |this, edit_message_editor| { + let focus_handle = edit_message_editor.focus_handle(cx); + this.child( + h_flex() + .gap_1() + .child( + Button::new("cancel-edit-message", "Cancel") + .key_binding(KeyBinding::for_action_in( + &menu::Cancel, + &focus_handle, + window, + cx, + )), + ) + .child( + Button::new( + "confirm-edit-message", + "Regenerate", + ) + .key_binding(KeyBinding::for_action_in( + &menu::Confirm, + &focus_handle, + window, + cx, + )), + ), + ) + }, + ) + .when( + edit_message_editor.is_none() && allow_editing_message, + |this| { + this.child(Button::new("edit-message", "Edit").on_click( + cx.listener({ + let message_text = message.text.clone(); + move |this, _, window, cx| { + this.start_editing_message( + message_id, + message_text.clone(), + window, + cx, + ); + } + }), + )) + }, ), ) .child(message_content), diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index ee85f3cc94430867ce4405467d72540e21a78bd9..368902fadaa4d668b57dc679ed0d20d05a218348 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -99,7 +99,13 @@ impl Thread { tools: Arc, _cx: &mut Context, ) -> Self { - let next_message_id = MessageId(saved.messages.len()); + let next_message_id = MessageId( + saved + .messages + .last() + .map(|message| message.id.0 + 1) + .unwrap_or(0), + ); let tool_use = ToolUseState::from_saved_messages(&saved.messages); Self { @@ -229,6 +235,34 @@ impl Thread { id } + pub fn edit_message( + &mut self, + id: MessageId, + new_role: Role, + new_text: String, + cx: &mut Context, + ) -> bool { + let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else { + return false; + }; + message.role = new_role; + message.text = new_text; + self.touch_updated_at(); + cx.emit(ThreadEvent::MessageEdited(id)); + true + } + + pub fn delete_message(&mut self, id: MessageId, cx: &mut Context) -> bool { + let Some(index) = self.messages.iter().position(|message| message.id == id) else { + return false; + }; + self.messages.remove(index); + self.context_by_message.remove(&id); + self.touch_updated_at(); + cx.emit(ThreadEvent::MessageDeleted(id)); + true + } + /// Returns the representation of this [`Thread`] in a textual form. /// /// This is the representation we use when attaching a thread as context to another thread. @@ -567,6 +601,8 @@ pub enum ThreadEvent { StreamedCompletion, StreamedAssistantText(MessageId, String), MessageAdded(MessageId), + MessageEdited(MessageId), + MessageDeleted(MessageId), SummaryChanged, UsePendingTools, ToolFinished {