From 2b9250843c110b13644c81b7e3abd17a92edc567 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Mon, 25 Nov 2024 16:51:32 -0500 Subject: [PATCH] assistant2: Include previous messages in the thread in the completion request (#21184) This PR makes it so previous messages in the thread are included when constructing the completion request, instead of only sending up the most recent user message. Release Notes: - N/A --- crates/assistant2/src/assistant_panel.rs | 2 +- crates/assistant2/src/message_editor.rs | 48 +++------------------ crates/assistant2/src/thread.rs | 54 ++++++++++++++++++++++-- 3 files changed, 58 insertions(+), 46 deletions(-) diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index 20ce26fc5909f703934c83528e2d8f62e30948fe..f3dd42e4d6b63da6f4d903088f788cc05477a845 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -234,7 +234,7 @@ impl Render for AssistantPanel { .p_2() .overflow_y_scroll() .bg(cx.theme().colors().panel_background) - .children(self.thread.read(cx).messages.iter().map(|message| { + .children(self.thread.read(cx).messages().map(|message| { v_flex() .p_2() .border_1() diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index e1606ff27af3759ff67733b84af39dde989ced6d..f0a8e260bc5fd816da493fadd2362b10e341709c 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -1,20 +1,13 @@ use editor::{Editor, EditorElement, EditorStyle}; use gpui::{AppContext, FocusableView, Model, TextStyle, View}; -use language_model::{ - LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role, -}; +use language_model::LanguageModelRegistry; use settings::Settings; use theme::ThemeSettings; use ui::{prelude::*, ButtonLike, ElevationIndex, KeyBinding}; -use crate::thread::{self, Thread}; +use crate::thread::{RequestKind, Thread}; use crate::Chat; -#[derive(Debug, Clone, Copy)] -pub enum RequestKind { - Chat, -} - pub struct MessageEditor { thread: Model, editor: View, @@ -54,47 +47,20 @@ impl MessageEditor { let model_registry = LanguageModelRegistry::read_global(cx); let model = model_registry.active_model()?; - let request = self.build_completion_request(request_kind, cx); - - let user_message = self.editor.read(cx).text(cx); - self.thread.update(cx, |thread, _cx| { - thread.messages.push(thread::Message { - role: Role::User, - text: user_message, - }); - }); - - self.editor.update(cx, |editor, cx| { + let user_message = self.editor.update(cx, |editor, cx| { + let text = editor.text(cx); editor.clear(cx); + text }); self.thread.update(cx, |thread, cx| { + thread.insert_user_message(user_message); + let request = thread.to_completion_request(request_kind, cx); thread.stream_completion(request, model, cx) }); None } - - fn build_completion_request( - &self, - _request_kind: RequestKind, - cx: &AppContext, - ) -> LanguageModelRequest { - let text = self.editor.read(cx).text(cx); - - let request = LanguageModelRequest { - messages: vec![LanguageModelRequestMessage { - role: Role::User, - content: vec![MessageContent::Text(text)], - cache: false, - }], - tools: Vec::new(), - stop: Vec::new(), - temperature: None, - }; - - request - } } impl FocusableView for MessageEditor { diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index a6c870b4563e218d92ffdb75182fb75079f04164..a433c10267ffe5532e68ba1bad10de2951bc65e1 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -1,12 +1,18 @@ use std::sync::Arc; use futures::StreamExt as _; -use gpui::{EventEmitter, ModelContext, Task}; +use gpui::{AppContext, EventEmitter, ModelContext, Task}; use language_model::{ - LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, Role, StopReason, + LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage, + MessageContent, Role, StopReason, }; use util::ResultExt as _; +#[derive(Debug, Clone, Copy)] +pub enum RequestKind { + Chat, +} + /// A message in a [`Thread`]. pub struct Message { pub role: Role, @@ -15,8 +21,8 @@ pub struct Message { /// A thread of conversation with the LLM. pub struct Thread { - pub messages: Vec, - pub pending_completion_tasks: Vec>, + messages: Vec, + pending_completion_tasks: Vec>, } impl Thread { @@ -27,6 +33,46 @@ impl Thread { } } + pub fn messages(&self) -> impl Iterator { + self.messages.iter() + } + + pub fn insert_user_message(&mut self, text: impl Into) { + self.messages.push(Message { + role: Role::User, + text: text.into(), + }); + } + + pub fn to_completion_request( + &self, + _request_kind: RequestKind, + _cx: &AppContext, + ) -> LanguageModelRequest { + let mut request = LanguageModelRequest { + messages: vec![], + tools: Vec::new(), + stop: Vec::new(), + temperature: None, + }; + + for message in &self.messages { + let mut request_message = LanguageModelRequestMessage { + role: message.role, + content: Vec::new(), + cache: false, + }; + + request_message + .content + .push(MessageContent::Text(message.text.clone())); + + request.messages.push(request_message); + } + + request + } + pub fn stream_completion( &mut self, request: LanguageModelRequest,