assistant2: Render messages as Markdown (#21496)

Marshall Bowers created

This PR updates Assistant 2 to render the messages in the thread as
Markdown:

<img width="1138" alt="Screenshot 2024-12-03 at 6 09 27 PM"
src="https://github.com/user-attachments/assets/c1c44fde-1efb-43cf-b9c9-768e6974c753">

Release Notes:

- N/A

Change summary

Cargo.lock                               |  2 
crates/assistant2/Cargo.toml             |  4 +
crates/assistant2/src/assistant_panel.rs | 74 ++++++++++++++++++++++++-
crates/assistant2/src/thread.rs          |  5 +
4 files changed, 81 insertions(+), 4 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -464,10 +464,12 @@ dependencies = [
  "feature_flags",
  "futures 0.3.31",
  "gpui",
+ "language",
  "language_model",
  "language_model_selector",
  "language_models",
  "log",
+ "markdown",
  "project",
  "proto",
  "serde",

crates/assistant2/Cargo.toml 🔗

@@ -23,15 +23,17 @@ editor.workspace = true
 feature_flags.workspace = true
 futures.workspace = true
 gpui.workspace = true
+language.workspace = true
 language_model.workspace = true
 language_model_selector.workspace = true
 language_models.workspace = true
 log.workspace = true
+markdown.workspace = true
 project.workspace = true
 proto.workspace = true
-settings.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+settings.workspace = true
 smol.workspace = true
 theme.workspace = true
 ui.workspace = true

crates/assistant2/src/assistant_panel.rs 🔗

@@ -3,13 +3,19 @@ use std::sync::Arc;
 use anyhow::Result;
 use assistant_tool::ToolWorkingSet;
 use client::zed_urls;
+use collections::HashMap;
 use gpui::{
     list, prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, Empty, EventEmitter,
-    FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels, Subscription,
-    Task, View, ViewContext, WeakView, WindowContext,
+    FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels,
+    StyleRefinement, Subscription, Task, TextStyleRefinement, View, ViewContext, WeakView,
+    WindowContext,
 };
+use language::LanguageRegistry;
 use language_model::{LanguageModelRegistry, Role};
 use language_model_selector::LanguageModelSelector;
+use markdown::{Markdown, MarkdownStyle};
+use settings::Settings;
+use theme::ThemeSettings;
 use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, Tab, Tooltip};
 use workspace::dock::{DockPosition, Panel, PanelEvent};
 use workspace::Workspace;
@@ -32,10 +38,12 @@ pub fn init(cx: &mut AppContext) {
 
 pub struct AssistantPanel {
     workspace: WeakView<Workspace>,
+    language_registry: Arc<LanguageRegistry>,
     #[allow(unused)]
     thread_store: Model<ThreadStore>,
     thread: Model<Thread>,
     thread_messages: Vec<MessageId>,
+    rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
     thread_list_state: ListState,
     message_editor: View<MessageEditor>,
     tools: Arc<ToolWorkingSet>,
@@ -77,9 +85,11 @@ impl AssistantPanel {
 
         Self {
             workspace: workspace.weak_handle(),
+            language_registry: workspace.project().read(cx).languages().clone(),
             thread_store,
             thread: thread.clone(),
             thread_messages: Vec::new(),
+            rendered_messages_by_id: HashMap::default(),
             thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
                 let this = cx.view().downgrade();
                 move |ix, cx: &mut WindowContext| {
@@ -104,6 +114,9 @@ impl AssistantPanel {
 
         self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
         self.thread = thread;
+        self.thread_messages.clear();
+        self.thread_list_state.reset(0);
+        self.rendered_messages_by_id.clear();
         self._subscriptions = subscriptions;
 
         self.message_editor.focus_handle(cx).focus(cx);
@@ -120,10 +133,61 @@ impl AssistantPanel {
                 self.last_error = Some(error.clone());
             }
             ThreadEvent::StreamedCompletion => {}
+            ThreadEvent::StreamedAssistantText(message_id, text) => {
+                if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
+                    markdown.update(cx, |markdown, cx| {
+                        markdown.append(text, cx);
+                    });
+                }
+            }
             ThreadEvent::MessageAdded(message_id) => {
                 let old_len = self.thread_messages.len();
                 self.thread_messages.push(*message_id);
                 self.thread_list_state.splice(old_len..old_len, 1);
+
+                if let Some(message_text) = self
+                    .thread
+                    .read(cx)
+                    .message(*message_id)
+                    .map(|message| message.text.clone())
+                {
+                    let theme_settings = ThemeSettings::get_global(cx);
+
+                    let mut text_style = cx.text_style();
+                    text_style.refine(&TextStyleRefinement {
+                        font_family: Some(theme_settings.ui_font.family.clone()),
+                        font_size: Some(TextSize::Default.rems(cx).into()),
+                        color: Some(cx.theme().colors().text),
+                        ..Default::default()
+                    });
+
+                    let markdown_style = MarkdownStyle {
+                        base_text_style: text_style,
+                        syntax: cx.theme().syntax().clone(),
+                        selection_background_color: cx.theme().players().local().selection,
+                        code_block: StyleRefinement {
+                            text: Some(TextStyleRefinement {
+                                font_family: Some(theme_settings.buffer_font.family.clone()),
+                                font_size: Some(theme_settings.buffer_font_size.into()),
+                                ..Default::default()
+                            }),
+                            ..Default::default()
+                        },
+                        ..Default::default()
+                    };
+
+                    let markdown = cx.new_view(|cx| {
+                        Markdown::new(
+                            message_text,
+                            markdown_style,
+                            Some(self.language_registry.clone()),
+                            None,
+                            cx,
+                        )
+                    });
+                    self.rendered_messages_by_id.insert(*message_id, markdown);
+                }
+
                 cx.notify();
             }
             ThreadEvent::UsePendingTools => {
@@ -323,6 +387,10 @@ impl AssistantPanel {
             return Empty.into_any();
         };
 
+        let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
+            return Empty.into_any();
+        };
+
         let (role_icon, role_name) = match message.role {
             Role::User => (IconName::Person, "You"),
             Role::Assistant => (IconName::ZedAssistant, "Assistant"),
@@ -350,7 +418,7 @@ impl AssistantPanel {
                                     .child(Label::new(role_name).size(LabelSize::Small)),
                             ),
                     )
-                    .child(v_flex().p_1p5().child(Label::new(message.text.clone()))),
+                    .child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())),
             )
             .into_any()
     }

crates/assistant2/src/thread.rs 🔗

@@ -167,6 +167,10 @@ impl Thread {
                                 if let Some(last_message) = thread.messages.last_mut() {
                                     if last_message.role == Role::Assistant {
                                         last_message.text.push_str(&chunk);
+                                        cx.emit(ThreadEvent::StreamedAssistantText(
+                                            last_message.id,
+                                            chunk,
+                                        ));
                                     }
                                 }
                             }
@@ -320,6 +324,7 @@ pub enum ThreadError {
 pub enum ThreadEvent {
     ShowError(ThreadError),
     StreamedCompletion,
+    StreamedAssistantText(MessageId, String),
     MessageAdded(MessageId),
     UsePendingTools,
     ToolFinished {