@@ -142,7 +142,9 @@ impl MessageEditor {
editor.clear(cx);
text
});
- let context = self.context_store.update(cx, |this, _cx| this.drain());
+ let context = self
+ .context_store
+ .update(cx, |this, _cx| this.context().clone());
self.thread.update(cx, |thread, cx| {
thread.insert_user_message(user_message, context, cx);
@@ -3,7 +3,7 @@ use std::sync::Arc;
use anyhow::Result;
use assistant_tool::ToolWorkingSet;
use chrono::{DateTime, Utc};
-use collections::HashMap;
+use collections::{HashMap, HashSet};
use futures::future::Shared;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
@@ -209,7 +209,13 @@ impl Thread {
temperature: None,
};
+ let mut referenced_context_ids = HashSet::default();
+
for message in &self.messages {
+ if let Some(context_ids) = self.context_by_message.get(&message.id) {
+ referenced_context_ids.extend(context_ids);
+ }
+
let mut request_message = LanguageModelRequestMessage {
role: message.role,
content: Vec::new(),
@@ -224,10 +230,6 @@ impl Thread {
}
}
- if let Some(context) = self.context_for_message(message.id) {
- attach_context_to_message(&mut request_message, context.clone());
- }
-
if !message.text.is_empty() {
request_message
.content
@@ -245,6 +247,22 @@ impl Thread {
request.messages.push(request_message);
}
+ if !referenced_context_ids.is_empty() {
+ let mut context_message = LanguageModelRequestMessage {
+ role: Role::User,
+ content: Vec::new(),
+ cache: false,
+ };
+
+ let referenced_context = referenced_context_ids
+ .into_iter()
+ .filter_map(|context_id| self.context.get(context_id))
+ .cloned();
+ attach_context_to_message(&mut context_message, referenced_context);
+
+ request.messages.push(context_message);
+ }
+
request
}