assistant2: Include previous messages in the thread in the completion request (#21184)

Created by Marshall Bowers

This PR makes it so that the previous messages in the thread are included when
constructing the completion request, instead of sending up only the most
recent user message.
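
For illustration only, here is a minimal, self-contained sketch of the idea. It uses simplified stand-in types rather than the real `gpui` and `language_model` APIs; the method names `insert_user_message` and `to_completion_request` mirror the new `Thread` methods added in this diff:

```rust
// Hypothetical, simplified types standing in for language_model's request structs.
#[derive(Clone, Copy)]
enum Role {
    User,
    Assistant,
}

struct Message {
    role: Role,
    text: String,
}

struct RequestMessage {
    role: Role,
    content: Vec<String>,
}

struct CompletionRequest {
    messages: Vec<RequestMessage>,
}

#[derive(Default)]
struct Thread {
    messages: Vec<Message>,
}

impl Thread {
    // Mirrors Thread::insert_user_message: append the user's text to the history.
    fn insert_user_message(&mut self, text: impl Into<String>) {
        self.messages.push(Message {
            role: Role::User,
            text: text.into(),
        });
    }

    // Mirrors Thread::to_completion_request: every message in the thread is
    // converted into a request message, so the model sees the whole history.
    fn to_completion_request(&self) -> CompletionRequest {
        CompletionRequest {
            messages: self
                .messages
                .iter()
                .map(|message| RequestMessage {
                    role: message.role,
                    content: vec![message.text.clone()],
                })
                .collect(),
        }
    }
}

fn main() {
    let mut thread = Thread::default();
    thread.insert_user_message("What does this PR change?");
    thread.insert_user_message("And why?");

    // Both user messages end up in the request, not just the latest one.
    let request = thread.to_completion_request();
    assert_eq!(request.messages.len(), 2);
}
```

In the actual change, `to_completion_request` also takes a `RequestKind` and an `AppContext`, and the resulting request is handed directly to `stream_completion`, as shown in the diffs below.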

Release Notes:

- N/A

Change summary

crates/assistant2/src/assistant_panel.rs |  2 
crates/assistant2/src/message_editor.rs  | 48 +++-------------------
crates/assistant2/src/thread.rs          | 54 ++++++++++++++++++++++++-
3 files changed, 58 insertions(+), 46 deletions(-)

Detailed changes

crates/assistant2/src/assistant_panel.rs 🔗

@@ -234,7 +234,7 @@ impl Render for AssistantPanel {
                     .p_2()
                     .overflow_y_scroll()
                     .bg(cx.theme().colors().panel_background)
-                    .children(self.thread.read(cx).messages.iter().map(|message| {
+                    .children(self.thread.read(cx).messages().map(|message| {
                         v_flex()
                             .p_2()
                             .border_1()

crates/assistant2/src/message_editor.rs 🔗

@@ -1,20 +1,13 @@
 use editor::{Editor, EditorElement, EditorStyle};
 use gpui::{AppContext, FocusableView, Model, TextStyle, View};
-use language_model::{
-    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
-};
+use language_model::LanguageModelRegistry;
 use settings::Settings;
 use theme::ThemeSettings;
 use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding};
 
-use crate::thread::{self, Thread};
+use crate::thread::{RequestKind, Thread};
 use crate::Chat;
 
-#[derive(Debug, Clone, Copy)]
-pub enum RequestKind {
-    Chat,
-}
-
 pub struct MessageEditor {
     thread: Model<Thread>,
     editor: View<Editor>,
@@ -54,47 +47,20 @@ impl MessageEditor {
         let model_registry = LanguageModelRegistry::read_global(cx);
         let model = model_registry.active_model()?;
 
-        let request = self.build_completion_request(request_kind, cx);
-
-        let user_message = self.editor.read(cx).text(cx);
-        self.thread.update(cx, |thread, _cx| {
-            thread.messages.push(thread::Message {
-                role: Role::User,
-                text: user_message,
-            });
-        });
-
-        self.editor.update(cx, |editor, cx| {
+        let user_message = self.editor.update(cx, |editor, cx| {
+            let text = editor.text(cx);
             editor.clear(cx);
+            text
         });
 
         self.thread.update(cx, |thread, cx| {
+            thread.insert_user_message(user_message);
+            let request = thread.to_completion_request(request_kind, cx);
             thread.stream_completion(request, model, cx)
         });
 
         None
     }
-
-    fn build_completion_request(
-        &self,
-        _request_kind: RequestKind,
-        cx: &AppContext,
-    ) -> LanguageModelRequest {
-        let text = self.editor.read(cx).text(cx);
-
-        let request = LanguageModelRequest {
-            messages: vec![LanguageModelRequestMessage {
-                role: Role::User,
-                content: vec![MessageContent::Text(text)],
-                cache: false,
-            }],
-            tools: Vec::new(),
-            stop: Vec::new(),
-            temperature: None,
-        };
-
-        request
-    }
 }
 
 impl FocusableView for MessageEditor {

crates/assistant2/src/thread.rs 🔗

@@ -1,12 +1,18 @@
 use std::sync::Arc;
 
 use futures::StreamExt as _;
-use gpui::{EventEmitter, ModelContext, Task};
+use gpui::{AppContext, EventEmitter, ModelContext, Task};
 use language_model::{
-    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, Role, StopReason,
+    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
+    MessageContent, Role, StopReason,
 };
 use util::ResultExt as _;
 
+#[derive(Debug, Clone, Copy)]
+pub enum RequestKind {
+    Chat,
+}
+
 /// A message in a [`Thread`].
 pub struct Message {
     pub role: Role,
@@ -15,8 +21,8 @@ pub struct Message {
 
 /// A thread of conversation with the LLM.
 pub struct Thread {
-    pub messages: Vec<Message>,
-    pub pending_completion_tasks: Vec<Task<()>>,
+    messages: Vec<Message>,
+    pending_completion_tasks: Vec<Task<()>>,
 }
 
 impl Thread {
@@ -27,6 +33,46 @@ impl Thread {
         }
     }
 
+    pub fn messages(&self) -> impl Iterator<Item = &Message> {
+        self.messages.iter()
+    }
+
+    pub fn insert_user_message(&mut self, text: impl Into<String>) {
+        self.messages.push(Message {
+            role: Role::User,
+            text: text.into(),
+        });
+    }
+
+    pub fn to_completion_request(
+        &self,
+        _request_kind: RequestKind,
+        _cx: &AppContext,
+    ) -> LanguageModelRequest {
+        let mut request = LanguageModelRequest {
+            messages: vec![],
+            tools: Vec::new(),
+            stop: Vec::new(),
+            temperature: None,
+        };
+
+        for message in &self.messages {
+            let mut request_message = LanguageModelRequestMessage {
+                role: message.role,
+                content: Vec::new(),
+                cache: false,
+            };
+
+            request_message
+                .content
+                .push(MessageContent::Text(message.text.clone()));
+
+            request.messages.push(request_message);
+        }
+
+        request
+    }
+
     pub fn stream_completion(
         &mut self,
         request: LanguageModelRequest,