agent: Snapshot context in user message instead of recreating it (#27967)

Agus Zubiaga and Antonio Scandurra created

This makes context essentially work the same way as `read-file`,
increasing the likelihood of cache hits.

Just like with `read-file`, we'll notify the model when the user makes
an edit to one of the tracked files. In the future, we want to send a
diff instead of just a list of files, but that's an orthogonal change.


Release Notes:
- agent: Improved caching of files in context

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

crates/agent/src/active_thread.rs               |  54 -
crates/agent/src/context.rs                     |  80 +-
crates/agent/src/thread.rs                      | 531 +++++++++++++++++-
crates/agent/src/thread_store.rs                |   3 
crates/assistant_eval/src/headless_assistant.rs |   2 
crates/assistant_tool/src/action_log.rs         |  17 
6 files changed, 551 insertions(+), 136 deletions(-)

Detailed changes

crates/agent/src/active_thread.rs 🔗

@@ -34,7 +34,7 @@ use ui::{Disclosure, IconButton, KeyBinding, Scrollbar, ScrollbarState, Tooltip,
 use util::ResultExt as _;
 use workspace::{OpenOptions, Workspace};
 
-use crate::context_store::{ContextStore, refresh_context_store_text};
+use crate::context_store::ContextStore;
 
 pub struct ActiveThread {
     language_registry: Arc<LanguageRegistry>,
@@ -593,54 +593,14 @@ impl ActiveThread {
                 }
 
                 if self.thread.read(cx).all_tools_finished() {
-                    let pending_refresh_buffers = self.thread.update(cx, |thread, cx| {
-                        thread.action_log().update(cx, |action_log, _cx| {
-                            action_log.take_stale_buffers_in_context()
-                        })
-                    });
-
-                    let context_update_task = if !pending_refresh_buffers.is_empty() {
-                        let refresh_task = refresh_context_store_text(
-                            self.context_store.clone(),
-                            &pending_refresh_buffers,
-                            cx,
-                        );
-
-                        cx.spawn(async move |this, cx| {
-                            let updated_context_ids = refresh_task.await;
-
-                            this.update(cx, |this, cx| {
-                                this.context_store.read_with(cx, |context_store, _cx| {
-                                    context_store
-                                        .context()
-                                        .iter()
-                                        .filter(|context| {
-                                            updated_context_ids.contains(&context.id())
-                                        })
-                                        .cloned()
-                                        .collect()
-                                })
-                            })
-                        })
-                    } else {
-                        Task::ready(anyhow::Ok(Vec::new()))
-                    };
-
                     let model_registry = LanguageModelRegistry::read_global(cx);
                     if let Some(model) = model_registry.active_model() {
-                        cx.spawn(async move |this, cx| {
-                            let updated_context = context_update_task.await?;
-
-                            this.update(cx, |this, cx| {
-                                this.thread.update(cx, |thread, cx| {
-                                    thread.attach_tool_results(updated_context, cx);
-                                    if !canceled {
-                                        thread.send_to_model(model, RequestKind::Chat, cx);
-                                    }
-                                });
-                            })
-                        })
-                        .detach();
+                        self.thread.update(cx, |thread, cx| {
+                            thread.attach_tool_results(cx);
+                            if !canceled {
+                                thread.send_to_model(model, RequestKind::Chat, cx);
+                            }
+                        });
                     }
                 }
             }

crates/agent/src/context.rs 🔗

@@ -146,11 +146,11 @@ pub struct ContextSymbolId {
     pub range: Range<Anchor>,
 }
 
-pub fn attach_context_to_message<'a>(
-    message: &mut LanguageModelRequestMessage,
+/// Formats a collection of contexts into a string representation
+pub fn format_context_as_string<'a>(
     contexts: impl Iterator<Item = &'a AssistantContext>,
     cx: &App,
-) {
+) -> Option<String> {
     let mut file_context = Vec::new();
     let mut directory_context = Vec::new();
     let mut symbol_context = Vec::new();
@@ -167,64 +167,78 @@ pub fn attach_context_to_message<'a>(
         }
     }
 
