From fffa40f97303d6c93b135bbc058a664af4093811 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Tue, 7 Jan 2025 14:16:30 -0500 Subject: [PATCH] assistant2: Make context persistent in the thread (#22789) This PR makes it so the context is persistent in the thread, rather than having to reattach it for each message. This PR intentionally does not make an attempt to refresh the attached context if it changes. That will come in a follow-up. Release Notes: - N/A --- crates/assistant2/src/context_store.rs | 6 ------ crates/assistant2/src/message_editor.rs | 4 +++- crates/assistant2/src/thread.rs | 28 ++++++++++++++++++++----- 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/crates/assistant2/src/context_store.rs b/crates/assistant2/src/context_store.rs index 6ce35f8971254db3f40f85e0adcd03c3fda104b3..f2659f53a86365e8bf9078e1c485f3dd24211631 100644 --- a/crates/assistant2/src/context_store.rs +++ b/crates/assistant2/src/context_store.rs @@ -36,12 +36,6 @@ impl ContextStore { &self.context } - pub fn drain(&mut self) -> Vec { - let context = self.context.drain(..).collect(); - self.clear(); - context - } - pub fn clear(&mut self) { self.context.clear(); self.files.clear(); diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index bad17d7f10b1a88ce84ea6ef6266bc337a81ea49..684b5c5df0c098ced990d5f9cba1fe24f39292b2 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -142,7 +142,9 @@ impl MessageEditor { editor.clear(cx); text }); - let context = self.context_store.update(cx, |this, _cx| this.drain()); + let context = self + .context_store + .update(cx, |this, _cx| this.context().clone()); self.thread.update(cx, |thread, cx| { thread.insert_user_message(user_message, context, cx); diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 0816e541b4adc50d8b15b7a385d29a3e3a77ff6e..f3de36f8a7a8fe2c634175eef6f269e719e7167b 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use anyhow::Result; use assistant_tool::ToolWorkingSet; use chrono::{DateTime, Utc}; -use collections::HashMap; +use collections::{HashMap, HashSet}; use futures::future::Shared; use futures::{FutureExt as _, StreamExt as _}; use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task}; @@ -209,7 +209,13 @@ impl Thread { temperature: None, }; + let mut referenced_context_ids = HashSet::default(); + for message in &self.messages { + if let Some(context_ids) = self.context_by_message.get(&message.id) { + referenced_context_ids.extend(context_ids); + } + let mut request_message = LanguageModelRequestMessage { role: message.role, content: Vec::new(), @@ -224,10 +230,6 @@ impl Thread { } } - if let Some(context) = self.context_for_message(message.id) { - attach_context_to_message(&mut request_message, context.clone()); - } - if !message.text.is_empty() { request_message .content @@ -245,6 +247,22 @@ impl Thread { request.messages.push(request_message); } + if !referenced_context_ids.is_empty() { + let mut context_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), + cache: false, + }; + + let referenced_context = referenced_context_ids + .into_iter() + .filter_map(|context_id| self.context.get(context_id)) + .cloned(); + attach_context_to_message(&mut context_message, referenced_context); + + request.messages.push(context_message); + } + request }