assistant2: Stream in completion text (#21182)

Marshall Bowers created

This PR makes it so that the completion text streams into the message
list rather than being buffered until the end.

Release Notes:

- N/A

Change summary

crates/assistant2/src/assistant_panel.rs |  5 +
crates/assistant2/src/message_editor.rs  | 51 +-----------------
crates/assistant2/src/thread.rs          | 70 +++++++++++++++++++++++++
3 files changed, 75 insertions(+), 51 deletions(-)

Detailed changes

crates/assistant2/src/assistant_panel.rs 🔗

@@ -1,7 +1,7 @@
 use anyhow::Result;
 use gpui::{
     prelude::*, px, Action, AppContext, AsyncWindowContext, EventEmitter, FocusHandle,
-    FocusableView, Model, Pixels, Task, View, ViewContext, WeakView, WindowContext,
+    FocusableView, Model, Pixels, Subscription, Task, View, ViewContext, WeakView, WindowContext,
 };
 use language_model::LanguageModelRegistry;
 use language_model_selector::LanguageModelSelector;
@@ -28,6 +28,7 @@ pub struct AssistantPanel {
     pane: View<Pane>,
     thread: Model<Thread>,
     message_editor: View<MessageEditor>,
+    _subscriptions: Vec<Subscription>,
 }
 
 impl AssistantPanel {
@@ -59,11 +60,13 @@ impl AssistantPanel {
         });
 
         let thread = cx.new_model(Thread::new);
+        let subscriptions = vec![cx.observe(&thread, |_, _, cx| cx.notify())];
 
         Self {
             pane,
             thread: thread.clone(),
             message_editor: cx.new_view(|cx| MessageEditor::new(thread, cx)),
+            _subscriptions: subscriptions,
         }
     }
 }

crates/assistant2/src/message_editor.rs 🔗

@@ -1,14 +1,11 @@
 use editor::{Editor, EditorElement, EditorStyle};
-use futures::StreamExt;
 use gpui::{AppContext, Model, TextStyle, View};
 use language_model::{
-    LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
-    LanguageModelRequestMessage, MessageContent, Role, StopReason,
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
 };
 use settings::Settings;
 use theme::ThemeSettings;
 use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
-use util::ResultExt;
 
 use crate::thread::{self, Thread};
 use crate::Chat;
@@ -71,50 +68,8 @@ impl MessageEditor {
             editor.clear(cx);
         });
 
-        let task = cx.spawn(|this, mut cx| async move {
-            let stream = model.stream_completion(request, &cx);
-            let stream_completion = async {
-                let mut events = stream.await?;
-                let mut stop_reason = StopReason::EndTurn;
-
-                let mut text = String::new();
-
-                while let Some(event) = events.next().await {
-                    let event = event?;
-                    match event {
-                        LanguageModelCompletionEvent::StartMessage { .. } => {}
-                        LanguageModelCompletionEvent::Stop(reason) => {
-                            stop_reason = reason;
-                        }
-                        LanguageModelCompletionEvent::Text(chunk) => {
-                            text.push_str(&chunk);
-                        }
-                        LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
-                    }
-
-                    smol::future::yield_now().await;
-                }
-
-                anyhow::Ok((stop_reason, text))
-            };
-
-            let result = stream_completion.await;
-
-            this.update(&mut cx, |this, cx| {
-                if let Some((_stop_reason, text)) = result.log_err() {
-                    this.thread.update(cx, |thread, _cx| {
-                        thread.messages.push(thread::Message {
-                            role: Role::Assistant,
-                            text,
-                        });
-                    });
-                }
-            })
-            .ok();
-        });
-
-        self.thread.update(cx, |thread, _cx| {
-            thread.pending_completion_tasks.push(task);
+        self.thread.update(cx, |thread, cx| {
+            thread.stream_completion(request, model, cx)
         });
 
         None

crates/assistant2/src/thread.rs 🔗

@@ -1,5 +1,11 @@
-use gpui::{ModelContext, Task};
-use language_model::Role;
+use std::sync::Arc;
+
+use futures::StreamExt as _;
+use gpui::{EventEmitter, ModelContext, Task};
+use language_model::{
+    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, Role, StopReason,
+};
+use util::ResultExt as _;
 
 /// A message in a [`Thread`].
 pub struct Message {
@@ -20,4 +26,64 @@ impl Thread {
             pending_completion_tasks: Vec::new(),
         }
     }
+
+    pub fn stream_completion(
+        &mut self,
+        request: LanguageModelRequest,
+        model: Arc<dyn LanguageModel>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let task = cx.spawn(|this, mut cx| async move {
+            let stream = model.stream_completion(request, &cx);
+            let stream_completion = async {
+                let mut events = stream.await?;
+                let mut stop_reason = StopReason::EndTurn;
+
+                while let Some(event) = events.next().await {
+                    let event = event?;
+
+                    this.update(&mut cx, |thread, cx| {
+                        match event {
+                            LanguageModelCompletionEvent::StartMessage { .. } => {
+                                thread.messages.push(Message {
+                                    role: Role::Assistant,
+                                    text: String::new(),
+                                });
+                            }
+                            LanguageModelCompletionEvent::Stop(reason) => {
+                                stop_reason = reason;
+                            }
+                            LanguageModelCompletionEvent::Text(chunk) => {
+                                if let Some(last_message) = thread.messages.last_mut() {
+                                    if last_message.role == Role::Assistant {
+                                        last_message.text.push_str(&chunk);
+                                    }
+                                }
+                            }
+                            LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
+                        }
+
+                        cx.emit(ThreadEvent::StreamedCompletion);
+                        cx.notify();
+                    })?;
+
+                    smol::future::yield_now().await;
+                }
+
+                anyhow::Ok(stop_reason)
+            };
+
+            let result = stream_completion.await;
+            let _ = result.log_err();
+        });
+
+        self.pending_completion_tasks.push(task);
+    }
 }
+
+#[derive(Debug, Clone)]
+pub enum ThreadEvent {
+    StreamedCompletion,
+}
+
+impl EventEmitter<ThreadEvent> for Thread {}