Cycle message roles on click

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/ai/src/ai.rs               |  10 ++
crates/ai/src/assistant.rs        | 139 +++++++++++++++++++++-----------
crates/theme/src/theme.rs         |   5 
styles/src/styleTree/assistant.ts |   3 
4 files changed, 106 insertions(+), 51 deletions(-)

Detailed changes

crates/ai/src/ai.rs 🔗

@@ -34,6 +34,16 @@ enum Role {
     System,
 }
 
+impl Role {
+    pub fn cycle(&mut self) {
+        *self = match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
 impl Display for Role {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
         match self {

crates/ai/src/assistant.rs 🔗

@@ -485,14 +485,16 @@ impl Assistant {
         let messages = self
             .messages
             .iter()
-            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
-                role: match message.role {
-                    Role::User => "user".into(),
-                    Role::Assistant => "assistant".into(),
-                    Role::System => "system".into(),
-                },
-                content: message.content.read(cx).text(),
-                name: None,
+            .filter_map(|message| {
+                Some(tiktoken_rs::ChatCompletionRequestMessage {
+                    role: match self.messages_metadata.get(&message.excerpt_id)?.role {
+                        Role::User => "user".into(),
+                        Role::Assistant => "assistant".into(),
+                        Role::System => "system".into(),
+                    },
+                    content: message.content.read(cx).text(),
+                    name: None,
+                })
             })
             .collect::<Vec<_>>();
         let model = self.model.clone();
@@ -529,9 +531,11 @@ impl Assistant {
         let messages = self
             .messages
             .iter()
-            .map(|message| RequestMessage {
-                role: message.role,
-                content: message.content.read(cx).text(),
+            .filter_map(|message| {
+                Some(RequestMessage {
+                    role: self.messages_metadata.get(&message.excerpt_id)?.role,
+                    content: message.content.read(cx).text(),
+                })
             })
             .collect();
         let request = OpenAIRequest {
@@ -621,6 +625,13 @@ impl Assistant {
         }
     }
 
+    fn cycle_message_role(&mut self, excerpt_id: ExcerptId, cx: &mut ModelContext<Self>) {
+        if let Some(metadata) = self.messages_metadata.get_mut(&excerpt_id) {
+            metadata.role.cycle();
+            cx.notify();
+        }
+    }
+
     fn push_message(
         &mut self,
         role: Role,
@@ -659,7 +670,6 @@ impl Assistant {
 
         self.messages.push(Message {
             excerpt_id,
-            role,
             content: content.clone(),
         });
         self.messages_metadata.insert(
@@ -681,9 +691,11 @@ impl Assistant {
                     .messages
                     .iter()
                     .take(2)
-                    .map(|message| RequestMessage {
-                        role: message.role,
-                        content: message.content.read(cx).text(),
+                    .filter_map(|message| {
+                        Some(RequestMessage {
+                            role: self.messages_metadata.get(&message.excerpt_id)?.role,
+                            content: message.content.read(cx).text(),
+                        })
                     })
                     .chain(Some(RequestMessage {
                         role: Role::User,
@@ -753,27 +765,51 @@ impl AssistantEditor {
                 {
                     let assistant = assistant.clone();
                     move |_editor, params: editor::RenderExcerptHeaderParams, cx| {
+                        enum Sender {}
                         enum ErrorTooltip {}
 
                         let theme = theme::current(cx);
                         let style = &theme.assistant;
-                        if let Some(metadata) = assistant.read(cx).messages_metadata.get(&params.id)
+                        let excerpt_id = params.id;
+                        if let Some(metadata) = assistant
+                            .read(cx)
+                            .messages_metadata
+                            .get(&excerpt_id)
+                            .cloned()
                         {
-                            let sender = match metadata.role {
-                                Role::User => Label::new("You", style.user_sender.text.clone())
-                                    .contained()
-                                    .with_style(style.user_sender.container),
-                                Role::Assistant => {
-                                    Label::new("Assistant", style.assistant_sender.text.clone())
-                                        .contained()
-                                        .with_style(style.assistant_sender.container)
+                            let sender = MouseEventHandler::<Sender, _>::new(
+                                params.id.into(),
+                                cx,
+                                |state, _| match metadata.role {
+                                    Role::User => {
+                                        let style = style.user_sender.style_for(state, false);
+                                        Label::new("You", style.text.clone())
+                                            .contained()
+                                            .with_style(style.container)
+                                    }
+                                    Role::Assistant => {
+                                        let style = style.assistant_sender.style_for(state, false);
+                                        Label::new("Assistant", style.text.clone())
+                                            .contained()
+                                            .with_style(style.container)
+                                    }
+                                    Role::System => {
+                                        let style = style.system_sender.style_for(state, false);
+                                        Label::new("System", style.text.clone())
+                                            .contained()
+                                            .with_style(style.container)
+                                    }
+                                },
+                            )
+                            .with_cursor_style(CursorStyle::PointingHand)
+                            .on_down(MouseButton::Left, {
+                                let assistant = assistant.clone();
+                                move |_, _, cx| {
+                                    assistant.update(cx, |assistant, cx| {
+                                        assistant.cycle_message_role(excerpt_id, cx)
+                                    })
                                 }
-                                Role::System => {
-                                    Label::new("System", style.assistant_sender.text.clone())
-                                        .contained()
-                                        .with_style(style.assistant_sender.container)
-                                }
-                            };
+                            });
 
                             Flex::row()
                                 .with_child(sender.aligned())
@@ -786,7 +822,7 @@ impl AssistantEditor {
                                     .with_style(style.sent_at.container)
                                     .aligned(),
                                 )
-                                .with_children(metadata.error.clone().map(|error| {
+                                .with_children(metadata.error.map(|error| {
                                     Svg::new("icons/circle_x_mark_12.svg")
                                         .with_color(style.error_icon.color)
                                         .constrained()
@@ -833,21 +869,22 @@ impl AssistantEditor {
         self.assistant.update(cx, |assistant, cx| {
             let editor = self.editor.read(cx);
             let newest_selection = editor.selections.newest_anchor();
-            let role = if newest_selection.head() == Anchor::min() {
-                assistant.messages.first().map(|message| message.role)
+            let excerpt_id = if newest_selection.head() == Anchor::min() {
+                assistant.messages.first().map(|message| message.excerpt_id)
             } else if newest_selection.head() == Anchor::max() {
-                assistant.messages.last().map(|message| message.role)
+                assistant.messages.last().map(|message| message.excerpt_id)
             } else {
-                assistant
-                    .messages_metadata
-                    .get(&newest_selection.head().excerpt_id())
-                    .map(|message| message.role)
+                Some(newest_selection.head().excerpt_id())
             };
 
-            if role.map_or(false, |role| role == Role::Assistant) {
-                assistant.push_message(Role::User, cx);
-            } else {
-                assistant.assist(cx);
+            if let Some(excerpt_id) = excerpt_id {
+                if let Some(metadata) = assistant.messages_metadata.get(&excerpt_id) {
+                    if metadata.role == Role::User {
+                        assistant.assist(cx);
+                    } else {
+                        assistant.push_message(Role::User, cx);
+                    }
+                }
             }
         });
     }
@@ -967,12 +1004,17 @@ impl AssistantEditor {
                     let range = cmp::max(message_range.start, selection.range().start)
                         ..cmp::min(message_range.end, selection.range().end);
                     if !range.is_empty() {
-                        spanned_messages += 1;
-                        write!(&mut copied_text, "## {}\n\n", message.role).unwrap();
-                        for chunk in assistant.buffer.read(cx).snapshot(cx).text_for_range(range) {
-                            copied_text.push_str(&chunk);
+                        if let Some(metadata) = assistant.messages_metadata.get(&message.excerpt_id)
+                        {
+                            spanned_messages += 1;
+                            write!(&mut copied_text, "## {}\n\n", metadata.role).unwrap();
+                            for chunk in
+                                assistant.buffer.read(cx).snapshot(cx).text_for_range(range)
+                            {
+                                copied_text.push_str(&chunk);
+                            }
+                            copied_text.push('\n');
                         }
-                        copied_text.push('\n');
                     }
                 }
 
@@ -1090,11 +1132,10 @@ impl Item for AssistantEditor {
 #[derive(Debug)]
 struct Message {
     excerpt_id: ExcerptId,
-    role: Role,
     content: ModelHandle<Buffer>,
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct MessageMetadata {
     role: Role,
     sent_at: DateTime<Local>,

crates/theme/src/theme.rs 🔗

@@ -974,8 +974,9 @@ pub struct AssistantStyle {
     pub container: ContainerStyle,
     pub header: ContainerStyle,
     pub sent_at: ContainedText,
-    pub user_sender: ContainedText,
-    pub assistant_sender: ContainedText,
+    pub user_sender: Interactive<ContainedText>,
+    pub assistant_sender: Interactive<ContainedText>,
+    pub system_sender: Interactive<ContainedText>,
     pub model_info_container: ContainerStyle,
     pub model: Interactive<ContainedText>,
     pub remaining_tokens: ContainedText,

styles/src/styleTree/assistant.ts 🔗

@@ -20,6 +20,9 @@ export default function assistant(colorScheme: ColorScheme) {
       assistantSender: {
         ...text(layer, "sans", "accent", { size: "sm", weight: "bold" }),
       },
+      systemSender: {
+        ...text(layer, "sans", "variant", { size: "sm", weight: "bold" }),
+      },
       sentAt: {
         margin: { top: 2, left: 8 },
         ...text(layer, "sans", "default", { size: "2xs" }),