assistant2: Render messages in the thread using a `list` (#21491)

Marshall Bowers created

This PR updates the rendering of the messages in the current thread to
use a `gpui::list`.

Release Notes:

- N/A

Change summary

crates/assistant2/src/assistant_panel.rs | 80 ++++++++++++++-----------
crates/assistant2/src/message_editor.rs  |  2 
crates/assistant2/src/thread.rs          | 15 +++-
3 files changed, 57 insertions(+), 40 deletions(-)

Detailed changes

crates/assistant2/src/assistant_panel.rs 🔗

@@ -4,9 +4,9 @@ use anyhow::Result;
 use assistant_tool::ToolWorkingSet;
 use client::zed_urls;
 use gpui::{
-    prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
-    FocusableView, FontWeight, Model, Pixels, Subscription, Task, View, ViewContext, WeakView,
-    WindowContext,
+    list, prelude::*, px, Action, AnyElement, AppContext, AsyncWindowContext, Empty, EventEmitter,
+    FocusHandle, FocusableView, FontWeight, ListAlignment, ListState, Model, Pixels, Subscription,
+    Task, View, ViewContext, WeakView, WindowContext,
 };
 use language_model::{LanguageModelRegistry, Role};
 use language_model_selector::LanguageModelSelector;
@@ -15,7 +15,7 @@ use workspace::dock::{DockPosition, Panel, PanelEvent};
 use workspace::Workspace;
 
 use crate::message_editor::MessageEditor;
-use crate::thread::{Message, Thread, ThreadError, ThreadEvent};
+use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent};
 use crate::thread_store::ThreadStore;
 use crate::{NewThread, ToggleFocus, ToggleModelSelector};
 
