diff --git a/Cargo.lock b/Cargo.lock index 6266df7d8667f8dffcfb41f74bda1974c9bc04ca..e1ff5d6dae07a0ba9c04f2b6c8ff1f7e1590d40f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -464,10 +464,12 @@ dependencies = [ "feature_flags", "futures 0.3.31", "gpui", + "language", "language_model", "language_model_selector", "language_models", "log", + "markdown", "project", "proto", "serde", diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 20e8dfbc9a9891d6778f8ceec600001636b1d5ff..257183a4ac1843599bd8b84bdf55cb82cb19dd78 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -23,15 +23,17 @@ editor.workspace = true feature_flags.workspace = true futures.workspace = true gpui.workspace = true +language.workspace = true language_model.workspace = true language_model_selector.workspace = true language_models.workspace = true log.workspace = true +markdown.workspace = true project.workspace = true proto.workspace = true -settings.workspace = true serde.workspace = true serde_json.workspace = true +settings.workspace = true smol.workspace = true theme.workspace = true ui.workspace = true diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index b4ac2731e0bdc072babc4267690bb4cdda6778ea..b8ce5b1a36d9600ae369202764867459f610bb01 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -3,13 +3,19 @@ use std::sync::Arc; use anyhow::Result; use assistant_tool::ToolWorkingSet; use client::zed_urls; +use collections::HashMap; use gpui::{ list, prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, Empty, EventEmitter, - FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels, Subscription, - Task, View, ViewContext, WeakView, WindowContext, + FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels, + StyleRefinement, Subscription, Task, TextStyleRefinement, View, ViewContext, WeakView, + WindowContext, }; +use language::LanguageRegistry; use language_model::{LanguageModelRegistry, Role}; use language_model_selector::LanguageModelSelector; +use markdown::{Markdown, MarkdownStyle}; +use settings::Settings; +use theme::ThemeSettings; use ui::{prelude::*, ButtonLike, Divider, IconButtonShape, Tab, Tooltip}; use workspace::dock::{DockPosition, Panel, PanelEvent}; use workspace::Workspace; @@ -32,10 +38,12 @@ pub fn init(cx: &mut AppContext) { pub struct AssistantPanel { workspace: WeakView, + language_registry: Arc, #[allow(unused)] thread_store: Model, thread: Model, thread_messages: Vec, + rendered_messages_by_id: HashMap>, thread_list_state: ListState, message_editor: View, tools: Arc, @@ -77,9 +85,11 @@ impl AssistantPanel { Self { workspace: workspace.weak_handle(), + language_registry: workspace.project().read(cx).languages().clone(), thread_store, thread: thread.clone(), thread_messages: Vec::new(), + rendered_messages_by_id: HashMap::default(), thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), { let this = cx.view().downgrade(); move |ix, cx: &mut WindowContext| { @@ -104,6 +114,9 @@ impl AssistantPanel { self.message_editor = cx.new_view(|cx| MessageEditor::new(thread.clone(), cx)); self.thread = thread; + self.thread_messages.clear(); + self.thread_list_state.reset(0); + self.rendered_messages_by_id.clear(); self._subscriptions = subscriptions; self.message_editor.focus_handle(cx).focus(cx); @@ -120,10 +133,61 @@ impl AssistantPanel { self.last_error = Some(error.clone()); } ThreadEvent::StreamedCompletion => {} + ThreadEvent::StreamedAssistantText(message_id, text) => { + if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) { + markdown.update(cx, |markdown, cx| { + markdown.append(text, cx); + }); + } + } ThreadEvent::MessageAdded(message_id) => { let old_len = self.thread_messages.len(); self.thread_messages.push(*message_id); self.thread_list_state.splice(old_len..old_len, 1); + + if let Some(message_text) = self + .thread + .read(cx) + .message(*message_id) + .map(|message| message.text.clone()) + { + let theme_settings = ThemeSettings::get_global(cx); + + let mut text_style = cx.text_style(); + text_style.refine(&TextStyleRefinement { + font_family: Some(theme_settings.ui_font.family.clone()), + font_size: Some(TextSize::Default.rems(cx).into()), + color: Some(cx.theme().colors().text), + ..Default::default() + }); + + let markdown_style = MarkdownStyle { + base_text_style: text_style, + syntax: cx.theme().syntax().clone(), + selection_background_color: cx.theme().players().local().selection, + code_block: StyleRefinement { + text: Some(TextStyleRefinement { + font_family: Some(theme_settings.buffer_font.family.clone()), + font_size: Some(theme_settings.buffer_font_size.into()), + ..Default::default() + }), + ..Default::default() + }, + ..Default::default() + }; + + let markdown = cx.new_view(|cx| { + Markdown::new( + message_text, + markdown_style, + Some(self.language_registry.clone()), + None, + cx, + ) + }); + self.rendered_messages_by_id.insert(*message_id, markdown); + } + cx.notify(); } ThreadEvent::UsePendingTools => { @@ -323,6 +387,10 @@ impl AssistantPanel { return Empty.into_any(); }; + let Some(markdown) = self.rendered_messages_by_id.get(&message_id) else { + return Empty.into_any(); + }; + let (role_icon, role_name) = match message.role { Role::User => (IconName::Person, "You"), Role::Assistant => (IconName::ZedAssistant, "Assistant"), @@ -350,7 +418,7 @@ impl AssistantPanel { .child(Label::new(role_name).size(LabelSize::Small)), ), ) - .child(v_flex().p_1p5().child(Label::new(message.text.clone()))), + .child(v_flex().p_1p5().text_ui(cx).child(markdown.clone())), ) .into_any() } diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 43868fffffcb9c9a286a640e7074bdcff435df08..a84132588430226672393dcca28a77b7932a370b 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -167,6 +167,10 @@ impl Thread { if let Some(last_message) = thread.messages.last_mut() { if last_message.role == Role::Assistant { last_message.text.push_str(&chunk); + cx.emit(ThreadEvent::StreamedAssistantText( + last_message.id, + chunk, + )); } } } @@ -320,6 +324,7 @@ pub enum ThreadError { pub enum ThreadEvent { ShowError(ThreadError), StreamedCompletion, + StreamedAssistantText(MessageId, String), MessageAdded(MessageId), UsePendingTools, ToolFinished {