-    let mut context_chunks = Vec::new();
+    if file_context.is_empty()
+        && directory_context.is_empty()
+        && symbol_context.is_empty()
+        && fetch_context.is_empty()
+        && thread_context.is_empty()
+    {
+        return None;
+    }
+
+    let mut result = String::new();
+    result.push_str("\n<context>\n\
+        The following items were attached by the user. You don't need to use other tools to read them.\n\n");
 
     if !file_context.is_empty() {
-        context_chunks.push("<files>\n");
+        result.push_str("<files>\n");
         for context in file_context {
-            context_chunks.push(&context.context_buffer.text);
+            result.push_str(&context.context_buffer.text);
         }
-        context_chunks.push("\n</files>\n");
+        result.push_str("</files>\n");
     }
 
     if !directory_context.is_empty() {
-        context_chunks.push("<directories>\n");
+        result.push_str("<directories>\n");
         for context in directory_context {
             for context_buffer in &context.context_buffers {
-                context_chunks.push(&context_buffer.text);
+                result.push_str(&context_buffer.text);
             }
         }
-        context_chunks.push("\n</directories>\n");
+        result.push_str("</directories>\n");
     }
 
     if !symbol_context.is_empty() {
-        context_chunks.push("<symbols>\n");
+        result.push_str("<symbols>\n");
         for context in symbol_context {
-            context_chunks.push(&context.context_symbol.text);
+            result.push_str(&context.context_symbol.text);
+            result.push('\n');
         }
-        context_chunks.push("\n</symbols>\n");
+        result.push_str("</symbols>\n");
     }
 
     if !fetch_context.is_empty() {
-        context_chunks.push("<fetched_urls>\n");
+        result.push_str("<fetched_urls>\n");
         for context in &fetch_context {
-            context_chunks.push(&context.url);
-            context_chunks.push(&context.text);
+            result.push_str(&context.url);
+            result.push('\n');
+            result.push_str(&context.text);
+            result.push('\n');
         }
-        context_chunks.push("\n</fetched_urls>\n");
+        result.push_str("</fetched_urls>\n");
     }
 
-    // Need to own the SharedString for summary so that it can be referenced.
-    let mut thread_context_chunks = Vec::new();
     if !thread_context.is_empty() {
-        context_chunks.push("<conversation_threads>\n");
+        result.push_str("<conversation_threads>\n");
         for context in &thread_context {
-            thread_context_chunks.push(context.summary(cx));
-            thread_context_chunks.push(context.text.clone());
+            result.push_str(&context.summary(cx));
+            result.push('\n');
+            result.push_str(&context.text);
+            result.push('\n');
         }
-        context_chunks.push("\n</conversation_threads>\n");
+        result.push_str("</conversation_threads>\n");
     }
 
-    for chunk in &thread_context_chunks {
-        context_chunks.push(chunk);
-    }
+    result.push_str("</context>\n");
+    Some(result)
+}
 
-    if !context_chunks.is_empty() {
-        message.content.push(
-            "\n<context>\n\
-                The following items were attached by the user. You don't need to use other tools to read them.\n\n".into(),
-        );
-        message.content.push(context_chunks.join("\n").into());
-        message.content.push("\n</context>\n".into());
+pub fn attach_context_to_message<'a>(
+    message: &mut LanguageModelRequestMessage,
+    contexts: impl Iterator<Item = &'a AssistantContext>,
+    cx: &App,
+) {
+    if let Some(context_string) = format_context_as_string(contexts, cx) {
+        message.content.push(context_string.into());
     }
 }

crates/agent/src/thread.rs 🔗

@@ -7,7 +7,7 @@ use anyhow::{Context as _, Result, anyhow};
 use assistant_settings::AssistantSettings;
 use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
 use chrono::{DateTime, Utc};
-use collections::{BTreeMap, HashMap, HashSet};
+use collections::{BTreeMap, HashMap};
 use fs::Fs;
 use futures::future::Shared;
 use futures::{FutureExt, StreamExt as _};
@@ -30,7 +30,7 @@ use settings::Settings;
 use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc};
 use uuid::Uuid;
 
