assistant2: Make context persistent in the thread (#22789)

Marshall Bowers created

This PR makes it so the context is persistent in the thread, rather than
having to reattach it for each message.

This PR intentionally does not make an attempt to refresh the attached
context if it changes. That will come in a follow-up.

Release Notes:

- N/A

Change summary

crates/assistant2/src/context_store.rs  |  6 -----
crates/assistant2/src/message_editor.rs |  4 ++
crates/assistant2/src/thread.rs         | 28 ++++++++++++++++++++++----
3 files changed, 26 insertions(+), 12 deletions(-)

Detailed changes

crates/assistant2/src/context_store.rs 🔗

@@ -36,12 +36,6 @@ impl ContextStore {
         &self.context
     }
 
-    pub fn drain(&mut self) -> Vec<Context> {
-        let context = self.context.drain(..).collect();
-        self.clear();
-        context
-    }
-
     pub fn clear(&mut self) {
         self.context.clear();
         self.files.clear();

crates/assistant2/src/message_editor.rs 🔗

@@ -142,7 +142,9 @@ impl MessageEditor {
             editor.clear(cx);
             text
         });
-        let context = self.context_store.update(cx, |this, _cx| this.drain());
+        let context = self
+            .context_store
+            .update(cx, |this, _cx| this.context().clone());
 
         self.thread.update(cx, |thread, cx| {
             thread.insert_user_message(user_message, context, cx);

crates/assistant2/src/thread.rs 🔗

@@ -3,7 +3,7 @@ use std::sync::Arc;
 use anyhow::Result;
 use assistant_tool::ToolWorkingSet;
 use chrono::{DateTime, Utc};
-use collections::HashMap;
+use collections::{HashMap, HashSet};
 use futures::future::Shared;
 use futures::{FutureExt as _, StreamExt as _};
 use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
@@ -209,7 +209,13 @@ impl Thread {
             temperature: None,
         };
 
+        let mut referenced_context_ids = HashSet::default();
+
         for message in &self.messages {
+            if let Some(context_ids) = self.context_by_message.get(&message.id) {
+                referenced_context_ids.extend(context_ids);
+            }
+
             let mut request_message = LanguageModelRequestMessage {
                 role: message.role,
                 content: Vec::new(),
@@ -224,10 +230,6 @@ impl Thread {
                 }
             }
 
-            if let Some(context) = self.context_for_message(message.id) {
-                attach_context_to_message(&mut request_message, context.clone());
-            }
-
             if !message.text.is_empty() {
                 request_message
                     .content
@@ -245,6 +247,22 @@ impl Thread {
             request.messages.push(request_message);
         }
 
+        if !referenced_context_ids.is_empty() {
+            let mut context_message = LanguageModelRequestMessage {
+                role: Role::User,
+                content: Vec::new(),
+                cache: false,
+            };
+
+            let referenced_context = referenced_context_ids
+                .into_iter()
+                .filter_map(|context_id| self.context.get(context_id))
+                .cloned();
+            attach_context_to_message(&mut context_message, referenced_context);
+
+            request.messages.push(context_message);
+        }
+
         request
     }