Cargo.lock 🔗
@@ -464,10 +464,12 @@ dependencies = [
"feature_flags",
"futures 0.3.31",
"gpui",
+ "language",
"language_model",
"language_model_selector",
"language_models",
"log",
+ "markdown",
"project",
"proto",
"serde",
Marshall Bowers created
This PR updates Assistant 2 to render the messages in the thread as
Markdown:
<img width="1138" alt="Screenshot 2024-12-03 at 6 09 27 PM"
src="https://github.com/user-attachments/assets/c1c44fde-1efb-43cf-b9c9-768e6974c753">
Release Notes:
- N/A
Cargo.lock | 2
crates/assistant2/Cargo.toml | 4 +
crates/assistant2/src/assistant_panel.rs | 74 ++++++++++++++++++++++++-
crates/assistant2/src/thread.rs | 5 +
4 files changed, 81 insertions(+), 4 deletions(-)
@@ -464,10 +464,12 @@ dependencies = [
"feature_flags",
"futures 0.3.31",
"gpui",
+ "language",
"language_model",
"language_model_selector",
"language_models",
"log",
+ "markdown",
"project",
"proto",
"serde",
@@ -23,15 +23,17 @@ editor.workspace = true
feature_flags.workspace = true
futures.workspace = true
gpui.workspace = true
+language.workspace = true
language_model.workspace = true
language_model_selector.workspace = true
language_models.workspace = true
log.workspace = true
+markdown.workspace = true
project.workspace = true
proto.workspace = true
-settings.workspace = true
serde.workspace = true
serde_json.workspace = true
+settings.workspace = true
smol.workspace = true
theme.workspace = true
ui.workspace = true
@@ -3,13 +3,19 @@ use std::sync::Arc;
use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use client::zed_urls;
+use collections::HashMap;
use gpui::{
list, prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, Empty, EventEmitter,
- FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels, Subscription,
- Task, View, ViewContext, WeakView, WindowContext,
+ FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels,
+ StyleRefinement, Subscription, Task, TextStyleRefinement, View, ViewContext, WeakView,
+ WindowContext,
};
+use language::LanguageRegistry;
use language_model::{LanguageModelRegistry, Role};
use language_model_selector::LanguageModelSelector;
+use markdown::{Markdown, MarkdownStyle};
+use settings::Settings;
+use theme::ThemeSettings;
use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, Tab, Tooltip};
use workspace::dock::{DockPosition, Panel, PanelEvent};
use workspace::Workspace;
@@ -32,10 +38,12 @@ pub fn init(cx: &mut AppContext) {
pub struct AssistantPanel {
workspace: WeakView<Workspace>,
+ language_registry: Arc<LanguageRegistry>,
#[allow(unused)]
thread_store: Model<ThreadStore>,
thread: Model<Thread>,
thread_messages: Vec<MessageId>,
+ rendered_messages_by_id: HashMap<MessageId, View<Markdown>>,
thread_list_state: ListState,
message_editor: View<MessageEditor>,
tools: Arc<ToolWorkingSet>,
@@ -77,9 +85,11 @@ impl AssistantPanel {
Self {
workspace: workspace.weak_handle(),
+ language_registry: workspace.project().read(cx).languages().clone(),
thread_store,
thread: thread.clone(),
thread_messages: Vec::new(),
+ rendered_messages_by_id: HashMap::default(),
thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
let this = cx.view().downgrade();
move |ix, cx: &mut WindowContext| {
@@ -104,6 +114,9 @@ impl AssistantPanel {
self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx));
self.thread = thread;
+ self.thread_messages.clear();
+ self.thread_list_state.reset(0);
+ self.rendered_messages_by_id.clear();
self._subscriptions = subscriptions;
self.message_editor.focus_handle(cx).focus(cx);
@@ -120,10 +133,61 @@ impl AssistantPanel {
self.last_error = Some(error.clone());
}
ThreadEvent::StreamedCompletion => {}
+ ThreadEvent::StreamedAssistantText(message_id, text) => {
+ if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) {
+ markdown.update(cx, |markdown, cx| {
+ markdown.append(text, cx);
+ });
+ }
+ }
ThreadEvent::MessageAdded(message_id) => {
let old_len = self.thread_messages.len();
self.thread_messages.push(*message_id);
self.thread_list_state.splice(old_len..old_len, 1);
+
+ if let Some(message_text) = self
+ .thread
+ .read(cx)
+ .message(*message_id)
+ .map(|message| message.text.clone())
+ {
+ let theme_settings = ThemeSettings::get_global(cx);
+
+ let mut text_style = cx.text_style();
+ text_style.refine(&TextStyleRefinement {
+ font_family: Some(theme_settings.ui_font.family.clone()),
+ font_size: Some(TextSize::Default.rems(cx).into()),
+ color: Some(cx.theme().colors().text),
+ ..Default::default()
+ });
+
+ let markdown_style = MarkdownStyle {
+ base_text_style: text_style,
+ syntax: cx.theme().syntax().clone(),
+ selection_background_color: cx.theme().players().local().selection,
+ code_block: StyleRefinement {
+ text: Some(TextStyleRefinement {
+ font_family: Some(theme_settings.buffer_font.family.clone()),
+ font_size: Some(theme_settings.buffer_font_size.into()),
+ ..Default::default()
+ }),
+ ..Default::default()
+ },
+ ..Default::default()
+ };
+
+ let markdown = cx.new_view(|cx| {
+ Markdown::new(
+ message_text,
+ markdown_style,
+ Some(self.language_registry.clone()),
+ None,
+ cx,
+ )
+ });
+ self.rendered_messages_by_id.insert(*message_id, markdown);
+ }
+
cx.notify();
}
ThreadEvent::UsePendingTools => {
@@ -323,6 +387,10 @@ impl AssistantPanel {
return Empty.into_any();
};
+ let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else {
+ return Empty.into_any();
+ };
+
let (role_icon, role_name) = match message.role {
Role::User => (IconName::Person, "You"),
Role::Assistant => (IconName::ZedAssistant, "Assistant"),
@@ -350,7 +418,7 @@ impl AssistantPanel {
.child(Label::new(role_name).size(LabelSize::Small)),
),
)
- .child(v_flex().p_1p5().child(Label::new(message.text.clone()))),
+ .child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())),
)
.into_any()
}
@@ -167,6 +167,10 @@ impl Thread {
if let Some(last_message) = thread.messages.last_mut() {
if last_message.role == Role::Assistant {
last_message.text.push_str(&chunk);
+ cx.emit(ThreadEvent::StreamedAssistantText(
+ last_message.id,
+ chunk,
+ ));
}
}
}
@@ -320,6 +324,7 @@ pub enum ThreadError {
pub enum ThreadEvent {
ShowError(ThreadError),
StreamedCompletion,
+ StreamedAssistantText(MessageId, String),
MessageAdded(MessageId),
UsePendingTools,
ToolFinished {