Cycle message roles on ctrl-r

Nathan Sobo created

Change summary

assets/keymaps/default.json |   3 
crates/ai/src/assistant.rs  | 168 ++++++++++++++++++++++----------------
2 files changed, 98 insertions(+), 73 deletions(-)

Detailed changes

assets/keymaps/default.json 🔗

@@ -201,7 +201,8 @@
     "bindings": {
       "cmd-enter": "assistant::Assist",
       "cmd->": "assistant::QuoteSelection",
-      "shift-enter": "assistant::Split"
+      "shift-enter": "assistant::Split",
+      "ctrl-r": "assistant::CycleMessageRole"
     }
   },
   {

crates/ai/src/assistant.rs 🔗

@@ -44,6 +44,7 @@ actions!(
         NewContext,
         Assist,
         Split,
+        CycleMessageRole,
         QuoteSelection,
         ToggleFocus,
         ResetKey
@@ -72,6 +73,7 @@ pub fn init(cx: &mut AppContext) {
     cx.add_action(AssistantEditor::quote_selection);
     cx.capture_action(AssistantEditor::copy);
     cx.capture_action(AssistantEditor::split);
+    cx.capture_action(AssistantEditor::cycle_message_role);
     cx.add_action(AssistantPanel::save_api_key);
     cx.add_action(AssistantPanel::reset_api_key);
     cx.add_action(
@@ -446,7 +448,7 @@ enum AssistantEvent {
 
 struct Assistant {
     buffer: ModelHandle<Buffer>,
-    messages: Vec<Message>,
+    message_anchors: Vec<MessageAnchor>,
     messages_metadata: HashMap<MessageId, MessageMetadata>,
     next_message_id: MessageId,
     summary: Option<String>,
@@ -491,7 +493,7 @@ impl Assistant {
         });
 
         let mut this = Self {
-            messages: Default::default(),
+            message_anchors: Default::default(),
             messages_metadata: Default::default(),
             next_message_id: Default::default(),
             summary: None,
@@ -506,11 +508,11 @@ impl Assistant {
             api_key,
             buffer,
         };
-        let message = Message {
+        let message = MessageAnchor {
             id: MessageId(post_inc(&mut this.next_message_id.0)),
             start: language::Anchor::MIN,
         };
-        this.messages.push(message.clone());
+        this.message_anchors.push(message.clone());
         this.messages_metadata.insert(
             message.id,
             MessageMetadata {
@@ -587,7 +589,7 @@ impl Assistant {
         cx.notify();
     }
 
-    fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(Message, Message)> {
+    fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<(MessageAnchor, MessageAnchor)> {
         let request = OpenAIRequest {
             model: self.model.clone(),
             messages: self.open_ai_request_messages(cx),
@@ -597,7 +599,7 @@ impl Assistant {
         let api_key = self.api_key.borrow().clone()?;
         let stream = stream_completion(api_key, cx.background().clone(), request);
         let assistant_message =
-            self.insert_message_after(self.messages.last()?.id, Role::Assistant, cx)?;
+            self.insert_message_after(self.message_anchors.last()?.id, Role::Assistant, cx)?;
         let user_message = self.insert_message_after(assistant_message.id, Role::User, cx)?;
         let task = cx.spawn_weak({
             |this, mut cx| async move {
@@ -613,14 +615,15 @@ impl Assistant {
                                 .update(&mut cx, |this, cx| {
                                     let text: Arc<str> = choice.delta.content?.into();
                                     let message_ix = this
-                                        .messages
+                                        .message_anchors
                                         .iter()
                                         .position(|message| message.id == assistant_message_id)?;
                                     this.buffer.update(cx, |buffer, cx| {
-                                        let offset = if message_ix + 1 == this.messages.len() {
+                                        let offset = if message_ix + 1 == this.message_anchors.len()
+                                        {
                                             buffer.len()
                                         } else {
-                                            this.messages[message_ix + 1]
+                                            this.message_anchors[message_ix + 1]
                                                 .start
                                                 .to_offset(buffer)
                                                 .saturating_sub(1)
@@ -685,25 +688,26 @@ impl Assistant {
         message_id: MessageId,
         role: Role,
         cx: &mut ModelContext<Self>,
-    ) -> Option<Message> {
+    ) -> Option<MessageAnchor> {
         if let Some(prev_message_ix) = self
-            .messages
+            .message_anchors
             .iter()
             .position(|message| message.id == message_id)
         {
             let start = self.buffer.update(cx, |buffer, cx| {
-                let offset = self.messages[prev_message_ix + 1..]
+                let offset = self.message_anchors[prev_message_ix + 1..]
                     .iter()
                     .find(|message| message.start.is_valid(buffer))
                     .map_or(buffer.len(), |message| message.start.to_offset(buffer) - 1);
                 buffer.edit([(offset..offset, "\n")], None, cx);
                 buffer.anchor_before(offset + 1)
             });
-            let message = Message {
+            let message = MessageAnchor {
                 id: MessageId(post_inc(&mut self.next_message_id.0)),
                 start,
             };
-            self.messages.insert(prev_message_ix + 1, message.clone());
+            self.message_anchors
+                .insert(prev_message_ix + 1, message.clone());
             self.messages_metadata.insert(
                 message.id,
                 MessageMetadata {
@@ -723,23 +727,21 @@ impl Assistant {
         &mut self,
         range: Range<usize>,
         cx: &mut ModelContext<Self>,
-    ) -> (Option<Message>, Option<Message>) {
+    ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
         let start_message = self.message_for_offset(range.start, cx);
         let end_message = self.message_for_offset(range.end, cx);
         if let Some((start_message, end_message)) = start_message.zip(end_message) {
-            let (start_message_ix, _, metadata, message_range) = start_message;
-            let (end_message_ix, _, _, _) = end_message;
-
             // Prevent splitting when range spans multiple messages.
-            if start_message_ix != end_message_ix {
+            if start_message.index != end_message.index {
                 return (None, None);
             }
 
-            let role = metadata.role;
+            let message = start_message;
+            let role = message.role;
             let mut edited_buffer = false;
 
             let mut suffix_start = None;
-            if range.start > message_range.start && range.end < message_range.end - 1 {
+            if range.start > message.range.start && range.end < message.range.end - 1 {
                 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
                     suffix_start = Some(range.end + 1);
                 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
@@ -748,7 +750,7 @@ impl Assistant {
             }
 
             let suffix = if let Some(suffix_start) = suffix_start {
-                Message {
+                MessageAnchor {
                     id: MessageId(post_inc(&mut self.next_message_id.0)),
                     start: self.buffer.read(cx).anchor_before(suffix_start),
                 }
@@ -757,13 +759,14 @@ impl Assistant {
                     buffer.edit([(range.end..range.end, "\n")], None, cx);
                 });
                 edited_buffer = true;
-                Message {
+                MessageAnchor {
                     id: MessageId(post_inc(&mut self.next_message_id.0)),
                     start: self.buffer.read(cx).anchor_before(range.end + 1),
                 }
             };
 
-            self.messages.insert(start_message_ix + 1, suffix.clone());
+            self.message_anchors
+                .insert(message.index + 1, suffix.clone());
             self.messages_metadata.insert(
                 suffix.id,
                 MessageMetadata {
@@ -773,11 +776,11 @@ impl Assistant {
                 },
             );
 
-            let new_messages = if range.start == range.end || range.start == message_range.start {
+            let new_messages = if range.start == range.end || range.start == message.range.start {
                 (None, Some(suffix))
             } else {
                 let mut prefix_end = None;
-                if range.start > message_range.start && range.end < message_range.end - 1 {
+                if range.start > message.range.start && range.end < message.range.end - 1 {
                     if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
                         prefix_end = Some(range.start + 1);
                     } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
@@ -789,7 +792,7 @@ impl Assistant {
 
                 let selection = if let Some(prefix_end) = prefix_end {
                     cx.emit(AssistantEvent::MessagesEdited);
-                    Message {
+                    MessageAnchor {
                         id: MessageId(post_inc(&mut self.next_message_id.0)),
                         start: self.buffer.read(cx).anchor_before(prefix_end),
                     }
@@ -798,14 +801,14 @@ impl Assistant {
                         buffer.edit([(range.start..range.start, "\n")], None, cx)
                     });
                     edited_buffer = true;
-                    Message {
+                    MessageAnchor {
                         id: MessageId(post_inc(&mut self.next_message_id.0)),
                         start: self.buffer.read(cx).anchor_before(range.end + 1),
                     }
                 };
 
-                self.messages
-                    .insert(start_message_ix + 1, selection.clone());
+                self.message_anchors
+                    .insert(message.index + 1, selection.clone());
                 self.messages_metadata.insert(
                     selection.id,
                     MessageMetadata {
@@ -827,7 +830,7 @@ impl Assistant {
     }
 
     fn summarize(&mut self, cx: &mut ModelContext<Self>) {
-        if self.messages.len() >= 2 && self.summary.is_none() {
+        if self.message_anchors.len() >= 2 && self.summary.is_none() {
             let api_key = self.api_key.borrow().clone();
             if let Some(api_key) = api_key {
                 let mut messages = self.open_ai_request_messages(cx);
@@ -870,50 +873,51 @@ impl Assistant {
     fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
         let buffer = self.buffer.read(cx);
         self.messages(cx)
-            .map(|(_ix, _message, metadata, range)| RequestMessage {
-                role: metadata.role,
-                content: buffer.text_for_range(range).collect(),
+            .map(|message| RequestMessage {
+                role: message.role,
+                content: buffer.text_for_range(message.range).collect(),
             })
             .collect()
     }
 
-    fn message_for_offset<'a>(
-        &'a self,
-        offset: usize,
-        cx: &'a AppContext,
-    ) -> Option<(usize, &Message, &MessageMetadata, Range<usize>)> {
+    fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
         let mut messages = self.messages(cx).peekable();
-        while let Some((ix, message, metadata, range)) = messages.next() {
-            if range.contains(&offset) || messages.peek().is_none() {
-                return Some((ix, message, metadata, range));
+        while let Some(message) = messages.next() {
+            if message.range.contains(&offset) || messages.peek().is_none() {
+                return Some(message);
             }
         }
         None
     }
 
-    fn messages<'a>(
-        &'a self,
-        cx: &'a AppContext,
-    ) -> impl 'a + Iterator<Item = (usize, &Message, &MessageMetadata, Range<usize>)> {
+    fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
         let buffer = self.buffer.read(cx);
-        let mut messages = self.messages.iter().enumerate().peekable();
+        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
         iter::from_fn(move || {
-            while let Some((ix, message)) = messages.next() {
-                let metadata = self.messages_metadata.get(&message.id)?;
-                let message_start = message.start.to_offset(buffer);
+            while let Some((ix, message_anchor)) = message_anchors.next() {
+                let metadata = self.messages_metadata.get(&message_anchor.id)?;
+                let message_start = message_anchor.start.to_offset(buffer);
                 let mut message_end = None;
-                while let Some((_, next_message)) = messages.peek() {
+                while let Some((_, next_message)) = message_anchors.peek() {
                     if next_message.start.is_valid(buffer) {
                         message_end = Some(next_message.start);
                         break;
                     } else {
-                        messages.next();
+                        message_anchors.next();
                     }
                 }
                 let message_end = message_end
                     .unwrap_or(language::Anchor::MAX)
                     .to_offset(buffer);
-                return Some((ix, message, metadata, message_start..message_end));
+                return Some(Message {
+                    index: ix,
+                    range: message_start..message_end,
+                    id: message_anchor.id,
+                    anchor: message_anchor.start,
+                    role: metadata.role,
+                    sent_at: metadata.sent_at,
+                    error: metadata.error.clone(),
+                });
             }
             None
         })
@@ -1003,6 +1007,15 @@ impl AssistantEditor {
         }
     }
 
+    fn cycle_message_role(&mut self, _: &CycleMessageRole, cx: &mut ViewContext<Self>) {
+        let cursor_offset = self.editor.read(cx).selections.newest(cx).head();
+        self.assistant.update(cx, |assistant, cx| {
+            if let Some(message) = assistant.message_for_offset(cursor_offset, cx) {
+                assistant.cycle_message_role(message.id, cx);
+            }
+        });
+    }
+
     fn handle_assistant_event(
         &mut self,
         _: ModelHandle<Assistant>,
@@ -1087,14 +1100,14 @@ impl AssistantEditor {
                 .assistant
                 .read(cx)
                 .messages(cx)
-                .map(|(_, message, metadata, _)| BlockProperties {
-                    position: buffer.anchor_in_excerpt(excerpt_id, message.start),
+                .map(|message| BlockProperties {
+                    position: buffer.anchor_in_excerpt(excerpt_id, message.anchor),
                     height: 2,
                     style: BlockStyle::Sticky,
                     render: Arc::new({
                         let assistant = self.assistant.clone();
-                        let metadata = metadata.clone();
-                        let message = message.clone();
+                        // let metadata = message.metadata.clone();
+                        // let message = message.clone();
                         move |cx| {
                             enum Sender {}
                             enum ErrorTooltip {}
@@ -1105,7 +1118,7 @@ impl AssistantEditor {
                             let sender = MouseEventHandler::<Sender, _>::new(
                                 message_id.0,
                                 cx,
-                                |state, _| match metadata.role {
+                                |state, _| match message.role {
                                     Role::User => {
                                         let style = style.user_sender.style_for(state, false);
                                         Label::new("You", style.text.clone())
@@ -1140,14 +1153,14 @@ impl AssistantEditor {
                                 .with_child(sender.aligned())
                                 .with_child(
                                     Label::new(
-                                        metadata.sent_at.format("%I:%M%P").to_string(),
+                                        message.sent_at.format("%I:%M%P").to_string(),
                                         style.sent_at.text.clone(),
                                     )
                                     .contained()
                                     .with_style(style.sent_at.container)
                                     .aligned(),
                                 )
-                                .with_children(metadata.error.clone().map(|error| {
+                                .with_children(message.error.as_ref().map(|error| {
                                     Svg::new("icons/circle_x_mark_12.svg")
                                         .with_color(style.error_icon.color)
                                         .constrained()
@@ -1156,7 +1169,7 @@ impl AssistantEditor {
                                         .with_style(style.error_icon.container)
                                         .with_tooltip::<ErrorTooltip>(
                                             message_id.0,
-                                            error,
+                                            error.to_string(),
                                             None,
                                             theme.tooltip.clone(),
                                             cx,
@@ -1252,15 +1265,15 @@ impl AssistantEditor {
             let selection = editor.selections.newest::<usize>(cx);
             let mut copied_text = String::new();
             let mut spanned_messages = 0;
-            for (_ix, _message, metadata, message_range) in assistant.messages(cx) {
-                if message_range.start >= selection.range().end {
+            for message in assistant.messages(cx) {
+                if message.range.start >= selection.range().end {
                     break;
-                } else if message_range.end >= selection.range().start {
-                    let range = cmp::max(message_range.start, selection.range().start)
-                        ..cmp::min(message_range.end, selection.range().end);
+                } else if message.range.end >= selection.range().start {
+                    let range = cmp::max(message.range.start, selection.range().start)
+                        ..cmp::min(message.range.end, selection.range().end);
                     if !range.is_empty() {
                         spanned_messages += 1;
-                        write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap();
+                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
                         for chunk in assistant.buffer.read(cx).text_for_range(range) {
                             copied_text.push_str(&chunk);
                         }
@@ -1395,7 +1408,7 @@ impl Item for AssistantEditor {
 struct MessageId(usize);
 
 #[derive(Clone, Debug)]
-struct Message {
+struct MessageAnchor {
     id: MessageId,
     start: language::Anchor,
 }
@@ -1404,7 +1417,18 @@ struct Message {
 struct MessageMetadata {
     role: Role,
     sent_at: DateTime<Local>,
-    error: Option<String>,
+    error: Option<Arc<str>>,
+}
+
+#[derive(Clone, Debug)]
+pub struct Message {
+    range: Range<usize>,
+    index: usize,
+    id: MessageId,
+    anchor: language::Anchor,
+    role: Role,
+    sent_at: DateTime<Local>,
+    error: Option<Arc<str>>,
 }
 
 async fn stream_completion(
@@ -1504,7 +1528,7 @@ mod tests {
         let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
         let buffer = assistant.read(cx).buffer.clone();
 
-        let message_1 = assistant.read(cx).messages[0].clone();
+        let message_1 = assistant.read(cx).message_anchors[0].clone();
         assert_eq!(
             messages(&assistant, cx),
             vec![(message_1.id, Role::User, 0..0)]
@@ -1630,7 +1654,7 @@ mod tests {
         let assistant = cx.add_model(|cx| Assistant::new(Default::default(), registry, cx));
         let buffer = assistant.read(cx).buffer.clone();
 
-        let message_1 = assistant.read(cx).messages[0].clone();
+        let message_1 = assistant.read(cx).message_anchors[0].clone();
         assert_eq!(
             messages(&assistant, cx),
             vec![(message_1.id, Role::User, 0..0)]
@@ -1724,7 +1748,7 @@ mod tests {
         assistant
             .read(cx)
             .messages(cx)
-            .map(|(_, message, metadata, range)| (message.id, metadata.role, range))
+            .map(|message| (message.id, message.role, message.range))
             .collect()
     }
 }