Assistant grouping (#11479)

Kyle Kelley created

Groups collections of assistant messages with their tool calls as
children of the assistant message container.


![image](https://github.com/zed-industries/zed/assets/836375/b26b7c90-4c8d-4bbd-972a-1e769d78a455)

Release Notes:

- N/A

Change summary

crates/assistant2/src/assistant2.rs              | 199 ++++++++++-------
crates/assistant2/src/ui/chat_message.rs         |  25 -
crates/assistant2/src/ui/stories/chat_message.rs |  21 -
3 files changed, 128 insertions(+), 117 deletions(-)

Detailed changes

crates/assistant2/src/assistant2.rs 🔗

@@ -342,8 +342,8 @@ impl AssistantChat {
         }
 
         if self.pending_completion.take().is_some() {
-            if let Some(ChatMessage::Assistant(message)) = self.messages.last() {
-                if message.body.text.is_empty() {
+            if let Some(ChatMessage::Assistant(grouping)) = self.messages.last() {
+                if grouping.messages.is_empty() {
                     self.pop_message(cx);
                 }
             }
@@ -478,22 +478,30 @@ impl AssistantChat {
                 while let Some(delta) = stream.next().await {
                     let delta = delta?;
                     this.update(cx, |this, cx| {
-                        if let Some(ChatMessage::Assistant(AssistantMessage {
-                            body: message_body,
-                            tool_calls: message_tool_calls,
+                        if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
+                            messages,
                             ..
                         })) = this.messages.last_mut()
                         {
+                            if messages.is_empty() {
+                                messages.push(AssistantMessage {
+                                    body: RichText::default(),
+                                    tool_calls: Vec::new(),
+                                })
+                            }
+
+                            let message = messages.last_mut().unwrap();
+
                             if let Some(content) = &delta.content {
                                 body.push_str(content);
                             }
 
                             for tool_call in delta.tool_calls {
                                 let index = tool_call.index as usize;
-                                if index >= message_tool_calls.len() {
-                                    message_tool_calls.resize_with(index + 1, Default::default);
+                                if index >= message.tool_calls.len() {
+                                    message.tool_calls.resize_with(index + 1, Default::default);
                                 }
-                                let call = &mut message_tool_calls[index];
+                                let call = &mut message.tool_calls[index];
 
                                 if let Some(id) = &tool_call.id {
                                     call.id.push_str(id);
@@ -512,7 +520,7 @@ impl AssistantChat {
                                 }
                             }
 
-                            *message_body =
+                            message.body =
                                 RichText::new(body.clone(), &[], &this.language_registry);
                             cx.notify();
                         } else {
@@ -527,9 +535,9 @@ impl AssistantChat {
 
             let mut tool_tasks = Vec::new();
             this.update(cx, |this, cx| {
-                if let Some(ChatMessage::Assistant(AssistantMessage {
+                if let Some(ChatMessage::Assistant(GroupedAssistantMessage {
                     error: message_error,
-                    tool_calls,
+                    messages,
                     ..
                 })) = this.messages.last_mut()
                 {
@@ -537,8 +545,10 @@ impl AssistantChat {
                         message_error.replace(SharedString::from(error.to_string()));
                         cx.notify();
                     } else {
-                        for tool_call in tool_calls.iter() {
-                            tool_tasks.push(this.tool_registry.call(tool_call, cx));
+                        if let Some(current_message) = messages.last_mut() {
+                            for tool_call in current_message.tool_calls.iter() {
+                                tool_tasks.push(this.tool_registry.call(tool_call, cx));
+                            }
                         }
                     }
                 }
@@ -554,21 +564,38 @@ impl AssistantChat {
             let tools = tools.into_iter().filter_map(|tool| tool.ok()).collect();
 
             this.update(cx, |this, cx| {
-                if let Some(ChatMessage::Assistant(AssistantMessage { tool_calls, .. })) =
+                if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
                     this.messages.last_mut()
                 {
-                    *tool_calls = tools;
-                    cx.notify();
+                    if let Some(current_message) = messages.last_mut() {
+                        current_message.tool_calls = tools;
+                        cx.notify();
+                    } else {
+                        unreachable!()
+                    }
                 }
             })?;
         }
     }
 
     fn push_new_assistant_message(&mut self, cx: &mut ViewContext<Self>) {
-        let message = ChatMessage::Assistant(AssistantMessage {
+        // If the last message is a grouped assistant message, add to the grouped message
+        if let Some(ChatMessage::Assistant(GroupedAssistantMessage { messages, .. })) =
+            self.messages.last_mut()
+        {
+            messages.push(AssistantMessage {
+                body: RichText::default(),
+                tool_calls: Vec::new(),
+            });
+            return;
+        }
+
+        let message = ChatMessage::Assistant(GroupedAssistantMessage {
             id: self.next_message_id.post_inc(),
-            body: RichText::default(),
-            tool_calls: Vec::new(),
+            messages: vec![AssistantMessage {
+                body: RichText::default(),
+                tool_calls: Vec::new(),
+            }],
             error: None,
         });
         self.push_message(message, cx);
@@ -687,15 +714,14 @@ impl AssistantChat {
                                 crate::ui::ChatMessage::new(
                                     *id,
                                     UserOrAssistant::User(self.user_store.read(cx).current_user()),
-                                    Some(
+                                    // todo!(): clean up the vec usage
+                                    vec![
                                         RichText::new(
                                             body.read(cx).text(cx),
                                             &[],
                                             &self.language_registry,
                                         )
                                         .element(ElementId::from(id.0), cx),
-                                    ),
-                                    Some(
                                         h_flex()
                                             .gap_2()
                                             .children(
@@ -704,7 +730,7 @@ impl AssistantChat {
                                                     .map(|attachment| attachment.view.clone()),
                                             )
                                             .into_any_element(),
-                                    ),
+                                    ],
                                     self.is_message_collapsed(id),
                                     Box::new(cx.listener({
                                         let id = *id;
@@ -719,33 +745,34 @@ impl AssistantChat {
                     }
                 })
                 .into_any(),
-            ChatMessage::Assistant(AssistantMessage {
+            ChatMessage::Assistant(GroupedAssistantMessage {
                 id,
-                body,
+                messages,
                 error,
-                tool_calls,
                 ..
             }) => {
-                let assistant_body = if body.text.is_empty() {
-                    None
-                } else {
-                    Some(
-                        div()
-                            .child(body.element(ElementId::from(id.0), cx))
-                            .into_any_element(),
-                    )
-                };
+                let mut message_elements = Vec::new();
+
+                for message in messages {
+                    if !message.body.text.is_empty() {
+                        message_elements.push(
+                            div()
+                                // todo!(): The element Id will need to be a combo of the base ID and the index within the grouping
+                                .child(message.body.element(ElementId::from(id.0), cx))
+                                .into_any_element(),
+                        )
+                    }
 
-                let tools = tool_calls
-                    .iter()
-                    .map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
-                    .collect::<Vec<AnyElement>>();
+                    let tools = message
+                        .tool_calls
+                        .iter()
+                        .map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
+                        .collect::<Vec<AnyElement>>();
 
-                let tools_body = if tools.is_empty() {
-                    None
-                } else {
-                    Some(div().children(tools).into_any_element())
-                };
+                    if !tools.is_empty() {
+                        message_elements.push(div().children(tools).into_any_element())
+                    }
+                }
 
                 div()
                     .when(is_first, |this| this.pt(padding))
@@ -753,8 +780,7 @@ impl AssistantChat {
                         crate::ui::ChatMessage::new(
                             *id,
                             UserOrAssistant::Assistant,
-                            assistant_body,
-                            tools_body,
+                            message_elements,
                             self.is_message_collapsed(id),
                             Box::new(cx.listener({
                                 let id = *id;
@@ -796,46 +822,47 @@ impl AssistantChat {
                         content: body.read(cx).text(cx),
                     });
                 }
-                ChatMessage::Assistant(AssistantMessage {
-                    body, tool_calls, ..
-                }) => {
-                    // In no case do we want to send an empty message. This shouldn't happen, but we might as well
-                    // not break the Chat API if it does.
-                    if body.text.is_empty() && tool_calls.is_empty() {
-                        continue;
-                    }
+                ChatMessage::Assistant(GroupedAssistantMessage { messages, .. }) => {
+                    for message in messages {
+                        let body = message.body.clone();
 
-                    let tool_calls_from_assistant = tool_calls
-                        .iter()
-                        .map(|tool_call| ToolCall {
-                            content: ToolCallContent::Function {
-                                function: FunctionContent {
-                                    name: tool_call.name.clone(),
-                                    arguments: tool_call.arguments.clone(),
-                                },
-                            },
-                            id: tool_call.id.clone(),
-                        })
-                        .collect();
-
-                    completion_messages.push(CompletionMessage::Assistant {
-                        content: Some(body.text.to_string()),
-                        tool_calls: tool_calls_from_assistant,
-                    });
+                        if body.text.is_empty() && message.tool_calls.is_empty() {
+                            continue;
+                        }
 
-                    for tool_call in tool_calls {
-                        // Every tool call _must_ have a result by ID, otherwise OpenAI will error.
-                        let content = match &tool_call.result {
-                            Some(result) => {
-                                result.generate(&tool_call.name, &mut project_context, cx)
-                            }
-                            None => "".to_string(),
-                        };
+                        let tool_calls_from_assistant = message
+                            .tool_calls
+                            .iter()
+                            .map(|tool_call| ToolCall {
+                                content: ToolCallContent::Function {
+                                    function: FunctionContent {
+                                        name: tool_call.name.clone(),
+                                        arguments: tool_call.arguments.clone(),
+                                    },
+                                },
+                                id: tool_call.id.clone(),
+                            })
+                            .collect();
 
-                        completion_messages.push(CompletionMessage::Tool {
-                            content,
-                            tool_call_id: tool_call.id.clone(),
+                        completion_messages.push(CompletionMessage::Assistant {
+                            content: Some(body.text.to_string()),
+                            tool_calls: tool_calls_from_assistant,
                         });
+
+                        for tool_call in &message.tool_calls {
+                            // Every tool call _must_ have a result by ID, otherwise OpenAI will error.
+                            let content = match &tool_call.result {
+                                Some(result) => {
+                                    result.generate(&tool_call.name, &mut project_context, cx)
+                                }
+                                None => "".to_string(),
+                            };
+
+                            completion_messages.push(CompletionMessage::Tool {
+                                content,
+                                tool_call_id: tool_call.id.clone(),
+                            });
+                        }
                     }
                 }
             }
@@ -885,7 +912,7 @@ impl MessageId {
 
 enum ChatMessage {
     User(UserMessage),
-    Assistant(AssistantMessage),
+    Assistant(GroupedAssistantMessage),
 }
 
 impl ChatMessage {
@@ -904,8 +931,12 @@ struct UserMessage {
 }
 
 struct AssistantMessage {
-    id: MessageId,
     body: RichText,
     tool_calls: Vec<ToolFunctionCall>,
+}
+
+struct GroupedAssistantMessage {
+    id: MessageId,
+    messages: Vec<AssistantMessage>,
     error: Option<SharedString>,
 }

crates/assistant2/src/ui/chat_message.rs 🔗

@@ -15,8 +15,7 @@ pub enum UserOrAssistant {
 pub struct ChatMessage {
     id: MessageId,
     player: UserOrAssistant,
-    message: Option<AnyElement>,
-    tools_used: Option<AnyElement>,
+    messages: Vec<AnyElement>,
     selected: bool,
     collapsed: bool,
     on_collapse_handle_click: Box<dyn Fn(&ClickEvent, &mut WindowContext) + 'static>,
@@ -26,16 +25,14 @@ impl ChatMessage {
     pub fn new(
         id: MessageId,
         player: UserOrAssistant,
-        message: Option<AnyElement>,
-        tools_used: Option<AnyElement>,
+        messages: Vec<AnyElement>,
         collapsed: bool,
         on_collapse_handle_click: Box<dyn Fn(&ClickEvent, &mut WindowContext) + 'static>,
     ) -> Self {
         Self {
             id,
             player,
-            message,
-            tools_used,
+            messages,
             selected: false,
             collapsed,
             on_collapse_handle_click,
@@ -117,19 +114,10 @@ impl RenderOnce for ChatMessage {
                             .icon_color(Color::Muted)
                             .on_click(self.on_collapse_handle_click)
                             .tooltip(|cx| Tooltip::text("Collapse Message", cx)),
-                        ), // .child(
-                           //     IconButton::new("copy-message", IconName::Copy)
-                           //         .icon_color(Color::Muted)
-                           //         .icon_size(IconSize::XSmall),
-                           // )
-                           // .child(
-                           //     IconButton::new("menu", IconName::Ellipsis)
-                           //         .icon_color(Color::Muted)
-                           //         .icon_size(IconSize::XSmall),
-                           // ),
+                        ),
                     ),
             )
-            .when(self.message.is_some() || self.tools_used.is_some(), |el| {
+            .when(self.messages.len() > 0, |el| {
                 el.child(
                     h_flex().child(
                         v_flex()
@@ -144,8 +132,7 @@ impl RenderOnce for ChatMessage {
                                 this.bg(background_color)
                             })
                             .when(self.collapsed, |this| this.h(collapsed_height))
-                            .children(self.message)
-                            .when_some(self.tools_used, |this, tools_used| this.child(tools_used)),
+                            .children(self.messages),
                     ),
                 )
             })

crates/assistant2/src/ui/stories/chat_message.rs 🔗

@@ -28,8 +28,7 @@ impl Render for ChatMessageStory {
                     ChatMessage::new(
                         MessageId(0),
                         UserOrAssistant::User(Some(user_1.clone())),
-                        Some(div().child("What can I do here?").into_any_element()),
-                        None,
+                        vec![div().child("What can I do here?").into_any_element()],
                         false,
                         Box::new(|_, _| {}),
                     ),
@@ -39,8 +38,7 @@ impl Render for ChatMessageStory {
                     ChatMessage::new(
                         MessageId(0),
                         UserOrAssistant::User(Some(user_1.clone())),
-                        Some(div().child("What can I do here?").into_any_element()),
-                        None,
+                        vec![div().child("What can I do here?").into_any_element()],
                         true,
                         Box::new(|_, _| {}),
                     ),
@@ -53,8 +51,7 @@ impl Render for ChatMessageStory {
                     ChatMessage::new(
                         MessageId(0),
                         UserOrAssistant::Assistant,
-                        Some(div().child("You can talk to me!").into_any_element()),
-                        None,
+                        vec![div().child("You can talk to me!").into_any_element()],
                         false,
                         Box::new(|_, _| {}),
                     ),
@@ -64,8 +61,7 @@ impl Render for ChatMessageStory {
                     ChatMessage::new(
                         MessageId(0),
                         UserOrAssistant::Assistant,
-                        Some(div().child(MULTI_LINE_MESSAGE).into_any_element()),
-                        None,
+                        vec![div().child(MULTI_LINE_MESSAGE).into_any_element()],
                         true,
                         Box::new(|_, _| {}),
                     ),
@@ -79,24 +75,21 @@ impl Render for ChatMessageStory {
                     .child(ChatMessage::new(
                         MessageId(0),
                         UserOrAssistant::User(Some(user_1.clone())),
-                        Some(div().child("What is Rust??").into_any_element()),
-                        None,
+                        vec![div().child("What is Rust??").into_any_element()],
                         false,
                         Box::new(|_, _| {}),
                     ))
                     .child(ChatMessage::new(
                         MessageId(0),
                         UserOrAssistant::Assistant,
-                        Some(div().child("Rust is a multi-paradigm programming language focused on performance and safety").into_any_element()),
-                        None,
+                        vec![div().child("Rust is a multi-paradigm programming language focused on performance and safety").into_any_element()],
                         false,
                         Box::new(|_, _| {}),
                     ))
                     .child(ChatMessage::new(
                         MessageId(0),
                         UserOrAssistant::User(Some(user_1)),
-                        Some(div().child("Sounds pretty cool!").into_any_element()),
-                        None,
+                        vec![div().child("Sounds pretty cool!").into_any_element()],
                         false,
                         Box::new(|_, _| {}),
                     )),