Replace `rich_text` with `markdown` in `assistant2` (#11650)

Antonio Scandurra created

We don't implement copy yet but it should be pretty straightforward to
add.


https://github.com/zed-industries/zed/assets/482957/6b4d7c34-de6b-4b07-aed9-608c771bbbdb

/cc: @rgbkrk @maxbrunsfeld @maxdeviant 

Release Notes:

- N/A

Change summary

Cargo.lock                               |   2 
crates/assistant2/Cargo.toml             |   3 
crates/assistant2/src/assistant2.rs      | 279 +++++++++++++++----------
crates/assistant2/src/ui/chat_message.rs |   2 
crates/markdown/src/markdown.rs          |  35 ++
5 files changed, 204 insertions(+), 117 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -390,6 +390,7 @@ dependencies = [
  "language",
  "languages",
  "log",
+ "markdown",
  "node_runtime",
  "open_ai",
  "picker",
@@ -397,7 +398,6 @@ dependencies = [
  "rand 0.8.5",
  "regex",
  "release_channel",
- "rich_text",
  "schemars",
  "semantic_index",
  "serde",

crates/assistant2/Cargo.toml 🔗

@@ -29,11 +29,11 @@ fuzzy.workspace = true
 gpui.workspace = true
 language.workspace = true
 log.workspace = true
+markdown.workspace = true
 open_ai.workspace = true
 picker.workspace = true
 project.workspace = true
 regex.workspace = true
-rich_text.workspace = true
 schemars.workspace = true
 semantic_index.workspace = true
 serde.workspace = true
@@ -52,6 +52,7 @@ env_logger.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
 language = { workspace = true, features = ["test-support"] }
 languages.workspace = true
+markdown = { workspace = true, features = ["test-support"] }
 node_runtime.workspace = true
 project = { workspace = true, features = ["test-support"] }
 rand.workspace = true

crates/assistant2/src/assistant2.rs 🔗

@@ -26,8 +26,8 @@ use gpui::{
     FocusableView, ListAlignment, ListState, Model, Render, Task, View, WeakView,
 };
 use language::{language_settings::SoftWrap, LanguageRegistry};
+use markdown::{Markdown, MarkdownStyle};
 use open_ai::{FunctionContent, ToolCall, ToolCallContent};
-use rich_text::RichText;
 use saved_conversation::{SavedAssistantMessagePart, SavedChatMessage, SavedConversation};
 use saved_conversations::SavedConversations;
 use semantic_index::{CloudEmbeddingProvider, ProjectIndex, ProjectIndexDebugView, SemanticIndex};
@@ -261,11 +261,11 @@ pub struct AssistantChat {
     tool_registry: Arc<ToolRegistry>,
     attachment_registry: Arc<AttachmentRegistry>,
     project_index: Model<ProjectIndex>,
+    markdown_style: MarkdownStyle,
 }
 
 struct EditingMessage {
     id: MessageId,
-    old_body: Arc<str>,
     body: View<Editor>,
 }
 
@@ -348,21 +348,49 @@ impl AssistantChat {
             pending_completion: None,
             attachment_registry,
             tool_registry,
+            markdown_style: MarkdownStyle {
+                code_block: gpui::TextStyleRefinement {
+                    font_family: Some("Zed Mono".into()),
+                    color: Some(cx.theme().colors().editor_foreground),
+                    background_color: Some(cx.theme().colors().editor_background),
+                    ..Default::default()
+                },
+                inline_code: gpui::TextStyleRefinement {
+                    font_family: Some("Zed Mono".into()),
+                    // @nate: Could we add inline-code specific styles to the theme?
+                    color: Some(cx.theme().colors().editor_foreground),
+                    background_color: Some(cx.theme().colors().editor_background),
+                    ..Default::default()
+                },
+                rule_color: Color::Muted.color(cx),
+                block_quote_border_color: Color::Muted.color(cx),
+                block_quote: gpui::TextStyleRefinement {
+                    color: Some(Color::Muted.color(cx)),
+                    ..Default::default()
+                },
+                link: gpui::TextStyleRefinement {
+                    color: Some(Color::Accent.color(cx)),
+                    underline: Some(gpui::UnderlineStyle {
+                        thickness: px(1.),
+                        color: Some(Color::Accent.color(cx)),
+                        wavy: false,
+                    }),
+                    ..Default::default()
+                },
+                syntax: cx.theme().syntax().clone(),
+                selection_background_color: {
+                    let mut selection = cx.theme().players().local().selection;
+                    selection.fade_out(0.7);
+                    selection
+                },
+            },
         }
     }
 
-    fn editing_message_id(&self) -> Option<MessageId> {
-        self.editing_message.as_ref().map(|message| message.id)
-    }
-
-    fn focused_message_id(&self, cx: &WindowContext) -> Option<MessageId> {
-        self.messages.iter().find_map(|message| match message {
-            ChatMessage::User(message) => message
-                .body
-                .focus_handle(cx)
-                .contains_focused(cx)
-                .then_some(message.id),
-            ChatMessage::Assistant(_) => None,
+    fn message_for_id(&self, id: MessageId) -> Option<&ChatMessage> {
+        self.messages.iter().find(|message| match message {
+            ChatMessage::User(message) => message.id == id,
+            ChatMessage::Assistant(message) => message.id == id,
         })
     }
 
@@ -372,10 +400,8 @@ impl AssistantChat {
 
     fn cancel(&mut self, _: &Cancel, cx: &mut ViewContext<Self>) {
         // If we're currently editing a message, cancel the edit.
-        if let Some(editing_message) = self.editing_message.take() {
-            editing_message
-                .body
-                .update(cx, |body, cx| body.set_text(editing_message.old_body, cx));
+        if self.editing_message.take().is_some() {
+            cx.notify();
             return;
         }
 
@@ -392,14 +418,7 @@ impl AssistantChat {
     }
 
     fn submit(&mut self, Submit(mode): &Submit, cx: &mut ViewContext<Self>) {
-        if let Some(focused_message_id) = self.focused_message_id(cx) {
-            self.truncate_messages(focused_message_id, cx);
-            self.pending_completion.take();
-            self.composer_editor.focus_handle(cx).focus(cx);
-            if self.editing_message_id() == Some(focused_message_id) {
-                self.editing_message.take();
-            }
-        } else if self.composer_editor.focus_handle(cx).is_focused(cx) {
+        if self.composer_editor.focus_handle(cx).is_focused(cx) {
             // Don't allow multiple concurrent completions.
             if self.pending_completion.is_some() {
                 cx.propagate();
@@ -410,10 +429,12 @@ impl AssistantChat {
                 let text = composer_editor.text(cx);
                 let id = self.next_message_id.post_inc();
                 let body = cx.new_view(|cx| {
-                    let mut editor = Editor::auto_height(80, cx);
-                    editor.set_text(text, cx);
-                    editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
-                    editor
+                    Markdown::new(
+                        text,
+                        self.markdown_style.clone(),
+                        self.language_registry.clone(),
+                        cx,
+                    )
                 });
                 composer_editor.clear(cx);
 
@@ -424,6 +445,26 @@ impl AssistantChat {
                 })
             });
             self.push_message(message, cx);
+        } else if let Some(editing_message) = self.editing_message.as_ref() {
+            let focus_handle = editing_message.body.focus_handle(cx);
+            if focus_handle.contains_focused(cx) {
+                if let Some(ChatMessage::User(user_message)) =
+                    self.message_for_id(editing_message.id)
+                {
+                    user_message.body.update(cx, |body, cx| {
+                        body.reset(editing_message.body.read(cx).text(cx), cx);
+                    });
+                }
+
+                self.truncate_messages(editing_message.id, cx);
+
+                self.pending_completion.take();
+                self.composer_editor.focus_handle(cx).focus(cx);
+                self.editing_message.take();
+            } else {
+                log::error!("unexpected state: no user message editor is focused.");
+                return;
+            }
         } else {
             log::error!("unexpected state: no user message editor is focused.");
             return;
@@ -512,7 +553,6 @@ impl AssistantChat {
                 });
 
                 let mut stream = completion?.await?;
-                let mut body = String::new();
                 while let Some(delta) = stream.next().await {
                     let delta = delta?;
                     this.update(cx, |this, cx| {
@@ -521,7 +561,14 @@ impl AssistantChat {
                         {
                             if messages.is_empty() {
                                 messages.push(AssistantMessagePart {
-                                    body: RichText::default(),
+                                    body: cx.new_view(|cx| {
+                                        Markdown::new(
+                                            "".into(),
+                                            this.markdown_style.clone(),
+                                            this.language_registry.clone(),
+                                            cx,
+                                        )
+                                    }),
                                     tool_calls: Vec::new(),
                                 })
                             }
@@ -529,7 +576,9 @@ impl AssistantChat {
                             let message = messages.last_mut().unwrap();
 
                             if let Some(content) = &delta.content {
-                                body.push_str(content);
+                                message
+                                    .body
+                                    .update(cx, |message, cx| message.append(&content, cx));
                             }
 
                             for tool_call_delta in delta.tool_calls {
@@ -558,8 +607,6 @@ impl AssistantChat {
                                 }
                             }
 
-                            message.body =
-                                RichText::new(body.clone(), &[], &this.language_registry);
                             cx.notify();
                         } else {
                             unreachable!()
@@ -608,7 +655,14 @@ impl AssistantChat {
             self.messages.last_mut()
         {
             messages.push(AssistantMessagePart {
-                body: RichText::default(),
+                body: cx.new_view(|cx| {
+                    Markdown::new(
+                        "".into(),
+                        self.markdown_style.clone(),
+                        self.language_registry.clone(),
+                        cx,
+                    )
+                }),
                 tool_calls: Vec::new(),
             });
             return;
@@ -617,7 +671,14 @@ impl AssistantChat {
         let message = ChatMessage::Assistant(AssistantMessage {
             id: self.next_message_id.post_inc(),
             messages: vec![AssistantMessagePart {
-                body: RichText::default(),
+                body: cx.new_view(|cx| {
+                    Markdown::new(
+                        "".into(),
+                        self.markdown_style.clone(),
+                        self.language_registry.clone(),
+                        cx,
+                    )
+                }),
                 tool_calls: Vec::new(),
             }],
             error: None,
@@ -760,66 +821,69 @@ impl AssistantChat {
                 .id(SharedString::from(format!("message-{}-container", id.0)))
                 .when(is_first, |this| this.pt(padding))
                 .map(|element| {
-                    if self.editing_message_id() == Some(*id) {
-                        element.child(Composer::new(
-                            body.clone(),
-                            self.project_index_button.clone(),
-                            self.active_file_button.clone(),
-                            crate::ui::ModelSelector::new(
-                                cx.view().downgrade(),
-                                self.model.clone(),
-                            )
-                            .into_any_element(),
-                        ))
-                    } else {
-                        element
-                            .on_click(cx.listener({
-                                let id = *id;
-                                let body = body.clone();
-                                move |assistant_chat, event: &ClickEvent, cx| {
-                                    if event.up.click_count == 2 {
-                                        assistant_chat.editing_message = Some(EditingMessage {
-                                            id,
-                                            body: body.clone(),
-                                            old_body: body.read(cx).text(cx).into(),
-                                        });
-                                        body.focus_handle(cx).focus(cx);
-                                    }
+                    if let Some(editing_message) = self.editing_message.as_ref() {
+                        if editing_message.id == *id {
+                            return element.child(Composer::new(
+                                editing_message.body.clone(),
+                                self.project_index_button.clone(),
+                                self.active_file_button.clone(),
+                                crate::ui::ModelSelector::new(
+                                    cx.view().downgrade(),
+                                    self.model.clone(),
+                                )
+                                .into_any_element(),
+                            ));
+                        }
+                    }
+
+                    element
+                        .on_click(cx.listener({
+                            let id = *id;
+                            let body = body.clone();
+                            move |assistant_chat, event: &ClickEvent, cx| {
+                                if event.up.click_count == 2 {
+                                    let body = cx.new_view(|cx| {
+                                        let mut editor = Editor::auto_height(80, cx);
+                                        let source = Arc::from(body.read(cx).source());
+                                        editor.set_text(source, cx);
+                                        editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
+                                        editor
+                                    });
+                                    assistant_chat.editing_message = Some(EditingMessage {
+                                        id,
+                                        body: body.clone(),
+                                    });
+                                    body.focus_handle(cx).focus(cx);
                                 }
-                            }))
-                            .child(
-                                crate::ui::ChatMessage::new(
-                                    *id,
-                                    UserOrAssistant::User(self.user_store.read(cx).current_user()),
-                                    // todo!(): clean up the vec usage
-                                    vec![
-                                        RichText::new(
-                                            body.read(cx).text(cx),
-                                            &[],
-                                            &self.language_registry,
+                            }
+                        }))
+                        .child(
+                            crate::ui::ChatMessage::new(
+                                *id,
+                                UserOrAssistant::User(self.user_store.read(cx).current_user()),
+                                // todo!(): clean up the vec usage
+                                vec![
+                                    body.clone().into_any_element(),
+                                    h_flex()
+                                        .gap_2()
+                                        .children(
+                                            attachments
+                                                .iter()
+                                                .map(|attachment| attachment.view.clone()),
                                         )
-                                        .element(ElementId::from(id.0), cx),
-                                        h_flex()
-                                            .gap_2()
-                                            .children(
-                                                attachments
-                                                    .iter()
-                                                    .map(|attachment| attachment.view.clone()),
-                                            )
-                                            .into_any_element(),
-                                    ],
-                                    self.is_message_collapsed(id),
-                                    Box::new(cx.listener({
-                                        let id = *id;
-                                        move |assistant_chat, _event, _cx| {
-                                            assistant_chat.toggle_message_collapsed(id)
-                                        }
-                                    })),
-                                )
-                                // TODO: Wire up selections.
-                                .selected(is_last),
+                                        .into_any_element(),
+                                ],
+                                self.is_message_collapsed(id),
+                                Box::new(cx.listener({
+                                    let id = *id;
+                                    move |assistant_chat, _event, _cx| {
+                                        assistant_chat.toggle_message_collapsed(id)
+                                    }
+                                })),
                             )
-                    }
+                            // TODO: Wire up selections.
+                            .selected(is_last),
+                        )
                 })
                 .into_any(),
             ChatMessage::Assistant(AssistantMessage {
@@ -831,13 +895,8 @@ impl AssistantChat {
                 let mut message_elements = Vec::new();
 
                 for message in messages {
-                    if !message.body.text.is_empty() {
-                        message_elements.push(
-                            div()
-                                // todo!(): The element Id will need to be a combo of the base ID and the index within the grouping
-                                .child(message.body.element(ElementId::from(id.0), cx))
-                                .into_any_element(),
-                        )
+                    if !message.body.read(cx).source().is_empty() {
+                        message_elements.push(div().child(message.body.clone()).into_any())
                     }
 
                     let tools = message
@@ -847,7 +906,7 @@ impl AssistantChat {
                         .collect::<Vec<AnyElement>>();
 
                     if !tools.is_empty() {
-                        message_elements.push(div().children(tools).into_any_element())
+                        message_elements.push(div().children(tools).into_any())
                     }
                 }
 
@@ -900,14 +959,14 @@ impl AssistantChat {
 
                     // Show user's message last so that the assistant is grounded in the user's request
                     completion_messages.push(CompletionMessage::User {
-                        content: body.read(cx).text(cx),
+                        content: body.read(cx).source().to_string(),
                     });
                 }
                 ChatMessage::Assistant(AssistantMessage { messages, .. }) => {
                     for message in messages {
                         let body = message.body.clone();
 
-                        if body.text.is_empty() && message.tool_calls.is_empty() {
+                        if body.read(cx).source().is_empty() && message.tool_calls.is_empty() {
                             continue;
                         }
 
@@ -926,7 +985,7 @@ impl AssistantChat {
                             .collect();
 
                         completion_messages.push(CompletionMessage::Assistant {
-                            content: Some(body.text.to_string()),
+                            content: Some(body.read(cx).source().to_string()),
                             tool_calls: tool_calls_from_assistant,
                         });
 
@@ -964,7 +1023,7 @@ impl AssistantChat {
         match message {
             ChatMessage::User(message) => SavedChatMessage::User {
                 id: message.id,
-                body: message.body.read(cx).text(cx),
+                body: message.body.read(cx).source().into(),
                 attachments: message
                     .attachments
                     .iter()
@@ -981,7 +1040,7 @@ impl AssistantChat {
                     .messages
                     .iter()
                     .map(|message| SavedAssistantMessagePart {
-                        body: message.body.text.clone(),
+                        body: message.body.read(cx).source().to_string().into(),
                         tool_calls: message
                             .tool_calls
                             .iter()
@@ -1093,7 +1152,7 @@ enum ChatMessage {
 impl ChatMessage {
     fn focus_handle(&self, cx: &AppContext) -> Option<FocusHandle> {
         match self {
-            ChatMessage::User(UserMessage { body, .. }) => Some(body.focus_handle(cx)),
+            ChatMessage::User(message) => Some(message.body.focus_handle(cx)),
             ChatMessage::Assistant(_) => None,
         }
     }
@@ -1101,12 +1160,12 @@ impl ChatMessage {
 
 struct UserMessage {
     pub id: MessageId,
-    pub body: View<Editor>,
+    pub body: View<Markdown>,
     pub attachments: Vec<UserAttachment>,
 }
 
 struct AssistantMessagePart {
-    pub body: RichText,
+    pub body: View<Markdown>,
     pub tool_calls: Vec<ToolFunctionCall>,
 }
 

crates/assistant2/src/ui/chat_message.rs 🔗

@@ -119,7 +119,7 @@ impl RenderOnce for ChatMessage {
             )
             .when(self.messages.len() > 0, |el| {
                 el.child(
-                    h_flex().child(
+                    h_flex().w_full().child(
                         v_flex()
                             .relative()
                             .overflow_hidden()

crates/markdown/src/markdown.rs 🔗

@@ -3,10 +3,10 @@ mod parser;
 use crate::parser::CodeBlockKind;
 use futures::FutureExt;
 use gpui::{
-    point, quad, AnyElement, Bounds, CursorStyle, DispatchPhase, Edges, FontStyle, FontWeight,
-    GlobalElementId, Hitbox, Hsla, MouseDownEvent, MouseEvent, MouseMoveEvent, MouseUpEvent, Point,
-    Render, StrikethroughStyle, Style, StyledText, Task, TextLayout, TextRun, TextStyle,
-    TextStyleRefinement, View,
+    point, quad, AnyElement, AppContext, Bounds, CursorStyle, DispatchPhase, Edges, FocusHandle,
+    FocusableView, FontStyle, FontWeight, GlobalElementId, Hitbox, Hsla, KeyContext,
+    MouseDownEvent, MouseEvent, MouseMoveEvent, MouseUpEvent, Point, Render, StrikethroughStyle,
+    Style, StyledText, Task, TextLayout, TextRun, TextStyle, TextStyleRefinement, View,
 };
 use language::{Language, LanguageRegistry, Rope};
 use parser::{parse_markdown, MarkdownEvent, MarkdownTag, MarkdownTagEnd};
@@ -36,6 +36,7 @@ pub struct Markdown {
     parsed_markdown: ParsedMarkdown,
     should_reparse: bool,
     pending_parse: Option<Task<Option<()>>>,
+    focus_handle: FocusHandle,
     language_registry: Arc<LanguageRegistry>,
 }
 
@@ -46,6 +47,7 @@ impl Markdown {
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
+        let focus_handle = cx.focus_handle();
         let mut this = Self {
             source,
             selection: Selection::default(),
@@ -55,6 +57,7 @@ impl Markdown {
             should_reparse: false,
             parsed_markdown: ParsedMarkdown::default(),
             pending_parse: None,
+            focus_handle,
             language_registry,
         };
         this.parse(cx);
@@ -66,6 +69,16 @@ impl Markdown {
         self.parse(cx);
     }
 
+    pub fn reset(&mut self, source: String, cx: &mut ViewContext<Self>) {
+        self.source = source;
+        self.selection = Selection::default();
+        self.autoscroll_request = None;
+        self.pending_parse = None;
+        self.should_reparse = false;
+        self.parsed_markdown = ParsedMarkdown::default();
+        self.parse(cx);
+    }
+
     pub fn source(&self) -> &str {
         &self.source
     }
@@ -120,6 +133,12 @@ impl Render for Markdown {
     }
 }
 
+impl FocusableView for Markdown {
+    fn focus_handle(&self, _cx: &AppContext) -> FocusHandle {
+        self.focus_handle.clone()
+    }
+}
+
 #[derive(Copy, Clone, Default, Debug)]
 struct Selection {
     start: usize,
@@ -309,6 +328,7 @@ impl MarkdownElement {
                                 reversed: false,
                                 pending: true,
                             };
+                            cx.focus(&markdown.focus_handle);
                         }
 
                         cx.notify();
@@ -593,6 +613,13 @@ impl Element for MarkdownElement {
         hitbox: &mut Self::PrepaintState,
         cx: &mut WindowContext,
     ) {
+        let focus_handle = self.markdown.read(cx).focus_handle.clone();
+        cx.set_focus_handle(&focus_handle);
+
+        let mut context = KeyContext::default();
+        context.add("Markdown");
+        cx.set_key_context(context);
+
         self.paint_mouse_listeners(hitbox, &rendered_markdown.text, cx);
         rendered_markdown.element.paint(cx);
         self.paint_selection(bounds, &rendered_markdown.text, cx);