Render thinking blocks

Nathan Sobo and Conrad Irwin created

For now, we style them the same as messages. Next up we'll improve the
styling.

Co-authored-by: Conrad Irwin <conrad.irwin@gmail.com>

Change summary

crates/acp/src/acp.rs         | 133 ++++++++++++++++++++----------------
crates/acp/src/server.rs      |  12 +-
crates/acp/src/thread_view.rs |  75 +++++++++++---------
3 files changed, 122 insertions(+), 98 deletions(-)

Detailed changes

crates/acp/src/acp.rs 🔗

@@ -1,7 +1,7 @@
 mod server;
 mod thread_view;
 
-use agentic_coding_protocol::{self as acp, Role};
+use agentic_coding_protocol::{self as acp};
 use anyhow::{Context as _, Result};
 use buffer_diff::BufferDiff;
 use chrono::{DateTime, Utc};
@@ -39,15 +39,13 @@ pub struct FileContent {
 }
 
 #[derive(Clone, Debug, Eq, PartialEq)]
-pub struct Message {
-    pub role: acp::Role,
-    pub chunks: Vec<MessageChunk>,
+pub struct UserMessage {
+    pub chunks: Vec<UserMessageChunk>,
 }
 
-impl Message {
-    fn into_acp(self, cx: &App) -> acp::Message {
-        acp::Message {
-            role: self.role,
+impl UserMessage {
+    fn into_acp(self, cx: &App) -> acp::UserMessage {
+        acp::UserMessage {
             chunks: self
                 .chunks
                 .into_iter()
@@ -58,7 +56,7 @@ impl Message {
 }
 
 #[derive(Clone, Debug, Eq, PartialEq)]
-pub enum MessageChunk {
+pub enum UserMessageChunk {
     Text {
         chunk: Entity<Markdown>,
     },
@@ -82,33 +80,57 @@ pub enum MessageChunk {
     },
 }
 
-impl MessageChunk {
+impl UserMessageChunk {
+    pub fn into_acp(self, cx: &App) -> acp::UserMessageChunk {
+        match self {
+            Self::Text { chunk } => acp::UserMessageChunk::Text {
+                chunk: chunk.read(cx).source().to_string(),
+            },
+            Self::File { .. } => todo!(),
+            Self::Directory { .. } => todo!(),
+            Self::Symbol { .. } => todo!(),
+            Self::Fetch { .. } => todo!(),
+        }
+    }
+
+    pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
+        Self::Text {
+            chunk: cx.new(|cx| {
+                Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
+            }),
+        }
+    }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct AssistantMessage {
+    pub chunks: Vec<AssistantMessageChunk>,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum AssistantMessageChunk {
+    Text { chunk: Entity<Markdown> },
+    Thought { chunk: Entity<Markdown> },
+}
+
+impl AssistantMessageChunk {
     pub fn from_acp(
-        chunk: acp::MessageChunk,
+        chunk: acp::AssistantMessageChunk,
         language_registry: Arc<LanguageRegistry>,
         cx: &mut App,
     ) -> Self {
         match chunk {
-            acp::MessageChunk::Text { chunk } => MessageChunk::Text {
+            acp::AssistantMessageChunk::Text { chunk } => Self::Text {
                 chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
             },
-        }
-    }
-
-    pub fn into_acp(self, cx: &App) -> acp::MessageChunk {
-        match self {
-            MessageChunk::Text { chunk } => acp::MessageChunk::Text {
-                chunk: chunk.read(cx).source().to_string(),
+            acp::AssistantMessageChunk::Thought { chunk } => Self::Thought {
+                chunk: cx.new(|cx| Markdown::new(chunk.into(), Some(language_registry), None, cx)),
             },
-            MessageChunk::File { .. } => todo!(),
-            MessageChunk::Directory { .. } => todo!(),
-            MessageChunk::Symbol { .. } => todo!(),
-            MessageChunk::Fetch { .. } => todo!(),
         }
     }
 
     pub fn from_str(chunk: &str, language_registry: Arc<LanguageRegistry>, cx: &mut App) -> Self {
-        MessageChunk::Text {
+        Self::Text {
             chunk: cx.new(|cx| {
                 Markdown::new(chunk.to_owned().into(), Some(language_registry), None, cx)
             }),
@@ -118,7 +140,8 @@ impl MessageChunk {
 
 #[derive(Debug)]
 pub enum AgentThreadEntryContent {
-    Message(Message),
+    UserMessage(UserMessage),
+    AssistantMessage(AssistantMessage),
     ToolCall(ToolCall),
 }
 
@@ -412,26 +435,28 @@ impl AcpThread {
         id
     }
 
-    pub fn push_assistant_chunk(&mut self, chunk: acp::MessageChunk, cx: &mut Context<Self>) {
+    pub fn push_assistant_chunk(
+        &mut self,
+        chunk: acp::AssistantMessageChunk,
+        cx: &mut Context<Self>,
+    ) {
         let entries_len = self.entries.len();
         if let Some(last_entry) = self.entries.last_mut()
-            && let AgentThreadEntryContent::Message(Message {
-                ref mut chunks,
-                role: Role::Assistant,
-            }) = last_entry.content
+            && let AgentThreadEntryContent::AssistantMessage(AssistantMessage { ref mut chunks }) =
+                last_entry.content
         {
             cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
 
             if let (
-                Some(MessageChunk::Text { chunk: old_chunk }),
-                acp::MessageChunk::Text { chunk: new_chunk },
+                Some(AssistantMessageChunk::Text { chunk: old_chunk }),
+                acp::AssistantMessageChunk::Text { chunk: new_chunk },
             ) = (chunks.last_mut(), &chunk)
             {
                 old_chunk.update(cx, |old_chunk, cx| {
                     old_chunk.append(&new_chunk, cx);
                 });
             } else {
-                chunks.push(MessageChunk::from_acp(
+                chunks.push(AssistantMessageChunk::from_acp(
                     chunk,
                     self.project.read(cx).languages().clone(),
                     cx,
@@ -441,11 +466,11 @@ impl AcpThread {
             return;
         }
 
-        let chunk = MessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
+        let chunk =
+            AssistantMessageChunk::from_acp(chunk, self.project.read(cx).languages().clone(), cx);
 
         self.push_entry(
-            AgentThreadEntryContent::Message(Message {
-                role: Role::Assistant,
+            AgentThreadEntryContent::AssistantMessage(AssistantMessage {
                 chunks: vec![chunk],
             }),
             cx,
@@ -603,7 +628,8 @@ impl AcpThread {
                     ToolCallStatus::WaitingForConfirmation { .. } => return true,
                     ToolCallStatus::Allowed { .. } | ToolCallStatus::Rejected => continue,
                 },
-                AgentThreadEntryContent::Message(_) => {
+                AgentThreadEntryContent::UserMessage(_)
+                | AgentThreadEntryContent::AssistantMessage(_) => {
                     // Reached the beginning of the turn
                     return false;
                 }
@@ -615,12 +641,12 @@ impl AcpThread {
     pub fn send(&mut self, message: &str, cx: &mut Context<Self>) -> Task<Result<()>> {
         let agent = self.server.clone();
         let id = self.id.clone();
-        let chunk = MessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
-        let message = Message {
-            role: Role::User,
+        let chunk =
+            UserMessageChunk::from_str(message, self.project.read(cx).languages().clone(), cx);
+        let message = UserMessage {
             chunks: vec![chunk],
         };
-        self.push_entry(AgentThreadEntryContent::Message(message.clone()), cx);
+        self.push_entry(AgentThreadEntryContent::UserMessage(message.clone()), cx);
         let acp_message = message.into_acp(cx);
         cx.spawn(async move |_, cx| {
             agent.send_message(id, acp_message, cx).await?;
@@ -688,17 +714,11 @@ mod tests {
             assert_eq!(thread.entries.len(), 2);
             assert!(matches!(
                 thread.entries[0].content,
-                AgentThreadEntryContent::Message(Message {
-                    role: Role::User,
-                    ..
-                })
+                AgentThreadEntryContent::UserMessage(_)
             ));
             assert!(matches!(
                 thread.entries[1].content,
-                AgentThreadEntryContent::Message(Message {
-                    role: Role::Assistant,
-                    ..
-                })
+                AgentThreadEntryContent::AssistantMessage(_)
             ));
         });
     }
@@ -729,7 +749,7 @@ mod tests {
             .unwrap();
         thread.read_with(cx, |thread, _cx| {
             assert!(matches!(
-                &thread.entries()[1].content,
+                &thread.entries()[2].content,
                 AgentThreadEntryContent::ToolCall(ToolCall {
                     status: ToolCallStatus::Allowed { .. },
                     ..
@@ -737,11 +757,8 @@ mod tests {
             ));
 
             assert!(matches!(
-                thread.entries[2].content,
-                AgentThreadEntryContent::Message(Message {
-                    role: Role::Assistant,
-                    ..
-                })
+                thread.entries[3].content,
+                AgentThreadEntryContent::AssistantMessage(_)
             ));
         });
     }
@@ -771,7 +788,7 @@ mod tests {
                         ..
                     },
                 ..
-            }) = &thread.entries()[1].content
+            }) = &thread.entries()[2].content
             else {
                 panic!();
             };
@@ -785,7 +802,7 @@ mod tests {
             thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
 
             assert!(matches!(
-                &thread.entries()[1].content,
+                &thread.entries()[2].content,
                 AgentThreadEntryContent::ToolCall(ToolCall {
                     status: ToolCallStatus::Allowed { .. },
                     ..
@@ -800,7 +817,7 @@ mod tests {
                 content: Some(ToolCallContent::Markdown { markdown }),
                 status: ToolCallStatus::Allowed { .. },
                 ..
-            }) = &thread.entries()[1].content
+            }) = &thread.entries()[2].content
             else {
                 panic!();
             };

crates/acp/src/server.rs 🔗

@@ -56,10 +56,10 @@ impl AcpClientDelegate {
 
 #[async_trait(?Send)]
 impl acp::Client for AcpClientDelegate {
-    async fn stream_message_chunk(
+    async fn stream_assistant_message_chunk(
         &self,
-        params: acp::StreamMessageChunkParams,
-    ) -> Result<acp::StreamMessageChunkResponse> {
+        params: acp::StreamAssistantMessageChunkParams,
+    ) -> Result<acp::StreamAssistantMessageChunkResponse> {
         let cx = &mut self.cx.clone();
 
         cx.update(|cx| {
@@ -68,7 +68,7 @@ impl acp::Client for AcpClientDelegate {
             });
         })?;
 
-        Ok(acp::StreamMessageChunkResponse)
+        Ok(acp::StreamAssistantMessageChunkResponse)
     }
 
     async fn request_tool_call_confirmation(
@@ -209,11 +209,11 @@ impl AcpServer {
     pub async fn send_message(
         &self,
         thread_id: ThreadId,
-        message: acp::Message,
+        message: acp::UserMessage,
         _cx: &mut AsyncApp,
     ) -> Result<()> {
         self.connection
-            .request(acp::SendMessageParams {
+            .request(acp::SendUserMessageParams {
                 thread_id: thread_id.clone().into(),
                 message,
             })

crates/acp/src/thread_view.rs 🔗

@@ -24,8 +24,9 @@ use util::{ResultExt, paths};
 use zed_actions::agent::Chat;
 
 use crate::{
-    AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, Diff, MessageChunk, Role,
-    ThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallId, ToolCallStatus,
+    AcpServer, AcpThread, AcpThreadEvent, AgentThreadEntryContent, AssistantMessage,
+    AssistantMessageChunk, Diff, ThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent,
+    ToolCallId, ToolCallStatus, UserMessageChunk,
 };
 
 pub struct AcpThreadView {
@@ -390,45 +391,51 @@ impl AcpThreadView {
         cx: &Context<Self>,
     ) -> AnyElement {
         match &entry.content {
-            AgentThreadEntryContent::Message(message) => {
-                let style = if message.role == Role::User {
-                    user_message_markdown_style(window, cx)
-                } else {
-                    default_markdown_style(window, cx)
-                };
+            AgentThreadEntryContent::UserMessage(message) => {
+                let style = user_message_markdown_style(window, cx);
+                let message_body = div().children(message.chunks.iter().map(|chunk| match chunk {
+                    UserMessageChunk::Text { chunk } => {
+                        // todo!() open link
+                        MarkdownElement::new(chunk.clone(), style.clone())
+                    }
+                    _ => todo!(),
+                }));
+                div()
+                    .p_2()
+                    .pt_5()
+                    .child(
+                        div()
+                            .text_xs()
+                            .p_3()
+                            .bg(cx.theme().colors().editor_background)
+                            .rounded_lg()
+                            .shadow_md()
+                            .border_1()
+                            .border_color(cx.theme().colors().border)
+                            .child(message_body),
+                    )
+                    .into_any()
+            }
+            AgentThreadEntryContent::AssistantMessage(AssistantMessage { chunks }) => {
+                let style = default_markdown_style(window, cx);
                 let message_body = div()
-                    .children(message.chunks.iter().map(|chunk| match chunk {
-                        MessageChunk::Text { chunk } => {
+                    .children(chunks.iter().map(|chunk| match chunk {
+                        AssistantMessageChunk::Text { chunk } => {
                             // todo!() open link
                             MarkdownElement::new(chunk.clone(), style.clone())
                         }
-                        _ => todo!(),
+                        AssistantMessageChunk::Thought { chunk } => {
+                            MarkdownElement::new(chunk.clone(), style.clone())
+                        }
                     }))
                     .into_any();
 
-                match message.role {
-                    Role::User => div()
-                        .p_2()
-                        .pt_5()
-                        .child(
-                            div()
-                                .text_xs()
-                                .p_3()
-                                .bg(cx.theme().colors().editor_background)
-                                .rounded_lg()
-                                .shadow_md()
-                                .border_1()
-                                .border_color(cx.theme().colors().border)
-                                .child(message_body),
-                        )
-                        .into_any(),
-                    Role::Assistant => div()
-                        .text_ui(cx)
-                        .p_5()
-                        .pt_2()
-                        .child(message_body)
-                        .into_any(),
-                }
+                div()
+                    .text_ui(cx)
+                    .p_5()
+                    .pt_2()
+                    .child(message_body)
+                    .into_any()
             }
             AgentThreadEntryContent::ToolCall(tool_call) => div()
                 .px_2()