diff --git a/crates/assistant2/src/context_store.rs b/crates/assistant2/src/context_store.rs index 6ce35f8971254db3f40f85e0adcd03c3fda104b3..f2659f53a86365e8bf9078e1c485f3dd24211631 100644 --- a/crates/assistant2/src/context_store.rs +++ b/crates/assistant2/src/context_store.rs @@ -36,12 +36,6 @@ impl ContextStore { &self.context } - pub fn drain(&mut self) -> Vec { - let context = self.context.drain(..).collect(); - self.clear(); - context - } - pub fn clear(&mut self) { self.context.clear(); self.files.clear(); diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index bad17d7f10b1a88ce84ea6ef6266bc337a81ea49..684b5c5df0c098ced990d5f9cba1fe24f39292b2 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -142,7 +142,9 @@ impl MessageEditor { editor.clear(cx); text }); - let context = self.context_store.update(cx, |this, _cx| this.drain()); + let context = self + .context_store + .update(cx, |this, _cx| this.context().clone()); self.thread.update(cx, |thread, cx| { thread.insert_user_message(user_message, context, cx); diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 0816e541b4adc50d8b15b7a385d29a3e3a77ff6e..f3de36f8a7a8fe2c634175eef6f269e719e7167b 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use anyhow::Result; use assistant_tool::ToolWorkingSet; use chrono::{DateTime, Utc}; -use collections::HashMap; +use collections::{HashMap, HashSet}; use futures::future::Shared; use futures::{FutureExt as _, StreamExt as _}; use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task}; @@ -209,7 +209,13 @@ impl Thread { temperature: None, }; + let mut referenced_context_ids = HashSet::default(); + for message in &self.messages { + if let Some(context_ids) = self.context_by_message.get(&message.id) { + referenced_context_ids.extend(context_ids); + } + let mut request_message = LanguageModelRequestMessage { role: message.role, content: Vec::new(), @@ -224,10 +230,6 @@ impl Thread { } } - if let Some(context) = self.context_for_message(message.id) { - attach_context_to_message(&mut request_message, context.clone()); - } - if !message.text.is_empty() { request_message .content @@ -245,6 +247,22 @@ impl Thread { request.messages.push(request_message); } + if !referenced_context_ids.is_empty() { + let mut context_message = LanguageModelRequestMessage { + role: Role::User, + content: Vec::new(), + cache: false, + }; + + let referenced_context = referenced_context_ids + .into_iter() + .filter_map(|context_id| self.context.get(context_id)) + .cloned(); + attach_context_to_message(&mut context_message, referenced_context); + + request.messages.push(context_message); + } + request }