@@ -35,6 +35,8 @@ pub struct AssistantPanel {
     #[allow(unused)]
     thread_store: Model<ThreadStore>,
     thread: Model<Thread>,
+    thread_messages: Vec<MessageId>,
+    thread_list_state: ListState,
     message_editor: View<MessageEditor>,
     tools: Arc<ToolWorkingSet>,
     last_error: Option<ThreadError>,
@@ -77,6 +79,14 @@ impl AssistantPanel {
             workspace: workspace.weak_handle(),
             thread_store,
             thread: thread.clone(),
+            thread_messages: Vec::new(),
+            thread_list_state: ListState::new(0, ListAlignment::Bottom, px(1024.), {
+                let this = cx.view().downgrade();
+                move |ix, cx: &mut WindowContext| {
+                    this.update(cx, |this, cx| this.render_message(ix, cx))
+                        .unwrap()
+                }
+            }),
             message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
             tools,
             last_error: None,
@@ -110,6 +120,12 @@ impl AssistantPanel {
                 self.last_error = Some(error.clone());
             }
             ThreadEvent::StreamedCompletion => {}
+            ThreadEvent::MessageAdded(message_id) => {
+                let old_len = self.thread_messages.len();
+                self.thread_messages.push(*message_id);
+                self.thread_list_state.splice(old_len..old_len, 1);
+                cx.notify();
+            }
             ThreadEvent::UsePendingTools => {
                 let pending_tool_uses = self
                     .thread
@@ -301,31 +317,42 @@ impl AssistantPanel {
         )
     }
 
-    fn render_message(&self, message: Message, cx: &mut ViewContext<Self>) -> impl IntoElement {
+    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement {
+        let message_id = self.thread_messages[ix];
+        let Some(message) = self.thread.read(cx).message(message_id) else {
+            return Empty.into_any();
+        };
+
         let (role_icon, role_name) = match message.role {
             Role::User => (IconName::Person, "You"),
             Role::Assistant => (IconName::ZedAssistant, "Assistant"),
             Role::System => (IconName::Settings, "System"),
         };
 
-        v_flex()
-            .border_1()
-            .border_color(cx.theme().colors().border_variant)
-            .rounded_md()
+        div()
+            .id(("message-container", ix))
+            .p_2()
             .child(
-                h_flex()
-                    .justify_between()
-                    .p_1p5()
-                    .border_b_1()
+                v_flex()
+                    .border_1()
                     .border_color(cx.theme().colors().border_variant)
+                    .rounded_md()
                     .child(
                         h_flex()
-                            .gap_2()
-                            .child(Icon::new(role_icon).size(IconSize::Small))
-                            .child(Label::new(role_name).size(LabelSize::Small)),
-                    ),
+                            .justify_between()
+                            .p_1p5()
+                            .border_b_1()
+                            .border_color(cx.theme().colors().border_variant)
+                            .child(
+                                h_flex()
+                                    .gap_2()
+                                    .child(Icon::new(role_icon).size(IconSize::Small))
+                                    .child(Label::new(role_name).size(LabelSize::Small)),
+                            ),
+                    )
+                    .child(v_flex().p_1p5().child(Label::new(message.text.clone()))),
             )
-            .child(v_flex().p_1p5().child(Label::new(message.text.clone())))
+            .into_any()
     }
 
     fn render_last_error(&self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
@@ -477,8 +504,6 @@ impl AssistantPanel {
 
 impl Render for AssistantPanel {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
-        let messages = self.thread.read(cx).messages().cloned().collect::<Vec<_>>();
-
         v_flex()
             .key_context("AssistantPanel2")
             .justify_between()
@@ -487,20 +512,7 @@ impl Render for AssistantPanel {
                 this.new_thread(cx);
             }))
             .child(self.render_toolbar(cx))
-            .child(
-                v_flex()
-                    .id("message-list")
-                    .gap_2()
-                    .size_full()
-                    .p_2()
-                    .overflow_y_scroll()
-                    .bg(cx.theme().colors().panel_background)
-                    .children(
-                        messages
-                            .into_iter()
-                            .map(|message| self.render_message(message, cx)),
-                    ),
-            )
+            .child(list(self.thread_list_state.clone()).flex_1())
             .child(
                 h_flex()
                     .border_t_1()

crates/assistant2/src/message_editor.rs 🔗

@@ -56,7 +56,7 @@ impl MessageEditor {
         });
 
         self.thread.update(cx, |thread, cx| {
-            thread.insert_user_message(user_message);
+            thread.insert_user_message(user_message, cx);
             let mut request = thread.to_completion_request(request_kind, cx);
 
             if self.use_tools {

crates/assistant2/src/thread.rs 🔗

@@ -63,8 +63,8 @@ impl Thread {
         }
     }
 
-    pub fn messages(&self) -> impl Iterator<Item = &Message> {
-        self.messages.iter()
+    pub fn message(&self, id: MessageId) -> Option<&Message> {
+        self.messages.iter().find(|message| message.id == id)
     }
 
     pub fn tools(&self) -> &Arc<ToolWorkingSet> {
@@ -75,12 +75,14 @@ impl Thread {
         self.pending_tool_uses_by_id.values().collect()
     }
 
-    pub fn insert_user_message(&mut self, text: impl Into<String>) {
+    pub fn insert_user_message(&mut self, text: impl Into<String>, cx: &mut ModelContext<Self>) {
+        let id = self.next_message_id.post_inc();
         self.messages.push(Message {
-            id: self.next_message_id.post_inc(),
+            id,
             role: Role::User,
             text: text.into(),
         });
+        cx.emit(ThreadEvent::MessageAdded(id));
     }
 
     pub fn to_completion_request(
@@ -150,11 +152,13 @@ impl Thread {
                     thread.update(&mut cx, |thread, cx| {
                         match event {
                             LanguageModelCompletionEvent::StartMessage { .. } => {
+                                let id = thread.next_message_id.post_inc();
                                 thread.messages.push(Message {
-                                    id: thread.next_message_id.post_inc(),
+                                    id,
                                     role: Role::Assistant,
                                     text: String::new(),
                                 });
+                                cx.emit(ThreadEvent::MessageAdded(id));
                             }
                             LanguageModelCompletionEvent::Stop(reason) => {
                                 stop_reason = reason;
@@ -316,6 +320,7 @@ pub enum ThreadError {
 pub enum ThreadEvent {
     ShowError(ThreadError),
     StreamedCompletion,
+    MessageAdded(MessageId),
     UsePendingTools,
     ToolFinished {
         #[allow(unused)]