-use crate::context::{AssistantContext, ContextId, attach_context_to_message};
+use crate::context::{AssistantContext, ContextId, format_context_as_string};
 use crate::thread_store::{
     SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
     SerializedToolUse,
@@ -82,6 +82,7 @@ pub struct Message {
     pub id: MessageId,
     pub role: Role,
     pub segments: Vec<MessageSegment>,
+    pub context: String,
 }
 
 impl Message {
@@ -110,6 +111,11 @@ impl Message {
 
     pub fn to_string(&self) -> String {
         let mut result = String::new();
+
+        if !self.context.is_empty() {
+            result.push_str(&self.context);
+        }
+
         for segment in &self.segments {
             match segment {
                 MessageSegment::Text(text) => result.push_str(text),
@@ -120,11 +126,12 @@ impl Message {
                 }
             }
         }
+
         result
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub enum MessageSegment {
     Text(String),
     Thinking(String),
@@ -335,6 +342,7 @@ impl Thread {
                             }
                         })
                         .collect(),
+                    context: message.context,
                 })
                 .collect(),
             next_message_id,
@@ -595,15 +603,58 @@ impl Thread {
         git_checkpoint: Option<GitStoreCheckpoint>,
         cx: &mut Context<Self>,
     ) -> MessageId {
-        let message_id =
-            self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
-        let context_ids = context
+        let text = text.into();
+
+        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
+
+        // Filter out contexts that have already been included in previous messages
+        let new_context: Vec<_> = context
+            .into_iter()
+            .filter(|ctx| !self.context.contains_key(&ctx.id()))
+            .collect();
+
+        if !new_context.is_empty() {
+            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
+                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
+                    message.context = context_string;
+                }
+            }
+
+            self.action_log.update(cx, |log, cx| {
+                // Track all buffers added as context
+                for ctx in &new_context {
+                    match ctx {
+                        AssistantContext::File(file_ctx) => {
+                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
+                        }
+                        AssistantContext::Directory(dir_ctx) => {
+                            for context_buffer in &dir_ctx.context_buffers {
+                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
+                            }
+                        }
+                        AssistantContext::Symbol(symbol_ctx) => {
+                            log.buffer_added_as_context(
+                                symbol_ctx.context_symbol.buffer.clone(),
+                                cx,
+                            );
+                        }
+                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
+                    }
+                }
+            });
+        }
+
+        let context_ids = new_context
             .iter()
             .map(|context| context.id())
             .collect::<Vec<_>>();
-        self.context
-            .extend(context.into_iter().map(|context| (context.id(), context)));
+        self.context.extend(
+            new_context
+                .into_iter()
+                .map(|context| (context.id(), context)),
+        );
         self.context_by_message.insert(message_id, context_ids);
