assistant2: Add support for editing the last message sent by the user (#26037)

Bennet Bo Fenner created

https://github.com/user-attachments/assets/df46632b-dfeb-4991-ab2e-86829b72be9b

Closes #ISSUE

Release Notes:

- N/A

Change summary

assets/keymaps/default-linux.json      |   9 
assets/keymaps/default-macos.json      |   9 
crates/assistant2/src/active_thread.rs | 266 ++++++++++++++++++++++++++-
crates/assistant2/src/thread.rs        |  38 +++
4 files changed, 304 insertions(+), 18 deletions(-)

Detailed changes

assets/keymaps/default-linux.json 🔗

@@ -626,6 +626,15 @@
       "enter": "assistant2::Chat"
     }
   },
+  {
+    "context": "EditMessageEditor > Editor",
+    "use_key_equivalents": true,
+    "bindings": {
+      "escape": "menu::Cancel",
+      "enter": "menu::Confirm",
+      "alt-enter": "editor::Newline"
+    }
+  },
   {
     "context": "ContextStrip",
     "bindings": {

assets/keymaps/default-macos.json 🔗

@@ -271,6 +271,15 @@
       "enter": "assistant2::Chat"
     }
   },
+  {
+    "context": "EditMessageEditor > Editor",
+    "use_key_equivalents": true,
+    "bindings": {
+      "escape": "menu::Cancel",
+      "enter": "menu::Confirm",
+      "alt-enter": "editor::Newline"
+    }
+  },
   {
     "context": "ContextStrip",
     "use_key_equivalents": true,

crates/assistant2/src/active_thread.rs 🔗

@@ -2,17 +2,18 @@ use std::sync::Arc;
 
 use assistant_tool::ToolWorkingSet;
 use collections::HashMap;
+use editor::{Editor, MultiBuffer};
 use gpui::{
-    list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
-    ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
-    UnderlineStyle, WeakEntity,
+    list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity,
+    Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
+    TextStyleRefinement, UnderlineStyle, WeakEntity,
 };
-use language::LanguageRegistry;
+use language::{Buffer, LanguageRegistry};
 use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
 use markdown::{Markdown, MarkdownStyle};
 use settings::Settings as _;
 use theme::ThemeSettings;
-use ui::{prelude::*, Disclosure};
+use ui::{prelude::*, Disclosure, KeyBinding};
 use workspace::Workspace;
 
 use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
@@ -29,11 +30,16 @@ pub struct ActiveThread {
     messages: Vec<MessageId>,
     list_state: ListState,
     rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
+    editing_message: Option<(MessageId, EditMessageState)>,
     expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
     last_error: Option<ThreadError>,
     _subscriptions: Vec<Subscription>,
 }
 
+struct EditMessageState {
+    editor: Entity<Editor>,
+}
+
 impl ActiveThread {
     pub fn new(
         thread: Entity<Thread>,
@@ -60,11 +66,12 @@ impl ActiveThread {
             expanded_tool_uses: HashMap::default(),
             list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
                 let this = cx.entity().downgrade();
-                move |ix, _: &mut Window, cx: &mut App| {
-                    this.update(cx, |this, cx| this.render_message(ix, cx))
+                move |ix, window: &mut Window, cx: &mut App| {
+                    this.update(cx, |this, cx| this.render_message(ix, window, cx))
                         .unwrap()
                 }
             }),
+            editing_message: None,
             last_error: None,
             _subscriptions: subscriptions,
         };
@@ -117,6 +124,44 @@ impl ActiveThread {
         self.messages.push(*id);
         self.list_state.splice(old_len..old_len, 1);
 
+        let markdown = self.render_markdown(text.into(), window, cx);
+        self.rendered_messages_by_id.insert(*id, markdown);
+        self.list_state.scroll_to(ListOffset {
+            item_ix: old_len,
+            offset_in_item: Pixels(0.0),
+        });
+    }
+
+    fn edited_message(
+        &mut self,
+        id: &MessageId,
+        text: String,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
+        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
+            return;
+        };
+        self.list_state.splice(index..index + 1, 1);
+        let markdown = self.render_markdown(text.into(), window, cx);
+        self.rendered_messages_by_id.insert(*id, markdown);
+    }
+
+    fn deleted_message(&mut self, id: &MessageId) {
+        let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
+            return;
+        };
+        self.messages.remove(index);
+        self.list_state.splice(index..index + 1, 0);
+        self.rendered_messages_by_id.remove(id);
+    }
+
+    fn render_markdown(
+        &self,
+        text: SharedString,
+        window: &Window,
+        cx: &mut Context<Self>,
+    ) -> Entity<Markdown> {
         let theme_settings = ThemeSettings::get_global(cx);
         let colors = cx.theme().colors();
         let ui_font_size = TextSize::Default.rems(cx);
@@ -182,20 +227,15 @@ impl ActiveThread {
             ..Default::default()
         };
 
-        let markdown = cx.new(|cx| {
+        cx.new(|cx| {
             Markdown::new(
-                text.into(),
+                text,
                 markdown_style,
                 Some(self.language_registry.clone()),
                 None,
                 cx,
             )
-        });
-        self.rendered_messages_by_id.insert(*id, markdown);
-        self.list_state.scroll_to(ListOffset {
-            item_ix: old_len,
-            offset_in_item: Pixels(0.0),
-        });
+        })
     }
 
     fn handle_thread_event(
@@ -241,6 +281,35 @@ impl ActiveThread {
 
                 cx.notify();
             }
+            ThreadEvent::MessageEdited(message_id) => {
+                if let Some(message_text) = self
+                    .thread
+                    .read(cx)
+                    .message(*message_id)
+                    .map(|message| message.text.clone())
+                {
+                    self.edited_message(message_id, message_text, window, cx);
+                }
+
+                self.thread_store
+                    .update(cx, |thread_store, cx| {
+                        thread_store.save_thread(&self.thread, cx)
+                    })
+                    .detach_and_log_err(cx);
+
+                cx.notify();
+            }
+            ThreadEvent::MessageDeleted(message_id) => {
+                self.deleted_message(message_id);
+
+                self.thread_store
+                    .update(cx, |thread_store, cx| {
+                        thread_store.save_thread(&self.thread, cx)
+                    })
+                    .detach_and_log_err(cx);
+
+                cx.notify();
+            }
             ThreadEvent::UsePendingTools => {
                 let pending_tool_uses = self
                     .thread
@@ -289,7 +358,101 @@ impl ActiveThread {
         }
     }
 
-    fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
+    fn start_editing_message(
+        &mut self,
+        message_id: MessageId,
+        message_text: String,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
+        let buffer = cx.new(|cx| {
+            MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
+        });
+        let editor = cx.new(|cx| {
+            let mut editor = Editor::new(
+                editor::EditorMode::AutoHeight { max_lines: 8 },
+                buffer,
+                None,
+                false,
+                window,
+                cx,
+            );
+            editor.focus_handle(cx).focus(window);
+            editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
+            editor
+        });
+        self.editing_message = Some((
+            message_id,
+            EditMessageState {
+                editor: editor.clone(),
+            },
+        ));
+        cx.notify();
+    }
+
+    fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
+        self.editing_message.take();
+        cx.notify();
+    }
+
+    fn confirm_editing_message(
+        &mut self,
+        _: &menu::Confirm,
+        _: &mut Window,
+        cx: &mut Context<Self>,
+    ) {
+        let Some((message_id, state)) = self.editing_message.take() else {
+            return;
+        };
+        let edited_text = state.editor.read(cx).text(cx);
+        self.thread.update(cx, |thread, cx| {
+            thread.edit_message(message_id, Role::User, edited_text, cx);
+            for message_id in self.messages_after(message_id) {
+                thread.delete_message(*message_id, cx);
+            }
+        });
+
+        let provider = LanguageModelRegistry::read_global(cx).active_provider();
+        if provider
+            .as_ref()
+            .map_or(false, |provider| provider.must_accept_terms(cx))
+        {
+            cx.notify();
+            return;
+        }
+        let model_registry = LanguageModelRegistry::read_global(cx);
+        let Some(model) = model_registry.active_model() else {
+            return;
+        };
+
+        self.thread.update(cx, |thread, cx| {
+            thread.send_to_model(model, RequestKind::Chat, false, cx)
+        });
+        cx.notify();
+    }
+
+    fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
+        self.messages
+            .iter()
+            .rev()
+            .find(|message_id| {
+                self.thread
+                    .read(cx)
+                    .message(**message_id)
+                    .map_or(false, |message| message.role == Role::User)
+            })
+            .cloned()
+    }
+
+    fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
+        self.messages
+            .iter()
+            .position(|id| *id == message_id)
+            .map(|index| &self.messages[index + 1..])
+            .unwrap_or(&[])
+    }
+
+    fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
         let message_id = self.messages[ix];
         let Some(message) = self.thread.read(cx).message(message_id) else {
             return Empty.into_any();
@@ -308,8 +471,28 @@ impl ActiveThread {
             return Empty.into_any();
         }
 
+        let allow_editing_message =
+            message.role == Role::User && self.last_user_message(cx) == Some(message_id);
+
+        let edit_message_editor = self
+            .editing_message
+            .as_ref()
+            .filter(|(id, _)| *id == message_id)
+            .map(|(_, state)| state.editor.clone());
+
         let message_content = v_flex()
-            .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
+            .child(
+                if let Some(edit_message_editor) = edit_message_editor.clone() {
+                    div()
+                        .key_context("EditMessageEditor")
+                        .on_action(cx.listener(Self::cancel_editing_message))
+                        .on_action(cx.listener(Self::confirm_editing_message))
+                        .p_2p5()
+                        .child(edit_message_editor)
+                } else {
+                    div().p_2p5().text_ui(cx).child(markdown.clone())
+                },
+            )
             .when_some(context, |parent, context| {
                 if !context.is_empty() {
                     parent.child(
@@ -358,6 +541,55 @@ impl ActiveThread {
                                                 .size(LabelSize::Small)
                                                 .color(Color::Muted),
                                         ),
+                                )
+                                .when_some(
+                                    edit_message_editor.clone(),
+                                    |this, edit_message_editor| {
+                                        let focus_handle = edit_message_editor.focus_handle(cx);
+                                        this.child(
+                                            h_flex()
+                                                .gap_1()
+                                                .child(
+                                                    Button::new("cancel-edit-message", "Cancel")
+                                                        .key_binding(KeyBinding::for_action_in(
+                                                            &menu::Cancel,
+                                                            &focus_handle,
+                                                            window,
+                                                            cx,
+                                                        )),
+                                                )
+                                                .child(
+                                                    Button::new(
+                                                        "confirm-edit-message",
+                                                        "Regenerate",
+                                                    )
+                                                    .key_binding(KeyBinding::for_action_in(
+                                                        &menu::Confirm,
+                                                        &focus_handle,
+                                                        window,
+                                                        cx,
+                                                    )),
+                                                ),
+                                        )
+                                    },
+                                )
+                                .when(
+                                    edit_message_editor.is_none() && allow_editing_message,
+                                    |this| {
+                                        this.child(Button::new("edit-message", "Edit").on_click(
+                                            cx.listener({
+                                                let message_text = message.text.clone();
+                                                move |this, _, window, cx| {
+                                                    this.start_editing_message(
+                                                        message_id,
+                                                        message_text.clone(),
+                                                        window,
+                                                        cx,
+                                                    );
+                                                }
+                                            }),
+                                        ))
+                                    },
                                 ),
                         )
                         .child(message_content),

crates/assistant2/src/thread.rs 🔗

@@ -99,7 +99,13 @@ impl Thread {
         tools: Arc<ToolWorkingSet>,
         _cx: &mut Context<Self>,
     ) -> Self {
-        let next_message_id = MessageId(saved.messages.len());
+        let next_message_id = MessageId(
+            saved
+                .messages
+                .last()
+                .map(|message| message.id.0 + 1)
+                .unwrap_or(0),
+        );
         let tool_use = ToolUseState::from_saved_messages(&saved.messages);
 
         Self {
@@ -229,6 +235,34 @@ impl Thread {
         id
     }
 
+    pub fn edit_message(
+        &mut self,
+        id: MessageId,
+        new_role: Role,
+        new_text: String,
+        cx: &mut Context<Self>,
+    ) -> bool {
+        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
+            return false;
+        };
+        message.role = new_role;
+        message.text = new_text;
+        self.touch_updated_at();
+        cx.emit(ThreadEvent::MessageEdited(id));
+        true
+    }
+
+    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
+        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
+            return false;
+        };
+        self.messages.remove(index);
+        self.context_by_message.remove(&id);
+        self.touch_updated_at();
+        cx.emit(ThreadEvent::MessageDeleted(id));
+        true
+    }
+
     /// Returns the representation of this [`Thread`] in a textual form.
     ///
     /// This is the representation we use when attaching a thread as context to another thread.
@@ -567,6 +601,8 @@ pub enum ThreadEvent {
     StreamedCompletion,
     StreamedAssistantText(MessageId, String),
     MessageAdded(MessageId),
+    MessageEdited(MessageId),
+    MessageDeleted(MessageId),
     SummaryChanged,
     UsePendingTools,
     ToolFinished {