@@ -2,17 +2,18 @@ use std::sync::Arc;
use assistant_tool::ToolWorkingSet;
use collections::HashMap;
+use editor::{Editor, MultiBuffer};
use gpui::{
- list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity, Length,
- ListAlignment, ListOffset, ListState, StyleRefinement, Subscription, TextStyleRefinement,
- UnderlineStyle, WeakEntity,
+ list, AbsoluteLength, AnyElement, App, DefiniteLength, EdgesRefinement, Empty, Entity,
+ Focusable, Length, ListAlignment, ListOffset, ListState, StyleRefinement, Subscription,
+ TextStyleRefinement, UnderlineStyle, WeakEntity,
};
-use language::LanguageRegistry;
+use language::{Buffer, LanguageRegistry};
use language_model::{LanguageModelRegistry, LanguageModelToolUseId, Role};
use markdown::{Markdown, MarkdownStyle};
use settings::Settings as _;
use theme::ThemeSettings;
-use ui::{prelude::*, Disclosure};
+use ui::{prelude::*, Disclosure, KeyBinding};
use workspace::Workspace;
use crate::thread::{MessageId, RequestKind, Thread, ThreadError, ThreadEvent};
@@ -29,11 +30,16 @@ pub struct ActiveThread {
messages: Vec<MessageId>,
list_state: ListState,
rendered_messages_by_id: HashMap<MessageId, Entity<Markdown>>,
+ editing_message: Option<(MessageId, EditMessageState)>,
expanded_tool_uses: HashMap<LanguageModelToolUseId, bool>,
last_error: Option<ThreadError>,
_subscriptions: Vec<Subscription>,
}
+struct EditMessageState {
+ editor: Entity<Editor>,
+}
+
impl ActiveThread {
pub fn new(
thread: Entity<Thread>,
@@ -60,11 +66,12 @@ impl ActiveThread {
expanded_tool_uses: HashMap::default(),
list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.entity().downgrade();
- move |ix, _: &mut Window, cx: &mut App| {
- this.update(cx, |this, cx| this.render_message(ix, cx))
+ move |ix, window: &mut Window, cx: &mut App| {
+ this.update(cx, |this, cx| this.render_message(ix, window, cx))
.unwrap()
}
}),
+ editing_message: None,
last_error: None,
_subscriptions: subscriptions,
};
@@ -117,6 +124,44 @@ impl ActiveThread {
self.messages.push(*id);
self.list_state.splice(old_len..old_len, 1);
+ let markdown = self.render_markdown(text.into(), window, cx);
+ self.rendered_messages_by_id.insert(*id, markdown);
+ self.list_state.scroll_to(ListOffset {
+ item_ix: old_len,
+ offset_in_item: Pixels(0.0),
+ });
+ }
+
+ fn edited_message(
+ &mut self,
+ id: &MessageId,
+ text: String,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
+ return;
+ };
+ self.list_state.splice(index..index + 1, 1);
+ let markdown = self.render_markdown(text.into(), window, cx);
+ self.rendered_messages_by_id.insert(*id, markdown);
+ }
+
+ fn deleted_message(&mut self, id: &MessageId) {
+ let Some(index) = self.messages.iter().position(|message_id| message_id == id) else {
+ return;
+ };
+ self.messages.remove(index);
+ self.list_state.splice(index..index + 1, 0);
+ self.rendered_messages_by_id.remove(id);
+ }
+
+ fn render_markdown(
+ &self,
+ text: SharedString,
+ window: &Window,
+ cx: &mut Context<Self>,
+ ) -> Entity<Markdown> {
let theme_settings = ThemeSettings::get_global(cx);
let colors = cx.theme().colors();
let ui_font_size = TextSize::Default.rems(cx);
@@ -182,20 +227,15 @@ impl ActiveThread {
..Default::default()
};
- let markdown = cx.new(|cx| {
+ cx.new(|cx| {
Markdown::new(
- text.into(),
+ text,
markdown_style,
Some(self.language_registry.clone()),
None,
cx,
)
- });
- self.rendered_messages_by_id.insert(*id, markdown);
- self.list_state.scroll_to(ListOffset {
- item_ix: old_len,
- offset_in_item: Pixels(0.0),
- });
+ })
}
fn handle_thread_event(
@@ -241,6 +281,35 @@ impl ActiveThread {
cx.notify();
}
+ ThreadEvent::MessageEdited(message_id) => {
+ if let Some(message_text) = self
+ .thread
+ .read(cx)
+ .message(*message_id)
+ .map(|message| message.text.clone())
+ {
+ self.edited_message(message_id, message_text, window, cx);
+ }
+
+ self.thread_store
+ .update(cx, |thread_store, cx| {
+ thread_store.save_thread(&self.thread, cx)
+ })
+ .detach_and_log_err(cx);
+
+ cx.notify();
+ }
+ ThreadEvent::MessageDeleted(message_id) => {
+ self.deleted_message(message_id);
+
+ self.thread_store
+ .update(cx, |thread_store, cx| {
+ thread_store.save_thread(&self.thread, cx)
+ })
+ .detach_and_log_err(cx);
+
+ cx.notify();
+ }
ThreadEvent::UsePendingTools => {
let pending_tool_uses = self
.thread
@@ -289,7 +358,101 @@ impl ActiveThread {
}
}
- fn render_message(&self, ix: usize, cx: &mut Context<Self>) -> AnyElement {
+ fn start_editing_message(
+ &mut self,
+ message_id: MessageId,
+ message_text: String,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let buffer = cx.new(|cx| {
+ MultiBuffer::singleton(cx.new(|cx| Buffer::local(message_text.clone(), cx)), cx)
+ });
+ let editor = cx.new(|cx| {
+ let mut editor = Editor::new(
+ editor::EditorMode::AutoHeight { max_lines: 8 },
+ buffer,
+ None,
+ false,
+ window,
+ cx,
+ );
+ editor.focus_handle(cx).focus(window);
+ editor.move_to_end(&editor::actions::MoveToEnd, window, cx);
+ editor
+ });
+ self.editing_message = Some((
+ message_id,
+ EditMessageState {
+ editor: editor.clone(),
+ },
+ ));
+ cx.notify();
+ }
+
+ fn cancel_editing_message(&mut self, _: &menu::Cancel, _: &mut Window, cx: &mut Context<Self>) {
+ self.editing_message.take();
+ cx.notify();
+ }
+
+ fn confirm_editing_message(
+ &mut self,
+ _: &menu::Confirm,
+ _: &mut Window,
+ cx: &mut Context<Self>,
+ ) {
+ let Some((message_id, state)) = self.editing_message.take() else {
+ return;
+ };
+ let edited_text = state.editor.read(cx).text(cx);
+ self.thread.update(cx, |thread, cx| {
+ thread.edit_message(message_id, Role::User, edited_text, cx);
+ for message_id in self.messages_after(message_id) {
+ thread.delete_message(*message_id, cx);
+ }
+ });
+
+ let provider = LanguageModelRegistry::read_global(cx).active_provider();
+ if provider
+ .as_ref()
+ .map_or(false, |provider| provider.must_accept_terms(cx))
+ {
+ cx.notify();
+ return;
+ }
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ let Some(model) = model_registry.active_model() else {
+ return;
+ };
+
+ self.thread.update(cx, |thread, cx| {
+ thread.send_to_model(model, RequestKind::Chat, false, cx)
+ });
+ cx.notify();
+ }
+
+ fn last_user_message(&self, cx: &Context<Self>) -> Option<MessageId> {
+ self.messages
+ .iter()
+ .rev()
+ .find(|message_id| {
+ self.thread
+ .read(cx)
+ .message(**message_id)
+ .map_or(false, |message| message.role == Role::User)
+ })
+ .cloned()
+ }
+
+ fn messages_after(&self, message_id: MessageId) -> &[MessageId] {
+ self.messages
+ .iter()
+ .position(|id| *id == message_id)
+ .map(|index| &self.messages[index + 1..])
+ .unwrap_or(&[])
+ }
+
+ fn render_message(&self, ix: usize, window: &mut Window, cx: &mut Context<Self>) -> AnyElement {
let message_id = self.messages[ix];
let Some(message) = self.thread.read(cx).message(message_id) else {
return Empty.into_any();
@@ -308,8 +471,28 @@ impl ActiveThread {
return Empty.into_any();
}
+ let allow_editing_message =
+ message.role == Role::User && self.last_user_message(cx) == Some(message_id);
+
+ let edit_message_editor = self
+ .editing_message
+ .as_ref()
+ .filter(|(id, _)| *id == message_id)
+ .map(|(_, state)| state.editor.clone());
+
let message_content = v_flex()
- .child(div().p_2p5().text_ui(cx).child(markdown.clone()))
+ .child(
+ if let Some(edit_message_editor) = edit_message_editor.clone() {
+ div()
+ .key_context("EditMessageEditor")
+ .on_action(cx.listener(Self::cancel_editing_message))
+ .on_action(cx.listener(Self::confirm_editing_message))
+ .p_2p5()
+ .child(edit_message_editor)
+ } else {
+ div().p_2p5().text_ui(cx).child(markdown.clone())
+ },
+ )
.when_some(context, |parent, context| {
if !context.is_empty() {
parent.child(
@@ -358,6 +541,55 @@ impl ActiveThread {
.size(LabelSize::Small)
.color(Color::Muted),
),
+ )
+ .when_some(
+ edit_message_editor.clone(),
+ |this, edit_message_editor| {
+ let focus_handle = edit_message_editor.focus_handle(cx);
+ this.child(
+ h_flex()
+ .gap_1()
+ .child(
+ Button::new("cancel-edit-message", "Cancel")
+ .key_binding(KeyBinding::for_action_in(
+ &menu::Cancel,
+ &focus_handle,
+ window,
+ cx,
+ )),
+ )
+ .child(
+ Button::new(
+ "confirm-edit-message",
+ "Regenerate",
+ )
+ .key_binding(KeyBinding::for_action_in(
+ &menu::Confirm,
+ &focus_handle,
+ window,
+ cx,
+ )),
+ ),
+ )
+ },
+ )
+ .when(
+ edit_message_editor.is_none() && allow_editing_message,
+ |this| {
+ this.child(Button::new("edit-message", "Edit").on_click(
+ cx.listener({
+ let message_text = message.text.clone();
+ move |this, _, window, cx| {
+ this.start_editing_message(
+ message_id,
+ message_text.clone(),
+ window,
+ cx,
+ );
+ }
+ }),
+ ))
+ },
),
)
.child(message_content),