+
         if let Some(git_checkpoint) = git_checkpoint {
             self.pending_checkpoint = Some(ThreadCheckpoint {
                 message_id,
@@ -620,7 +671,12 @@ impl Thread {
         cx: &mut Context<Self>,
     ) -> MessageId {
         let id = self.next_message_id.post_inc();
-        self.messages.push(Message { id, role, segments });
+        self.messages.push(Message {
+            id,
+            role,
+            segments,
+            context: String::new(),
+        });
         self.touch_updated_at();
         cx.emit(ThreadEvent::MessageAdded(id));
         id
@@ -726,6 +782,7 @@ impl Thread {
                                 content: tool_result.content.clone(),
                             })
                             .collect(),
+                        context: message.context.clone(),
                     })
                     .collect(),
                 initial_project_snapshot,
@@ -912,8 +969,6 @@ impl Thread {
             log::error!("system_prompt_context not set.")
         }
 
-        let mut added_context_ids = HashSet::<ContextId>::default();
-
         for message in &self.messages {
             let mut request_message = LanguageModelRequestMessage {
                 role: message.role,
@@ -934,23 +989,6 @@ impl Thread {
                 }
             }
 
-            // Attach context to this message if it's the first to reference it
-            if let Some(context_ids) = self.context_by_message.get(&message.id) {
-                let new_context_ids: Vec<_> = context_ids
-                    .iter()
-                    .filter(|id| !added_context_ids.contains(id))
-                    .collect();
-
-                if !new_context_ids.is_empty() {
-                    let referenced_context = new_context_ids
-                        .iter()
-                        .filter_map(|context_id| self.context.get(*context_id));
-
-                    attach_context_to_message(&mut request_message, referenced_context, cx);
-                    added_context_ids.extend(context_ids.iter());
-                }
-            }
-
             if !message.segments.is_empty() {
                 request_message
                     .content
@@ -970,11 +1008,9 @@ impl Thread {
             request.messages.push(request_message);
         }
 
-        // Set a cache breakpoint at the second-to-last message.
         // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
-        let breakpoint_index = request.messages.len() - 2;
-        for (index, message) in request.messages.iter_mut().enumerate() {
-            message.cache = index == breakpoint_index;
+        if let Some(last) = request.messages.last_mut() {
+            last.cache = true;
         }
 
         self.attached_tracked_files_state(&mut request.messages, cx);
@@ -999,7 +1035,7 @@ impl Thread {
             };
 
             if stale_message.is_empty() {
-                write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
+                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
             }
 
             writeln!(&mut stale_message, "- {}", file.path().display()).ok();
@@ -1453,17 +1489,7 @@ impl Thread {
         })
     }
 
-    pub fn attach_tool_results(
-        &mut self,
-        updated_context: Vec<AssistantContext>,
-        cx: &mut Context<Self>,
-    ) {
-        self.context.extend(
-            updated_context
-                .into_iter()
-                .map(|context| (context.id(), context)),
-        );
-
+    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
         // Insert a user message to contain the tool results.
         self.insert_user_message(
             // TODO: Sending up a user message without any content results in the model sending back
@@ -1672,6 +1698,11 @@ impl Thread {
                     Role::System => "System",
                 }
             )?;
+
+            if !message.context.is_empty() {
+                writeln!(markdown, "{}", message.context)?;
+            }
+
             for segment in &message.segments {
                 match segment {
                     MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
@@ -1828,3 +1859,415 @@ struct PendingCompletion {
     id: usize,
     _task: Task<()>,
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{ThreadStore, context_store::ContextStore, thread_store};
+    use assistant_settings::AssistantSettings;
+    use context_server::ContextServerSettings;
+    use editor::EditorSettings;
+    use gpui::TestAppContext;
+    use project::{FakeFs, Project};
+    use prompt_store::PromptBuilder;
+    use serde_json::json;
+    use settings::{Settings, SettingsStore};
+    use std::sync::Arc;
+    use theme::ThemeSettings;
+    use util::path;
+    use workspace::Workspace;
+
+    #[gpui::test]
+    async fn test_message_with_context(cx: &mut TestAppContext) {
+        init_test_settings(cx);
+
+        let project = create_test_project(
+            cx,
+            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
+        )
+        .await;
+
+        let (_workspace, _thread_store, thread, context_store) =
+            setup_test_environment(cx, project.clone()).await;
+
+        add_file_to_context(&project, &context_store, "test/code.rs", cx)
+            .await
+            .unwrap();
+
+        let context =
+            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
+
+        // Insert user message with context
+        let message_id = thread.update(cx, |thread, cx| {
+            thread.insert_user_message("Please explain this code", vec![context], None, cx)
+        });
+
+        // Check content and context in message object
+        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
+
+        // Use different path format strings based on platform for the test
+        #[cfg(windows)]
+        let path_part = r"test\code.rs";
+        #[cfg(not(windows))]
+        let path_part = "test/code.rs";
+
+        let expected_context = format!(
+            r#"
+<context>
+The following items were attached by the user. You don't need to use other tools to read them.
+
+<files>
+```rs {path_part}
+fn main() {{
+    println!("Hello, world!");
+}}
+```
+</files>
+</context>
+"#
+        );
+
+        assert_eq!(message.role, Role::User);
+        assert_eq!(message.segments.len(), 1);
+        assert_eq!(
+            message.segments[0],
+            MessageSegment::Text("Please explain this code".to_string())
+        );
+        assert_eq!(message.context, expected_context);
+
+        // Check message in request
+        let request = thread.read_with(cx, |thread, cx| {
+            thread.to_completion_request(RequestKind::Chat, cx)
+        });
+
+        assert_eq!(request.messages.len(), 1);
+        let expected_full_message = format!("{}Please explain this code", expected_context);
+        assert_eq!(request.messages[0].string_contents(), expected_full_message);
+    }
+
+    #[gpui::test]
+    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
+        init_test_settings(cx);
+
+        let project = create_test_project(
+            cx,
+            json!({
+                "file1.rs": "fn function1() {}\n",
+                "file2.rs": "fn function2() {}\n",
+                "file3.rs": "fn function3() {}\n",
+            }),
+        )
+        .await;
+
+        let (_, _thread_store, thread, context_store) =
+            setup_test_environment(cx, project.clone()).await;
+
+        // Open files individually
+        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
+            .await
+            .unwrap();
+        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
+            .await
+            .unwrap();
+        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
+            .await
+            .unwrap();
+
+        // Get the context objects
+        let contexts = context_store.update(cx, |store, _| store.context().clone());
+        assert_eq!(contexts.len(), 3);
+
+        // First message with context 1
+        let message1_id = thread.update(cx, |thread, cx| {
+            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
+        });
+
+        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
+        let message2_id = thread.update(cx, |thread, cx| {
+            thread.insert_user_message(
+                "Message 2",
+                vec![contexts[0].clone(), contexts[1].clone()],
+                None,
+                cx,
+            )
+        });
+
+        // Third message with all three contexts (contexts 1 and 2 should be skipped)
+        let message3_id = thread.update(cx, |thread, cx| {
+            thread.insert_user_message(
+                "Message 3",
+                vec![
+                    contexts[0].clone(),
+                    contexts[1].clone(),
+                    contexts[2].clone(),
+                ],
+                None,
+                cx,
+            )
+        });
+
+        // Check what contexts are included in each message
+        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
+            (
+                thread.message(message1_id).unwrap().clone(),
+                thread.message(message2_id).unwrap().clone(),
+                thread.message(message3_id).unwrap().clone(),
+            )
+        });
+
+        // First message should include context 1
+        assert!(message1.context.contains("file1.rs"));
+
+        // Second message should include only context 2 (not 1)
+        assert!(!message2.context.contains("file1.rs"));
+        assert!(message2.context.contains("file2.rs"));
+
+        // Third message should include only context 3 (not 1 or 2)
+        assert!(!message3.context.contains("file1.rs"));
+        assert!(!message3.context.contains("file2.rs"));
+        assert!(message3.context.contains("file3.rs"));
+
+        // Check entire request to make sure all contexts are properly included
+        let request = thread.read_with(cx, |thread, cx| {
+            thread.to_completion_request(RequestKind::Chat, cx)
+        });
+
+        // The request should contain all 3 messages
+        assert_eq!(request.messages.len(), 3);
+
+        // Check that the contexts are properly formatted in each message
+        assert!(request.messages[0].string_contents().contains("file1.rs"));
+        assert!(!request.messages[0].string_contents().contains("file2.rs"));
+        assert!(!request.messages[0].string_contents().contains("file3.rs"));
+
+        assert!(!request.messages[1].string_contents().contains("file1.rs"));
+        assert!(request.messages[1].string_contents().contains("file2.rs"));
+        assert!(!request.messages[1].string_contents().contains("file3.rs"));
+
+        assert!(!request.messages[2].string_contents().contains("file1.rs"));
+        assert!(!request.messages[2].string_contents().contains("file2.rs"));
+        assert!(request.messages[2].string_contents().contains("file3.rs"));
+    }
+
+    #[gpui::test]
+    async fn test_message_without_files(cx: &mut TestAppContext) {
+        init_test_settings(cx);
+
+        let project = create_test_project(
+            cx,
+            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
+        )
+        .await;
+
+        let (_, _thread_store, thread, _context_store) =
+            setup_test_environment(cx, project.clone()).await;
+
+        // Insert user message without any context (empty context vector)
+        let message_id = thread.update(cx, |thread, cx| {
+            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
+        });
+
+        // Check content and context in message object
+        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
+
+        // Context should be empty when no files are included
+        assert_eq!(message.role, Role::User);
+        assert_eq!(message.segments.len(), 1);
+        assert_eq!(
+            message.segments[0],
+            MessageSegment::Text("What is the best way to learn Rust?".to_string())
+        );
+        assert_eq!(message.context, "");
+
+        // Check message in request
+        let request = thread.read_with(cx, |thread, cx| {
+            thread.to_completion_request(RequestKind::Chat, cx)
+        });
+
+        assert_eq!(request.messages.len(), 1);
+        assert_eq!(
+            request.messages[0].string_contents(),
+            "What is the best way to learn Rust?"
+        );
+
+        // Add second message, also without context
+        let message2_id = thread.update(cx, |thread, cx| {
+            thread.insert_user_message("Are there any good books?", vec![], None, cx)
+        });
+
+        let message2 =
+            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
+        assert_eq!(message2.context, "");
+
+        // Check that both messages appear in the request
+        let request = thread.read_with(cx, |thread, cx| {
+            thread.to_completion_request(RequestKind::Chat, cx)
+        });
+
+        assert_eq!(request.messages.len(), 2);
+        assert_eq!(
+            request.messages[0].string_contents(),
+            "What is the best way to learn Rust?"
+        );
+        assert_eq!(
+            request.messages[1].string_contents(),
+            "Are there any good books?"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
+        init_test_settings(cx);
+
+        let project = create_test_project(
+            cx,
+            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
+        )
+        .await;
+
+        let (_workspace, _thread_store, thread, context_store) =
+            setup_test_environment(cx, project.clone()).await;
+
+        // Open buffer and add it to context
+        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
+            .await
+            .unwrap();
+
+        let context =
+            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
+
+        // Insert user message with the buffer as context
+        thread.update(cx, |thread, cx| {
+            thread.insert_user_message("Explain this code", vec![context], None, cx)
+        });
+
+        // Create a request and check that it doesn't have a stale buffer warning yet
+        let initial_request = thread.read_with(cx, |thread, cx| {
+            thread.to_completion_request(RequestKind::Chat, cx)
+        });
+
+        // Make sure we don't have a stale file warning yet
+        let has_stale_warning = initial_request.messages.iter().any(|msg| {
+            msg.string_contents()
+                .contains("These files changed since last read:")
+        });
+        assert!(
+            !has_stale_warning,
+            "Should not have stale buffer warning before buffer is modified"
+        );
+
+        // Modify the buffer
+        buffer.update(cx, |buffer, cx| {
+            // Find a position at the end of line 1
+            buffer.edit(
+                [(1..1, "\n    println!(\"Added a new line\");\n")],
+                None,
+                cx,
+            );
+        });
+
+        // Insert another user message without context
+        thread.update(cx, |thread, cx| {
+            thread.insert_user_message("What does the code do now?", vec![], None, cx)
+        });
+
+        // Create a new request and check for the stale buffer warning
+        let new_request = thread.read_with(cx, |thread, cx| {
+            thread.to_completion_request(RequestKind::Chat, cx)
+        });
+
+        // We should have a stale file warning as the last message
+        let last_message = new_request
+            .messages
+            .last()
+            .expect("Request should have messages");
+
+        // The last message should be the stale buffer notification
+        assert_eq!(last_message.role, Role::User);
+
+        // Check the exact content of the message
+        let expected_content = "These files changed since last read:\n- code.rs\n";
+        assert_eq!(
+            last_message.string_contents(),
+            expected_content,
+            "Last message should be exactly the stale buffer notification"
+        );
+    }
+
+    fn init_test_settings(cx: &mut TestAppContext) {
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            language::init(cx);
+            Project::init_settings(cx);
+            AssistantSettings::register(cx);
+            thread_store::init(cx);
+            workspace::init_settings(cx);
+            ThemeSettings::register(cx);
+            ContextServerSettings::register(cx);
+            EditorSettings::register(cx);
+        });
+    }
+
+    // Helper to create a test project with test files
+    async fn create_test_project(
+        cx: &mut TestAppContext,
+        files: serde_json::Value,
+    ) -> Entity<Project> {
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(path!("/test"), files).await;
+        Project::test(fs, [path!("/test").as_ref()], cx).await
+    }
+
+    async fn setup_test_environment(
+        cx: &mut TestAppContext,
+        project: Entity<Project>,
+    ) -> (
+        Entity<Workspace>,
+        Entity<ThreadStore>,
+        Entity<Thread>,
+        Entity<ContextStore>,
+    ) {
+        let (workspace, cx) =
+            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
+
+        let thread_store = cx.update(|_, cx| {
+            ThreadStore::new(
+                project.clone(),
+                Arc::default(),
+                Arc::new(PromptBuilder::new(None).unwrap()),
+                cx,
+            )
+            .unwrap()
+        });
+
+        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
+        let context_store = cx.new(|_cx| ContextStore::new(workspace.downgrade(), None));
+
+        (workspace, thread_store, thread, context_store)
+    }
+
+    async fn add_file_to_context(
+        project: &Entity<Project>,
+        context_store: &Entity<ContextStore>,
+        path: &str,
+        cx: &mut TestAppContext,
+    ) -> Result<Entity<language::Buffer>> {
+        let buffer_path = project
+            .read_with(cx, |project, cx| project.find_project_path(path, cx))
+            .unwrap();
+
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
+            .await
+            .unwrap();
+
+        context_store
+            .update(cx, |store, cx| {
+                store.add_file_from_buffer(buffer.clone(), cx)
+            })
+            .await?;
+
+        Ok(buffer)
+    }
+}

crates/agent/src/thread_store.rs 🔗

@@ -374,6 +374,8 @@ pub struct SerializedMessage {
     pub tool_uses: Vec<SerializedToolUse>,
     #[serde(default)]
     pub tool_results: Vec<SerializedToolResult>,
+    #[serde(default)]
+    pub context: String,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -441,6 +443,7 @@ impl LegacySerializedMessage {
             segments: vec![SerializedMessageSegment::Text { text: self.text }],
             tool_uses: self.tool_uses,
             tool_results: self.tool_results,
+            context: String::new(),
         }
     }
 }

crates/assistant_eval/src/headless_assistant.rs 🔗

@@ -124,7 +124,7 @@ impl HeadlessAssistant {
                     let model_registry = LanguageModelRegistry::read_global(cx);
                     if let Some(model) = model_registry.active_model() {
                         thread.update(cx, |thread, cx| {
-                            thread.attach_tool_results(vec![], cx);
+                            thread.attach_tool_results(cx);
                             thread.send_to_model(model, RequestKind::Chat, cx);
                         });
                     }

crates/assistant_tool/src/action_log.rs 🔗

@@ -1,6 +1,6 @@
 use anyhow::{Context as _, Result};
 use buffer_diff::BufferDiff;
-use collections::{BTreeMap, HashSet};
+use collections::BTreeMap;
 use futures::{StreamExt, channel::mpsc};
 use gpui::{App, AppContext, AsyncApp, Context, Entity, Subscription, Task, WeakEntity};
 use language::{Anchor, Buffer, BufferEvent, DiskState, Point};
@@ -10,9 +10,6 @@ use util::RangeExt;
 
 /// Tracks actions performed by tools in a thread
 pub struct ActionLog {
-    /// Buffers that user manually added to the context, and whose content has
-    /// changed since the model last saw them.
-    stale_buffers_in_context: HashSet<Entity<Buffer>>,
     /// Buffers that we want to notify the model about when they change.
     tracked_buffers: BTreeMap<Entity<Buffer>, TrackedBuffer>,
     /// Has the model edited a file since it last checked diagnostics?
@@ -23,7 +20,6 @@ impl ActionLog {
     /// Creates a new, empty action log.
     pub fn new() -> Self {
         Self {
-            stale_buffers_in_context: HashSet::default(),
             tracked_buffers: BTreeMap::default(),
             edited_since_project_diagnostics_check: false,
         }
@@ -259,6 +255,11 @@ impl ActionLog {
         self.track_buffer(buffer, false, cx);
     }
 
+    /// Track a buffer that was added as context, so we can notify the model about user edits.
+    pub fn buffer_added_as_context(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
+        self.track_buffer(buffer, false, cx);
+    }
+
     /// Track a buffer as read, so we can notify the model about user edits.
     pub fn will_create_buffer(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
         self.track_buffer(buffer.clone(), true, cx);
@@ -268,7 +269,6 @@ impl ActionLog {
     /// Mark a buffer as edited, so we can refresh it in the context
     pub fn buffer_edited(&mut self, buffer: Entity<Buffer>, cx: &mut Context<Self>) {
         self.edited_since_project_diagnostics_check = true;
-        self.stale_buffers_in_context.insert(buffer.clone());
 
         let tracked_buffer = self.track_buffer(buffer.clone(), false, cx);
         if let TrackedBufferStatus::Deleted = tracked_buffer.status {
@@ -391,11 +391,6 @@ impl ActionLog {
             })
             .map(|(buffer, _)| buffer)
     }
-
-    /// Takes and returns the set of buffers pending refresh, clearing internal state.
-    pub fn take_stale_buffers_in_context(&mut self) -> HashSet<Entity<Buffer>> {
-        std::mem::take(&mut self.stale_buffers_in_context)
-    }
 }
 
 fn apply_non_conflicting_edits(