agent2: Initial infra for checkpoints and message editing (#36120)

Ben Brandt and Antonio Scandurra created

Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <me@as-cii.com>

Change summary

Cargo.lock                             |   2 
crates/acp_thread/Cargo.toml           |   1 
crates/acp_thread/src/acp_thread.rs    | 470 +++++++++++++++--
crates/acp_thread/src/connection.rs    |  34 +
crates/agent2/src/agent.rs             |  38 +
crates/agent2/src/tests/mod.rs         | 190 ++++++-
crates/agent2/src/thread.rs            | 703 +++++++++++++++------------
crates/agent_servers/src/acp/v0.rs     |   1 
crates/agent_servers/src/acp/v1.rs     |   1 
crates/agent_servers/src/claude.rs     |   3 
crates/agent_ui/src/acp/thread_view.rs |  14 
crates/agent_ui/src/agent_diff.rs      |   3 
crates/fs/Cargo.toml                   |   1 
crates/fs/src/fake_git_repo.rs         | 118 ++++
crates/fs/src/fs.rs                    | 364 +++++++++-----
crates/git/Cargo.toml                  |   4 
crates/git/src/git.rs                  |   7 
17 files changed, 1,373 insertions(+), 581 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -31,6 +31,7 @@ dependencies = [
  "ui",
  "url",
  "util",
+ "uuid",
  "watch",
  "workspace-hack",
 ]
@@ -6446,6 +6447,7 @@ dependencies = [
  "log",
  "parking_lot",
  "pretty_assertions",
+ "rand 0.8.5",
  "regex",
  "rope",
  "schemars",

crates/acp_thread/Cargo.toml 🔗

@@ -36,6 +36,7 @@ terminal.workspace = true
 ui.workspace = true
 url.workspace = true
 util.workspace = true
+uuid.workspace = true
 watch.workspace = true
 workspace-hack.workspace = true
 

crates/acp_thread/src/acp_thread.rs 🔗

@@ -9,18 +9,19 @@ pub use mention::*;
 pub use terminal::*;
 
 use action_log::ActionLog;
-use agent_client_protocol::{self as acp};
-use anyhow::{Context as _, Result};
+use agent_client_protocol as acp;
+use anyhow::{Context as _, Result, anyhow};
 use editor::Bias;
 use futures::{FutureExt, channel::oneshot, future::BoxFuture};
 use gpui::{AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
 use itertools::Itertools;
 use language::{Anchor, Buffer, BufferSnapshot, LanguageRegistry, Point, ToPoint, text_diff};
 use markdown::Markdown;
-use project::{AgentLocation, Project};
+use project::{AgentLocation, Project, git_store::GitStoreCheckpoint};
 use std::collections::HashMap;
 use std::error::Error;
-use std::fmt::Formatter;
+use std::fmt::{Formatter, Write};
+use std::ops::Range;
 use std::process::ExitStatus;
 use std::rc::Rc;
 use std::{fmt::Display, mem, path::PathBuf, sync::Arc};
@@ -29,24 +30,23 @@ use util::ResultExt;
 
 #[derive(Debug)]
 pub struct UserMessage {
+    pub id: Option<UserMessageId>,
     pub content: ContentBlock,
+    pub checkpoint: Option<GitStoreCheckpoint>,
 }
 
 impl UserMessage {
-    pub fn from_acp(
-        message: impl IntoIterator<Item = acp::ContentBlock>,
-        language_registry: Arc<LanguageRegistry>,
-        cx: &mut App,
-    ) -> Self {
-        let mut content = ContentBlock::Empty;
-        for chunk in message {
-            content.append(chunk, &language_registry, cx)
-        }
-        Self { content: content }
-    }
-
     fn to_markdown(&self, cx: &App) -> String {
-        format!("## User\n\n{}\n\n", self.content.to_markdown(cx))
+        let mut markdown = String::new();
+        if let Some(_) = self.checkpoint {
+            writeln!(markdown, "## User (checkpoint)").unwrap();
+        } else {
+            writeln!(markdown, "## User").unwrap();
+        }
+        writeln!(markdown).unwrap();
+        writeln!(markdown, "{}", self.content.to_markdown(cx)).unwrap();
+        writeln!(markdown).unwrap();
+        markdown
     }
 }
 
@@ -633,6 +633,7 @@ pub struct AcpThread {
 pub enum AcpThreadEvent {
     NewEntry,
     EntryUpdated(usize),
+    EntriesRemoved(Range<usize>),
     ToolAuthorizationRequired,
     Stopped,
     Error,
@@ -772,7 +773,7 @@ impl AcpThread {
     ) -> Result<()> {
         match update {
             acp::SessionUpdate::UserMessageChunk { content } => {
-                self.push_user_content_block(content, cx);
+                self.push_user_content_block(None, content, cx);
             }
             acp::SessionUpdate::AgentMessageChunk { content } => {
                 self.push_assistant_content_block(content, false, cx);
@@ -793,18 +794,32 @@ impl AcpThread {
         Ok(())
     }
 
-    pub fn push_user_content_block(&mut self, chunk: acp::ContentBlock, cx: &mut Context<Self>) {
+    pub fn push_user_content_block(
+        &mut self,
+        message_id: Option<UserMessageId>,
+        chunk: acp::ContentBlock,
+        cx: &mut Context<Self>,
+    ) {
         let language_registry = self.project.read(cx).languages().clone();
         let entries_len = self.entries.len();
 
         if let Some(last_entry) = self.entries.last_mut()
-            && let AgentThreadEntry::UserMessage(UserMessage { content }) = last_entry
+            && let AgentThreadEntry::UserMessage(UserMessage { id, content, .. }) = last_entry
         {
+            *id = message_id.or(id.take());
             content.append(chunk, &language_registry, cx);
-            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
+            let idx = entries_len - 1;
+            cx.emit(AcpThreadEvent::EntryUpdated(idx));
         } else {
             let content = ContentBlock::new(chunk, &language_registry, cx);
-            self.push_entry(AgentThreadEntry::UserMessage(UserMessage { content }), cx);
+            self.push_entry(
+                AgentThreadEntry::UserMessage(UserMessage {
+                    id: message_id,
+                    content,
+                    checkpoint: None,
+                }),
+                cx,
+            );
         }
     }
 
@@ -819,7 +834,8 @@ impl AcpThread {
         if let Some(last_entry) = self.entries.last_mut()
             && let AgentThreadEntry::AssistantMessage(AssistantMessage { chunks }) = last_entry
         {
-            cx.emit(AcpThreadEvent::EntryUpdated(entries_len - 1));
+            let idx = entries_len - 1;
+            cx.emit(AcpThreadEvent::EntryUpdated(idx));
             match (chunks.last_mut(), is_thought) {
                 (Some(AssistantMessageChunk::Message { block }), false)
                 | (Some(AssistantMessageChunk::Thought { block }), true) => {
@@ -1118,69 +1134,113 @@ impl AcpThread {
             self.project.read(cx).languages().clone(),
             cx,
         );
+        let git_store = self.project.read(cx).git_store().clone();
+
+        let old_checkpoint = git_store.update(cx, |git, cx| git.checkpoint(cx));
+        let message_id = if self
+            .connection
+            .session_editor(&self.session_id, cx)
+            .is_some()
+        {
+            Some(UserMessageId::new())
+        } else {
+            None
+        };
         self.push_entry(
-            AgentThreadEntry::UserMessage(UserMessage { content: block }),
+            AgentThreadEntry::UserMessage(UserMessage {
+                id: message_id.clone(),
+                content: block,
+                checkpoint: None,
+            }),
             cx,
         );
         self.clear_completed_plan_entries(cx);
 
+        let (old_checkpoint_tx, old_checkpoint_rx) = oneshot::channel();
         let (tx, rx) = oneshot::channel();
         let cancel_task = self.cancel(cx);
+        let request = acp::PromptRequest {
+            prompt: message,
+            session_id: self.session_id.clone(),
+        };
 
-        self.send_task = Some(cx.spawn(async move |this, cx| {
-            async {
+        self.send_task = Some(cx.spawn({
+            let message_id = message_id.clone();
+            async move |this, cx| {
                 cancel_task.await;
 
-                let result = this
-                    .update(cx, |this, cx| {
-                        this.connection.prompt(
-                            acp::PromptRequest {
-                                prompt: message,
-                                session_id: this.session_id.clone(),
-                            },
-                            cx,
-                        )
-                    })?
-                    .await;
-
-                tx.send(result).log_err();
-
-                anyhow::Ok(())
+                old_checkpoint_tx.send(old_checkpoint.await).ok();
+                if let Ok(result) = this.update(cx, |this, cx| {
+                    this.connection.prompt(message_id, request, cx)
+                }) {
+                    tx.send(result.await).log_err();
+                }
             }
-            .await
-            .log_err();
         }));
 
-        cx.spawn(async move |this, cx| match rx.await {
-            Ok(Err(e)) => {
-                this.update(cx, |this, cx| {
-                    this.send_task.take();
-                    cx.emit(AcpThreadEvent::Error)
-                })
+        cx.spawn(async move |this, cx| {
+            let old_checkpoint = old_checkpoint_rx
+                .await
+                .map_err(|_| anyhow!("send canceled"))
+                .flatten()
+                .context("failed to get old checkpoint")
                 .log_err();
-                Err(e)?
-            }
-            result => {
-                let cancelled = matches!(
-                    result,
-                    Ok(Ok(acp::PromptResponse {
-                        stop_reason: acp::StopReason::Cancelled
-                    }))
-                );
 
-                // We only take the task if the current prompt wasn't cancelled.
-                //
-                // This prompt may have been cancelled because another one was sent
-                // while it was still generating. In these cases, dropping `send_task`
-                // would cause the next generation to be cancelled.
-                if !cancelled {
-                    this.update(cx, |this, _cx| this.send_task.take()).ok();
-                }
+            let response = rx.await;
 
-                this.update(cx, |_, cx| cx.emit(AcpThreadEvent::Stopped))
+            if let Some((old_checkpoint, message_id)) = old_checkpoint.zip(message_id) {
+                let new_checkpoint = git_store
+                    .update(cx, |git, cx| git.checkpoint(cx))?
+                    .await
+                    .context("failed to get new checkpoint")
                     .log_err();
-                Ok(())
+                if let Some(new_checkpoint) = new_checkpoint {
+                    let equal = git_store
+                        .update(cx, |git, cx| {
+                            git.compare_checkpoints(old_checkpoint.clone(), new_checkpoint, cx)
+                        })?
+                        .await
+                        .unwrap_or(true);
+                    if !equal {
+                        this.update(cx, |this, cx| {
+                            if let Some((ix, message)) = this.user_message_mut(&message_id) {
+                                message.checkpoint = Some(old_checkpoint);
+                                cx.emit(AcpThreadEvent::EntryUpdated(ix));
+                            }
+                        })?;
+                    }
+                }
             }
+
+            this.update(cx, |this, cx| {
+                match response {
+                    Ok(Err(e)) => {
+                        this.send_task.take();
+                        cx.emit(AcpThreadEvent::Error);
+                        Err(e)
+                    }
+                    result => {
+                        let cancelled = matches!(
+                            result,
+                            Ok(Ok(acp::PromptResponse {
+                                stop_reason: acp::StopReason::Cancelled
+                            }))
+                        );
+
+                        // We only take the task if the current prompt wasn't cancelled.
+                        //
+                        // This prompt may have been cancelled because another one was sent
+                        // while it was still generating. In these cases, dropping `send_task`
+                        // would cause the next generation to be cancelled.
+                        if !cancelled {
+                            this.send_task.take();
+                        }
+
+                        cx.emit(AcpThreadEvent::Stopped);
+                        Ok(())
+                    }
+                }
+            })?
         })
         .boxed()
     }
@@ -1212,6 +1272,66 @@ impl AcpThread {
         cx.foreground_executor().spawn(send_task)
     }
 
+    /// Rewinds this thread to before the entry at `index`, removing it and all
+    /// subsequent entries while reverting any changes made from that point.
+    pub fn rewind(&mut self, id: UserMessageId, cx: &mut Context<Self>) -> Task<Result<()>> {
+        let Some(session_editor) = self.connection.session_editor(&self.session_id, cx) else {
+            return Task::ready(Err(anyhow!("not supported")));
+        };
+        let Some(message) = self.user_message(&id) else {
+            return Task::ready(Err(anyhow!("message not found")));
+        };
+
+        let checkpoint = message.checkpoint.clone();
+
+        let git_store = self.project.read(cx).git_store().clone();
+        cx.spawn(async move |this, cx| {
+            if let Some(checkpoint) = checkpoint {
+                git_store
+                    .update(cx, |git, cx| git.restore_checkpoint(checkpoint, cx))?
+                    .await?;
+            }
+
+            cx.update(|cx| session_editor.truncate(id.clone(), cx))?
+                .await?;
+            this.update(cx, |this, cx| {
+                if let Some((ix, _)) = this.user_message_mut(&id) {
+                    let range = ix..this.entries.len();
+                    this.entries.truncate(ix);
+                    cx.emit(AcpThreadEvent::EntriesRemoved(range));
+                }
+            })
+        })
+    }
+
+    fn user_message(&self, id: &UserMessageId) -> Option<&UserMessage> {
+        self.entries.iter().find_map(|entry| {
+            if let AgentThreadEntry::UserMessage(message) = entry {
+                if message.id.as_ref() == Some(&id) {
+                    Some(message)
+                } else {
+                    None
+                }
+            } else {
+                None
+            }
+        })
+    }
+
+    fn user_message_mut(&mut self, id: &UserMessageId) -> Option<(usize, &mut UserMessage)> {
+        self.entries.iter_mut().enumerate().find_map(|(ix, entry)| {
+            if let AgentThreadEntry::UserMessage(message) = entry {
+                if message.id.as_ref() == Some(&id) {
+                    Some((ix, message))
+                } else {
+                    None
+                }
+            } else {
+                None
+            }
+        })
+    }
+
     pub fn read_text_file(
         &self,
         path: PathBuf,
@@ -1414,13 +1534,18 @@ mod tests {
     use futures::{channel::mpsc, future::LocalBoxFuture, select};
     use gpui::{AsyncApp, TestAppContext, WeakEntity};
     use indoc::indoc;
-    use project::FakeFs;
+    use project::{FakeFs, Fs};
     use rand::Rng as _;
     use serde_json::json;
     use settings::SettingsStore;
     use smol::stream::StreamExt as _;
-    use std::{cell::RefCell, path::Path, rc::Rc, time::Duration};
-
+    use std::{
+        cell::RefCell,
+        path::Path,
+        rc::Rc,
+        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
+        time::Duration,
+    };
     use util::path;
 
     fn init_test(cx: &mut TestAppContext) {
@@ -1452,6 +1577,7 @@ mod tests {
         // Test creating a new user message
         thread.update(cx, |thread, cx| {
             thread.push_user_content_block(
+                None,
                 acp::ContentBlock::Text(acp::TextContent {
                     annotations: None,
                     text: "Hello, ".to_string(),
@@ -1463,6 +1589,7 @@ mod tests {
         thread.update(cx, |thread, cx| {
             assert_eq!(thread.entries.len(), 1);
             if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
+                assert_eq!(user_msg.id, None);
                 assert_eq!(user_msg.content.to_markdown(cx), "Hello, ");
             } else {
                 panic!("Expected UserMessage");
@@ -1470,8 +1597,10 @@ mod tests {
         });
 
         // Test appending to existing user message
+        let message_1_id = UserMessageId::new();
         thread.update(cx, |thread, cx| {
             thread.push_user_content_block(
+                Some(message_1_id.clone()),
                 acp::ContentBlock::Text(acp::TextContent {
                     annotations: None,
                     text: "world!".to_string(),
@@ -1483,6 +1612,7 @@ mod tests {
         thread.update(cx, |thread, cx| {
             assert_eq!(thread.entries.len(), 1);
             if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[0] {
+                assert_eq!(user_msg.id, Some(message_1_id));
                 assert_eq!(user_msg.content.to_markdown(cx), "Hello, world!");
             } else {
                 panic!("Expected UserMessage");
@@ -1501,8 +1631,10 @@ mod tests {
             );
         });
 
+        let message_2_id = UserMessageId::new();
         thread.update(cx, |thread, cx| {
             thread.push_user_content_block(
+                Some(message_2_id.clone()),
                 acp::ContentBlock::Text(acp::TextContent {
                     annotations: None,
                     text: "New user message".to_string(),
@@ -1514,6 +1646,7 @@ mod tests {
         thread.update(cx, |thread, cx| {
             assert_eq!(thread.entries.len(), 3);
             if let AgentThreadEntry::UserMessage(user_msg) = &thread.entries[2] {
+                assert_eq!(user_msg.id, Some(message_2_id));
                 assert_eq!(user_msg.content.to_markdown(cx), "New user message");
             } else {
                 panic!("Expected UserMessage at index 2");
@@ -1830,6 +1963,180 @@ mod tests {
         assert!(cx.read(|cx| !thread.read(cx).has_pending_edit_tool_calls()));
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_checkpoints(cx: &mut TestAppContext) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(
+            path!("/test"),
+            json!({
+                ".git": {}
+            }),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
+
+        let simulate_changes = Arc::new(AtomicBool::new(true));
+        let next_filename = Arc::new(AtomicUsize::new(0));
+        let connection = Rc::new(FakeAgentConnection::new().on_user_message({
+            let simulate_changes = simulate_changes.clone();
+            let next_filename = next_filename.clone();
+            let fs = fs.clone();
+            move |request, thread, mut cx| {
+                let fs = fs.clone();
+                let simulate_changes = simulate_changes.clone();
+                let next_filename = next_filename.clone();
+                async move {
+                    if simulate_changes.load(SeqCst) {
+                        let filename = format!("/test/file-{}", next_filename.fetch_add(1, SeqCst));
+                        fs.write(Path::new(&filename), b"").await?;
+                    }
+
+                    let acp::ContentBlock::Text(content) = &request.prompt[0] else {
+                        panic!("expected text content block");
+                    };
+                    thread.update(&mut cx, |thread, cx| {
+                        thread
+                            .handle_session_update(
+                                acp::SessionUpdate::AgentMessageChunk {
+                                    content: content.text.to_uppercase().into(),
+                                },
+                                cx,
+                            )
+                            .unwrap();
+                    })?;
+                    Ok(acp::PromptResponse {
+                        stop_reason: acp::StopReason::EndTurn,
+                    })
+                }
+                .boxed_local()
+            }
+        }));
+        let thread = connection
+            .new_thread(project, Path::new(path!("/test")), &mut cx.to_async())
+            .await
+            .unwrap();
+
+        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["Lorem".into()], cx)))
+            .await
+            .unwrap();
+        thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc! {"
+                    ## User (checkpoint)
+
+                    Lorem
+
+                    ## Assistant
+
+                    LOREM
+
+                "}
+            );
+        });
+        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
+
+        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["ipsum".into()], cx)))
+            .await
+            .unwrap();
+        thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc! {"
+                    ## User (checkpoint)
+
+                    Lorem
+
+                    ## Assistant
+
+                    LOREM
+
+                    ## User (checkpoint)
+
+                    ipsum
+
+                    ## Assistant
+
+                    IPSUM
+
+                "}
+            );
+        });
+        assert_eq!(
+            fs.files(),
+            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
+        );
+
+        // Checkpoint isn't stored when there are no changes.
+        simulate_changes.store(false, SeqCst);
+        cx.update(|cx| thread.update(cx, |thread, cx| thread.send(vec!["dolor".into()], cx)))
+            .await
+            .unwrap();
+        thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc! {"
+                    ## User (checkpoint)
+
+                    Lorem
+
+                    ## Assistant
+
+                    LOREM
+
+                    ## User (checkpoint)
+
+                    ipsum
+
+                    ## Assistant
+
+                    IPSUM
+
+                    ## User
+
+                    dolor
+
+                    ## Assistant
+
+                    DOLOR
+
+                "}
+            );
+        });
+        assert_eq!(
+            fs.files(),
+            vec![Path::new("/test/file-0"), Path::new("/test/file-1")]
+        );
+
+        // Rewinding the conversation truncates the history and restores the checkpoint.
+        thread
+            .update(cx, |thread, cx| {
+                let AgentThreadEntry::UserMessage(message) = &thread.entries[2] else {
+                    panic!("unexpected entries {:?}", thread.entries)
+                };
+                thread.rewind(message.id.clone().unwrap(), cx)
+            })
+            .await
+            .unwrap();
+        thread.read_with(cx, |thread, cx| {
+            assert_eq!(
+                thread.to_markdown(cx),
+                indoc! {"
+                    ## User (checkpoint)
+
+                    Lorem
+
+                    ## Assistant
+
+                    LOREM
+
+                "}
+            );
+        });
+        assert_eq!(fs.files(), vec![Path::new("/test/file-0")]);
+    }
+
     async fn run_until_first_tool_call(
         thread: &Entity<AcpThread>,
         cx: &mut TestAppContext,
@@ -1938,6 +2245,7 @@ mod tests {
 
         fn prompt(
             &self,
+            _id: Option<UserMessageId>,
             params: acp::PromptRequest,
             cx: &mut App,
         ) -> Task<gpui::Result<acp::PromptResponse>> {
@@ -1966,5 +2274,25 @@ mod tests {
             })
             .detach();
         }
+
+        fn session_editor(
+            &self,
+            session_id: &acp::SessionId,
+            _cx: &mut App,
+        ) -> Option<Rc<dyn AgentSessionEditor>> {
+            Some(Rc::new(FakeAgentSessionEditor {
+                _session_id: session_id.clone(),
+            }))
+        }
+    }
+
+    struct FakeAgentSessionEditor {
+        _session_id: acp::SessionId,
+    }
+
+    impl AgentSessionEditor for FakeAgentSessionEditor {
+        fn truncate(&self, _message_id: UserMessageId, _cx: &mut App) -> Task<Result<()>> {
+            Task::ready(Ok(()))
+        }
     }
 }

crates/acp_thread/src/connection.rs 🔗

@@ -1,13 +1,21 @@
-use std::{error::Error, fmt, path::Path, rc::Rc};
-
+use crate::AcpThread;
 use agent_client_protocol::{self as acp};
 use anyhow::Result;
 use collections::IndexMap;
 use gpui::{AsyncApp, Entity, SharedString, Task};
 use project::Project;
+use std::{error::Error, fmt, path::Path, rc::Rc, sync::Arc};
 use ui::{App, IconName};
+use uuid::Uuid;
 
-use crate::AcpThread;
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct UserMessageId(Arc<str>);
+
+impl UserMessageId {
+    pub fn new() -> Self {
+        Self(Uuid::new_v4().to_string().into())
+    }
+}
 
 pub trait AgentConnection {
     fn new_thread(
@@ -21,11 +29,23 @@ pub trait AgentConnection {
 
     fn authenticate(&self, method: acp::AuthMethodId, cx: &mut App) -> Task<Result<()>>;
 
-    fn prompt(&self, params: acp::PromptRequest, cx: &mut App)
-    -> Task<Result<acp::PromptResponse>>;
+    fn prompt(
+        &self,
+        user_message_id: Option<UserMessageId>,
+        params: acp::PromptRequest,
+        cx: &mut App,
+    ) -> Task<Result<acp::PromptResponse>>;
 
     fn cancel(&self, session_id: &acp::SessionId, cx: &mut App);
 
+    fn session_editor(
+        &self,
+        _session_id: &acp::SessionId,
+        _cx: &mut App,
+    ) -> Option<Rc<dyn AgentSessionEditor>> {
+        None
+    }
+
     /// Returns this agent as an [Rc<dyn ModelSelector>] if the model selection capability is supported.
     ///
     /// If the agent does not support model selection, returns [None].
@@ -35,6 +55,10 @@ pub trait AgentConnection {
     }
 }
 
+pub trait AgentSessionEditor {
+    fn truncate(&self, message_id: UserMessageId, cx: &mut App) -> Task<Result<()>>;
+}
+
 #[derive(Debug)]
 pub struct AuthRequired;
 

crates/agent2/src/agent.rs 🔗

@@ -1,8 +1,9 @@
 use crate::{AgentResponseEvent, Thread, templates::Templates};
 use crate::{
     ContextServerRegistry, CopyPathTool, CreateDirectoryTool, DiagnosticsTool, EditFileTool,
-    FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MessageContent, MovePathTool, NowTool,
-    OpenTool, ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, WebSearchTool,
+    FetchTool, FindPathTool, GrepTool, ListDirectoryTool, MovePathTool, NowTool, OpenTool,
+    ReadFileTool, TerminalTool, ThinkingTool, ToolCallAuthorization, UserMessageContent,
+    WebSearchTool,
 };
 use acp_thread::AgentModelSelector;
 use agent_client_protocol as acp;
@@ -637,9 +638,11 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
 
     fn prompt(
         &self,
+        id: Option<acp_thread::UserMessageId>,
         params: acp::PromptRequest,
         cx: &mut App,
     ) -> Task<Result<acp::PromptResponse>> {
+        let id = id.expect("UserMessageId is required");
         let session_id = params.session_id.clone();
         let agent = self.0.clone();
         log::info!("Received prompt request for session: {}", session_id);
@@ -660,13 +663,14 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
                 })?;
             log::debug!("Found session for: {}", session_id);
 
-            let message: Vec<MessageContent> = params
+            let content: Vec<UserMessageContent> = params
                 .prompt
                 .into_iter()
                 .map(Into::into)
                 .collect::<Vec<_>>();
-            log::info!("Converted prompt to message: {} chars", message.len());
-            log::debug!("Message content: {:?}", message);
+            log::info!("Converted prompt to message: {} chars", content.len());
+            log::debug!("Message id: {:?}", id);
+            log::debug!("Message content: {:?}", content);
 
             // Get model using the ModelSelector capability (always available for agent2)
             // Get the selected model from the thread directly
@@ -674,7 +678,8 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
 
             // Send to thread
             log::info!("Sending message to thread with model: {:?}", model.name());
-            let mut response_stream = thread.update(cx, |thread, cx| thread.send(message, cx))?;
+            let mut response_stream =
+                thread.update(cx, |thread, cx| thread.send(id, content, cx))?;
 
             // Handle response stream and forward to session.acp_thread
             while let Some(result) = response_stream.next().await {
@@ -768,6 +773,27 @@ impl acp_thread::AgentConnection for NativeAgentConnection {
             }
         });
     }
+
+    fn session_editor(
+        &self,
+        session_id: &agent_client_protocol::SessionId,
+        cx: &mut App,
+    ) -> Option<Rc<dyn acp_thread::AgentSessionEditor>> {
+        self.0.update(cx, |agent, _cx| {
+            agent
+                .sessions
+                .get(session_id)
+                .map(|session| Rc::new(NativeAgentSessionEditor(session.thread.clone())) as _)
+        })
+    }
+}
+
+struct NativeAgentSessionEditor(Entity<Thread>);
+
+impl acp_thread::AgentSessionEditor for NativeAgentSessionEditor {
+    fn truncate(&self, message_id: acp_thread::UserMessageId, cx: &mut App) -> Task<Result<()>> {
+        Task::ready(self.0.update(cx, |thread, _cx| thread.truncate(message_id)))
+    }
 }
 
 #[cfg(test)]

crates/agent2/src/tests/mod.rs 🔗

@@ -1,6 +1,5 @@
 use super::*;
-use crate::MessageContent;
-use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList};
+use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelList, UserMessageId};
 use action_log::ActionLog;
 use agent_client_protocol::{self as acp};
 use agent_settings::AgentProfileId;
@@ -38,15 +37,19 @@ async fn test_echo(cx: &mut TestAppContext) {
 
     let events = thread
         .update(cx, |thread, cx| {
-            thread.send("Testing: Reply with 'Hello'", cx)
+            thread.send(UserMessageId::new(), ["Testing: Reply with 'Hello'"], cx)
         })
         .collect()
         .await;
     thread.update(cx, |thread, _cx| {
         assert_eq!(
-            thread.messages().last().unwrap().content,
-            vec![MessageContent::Text("Hello".to_string())]
-        );
+            thread.last_message().unwrap().to_markdown(),
+            indoc! {"
+                ## Assistant
+
+                Hello
+            "}
+        )
     });
     assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
 }
@@ -59,12 +62,13 @@ async fn test_thinking(cx: &mut TestAppContext) {
     let events = thread
         .update(cx, |thread, cx| {
             thread.send(
-                indoc! {"
+                UserMessageId::new(),
+                [indoc! {"
                     Testing:
 
                     Generate a thinking step where you just think the word 'Think',
                     and have your final answer be 'Hello'
-                "},
+                "}],
                 cx,
             )
         })
@@ -72,9 +76,10 @@ async fn test_thinking(cx: &mut TestAppContext) {
         .await;
     thread.update(cx, |thread, _cx| {
         assert_eq!(
-            thread.messages().last().unwrap().to_markdown(),
+            thread.last_message().unwrap().to_markdown(),
             indoc! {"
-                ## assistant
+                ## Assistant
+
                 <think>Think</think>
                 Hello
             "}
@@ -95,7 +100,9 @@ async fn test_system_prompt(cx: &mut TestAppContext) {
 
     project_context.borrow_mut().shell = "test-shell".into();
     thread.update(cx, |thread, _| thread.add_tool(EchoTool));
-    thread.update(cx, |thread, cx| thread.send("abc", cx));
+    thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["abc"], cx)
+    });
     cx.run_until_parked();
     let mut pending_completions = fake_model.pending_completions();
     assert_eq!(
@@ -132,7 +139,8 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
         .update(cx, |thread, cx| {
             thread.add_tool(EchoTool);
             thread.send(
-                "Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'.",
+                UserMessageId::new(),
+                ["Now test the echo tool with 'Hello'. Does it work? Say 'Yes' or 'No'."],
                 cx,
             )
         })
@@ -146,7 +154,11 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
             thread.remove_tool(&AgentTool::name(&EchoTool));
             thread.add_tool(DelayTool);
             thread.send(
-                "Now call the delay tool with 200ms. When the timer goes off, then you echo the output of the tool.",
+                UserMessageId::new(),
+                [
+                    "Now call the delay tool with 200ms.",
+                    "When the timer goes off, then you echo the output of the tool.",
+                ],
                 cx,
             )
         })
@@ -156,13 +168,14 @@ async fn test_basic_tool_calls(cx: &mut TestAppContext) {
     thread.update(cx, |thread, _cx| {
         assert!(
             thread
-                .messages()
-                .last()
+                .last_message()
+                .unwrap()
+                .as_agent_message()
                 .unwrap()
                 .content
                 .iter()
                 .any(|content| {
-                    if let MessageContent::Text(text) = content {
+                    if let AgentMessageContent::Text(text) = content {
                         text.contains("Ding")
                     } else {
                         false
@@ -182,7 +195,7 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
     // Test a tool call that's likely to complete *before* streaming stops.
     let mut events = thread.update(cx, |thread, cx| {
         thread.add_tool(WordListTool);
-        thread.send("Test the word_list tool.", cx)
+        thread.send(UserMessageId::new(), ["Test the word_list tool."], cx)
     });
 
     let mut saw_partial_tool_use = false;
@@ -190,8 +203,10 @@ async fn test_streaming_tool_calls(cx: &mut TestAppContext) {
         if let Ok(AgentResponseEvent::ToolCall(tool_call)) = event {
             thread.update(cx, |thread, _cx| {
                 // Look for a tool use in the thread's last message
-                let last_content = thread.messages().last().unwrap().content.last().unwrap();
-                if let MessageContent::ToolUse(last_tool_use) = last_content {
+                let message = thread.last_message().unwrap();
+                let agent_message = message.as_agent_message().unwrap();
+                let last_content = agent_message.content.last().unwrap();
+                if let AgentMessageContent::ToolUse(last_tool_use) = last_content {
                     assert_eq!(last_tool_use.name.as_ref(), "word_list");
                     if tool_call.status == acp::ToolCallStatus::Pending {
                         if !last_tool_use.is_input_complete
@@ -229,7 +244,7 @@ async fn test_tool_authorization(cx: &mut TestAppContext) {
 
     let mut events = thread.update(cx, |thread, cx| {
         thread.add_tool(ToolRequiringPermission);
-        thread.send("abc", cx)
+        thread.send(UserMessageId::new(), ["abc"], cx)
     });
     cx.run_until_parked();
     fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
@@ -357,7 +372,9 @@ async fn test_tool_hallucination(cx: &mut TestAppContext) {
     let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
     let fake_model = model.as_fake();
 
-    let mut events = thread.update(cx, |thread, cx| thread.send("abc", cx));
+    let mut events = thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["abc"], cx)
+    });
     cx.run_until_parked();
     fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(
         LanguageModelToolUse {
@@ -449,7 +466,12 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
         .update(cx, |thread, cx| {
             thread.add_tool(DelayTool);
             thread.send(
-                "Call the delay tool twice in the same message. Once with 100ms. Once with 300ms. When both timers are complete, describe the outputs.",
+                UserMessageId::new(),
+                [
+                    "Call the delay tool twice in the same message.",
+                    "Once with 100ms. Once with 300ms.",
+                    "When both timers are complete, describe the outputs.",
+                ],
                 cx,
             )
         })
@@ -460,12 +482,13 @@ async fn test_concurrent_tool_calls(cx: &mut TestAppContext) {
     assert_eq!(stop_reasons, vec![acp::StopReason::EndTurn]);
 
     thread.update(cx, |thread, _cx| {
-        let last_message = thread.messages().last().unwrap();
-        let text = last_message
+        let last_message = thread.last_message().unwrap();
+        let agent_message = last_message.as_agent_message().unwrap();
+        let text = agent_message
             .content
             .iter()
             .filter_map(|content| {
-                if let MessageContent::Text(text) = content {
+                if let AgentMessageContent::Text(text) = content {
                     Some(text.as_str())
                 } else {
                     None
@@ -521,7 +544,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
     // Test that test-1 profile (default) has echo and delay tools
     thread.update(cx, |thread, cx| {
         thread.set_profile(AgentProfileId("test-1".into()));
-        thread.send("test", cx);
+        thread.send(UserMessageId::new(), ["test"], cx);
     });
     cx.run_until_parked();
 
@@ -539,7 +562,7 @@ async fn test_profiles(cx: &mut TestAppContext) {
     // Switch to test-2 profile, and verify that it has only the infinite tool.
     thread.update(cx, |thread, cx| {
         thread.set_profile(AgentProfileId("test-2".into()));
-        thread.send("test2", cx)
+        thread.send(UserMessageId::new(), ["test2"], cx)
     });
     cx.run_until_parked();
     let mut pending_completions = fake_model.pending_completions();
@@ -562,7 +585,8 @@ async fn test_cancellation(cx: &mut TestAppContext) {
         thread.add_tool(InfiniteTool);
         thread.add_tool(EchoTool);
         thread.send(
-            "Call the echo tool and then call the infinite tool, then explain their output",
+            UserMessageId::new(),
+            ["Call the echo tool, then call the infinite tool, then explain their output"],
             cx,
         )
     });
@@ -607,14 +631,20 @@ async fn test_cancellation(cx: &mut TestAppContext) {
     // Ensure we can still send a new message after cancellation.
     let events = thread
         .update(cx, |thread, cx| {
-            thread.send("Testing: reply with 'Hello' then stop.", cx)
+            thread.send(
+                UserMessageId::new(),
+                ["Testing: reply with 'Hello' then stop."],
+                cx,
+            )
         })
         .collect::<Vec<_>>()
         .await;
     thread.update(cx, |thread, _cx| {
+        let message = thread.last_message().unwrap();
+        let agent_message = message.as_agent_message().unwrap();
         assert_eq!(
-            thread.messages().last().unwrap().content,
-            vec![MessageContent::Text("Hello".to_string())]
+            agent_message.content,
+            vec![AgentMessageContent::Text("Hello".to_string())]
         );
     });
     assert_eq!(stop_events(events), vec![acp::StopReason::EndTurn]);
@@ -625,13 +655,16 @@ async fn test_refusal(cx: &mut TestAppContext) {
     let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
     let fake_model = model.as_fake();
 
-    let events = thread.update(cx, |thread, cx| thread.send("Hello", cx));
+    let events = thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Hello"], cx)
+    });
     cx.run_until_parked();
     thread.read_with(cx, |thread, _| {
         assert_eq!(
             thread.to_markdown(),
             indoc! {"
-                ## user
+                ## User
+
                 Hello
             "}
         );
@@ -643,9 +676,12 @@ async fn test_refusal(cx: &mut TestAppContext) {
         assert_eq!(
             thread.to_markdown(),
             indoc! {"
-                ## user
+                ## User
+
                 Hello
-                ## assistant
+
+                ## Assistant
+
                 Hey!
             "}
         );
@@ -661,6 +697,85 @@ async fn test_refusal(cx: &mut TestAppContext) {
     });
 }
 
+#[gpui::test]
+async fn test_truncate(cx: &mut TestAppContext) {
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let fake_model = model.as_fake();
+
+    let message_id = UserMessageId::new();
+    thread.update(cx, |thread, cx| {
+        thread.send(message_id.clone(), ["Hello"], cx)
+    });
+    cx.run_until_parked();
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(
+            thread.to_markdown(),
+            indoc! {"
+                ## User
+
+                Hello
+            "}
+        );
+    });
+
+    fake_model.send_last_completion_stream_text_chunk("Hey!");
+    cx.run_until_parked();
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(
+            thread.to_markdown(),
+            indoc! {"
+                ## User
+
+                Hello
+
+                ## Assistant
+
+                Hey!
+            "}
+        );
+    });
+
+    thread
+        .update(cx, |thread, _cx| thread.truncate(message_id))
+        .unwrap();
+    cx.run_until_parked();
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(thread.to_markdown(), "");
+    });
+
+    // Ensure we can still send a new message after truncation.
+    thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Hi"], cx)
+    });
+    thread.update(cx, |thread, _cx| {
+        assert_eq!(
+            thread.to_markdown(),
+            indoc! {"
+                ## User
+
+                Hi
+            "}
+        );
+    });
+    cx.run_until_parked();
+    fake_model.send_last_completion_stream_text_chunk("Ahoy!");
+    cx.run_until_parked();
+    thread.read_with(cx, |thread, _| {
+        assert_eq!(
+            thread.to_markdown(),
+            indoc! {"
+                ## User
+
+                Hi
+
+                ## Assistant
+
+                Ahoy!
+            "}
+        );
+    });
+}
+
 #[gpui::test]
 async fn test_agent_connection(cx: &mut TestAppContext) {
     cx.update(settings::init);
@@ -774,6 +889,7 @@ async fn test_agent_connection(cx: &mut TestAppContext) {
     let result = cx
         .update(|cx| {
             connection.prompt(
+                Some(acp_thread::UserMessageId::new()),
                 acp::PromptRequest {
                     session_id: session_id.clone(),
                     prompt: vec!["ghi".into()],
@@ -796,7 +912,9 @@ async fn test_tool_updates_to_completion(cx: &mut TestAppContext) {
     thread.update(cx, |thread, _cx| thread.add_tool(ThinkingTool));
     let fake_model = model.as_fake();
 
-    let mut events = thread.update(cx, |thread, cx| thread.send("Think", cx));
+    let mut events = thread.update(cx, |thread, cx| {
+        thread.send(UserMessageId::new(), ["Think"], cx)
+    });
     cx.run_until_parked();
 
     // Simulate streaming partial input.

crates/agent2/src/thread.rs 🔗

@@ -1,12 +1,12 @@
 use crate::{ContextServerRegistry, SystemPromptTemplate, Template, Templates};
-use acp_thread::MentionUri;
+use acp_thread::{MentionUri, UserMessageId};
 use action_log::ActionLog;
 use agent_client_protocol as acp;
 use agent_settings::{AgentProfileId, AgentSettings};
 use anyhow::{Context as _, Result, anyhow};
 use assistant_tool::adapt_schema_to_format;
 use cloud_llm_client::{CompletionIntent, CompletionMode};
-use collections::HashMap;
+use collections::IndexMap;
 use fs::Fs;
 use futures::{
     channel::{mpsc, oneshot},
@@ -19,7 +19,6 @@ use language_model::{
     LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
     LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role, StopReason,
 };
-use log;
 use project::Project;
 use prompt_store::ProjectContext;
 use schemars::{JsonSchema, Schema};
@@ -30,49 +29,199 @@ use std::fmt::Write;
 use std::{cell::RefCell, collections::BTreeMap, path::Path, rc::Rc, sync::Arc};
 use util::{ResultExt, markdown::MarkdownCodeBlock};
 
-#[derive(Debug, Clone)]
-pub struct AgentMessage {
-    pub role: Role,
-    pub content: Vec<MessageContent>,
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum Message {
+    User(UserMessage),
+    Agent(AgentMessage),
+}
+
+impl Message {
+    pub fn as_agent_message(&self) -> Option<&AgentMessage> {
+        match self {
+            Message::Agent(agent_message) => Some(agent_message),
+            _ => None,
+        }
+    }
+
+    pub fn to_markdown(&self) -> String {
+        match self {
+            Message::User(message) => message.to_markdown(),
+            Message::Agent(message) => message.to_markdown(),
+        }
+    }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub struct UserMessage {
+    pub id: UserMessageId,
+    pub content: Vec<UserMessageContent>,
 }
 
 #[derive(Debug, Clone, PartialEq, Eq)]
-pub enum MessageContent {
+pub enum UserMessageContent {
     Text(String),
-    Thinking {
-        text: String,
-        signature: Option<String>,
-    },
-    Mention {
-        uri: MentionUri,
-        content: String,
-    },
-    RedactedThinking(String),
+    Mention { uri: MentionUri, content: String },
     Image(LanguageModelImage),
-    ToolUse(LanguageModelToolUse),
-    ToolResult(LanguageModelToolResult),
+}
+
+impl UserMessage {
+    pub fn to_markdown(&self) -> String {
+        let mut markdown = String::from("## User\n\n");
+
+        for content in &self.content {
+            match content {
+                UserMessageContent::Text(text) => {
+                    markdown.push_str(text);
+                    markdown.push('\n');
+                }
+                UserMessageContent::Image(_) => {
+                    markdown.push_str("<image />\n");
+                }
+                UserMessageContent::Mention { uri, content } => {
+                    if !content.is_empty() {
+                        markdown.push_str(&format!("{}\n\n{}\n", uri.to_link(), content));
+                    } else {
+                        markdown.push_str(&format!("{}\n", uri.to_link()));
+                    }
+                }
+            }
+        }
+
+        markdown
+    }
+
+    fn to_request(&self) -> LanguageModelRequestMessage {
+        let mut message = LanguageModelRequestMessage {
+            role: Role::User,
+            content: Vec::with_capacity(self.content.len()),
+            cache: false,
+        };
+
+        const OPEN_CONTEXT: &str = "<context>\n\
+            The following items were attached by the user. \
+            They are up-to-date and don't need to be re-read.\n\n";
+
+        const OPEN_FILES_TAG: &str = "<files>";
+        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
+        const OPEN_THREADS_TAG: &str = "<threads>";
+        const OPEN_RULES_TAG: &str =
+            "<rules>\nThe user has specified the following rules that should be applied:\n";
+
+        let mut file_context = OPEN_FILES_TAG.to_string();
+        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
+        let mut thread_context = OPEN_THREADS_TAG.to_string();
+        let mut rules_context = OPEN_RULES_TAG.to_string();
+
+        for chunk in &self.content {
+            let chunk = match chunk {
+                UserMessageContent::Text(text) => {
+                    language_model::MessageContent::Text(text.clone())
+                }
+                UserMessageContent::Image(value) => {
+                    language_model::MessageContent::Image(value.clone())
+                }
+                UserMessageContent::Mention { uri, content } => {
+                    match uri {
+                        MentionUri::File(path) | MentionUri::Symbol(path, _) => {
+                            write!(
+                                &mut symbol_context,
+                                "\n{}",
+                                MarkdownCodeBlock {
+                                    tag: &codeblock_tag(&path),
+                                    text: &content.to_string(),
+                                }
+                            )
+                            .ok();
+                        }
+                        MentionUri::Thread(_session_id) => {
+                            write!(&mut thread_context, "\n{}\n", content).ok();
+                        }
+                        MentionUri::Rule(_user_prompt_id) => {
+                            write!(
+                                &mut rules_context,
+                                "\n{}",
+                                MarkdownCodeBlock {
+                                    tag: "",
+                                    text: &content
+                                }
+                            )
+                            .ok();
+                        }
+                    }
+
+                    language_model::MessageContent::Text(uri.to_link())
+                }
+            };
+
+            message.content.push(chunk);
+        }
+
+        let len_before_context = message.content.len();
+
+        if file_context.len() > OPEN_FILES_TAG.len() {
+            file_context.push_str("</files>\n");
+            message
+                .content
+                .push(language_model::MessageContent::Text(file_context));
+        }
+
+        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
+            symbol_context.push_str("</symbols>\n");
+            message
+                .content
+                .push(language_model::MessageContent::Text(symbol_context));
+        }
+
+        if thread_context.len() > OPEN_THREADS_TAG.len() {
+            thread_context.push_str("</threads>\n");
+            message
+                .content
+                .push(language_model::MessageContent::Text(thread_context));
+        }
+
+        if rules_context.len() > OPEN_RULES_TAG.len() {
+            rules_context.push_str("</user_rules>\n");
+            message
+                .content
+                .push(language_model::MessageContent::Text(rules_context));
+        }
+
+        if message.content.len() > len_before_context {
+            message.content.insert(
+                len_before_context,
+                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
+            );
+            message
+                .content
+                .push(language_model::MessageContent::Text("</context>".into()));
+        }
+
+        message
+    }
 }
 
 impl AgentMessage {
     pub fn to_markdown(&self) -> String {
-        let mut markdown = format!("## {}\n", self.role);
+        let mut markdown = String::from("## Assistant\n\n");
 
         for content in &self.content {
             match content {
-                MessageContent::Text(text) => {
+                AgentMessageContent::Text(text) => {
                     markdown.push_str(text);
                     markdown.push('\n');
                 }
-                MessageContent::Thinking { text, .. } => {
+                AgentMessageContent::Thinking { text, .. } => {
                     markdown.push_str("<think>");
                     markdown.push_str(text);
                     markdown.push_str("</think>\n");
                 }
-                MessageContent::RedactedThinking(_) => markdown.push_str("<redacted_thinking />\n"),
-                MessageContent::Image(_) => {
+                AgentMessageContent::RedactedThinking(_) => {
+                    markdown.push_str("<redacted_thinking />\n")
+                }
+                AgentMessageContent::Image(_) => {
                     markdown.push_str("<image />\n");
                 }
-                MessageContent::ToolUse(tool_use) => {
+                AgentMessageContent::ToolUse(tool_use) => {
                     markdown.push_str(&format!(
                         "**Tool Use**: {} (ID: {})\n",
                         tool_use.name, tool_use.id
@@ -85,41 +234,106 @@ impl AgentMessage {
                         }
                     ));
                 }
-                MessageContent::ToolResult(tool_result) => {
-                    markdown.push_str(&format!(
-                        "**Tool Result**: {} (ID: {})\n\n",
-                        tool_result.tool_name, tool_result.tool_use_id
-                    ));
-                    if tool_result.is_error {
-                        markdown.push_str("**ERROR:**\n");
-                    }
+            }
+        }
 
-                    match &tool_result.content {
-                        LanguageModelToolResultContent::Text(text) => {
-                            writeln!(markdown, "{text}\n").ok();
-                        }
-                        LanguageModelToolResultContent::Image(_) => {
-                            writeln!(markdown, "<image />\n").ok();
-                        }
-                    }
+        for tool_result in self.tool_results.values() {
+            markdown.push_str(&format!(
+                "**Tool Result**: {} (ID: {})\n\n",
+                tool_result.tool_name, tool_result.tool_use_id
+            ));
+            if tool_result.is_error {
+                markdown.push_str("**ERROR:**\n");
+            }
 
-                    if let Some(output) = tool_result.output.as_ref() {
-                        writeln!(
-                            markdown,
-                            "**Debug Output**:\n\n```json\n{}\n```\n",
-                            serde_json::to_string_pretty(output).unwrap()
-                        )
-                        .unwrap();
-                    }
+            match &tool_result.content {
+                LanguageModelToolResultContent::Text(text) => {
+                    writeln!(markdown, "{text}\n").ok();
                 }
-                MessageContent::Mention { uri, .. } => {
-                    write!(markdown, "{}", uri.to_link()).ok();
+                LanguageModelToolResultContent::Image(_) => {
+                    writeln!(markdown, "<image />\n").ok();
                 }
             }
+
+            if let Some(output) = tool_result.output.as_ref() {
+                writeln!(
+                    markdown,
+                    "**Debug Output**:\n\n```json\n{}\n```\n",
+                    serde_json::to_string_pretty(output).unwrap()
+                )
+                .unwrap();
+            }
         }
 
         markdown
     }
+
+    pub fn to_request(&self) -> Vec<LanguageModelRequestMessage> {
+        let mut content = Vec::with_capacity(self.content.len());
+        for chunk in &self.content {
+            let chunk = match chunk {
+                AgentMessageContent::Text(text) => {
+                    language_model::MessageContent::Text(text.clone())
+                }
+                AgentMessageContent::Thinking { text, signature } => {
+                    language_model::MessageContent::Thinking {
+                        text: text.clone(),
+                        signature: signature.clone(),
+                    }
+                }
+                AgentMessageContent::RedactedThinking(value) => {
+                    language_model::MessageContent::RedactedThinking(value.clone())
+                }
+                AgentMessageContent::ToolUse(value) => {
+                    language_model::MessageContent::ToolUse(value.clone())
+                }
+                AgentMessageContent::Image(value) => {
+                    language_model::MessageContent::Image(value.clone())
+                }
+            };
+            content.push(chunk);
+        }
+
+        let mut messages = vec![LanguageModelRequestMessage {
+            role: Role::Assistant,
+            content,
+            cache: false,
+        }];
+
+        if !self.tool_results.is_empty() {
+            let mut tool_results = Vec::with_capacity(self.tool_results.len());
+            for tool_result in self.tool_results.values() {
+                tool_results.push(language_model::MessageContent::ToolResult(
+                    tool_result.clone(),
+                ));
+            }
+            messages.push(LanguageModelRequestMessage {
+                role: Role::User,
+                content: tool_results,
+                cache: false,
+            });
+        }
+
+        messages
+    }
+}
+
+#[derive(Default, Debug, Clone, PartialEq, Eq)]
+pub struct AgentMessage {
+    pub content: Vec<AgentMessageContent>,
+    pub tool_results: IndexMap<LanguageModelToolUseId, LanguageModelToolResult>,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub enum AgentMessageContent {
+    Text(String),
+    Thinking {
+        text: String,
+        signature: Option<String>,
+    },
+    RedactedThinking(String),
+    Image(LanguageModelImage),
+    ToolUse(LanguageModelToolUse),
 }
 
 #[derive(Debug)]
@@ -140,13 +354,13 @@ pub struct ToolCallAuthorization {
 }
 
 pub struct Thread {
-    messages: Vec<AgentMessage>,
+    messages: Vec<Message>,
     completion_mode: CompletionMode,
     /// Holds the task that handles agent interaction until the end of the turn.
     /// Survives across multiple requests as the model performs tool calls and
     /// we run tools, report their results.
     running_turn: Option<Task<()>>,
-    pending_tool_uses: HashMap<LanguageModelToolUseId, LanguageModelToolUse>,
+    pending_agent_message: Option<AgentMessage>,
     tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
     context_server_registry: Entity<ContextServerRegistry>,
     profile_id: AgentProfileId,
@@ -172,7 +386,7 @@ impl Thread {
             messages: Vec::new(),
             completion_mode: CompletionMode::Normal,
             running_turn: None,
-            pending_tool_uses: HashMap::default(),
+            pending_agent_message: None,
             tools: BTreeMap::default(),
             context_server_registry,
             profile_id,
@@ -196,8 +410,13 @@ impl Thread {
         self.completion_mode = mode;
     }
 
-    pub fn messages(&self) -> &[AgentMessage] {
-        &self.messages
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn last_message(&self) -> Option<Message> {
+        if let Some(message) = self.pending_agent_message.clone() {
+            Some(Message::Agent(message))
+        } else {
+            self.messages.last().cloned()
+        }
     }
 
     pub fn add_tool(&mut self, tool: impl AgentTool) {
@@ -213,35 +432,36 @@ impl Thread {
     }
 
     pub fn cancel(&mut self) {
+        // TODO: do we need to emit a stop::cancel for ACP?
         self.running_turn.take();
+        self.flush_pending_agent_message();
+    }
 
-        let tool_results = self
-            .pending_tool_uses
-            .drain()
-            .map(|(tool_use_id, tool_use)| {
-                MessageContent::ToolResult(LanguageModelToolResult {
-                    tool_use_id,
-                    tool_name: tool_use.name.clone(),
-                    is_error: true,
-                    content: LanguageModelToolResultContent::Text("Tool canceled by user".into()),
-                    output: None,
-                })
-            })
-            .collect::<Vec<_>>();
-        self.last_user_message().content.extend(tool_results);
+    pub fn truncate(&mut self, message_id: UserMessageId) -> Result<()> {
+        self.cancel();
+        let Some(position) = self.messages.iter().position(
+            |msg| matches!(msg, Message::User(UserMessage { id, .. }) if id == &message_id),
+        ) else {
+            return Err(anyhow!("Message not found"));
+        };
+        self.messages.truncate(position);
+        Ok(())
     }
 
     /// Sending a message results in the model streaming a response, which could include tool calls.
     /// After calling tools, the model will stops and waits for any outstanding tool calls to be completed and their results sent.
     /// The returned channel will report all the occurrences in which the model stops before erroring or ending its turn.
-    pub fn send(
+    pub fn send<T>(
         &mut self,
-        content: impl Into<UserMessage>,
+        message_id: UserMessageId,
+        content: impl IntoIterator<Item = T>,
         cx: &mut Context<Self>,
-    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>> {
-        let content = content.into().0;
-
+    ) -> mpsc::UnboundedReceiver<Result<AgentResponseEvent, LanguageModelCompletionError>>
+    where
+        T: Into<UserMessageContent>,
+    {
         let model = self.selected_model.clone();
+        let content = content.into_iter().map(Into::into).collect::<Vec<_>>();
         log::info!("Thread::send called with model: {:?}", model.name());
         log::debug!("Thread::send content: {:?}", content);
 
@@ -251,10 +471,10 @@ impl Thread {
         let event_stream = AgentResponseEventStream(events_tx);
 
         let user_message_ix = self.messages.len();
-        self.messages.push(AgentMessage {
-            role: Role::User,
+        self.messages.push(Message::User(UserMessage {
+            id: message_id,
             content,
-        });
+        }));
         log::info!("Total messages in thread: {}", self.messages.len());
         self.running_turn = Some(cx.spawn(async move |thread, cx| {
             log::info!("Starting agent turn execution");
@@ -270,15 +490,11 @@ impl Thread {
                         thread.build_completion_request(completion_intent, cx)
                     })?;
 
-                    // println!(
-                    //     "request: {}",
-                    //     serde_json::to_string_pretty(&request).unwrap()
-                    // );
-
                     // Stream events, appending to messages and collecting up tool uses.
                     log::info!("Calling model.stream_completion");
                     let mut events = model.stream_completion(request, cx).await?;
                     log::debug!("Stream completion started successfully");
+
                     let mut tool_uses = FuturesUnordered::new();
                     while let Some(event) = events.next().await {
                         match event {
@@ -286,6 +502,7 @@ impl Thread {
                                 event_stream.send_stop(reason);
                                 if reason == StopReason::Refusal {
                                     thread.update(cx, |thread, _cx| {
+                                        thread.pending_agent_message = None;
                                         thread.messages.truncate(user_message_ix);
                                     })?;
                                     break 'outer;
@@ -338,15 +555,16 @@ impl Thread {
                         );
                         thread
                             .update(cx, |thread, _cx| {
-                                thread.pending_tool_uses.remove(&tool_result.tool_use_id);
                                 thread
-                                    .last_user_message()
-                                    .content
-                                    .push(MessageContent::ToolResult(tool_result));
+                                    .pending_agent_message()
+                                    .tool_results
+                                    .insert(tool_result.tool_use_id.clone(), tool_result);
                             })
                             .ok();
                     }
 
+                    thread.update(cx, |thread, _cx| thread.flush_pending_agent_message())?;
+
                     completion_intent = CompletionIntent::ToolResults;
                 }
 
@@ -354,6 +572,10 @@ impl Thread {
             }
             .await;
 
+            thread
+                .update(cx, |thread, _cx| thread.flush_pending_agent_message())
+                .ok();
+
             if let Err(error) = turn_result {
                 log::error!("Turn execution failed: {:?}", error);
                 event_stream.send_error(error);
@@ -364,7 +586,7 @@ impl Thread {
         events_rx
     }
 
-    pub fn build_system_message(&self) -> AgentMessage {
+    pub fn build_system_message(&self) -> LanguageModelRequestMessage {
         log::debug!("Building system message");
         let prompt = SystemPromptTemplate {
             project: &self.project_context.borrow(),
@@ -374,9 +596,10 @@ impl Thread {
         .context("failed to build system prompt")
         .expect("Invalid template");
         log::debug!("System message built");
-        AgentMessage {
+        LanguageModelRequestMessage {
             role: Role::System,
-            content: vec![prompt.as_str().into()],
+            content: vec![prompt.into()],
+            cache: true,
         }
     }
 
@@ -394,10 +617,7 @@ impl Thread {
 
         match event {
             StartMessage { .. } => {
-                self.messages.push(AgentMessage {
-                    role: Role::Assistant,
-                    content: Vec::new(),
-                });
+                self.messages.push(Message::Agent(AgentMessage::default()));
             }
             Text(new_text) => self.handle_text_event(new_text, event_stream, cx),
             Thinking { text, signature } => {
@@ -435,11 +655,13 @@ impl Thread {
     ) {
         events_stream.send_text(&new_text);
 
-        let last_message = self.last_assistant_message();
-        if let Some(MessageContent::Text(text)) = last_message.content.last_mut() {
+        let last_message = self.pending_agent_message();
+        if let Some(AgentMessageContent::Text(text)) = last_message.content.last_mut() {
             text.push_str(&new_text);
         } else {
-            last_message.content.push(MessageContent::Text(new_text));
+            last_message
+                .content
+                .push(AgentMessageContent::Text(new_text));
         }
 
         cx.notify();
@@ -454,13 +676,14 @@ impl Thread {
     ) {
         event_stream.send_thinking(&new_text);
 
-        let last_message = self.last_assistant_message();
-        if let Some(MessageContent::Thinking { text, signature }) = last_message.content.last_mut()
+        let last_message = self.pending_agent_message();
+        if let Some(AgentMessageContent::Thinking { text, signature }) =
+            last_message.content.last_mut()
         {
             text.push_str(&new_text);
             *signature = new_signature.or(signature.take());
         } else {
-            last_message.content.push(MessageContent::Thinking {
+            last_message.content.push(AgentMessageContent::Thinking {
                 text: new_text,
                 signature: new_signature,
             });
@@ -470,10 +693,10 @@ impl Thread {
     }
 
     fn handle_redacted_thinking_event(&mut self, data: String, cx: &mut Context<Self>) {
-        let last_message = self.last_assistant_message();
+        let last_message = self.pending_agent_message();
         last_message
             .content
-            .push(MessageContent::RedactedThinking(data));
+            .push(AgentMessageContent::RedactedThinking(data));
         cx.notify();
     }
 
@@ -486,14 +709,17 @@ impl Thread {
         cx.notify();
 
         let tool = self.tools.get(tool_use.name.as_ref()).cloned();
-
-        self.pending_tool_uses
-            .insert(tool_use.id.clone(), tool_use.clone());
-        let last_message = self.last_assistant_message();
+        let mut title = SharedString::from(&tool_use.name);
+        let mut kind = acp::ToolKind::Other;
+        if let Some(tool) = tool.as_ref() {
+            title = tool.initial_title(tool_use.input.clone());
+            kind = tool.kind();
+        }
 
         // Ensure the last message ends in the current tool use
+        let last_message = self.pending_agent_message();
         let push_new_tool_use = last_message.content.last_mut().map_or(true, |content| {
-            if let MessageContent::ToolUse(last_tool_use) = content {
+            if let AgentMessageContent::ToolUse(last_tool_use) = content {
                 if last_tool_use.id == tool_use.id {
                     *last_tool_use = tool_use.clone();
                     false
@@ -505,18 +731,11 @@ impl Thread {
             }
         });
 
-        let mut title = SharedString::from(&tool_use.name);
-        let mut kind = acp::ToolKind::Other;
-        if let Some(tool) = tool.as_ref() {
-            title = tool.initial_title(tool_use.input.clone());
-            kind = tool.kind();
-        }
-
         if push_new_tool_use {
             event_stream.send_tool_call(&tool_use.id, title, kind, tool_use.input.clone());
             last_message
                 .content
-                .push(MessageContent::ToolUse(tool_use.clone()));
+                .push(AgentMessageContent::ToolUse(tool_use.clone()));
         } else {
             event_stream.update_tool_call_fields(
                 &tool_use.id,
@@ -601,30 +820,37 @@ impl Thread {
         }
     }
 
-    /// Guarantees the last message is from the assistant and returns a mutable reference.
-    fn last_assistant_message(&mut self) -> &mut AgentMessage {
-        if self
-            .messages
-            .last()
-            .map_or(true, |m| m.role != Role::Assistant)
-        {
-            self.messages.push(AgentMessage {
-                role: Role::Assistant,
-                content: Vec::new(),
-            });
-        }
-        self.messages.last_mut().unwrap()
+    fn pending_agent_message(&mut self) -> &mut AgentMessage {
+        self.pending_agent_message.get_or_insert_default()
     }
 
-    /// Guarantees the last message is from the user and returns a mutable reference.
-    fn last_user_message(&mut self) -> &mut AgentMessage {
-        if self.messages.last().map_or(true, |m| m.role != Role::User) {
-            self.messages.push(AgentMessage {
-                role: Role::User,
-                content: Vec::new(),
-            });
+    fn flush_pending_agent_message(&mut self) {
+        let Some(mut message) = self.pending_agent_message.take() else {
+            return;
+        };
+
+        for content in &message.content {
+            let AgentMessageContent::ToolUse(tool_use) = content else {
+                continue;
+            };
+
+            if !message.tool_results.contains_key(&tool_use.id) {
+                message.tool_results.insert(
+                    tool_use.id.clone(),
+                    LanguageModelToolResult {
+                        tool_use_id: tool_use.id.clone(),
+                        tool_name: tool_use.name.clone(),
+                        is_error: true,
+                        content: LanguageModelToolResultContent::Text(
+                            "Tool canceled by user".into(),
+                        ),
+                        output: None,
+                    },
+                );
+            }
         }
-        self.messages.last_mut().unwrap()
+
+        self.messages.push(Message::Agent(message));
     }
 
     pub(crate) fn build_completion_request(
@@ -712,46 +938,36 @@ impl Thread {
             "Building request messages from {} thread messages",
             self.messages.len()
         );
+        let mut messages = vec![self.build_system_message()];
+        for message in &self.messages {
+            match message {
+                Message::User(message) => messages.push(message.to_request()),
+                Message::Agent(message) => messages.extend(message.to_request()),
+            }
+        }
+
+        if let Some(message) = self.pending_agent_message.as_ref() {
+            messages.extend(message.to_request());
+        }
 
-        let messages = Some(self.build_system_message())
-            .iter()
-            .chain(self.messages.iter())
-            .map(|message| {
-                log::trace!(
-                    "  - {} message with {} content items",
-                    match message.role {
-                        Role::System => "System",
-                        Role::User => "User",
-                        Role::Assistant => "Assistant",
-                    },
-                    message.content.len()
-                );
-                message.to_request()
-            })
-            .collect();
         messages
     }
 
     pub fn to_markdown(&self) -> String {
         let mut markdown = String::new();
-        for message in &self.messages {
+        for (ix, message) in self.messages.iter().enumerate() {
+            if ix > 0 {
+                markdown.push('\n');
+            }
             markdown.push_str(&message.to_markdown());
         }
-        markdown
-    }
-}
 
-pub struct UserMessage(Vec<MessageContent>);
-
-impl From<Vec<MessageContent>> for UserMessage {
-    fn from(content: Vec<MessageContent>) -> Self {
-        UserMessage(content)
-    }
-}
+        if let Some(message) = self.pending_agent_message.as_ref() {
+            markdown.push('\n');
+            markdown.push_str(&message.to_markdown());
+        }
 
-impl<T: Into<MessageContent>> From<T> for UserMessage {
-    fn from(content: T) -> Self {
-        UserMessage(vec![content.into()])
+        markdown
     }
 }
 
@@ -1151,130 +1367,6 @@ impl std::ops::DerefMut for ToolCallEventStreamReceiver {
     }
 }
 
-impl AgentMessage {
-    fn to_request(&self) -> language_model::LanguageModelRequestMessage {
-        let mut message = LanguageModelRequestMessage {
-            role: self.role,
-            content: Vec::with_capacity(self.content.len()),
-            cache: false,
-        };
-
-        const OPEN_CONTEXT: &str = "<context>\n\
-            The following items were attached by the user. \
-            They are up-to-date and don't need to be re-read.\n\n";
-
-        const OPEN_FILES_TAG: &str = "<files>";
-        const OPEN_SYMBOLS_TAG: &str = "<symbols>";
-        const OPEN_THREADS_TAG: &str = "<threads>";
-        const OPEN_RULES_TAG: &str =
-            "<rules>\nThe user has specified the following rules that should be applied:\n";
-
-        let mut file_context = OPEN_FILES_TAG.to_string();
-        let mut symbol_context = OPEN_SYMBOLS_TAG.to_string();
-        let mut thread_context = OPEN_THREADS_TAG.to_string();
-        let mut rules_context = OPEN_RULES_TAG.to_string();
-
-        for chunk in &self.content {
-            let chunk = match chunk {
-                MessageContent::Text(text) => language_model::MessageContent::Text(text.clone()),
-                MessageContent::Thinking { text, signature } => {
-                    language_model::MessageContent::Thinking {
-                        text: text.clone(),
-                        signature: signature.clone(),
-                    }
-                }
-                MessageContent::RedactedThinking(value) => {
-                    language_model::MessageContent::RedactedThinking(value.clone())
-                }
-                MessageContent::ToolUse(value) => {
-                    language_model::MessageContent::ToolUse(value.clone())
-                }
-                MessageContent::ToolResult(value) => {
-                    language_model::MessageContent::ToolResult(value.clone())
-                }
-                MessageContent::Image(value) => {
-                    language_model::MessageContent::Image(value.clone())
-                }
-                MessageContent::Mention { uri, content } => {
-                    match uri {
-                        MentionUri::File(path) | MentionUri::Symbol(path, _) => {
-                            write!(
-                                &mut symbol_context,
-                                "\n{}",
-                                MarkdownCodeBlock {
-                                    tag: &codeblock_tag(&path),
-                                    text: &content.to_string(),
-                                }
-                            )
-                            .ok();
-                        }
-                        MentionUri::Thread(_session_id) => {
-                            write!(&mut thread_context, "\n{}\n", content).ok();
-                        }
-                        MentionUri::Rule(_user_prompt_id) => {
-                            write!(
-                                &mut rules_context,
-                                "\n{}",
-                                MarkdownCodeBlock {
-                                    tag: "",
-                                    text: &content
-                                }
-                            )
-                            .ok();
-                        }
-                    }
-
-                    language_model::MessageContent::Text(uri.to_link())
-                }
-            };
-
-            message.content.push(chunk);
-        }
-
-        let len_before_context = message.content.len();
-
-        if file_context.len() > OPEN_FILES_TAG.len() {
-            file_context.push_str("</files>\n");
-            message
-                .content
-                .push(language_model::MessageContent::Text(file_context));
-        }
-
-        if symbol_context.len() > OPEN_SYMBOLS_TAG.len() {
-            symbol_context.push_str("</symbols>\n");
-            message
-                .content
-                .push(language_model::MessageContent::Text(symbol_context));
-        }
-
-        if thread_context.len() > OPEN_THREADS_TAG.len() {
-            thread_context.push_str("</threads>\n");
-            message
-                .content
-                .push(language_model::MessageContent::Text(thread_context));
-        }
-
-        if rules_context.len() > OPEN_RULES_TAG.len() {
-            rules_context.push_str("</user_rules>\n");
-            message
-                .content
-                .push(language_model::MessageContent::Text(rules_context));
-        }
-
-        if message.content.len() > len_before_context {
-            message.content.insert(
-                len_before_context,
-                language_model::MessageContent::Text(OPEN_CONTEXT.into()),
-            );
-            message
-                .content
-                .push(language_model::MessageContent::Text("</context>".into()));
-        }
-
-        message
-    }
-}
-
 fn codeblock_tag(full_path: &Path) -> String {
     let mut result = String::new();
 
@@ -1287,16 +1379,20 @@ fn codeblock_tag(full_path: &Path) -> String {
     result
 }
 
-impl From<acp::ContentBlock> for MessageContent {
+impl From<&str> for UserMessageContent {
+    fn from(text: &str) -> Self {
+        Self::Text(text.into())
+    }
+}
+
+impl From<acp::ContentBlock> for UserMessageContent {
     fn from(value: acp::ContentBlock) -> Self {
         match value {
-            acp::ContentBlock::Text(text_content) => MessageContent::Text(text_content.text),
-            acp::ContentBlock::Image(image_content) => {
-                MessageContent::Image(convert_image(image_content))
-            }
+            acp::ContentBlock::Text(text_content) => Self::Text(text_content.text),
+            acp::ContentBlock::Image(image_content) => Self::Image(convert_image(image_content)),
             acp::ContentBlock::Audio(_) => {
                 // TODO
-                MessageContent::Text("[audio]".to_string())
+                Self::Text("[audio]".to_string())
             }
             acp::ContentBlock::ResourceLink(resource_link) => {
                 match MentionUri::parse(&resource_link.uri) {
@@ -1306,10 +1402,7 @@ impl From<acp::ContentBlock> for MessageContent {
                     },
                     Err(err) => {
                         log::error!("Failed to parse mention link: {}", err);
-                        MessageContent::Text(format!(
-                            "[{}]({})",
-                            resource_link.name, resource_link.uri
-                        ))
+                        Self::Text(format!("[{}]({})", resource_link.name, resource_link.uri))
                     }
                 }
             }
@@ -1322,7 +1415,7 @@ impl From<acp::ContentBlock> for MessageContent {
                         },
                         Err(err) => {
                             log::error!("Failed to parse mention link: {}", err);
-                            MessageContent::Text(
+                            Self::Text(
                                 MarkdownCodeBlock {
                                     tag: &resource.uri,
                                     text: &resource.text,
@@ -1334,7 +1427,7 @@ impl From<acp::ContentBlock> for MessageContent {
                 }
                 acp::EmbeddedResourceResource::BlobResourceContents(_) => {
                     // TODO
-                    MessageContent::Text("[blob]".to_string())
+                    Self::Text("[blob]".to_string())
                 }
             },
         }
@@ -1348,9 +1441,3 @@ fn convert_image(image_content: acp::ImageContent) -> LanguageModelImage {
         size: gpui::Size::new(0.into(), 0.into()),
     }
 }
-
-impl From<&str> for MessageContent {
-    fn from(text: &str) -> Self {
-        MessageContent::Text(text.into())
-    }
-}

crates/agent_servers/src/acp/v0.rs 🔗

@@ -467,6 +467,7 @@ impl AgentConnection for AcpConnection {
 
     fn prompt(
         &self,
+        _id: Option<acp_thread::UserMessageId>,
         params: acp::PromptRequest,
         cx: &mut App,
     ) -> Task<Result<acp::PromptResponse>> {

crates/agent_servers/src/acp/v1.rs 🔗

@@ -171,6 +171,7 @@ impl AgentConnection for AcpConnection {
 
     fn prompt(
         &self,
+        _id: Option<acp_thread::UserMessageId>,
         params: acp::PromptRequest,
         cx: &mut App,
     ) -> Task<Result<acp::PromptResponse>> {

crates/agent_servers/src/claude.rs 🔗

@@ -210,6 +210,7 @@ impl AgentConnection for ClaudeAgentConnection {
 
     fn prompt(
         &self,
+        _id: Option<acp_thread::UserMessageId>,
         params: acp::PromptRequest,
         cx: &mut App,
     ) -> Task<Result<acp::PromptResponse>> {
@@ -423,7 +424,7 @@ impl ClaudeAgentSession {
                             if !turn_state.borrow().is_cancelled() {
                                 thread
                                     .update(cx, |thread, cx| {
-                                        thread.push_user_content_block(text.into(), cx)
+                                        thread.push_user_content_block(None, text.into(), cx)
                                     })
                                     .log_err();
                             }

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -679,17 +679,19 @@ impl AcpThreadView {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        let count = self.list_state.item_count();
         match event {
             AcpThreadEvent::NewEntry => {
                 let index = thread.read(cx).entries().len() - 1;
                 self.sync_thread_entry_view(index, window, cx);
-                self.list_state.splice(count..count, 1);
+                self.list_state.splice(index..index, 1);
             }
             AcpThreadEvent::EntryUpdated(index) => {
-                let index = *index;
-                self.sync_thread_entry_view(index, window, cx);
-                self.list_state.splice(index..index + 1, 1);
+                self.sync_thread_entry_view(*index, window, cx);
+                self.list_state.splice(*index..index + 1, 1);
+            }
+            AcpThreadEvent::EntriesRemoved(range) => {
+                // TODO: Clean up unused diff editors and terminal views
+                self.list_state.splice(range.clone(), 0);
             }
             AcpThreadEvent::ToolAuthorizationRequired => {
                 self.notify_with_sound("Waiting for tool confirmation", IconName::Info, window, cx);
@@ -3789,6 +3791,7 @@ mod tests {
 
         fn prompt(
             &self,
+            _id: Option<acp_thread::UserMessageId>,
             params: acp::PromptRequest,
             cx: &mut App,
         ) -> Task<gpui::Result<acp::PromptResponse>> {
@@ -3873,6 +3876,7 @@ mod tests {
 
         fn prompt(
             &self,
+            _id: Option<acp_thread::UserMessageId>,
             _params: acp::PromptRequest,
             _cx: &mut App,
         ) -> Task<gpui::Result<acp::PromptResponse>> {

crates/agent_ui/src/agent_diff.rs 🔗

@@ -1521,7 +1521,8 @@ impl AgentDiff {
                     self.update_reviewing_editors(workspace, window, cx);
                 }
             }
-            AcpThreadEvent::Stopped
+            AcpThreadEvent::EntriesRemoved(_)
+            | AcpThreadEvent::Stopped
             | AcpThreadEvent::ToolAuthorizationRequired
             | AcpThreadEvent::Error
             | AcpThreadEvent::ServerExited(_) => {}

crates/fs/Cargo.toml 🔗

@@ -51,6 +51,7 @@ ashpd.workspace = true
 
 [dev-dependencies]
 gpui = { workspace = true, features = ["test-support"] }
+git = { workspace = true, features = ["test-support"] }
 
 [features]
 test-support = ["gpui/test-support", "git/test-support"]

crates/fs/src/fake_git_repo.rs 🔗

@@ -1,8 +1,9 @@
-use crate::{FakeFs, Fs};
+use crate::{FakeFs, FakeFsEntry, Fs};
 use anyhow::{Context as _, Result};
 use collections::{HashMap, HashSet};
 use futures::future::{self, BoxFuture, join_all};
 use git::{
+    Oid,
     blame::Blame,
     repository::{
         AskPassDelegate, Branch, CommitDetails, CommitOptions, FetchOptions, GitRepository,
@@ -12,6 +13,7 @@ use git::{
 };
 use gpui::{AsyncApp, BackgroundExecutor, SharedString, Task};
 use ignore::gitignore::GitignoreBuilder;
+use parking_lot::Mutex;
 use rope::Rope;
 use smol::future::FutureExt as _;
 use std::{path::PathBuf, sync::Arc};
@@ -19,6 +21,7 @@ use std::{path::PathBuf, sync::Arc};
 #[derive(Clone)]
 pub struct FakeGitRepository {
     pub(crate) fs: Arc<FakeFs>,
+    pub(crate) checkpoints: Arc<Mutex<HashMap<Oid, FakeFsEntry>>>,
     pub(crate) executor: BackgroundExecutor,
     pub(crate) dot_git_path: PathBuf,
     pub(crate) repository_dir_path: PathBuf,
@@ -469,22 +472,57 @@ impl GitRepository for FakeGitRepository {
     }
 
     fn checkpoint(&self) -> BoxFuture<'static, Result<GitRepositoryCheckpoint>> {
-        unimplemented!()
+        let executor = self.executor.clone();
+        let fs = self.fs.clone();
+        let checkpoints = self.checkpoints.clone();
+        let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
+        async move {
+            executor.simulate_random_delay().await;
+            let oid = Oid::random(&mut executor.rng());
+            let entry = fs.entry(&repository_dir_path)?;
+            checkpoints.lock().insert(oid, entry);
+            Ok(GitRepositoryCheckpoint { commit_sha: oid })
+        }
+        .boxed()
     }
 
-    fn restore_checkpoint(
-        &self,
-        _checkpoint: GitRepositoryCheckpoint,
-    ) -> BoxFuture<'_, Result<()>> {
-        unimplemented!()
+    fn restore_checkpoint(&self, checkpoint: GitRepositoryCheckpoint) -> BoxFuture<'_, Result<()>> {
+        let executor = self.executor.clone();
+        let fs = self.fs.clone();
+        let checkpoints = self.checkpoints.clone();
+        let repository_dir_path = self.repository_dir_path.parent().unwrap().to_path_buf();
+        async move {
+            executor.simulate_random_delay().await;
+            let checkpoints = checkpoints.lock();
+            let entry = checkpoints
+                .get(&checkpoint.commit_sha)
+                .context(format!("invalid checkpoint: {}", checkpoint.commit_sha))?;
+            fs.insert_entry(&repository_dir_path, entry.clone())?;
+            Ok(())
+        }
+        .boxed()
     }
 
     fn compare_checkpoints(
         &self,
-        _left: GitRepositoryCheckpoint,
-        _right: GitRepositoryCheckpoint,
+        left: GitRepositoryCheckpoint,
+        right: GitRepositoryCheckpoint,
     ) -> BoxFuture<'_, Result<bool>> {
-        unimplemented!()
+        let executor = self.executor.clone();
+        let checkpoints = self.checkpoints.clone();
+        async move {
+            executor.simulate_random_delay().await;
+            let checkpoints = checkpoints.lock();
+            let left = checkpoints
+                .get(&left.commit_sha)
+                .context(format!("invalid left checkpoint: {}", left.commit_sha))?;
+            let right = checkpoints
+                .get(&right.commit_sha)
+                .context(format!("invalid right checkpoint: {}", right.commit_sha))?;
+
+            Ok(left == right)
+        }
+        .boxed()
     }
 
     fn diff_checkpoints(
@@ -499,3 +537,63 @@ impl GitRepository for FakeGitRepository {
         unimplemented!()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::{FakeFs, Fs};
+    use gpui::BackgroundExecutor;
+    use serde_json::json;
+    use std::path::Path;
+    use util::path;
+
+    #[gpui::test]
+    async fn test_checkpoints(executor: BackgroundExecutor) {
+        let fs = FakeFs::new(executor);
+        fs.insert_tree(
+            path!("/"),
+            json!({
+                "bar": {
+                    "baz": "qux"
+                },
+                "foo": {
+                    ".git": {},
+                    "a": "lorem",
+                    "b": "ipsum",
+                },
+            }),
+        )
+        .await;
+        fs.with_git_state(Path::new("/foo/.git"), true, |_git| {})
+            .unwrap();
+        let repository = fs.open_repo(Path::new("/foo/.git")).unwrap();
+
+        let checkpoint_1 = repository.checkpoint().await.unwrap();
+        fs.write(Path::new("/foo/b"), b"IPSUM").await.unwrap();
+        fs.write(Path::new("/foo/c"), b"dolor").await.unwrap();
+        let checkpoint_2 = repository.checkpoint().await.unwrap();
+        let checkpoint_3 = repository.checkpoint().await.unwrap();
+
+        assert!(
+            repository
+                .compare_checkpoints(checkpoint_2.clone(), checkpoint_3.clone())
+                .await
+                .unwrap()
+        );
+        assert!(
+            !repository
+                .compare_checkpoints(checkpoint_1.clone(), checkpoint_2.clone())
+                .await
+                .unwrap()
+        );
+
+        repository.restore_checkpoint(checkpoint_1).await.unwrap();
+        assert_eq!(
+            fs.files_with_contents(Path::new("")),
+            [
+                (Path::new("/bar/baz").into(), b"qux".into()),
+                (Path::new("/foo/a").into(), b"lorem".into()),
+                (Path::new("/foo/b").into(), b"ipsum".into())
+            ]
+        );
+    }
+}

crates/fs/src/fs.rs 🔗

@@ -924,7 +924,7 @@ pub struct FakeFs {
 
 #[cfg(any(test, feature = "test-support"))]
 struct FakeFsState {
-    root: Arc<Mutex<FakeFsEntry>>,
+    root: FakeFsEntry,
     next_inode: u64,
     next_mtime: SystemTime,
     git_event_tx: smol::channel::Sender<PathBuf>,
@@ -939,7 +939,7 @@ struct FakeFsState {
 }
 
 #[cfg(any(test, feature = "test-support"))]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 enum FakeFsEntry {
     File {
         inode: u64,
@@ -953,7 +953,7 @@ enum FakeFsEntry {
         inode: u64,
         mtime: MTime,
         len: u64,
-        entries: BTreeMap<String, Arc<Mutex<FakeFsEntry>>>,
+        entries: BTreeMap<String, FakeFsEntry>,
         git_repo_state: Option<Arc<Mutex<FakeGitRepositoryState>>>,
     },
     Symlink {
@@ -961,6 +961,67 @@ enum FakeFsEntry {
     },
 }
 
+#[cfg(any(test, feature = "test-support"))]
+impl PartialEq for FakeFsEntry {
+    fn eq(&self, other: &Self) -> bool {
+        match (self, other) {
+            (
+                Self::File {
+                    inode: l_inode,
+                    mtime: l_mtime,
+                    len: l_len,
+                    content: l_content,
+                    git_dir_path: l_git_dir_path,
+                },
+                Self::File {
+                    inode: r_inode,
+                    mtime: r_mtime,
+                    len: r_len,
+                    content: r_content,
+                    git_dir_path: r_git_dir_path,
+                },
+            ) => {
+                l_inode == r_inode
+                    && l_mtime == r_mtime
+                    && l_len == r_len
+                    && l_content == r_content
+                    && l_git_dir_path == r_git_dir_path
+            }
+            (
+                Self::Dir {
+                    inode: l_inode,
+                    mtime: l_mtime,
+                    len: l_len,
+                    entries: l_entries,
+                    git_repo_state: l_git_repo_state,
+                },
+                Self::Dir {
+                    inode: r_inode,
+                    mtime: r_mtime,
+                    len: r_len,
+                    entries: r_entries,
+                    git_repo_state: r_git_repo_state,
+                },
+            ) => {
+                let same_repo_state = match (l_git_repo_state.as_ref(), r_git_repo_state.as_ref()) {
+                    (Some(l), Some(r)) => Arc::ptr_eq(l, r),
+                    (None, None) => true,
+                    _ => false,
+                };
+                l_inode == r_inode
+                    && l_mtime == r_mtime
+                    && l_len == r_len
+                    && l_entries == r_entries
+                    && same_repo_state
+            }
+            (Self::Symlink { target: l_target }, Self::Symlink { target: r_target }) => {
+                l_target == r_target
+            }
+            _ => false,
+        }
+    }
+}
+
 #[cfg(any(test, feature = "test-support"))]
 impl FakeFsState {
     fn get_and_increment_mtime(&mut self) -> MTime {
@@ -975,25 +1036,9 @@ impl FakeFsState {
         inode
     }
 
-    fn read_path(&self, target: &Path) -> Result<Arc<Mutex<FakeFsEntry>>> {
-        Ok(self
-            .try_read_path(target, true)
-            .ok_or_else(|| {
-                anyhow!(io::Error::new(
-                    io::ErrorKind::NotFound,
-                    format!("not found: {target:?}")
-                ))
-            })?
-            .0)
-    }
-
-    fn try_read_path(
-        &self,
-        target: &Path,
-        follow_symlink: bool,
-    ) -> Option<(Arc<Mutex<FakeFsEntry>>, PathBuf)> {
-        let mut path = target.to_path_buf();
+    fn canonicalize(&self, target: &Path, follow_symlink: bool) -> Option<PathBuf> {
         let mut canonical_path = PathBuf::new();
+        let mut path = target.to_path_buf();
         let mut entry_stack = Vec::new();
         'outer: loop {
             let mut path_components = path.components().peekable();
@@ -1003,7 +1048,7 @@ impl FakeFsState {
                     Component::Prefix(prefix_component) => prefix = Some(prefix_component),
                     Component::RootDir => {
                         entry_stack.clear();
-                        entry_stack.push(self.root.clone());
+                        entry_stack.push(&self.root);
                         canonical_path.clear();
                         match prefix {
                             Some(prefix_component) => {
@@ -1020,20 +1065,18 @@ impl FakeFsState {
                         canonical_path.pop();
                     }
                     Component::Normal(name) => {
-                        let current_entry = entry_stack.last().cloned()?;
-                        let current_entry = current_entry.lock();
-                        if let FakeFsEntry::Dir { entries, .. } = &*current_entry {
-                            let entry = entries.get(name.to_str().unwrap()).cloned()?;
+                        let current_entry = *entry_stack.last()?;
+                        if let FakeFsEntry::Dir { entries, .. } = current_entry {
+                            let entry = entries.get(name.to_str().unwrap())?;
                             if path_components.peek().is_some() || follow_symlink {
-                                let entry = entry.lock();
-                                if let FakeFsEntry::Symlink { target, .. } = &*entry {
+                                if let FakeFsEntry::Symlink { target, .. } = entry {
                                     let mut target = target.clone();
                                     target.extend(path_components);
                                     path = target;
                                     continue 'outer;
                                 }
                             }
-                            entry_stack.push(entry.clone());
+                            entry_stack.push(entry);
                             canonical_path = canonical_path.join(name);
                         } else {
                             return None;
@@ -1043,19 +1086,72 @@ impl FakeFsState {
             }
             break;
         }
-        Some((entry_stack.pop()?, canonical_path))
+
+        if entry_stack.is_empty() {
+            None
+        } else {
+            Some(canonical_path)
+        }
+    }
+
+    fn try_entry(
+        &mut self,
+        target: &Path,
+        follow_symlink: bool,
+    ) -> Option<(&mut FakeFsEntry, PathBuf)> {
+        let canonical_path = self.canonicalize(target, follow_symlink)?;
+
+        let mut components = canonical_path.components();
+        let Some(Component::RootDir) = components.next() else {
+            panic!(
+                "the path {:?} was not canonicalized properly {:?}",
+                target, canonical_path
+            )
+        };
+
+        let mut entry = &mut self.root;
+        for component in components {
+            match component {
+                Component::Normal(name) => {
+                    if let FakeFsEntry::Dir { entries, .. } = entry {
+                        entry = entries.get_mut(name.to_str().unwrap())?;
+                    } else {
+                        return None;
+                    }
+                }
+                _ => {
+                    panic!(
+                        "the path {:?} was not canonicalized properly {:?}",
+                        target, canonical_path
+                    )
+                }
+            }
+        }
+
+        Some((entry, canonical_path))
     }
 
-    fn write_path<Fn, T>(&self, path: &Path, callback: Fn) -> Result<T>
+    fn entry(&mut self, target: &Path) -> Result<&mut FakeFsEntry> {
+        Ok(self
+            .try_entry(target, true)
+            .ok_or_else(|| {
+                anyhow!(io::Error::new(
+                    io::ErrorKind::NotFound,
+                    format!("not found: {target:?}")
+                ))
+            })?
+            .0)
+    }
+
+    fn write_path<Fn, T>(&mut self, path: &Path, callback: Fn) -> Result<T>
     where
-        Fn: FnOnce(btree_map::Entry<String, Arc<Mutex<FakeFsEntry>>>) -> Result<T>,
+        Fn: FnOnce(btree_map::Entry<String, FakeFsEntry>) -> Result<T>,
     {
         let path = normalize_path(path);
         let filename = path.file_name().context("cannot overwrite the root")?;
         let parent_path = path.parent().unwrap();
 
-        let parent = self.read_path(parent_path)?;
-        let mut parent = parent.lock();
+        let parent = self.entry(parent_path)?;
         let new_entry = parent
             .dir_entries(parent_path)?
             .entry(filename.to_str().unwrap().into());
@@ -1105,13 +1201,13 @@ impl FakeFs {
             this: this.clone(),
             executor: executor.clone(),
             state: Arc::new(Mutex::new(FakeFsState {
-                root: Arc::new(Mutex::new(FakeFsEntry::Dir {
+                root: FakeFsEntry::Dir {
                     inode: 0,
                     mtime: MTime(UNIX_EPOCH),
                     len: 0,
                     entries: Default::default(),
                     git_repo_state: None,
-                })),
+                },
                 git_event_tx: tx,
                 next_mtime: UNIX_EPOCH + Self::SYSTEMTIME_INTERVAL,
                 next_inode: 1,
@@ -1161,15 +1257,15 @@ impl FakeFs {
             .write_path(path, move |entry| {
                 match entry {
                     btree_map::Entry::Vacant(e) => {
-                        e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
+                        e.insert(FakeFsEntry::File {
                             inode: new_inode,
                             mtime: new_mtime,
                             content: Vec::new(),
                             len: 0,
                             git_dir_path: None,
-                        })));
+                        });
                     }
-                    btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut().lock() {
+                    btree_map::Entry::Occupied(mut e) => match &mut *e.get_mut() {
                         FakeFsEntry::File { mtime, .. } => *mtime = new_mtime,
                         FakeFsEntry::Dir { mtime, .. } => *mtime = new_mtime,
                         FakeFsEntry::Symlink { .. } => {}
@@ -1188,7 +1284,7 @@ impl FakeFs {
     pub async fn insert_symlink(&self, path: impl AsRef<Path>, target: PathBuf) {
         let mut state = self.state.lock();
         let path = path.as_ref();
-        let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
+        let file = FakeFsEntry::Symlink { target };
         state
             .write_path(path.as_ref(), move |e| match e {
                 btree_map::Entry::Vacant(e) => {
@@ -1221,13 +1317,13 @@ impl FakeFs {
             match entry {
                 btree_map::Entry::Vacant(e) => {
                     kind = Some(PathEventKind::Created);
-                    e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
+                    e.insert(FakeFsEntry::File {
                         inode: new_inode,
                         mtime: new_mtime,
                         len: new_len,
                         content: new_content,
                         git_dir_path: None,
-                    })));
+                    });
                 }
                 btree_map::Entry::Occupied(mut e) => {
                     kind = Some(PathEventKind::Changed);
@@ -1237,7 +1333,7 @@ impl FakeFs {
                         len,
                         content,
                         ..
-                    } = &mut *e.get_mut().lock()
+                    } = e.get_mut()
                     {
                         *mtime = new_mtime;
                         *content = new_content;
@@ -1259,9 +1355,8 @@ impl FakeFs {
     pub fn read_file_sync(&self, path: impl AsRef<Path>) -> Result<Vec<u8>> {
         let path = path.as_ref();
         let path = normalize_path(path);
-        let state = self.state.lock();
-        let entry = state.read_path(&path)?;
-        let entry = entry.lock();
+        let mut state = self.state.lock();
+        let entry = state.entry(&path)?;
         entry.file_content(&path).cloned()
     }
 
@@ -1269,9 +1364,8 @@ impl FakeFs {
         let path = path.as_ref();
         let path = normalize_path(path);
         self.simulate_random_delay().await;
-        let state = self.state.lock();
-        let entry = state.read_path(&path)?;
-        let entry = entry.lock();
+        let mut state = self.state.lock();
+        let entry = state.entry(&path)?;
         entry.file_content(&path).cloned()
     }
 
@@ -1292,6 +1386,25 @@ impl FakeFs {
         self.state.lock().flush_events(count);
     }
 
+    pub(crate) fn entry(&self, target: &Path) -> Result<FakeFsEntry> {
+        self.state.lock().entry(target).cloned()
+    }
+
+    pub(crate) fn insert_entry(&self, target: &Path, new_entry: FakeFsEntry) -> Result<()> {
+        let mut state = self.state.lock();
+        state.write_path(target, |entry| {
+            match entry {
+                btree_map::Entry::Vacant(vacant_entry) => {
+                    vacant_entry.insert(new_entry);
+                }
+                btree_map::Entry::Occupied(mut occupied_entry) => {
+                    occupied_entry.insert(new_entry);
+                }
+            }
+            Ok(())
+        })
+    }
+
     #[must_use]
     pub fn insert_tree<'a>(
         &'a self,
@@ -1361,20 +1474,19 @@ impl FakeFs {
         F: FnOnce(&mut FakeGitRepositoryState, &Path, &Path) -> T,
     {
         let mut state = self.state.lock();
-        let entry = state.read_path(dot_git).context("open .git")?;
-        let mut entry = entry.lock();
+        let git_event_tx = state.git_event_tx.clone();
+        let entry = state.entry(dot_git).context("open .git")?;
 
-        if let FakeFsEntry::Dir { git_repo_state, .. } = &mut *entry {
+        if let FakeFsEntry::Dir { git_repo_state, .. } = entry {
             let repo_state = git_repo_state.get_or_insert_with(|| {
                 log::debug!("insert git state for {dot_git:?}");
-                Arc::new(Mutex::new(FakeGitRepositoryState::new(
-                    state.git_event_tx.clone(),
-                )))
+                Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
             });
             let mut repo_state = repo_state.lock();
 
             let result = f(&mut repo_state, dot_git, dot_git);
 
+            drop(repo_state);
             if emit_git_event {
                 state.emit_event([(dot_git, None)]);
             }
@@ -1398,21 +1510,20 @@ impl FakeFs {
                 }
             }
             .clone();
-            drop(entry);
-            let Some((git_dir_entry, canonical_path)) = state.try_read_path(&path, true) else {
+            let Some((git_dir_entry, canonical_path)) = state.try_entry(&path, true) else {
                 anyhow::bail!("pointed-to git dir {path:?} not found")
             };
             let FakeFsEntry::Dir {
                 git_repo_state,
                 entries,
                 ..
-            } = &mut *git_dir_entry.lock()
+            } = git_dir_entry
             else {
                 anyhow::bail!("gitfile points to a non-directory")
             };
             let common_dir = if let Some(child) = entries.get("commondir") {
                 Path::new(
-                    std::str::from_utf8(child.lock().file_content("commondir".as_ref())?)
+                    std::str::from_utf8(child.file_content("commondir".as_ref())?)
                         .context("commondir content")?,
                 )
                 .to_owned()
@@ -1420,15 +1531,14 @@ impl FakeFs {
                 canonical_path.clone()
             };
             let repo_state = git_repo_state.get_or_insert_with(|| {
-                Arc::new(Mutex::new(FakeGitRepositoryState::new(
-                    state.git_event_tx.clone(),
-                )))
+                Arc::new(Mutex::new(FakeGitRepositoryState::new(git_event_tx)))
             });
             let mut repo_state = repo_state.lock();
 
             let result = f(&mut repo_state, &canonical_path, &common_dir);
 
             if emit_git_event {
+                drop(repo_state);
                 state.emit_event([(canonical_path, None)]);
             }
 
@@ -1655,14 +1765,12 @@ impl FakeFs {
     pub fn paths(&self, include_dot_git: bool) -> Vec<PathBuf> {
         let mut result = Vec::new();
         let mut queue = collections::VecDeque::new();
-        queue.push_back((
-            PathBuf::from(util::path!("/")),
-            self.state.lock().root.clone(),
-        ));
+        let state = &*self.state.lock();
+        queue.push_back((PathBuf::from(util::path!("/")), &state.root));
         while let Some((path, entry)) = queue.pop_front() {
-            if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
+            if let FakeFsEntry::Dir { entries, .. } = entry {
                 for (name, entry) in entries {
-                    queue.push_back((path.join(name), entry.clone()));
+                    queue.push_back((path.join(name), entry));
                 }
             }
             if include_dot_git
@@ -1679,14 +1787,12 @@ impl FakeFs {
     pub fn directories(&self, include_dot_git: bool) -> Vec<PathBuf> {
         let mut result = Vec::new();
         let mut queue = collections::VecDeque::new();
-        queue.push_back((
-            PathBuf::from(util::path!("/")),
-            self.state.lock().root.clone(),
-        ));
+        let state = &*self.state.lock();
+        queue.push_back((PathBuf::from(util::path!("/")), &state.root));
         while let Some((path, entry)) = queue.pop_front() {
-            if let FakeFsEntry::Dir { entries, .. } = &*entry.lock() {
+            if let FakeFsEntry::Dir { entries, .. } = entry {
                 for (name, entry) in entries {
-                    queue.push_back((path.join(name), entry.clone()));
+                    queue.push_back((path.join(name), entry));
                 }
                 if include_dot_git
                     || !path
@@ -1703,17 +1809,14 @@ impl FakeFs {
     pub fn files(&self) -> Vec<PathBuf> {
         let mut result = Vec::new();
         let mut queue = collections::VecDeque::new();
-        queue.push_back((
-            PathBuf::from(util::path!("/")),
-            self.state.lock().root.clone(),
-        ));
+        let state = &*self.state.lock();
+        queue.push_back((PathBuf::from(util::path!("/")), &state.root));
         while let Some((path, entry)) = queue.pop_front() {
-            let e = entry.lock();
-            match &*e {
+            match entry {
                 FakeFsEntry::File { .. } => result.push(path),
                 FakeFsEntry::Dir { entries, .. } => {
                     for (name, entry) in entries {
-                        queue.push_back((path.join(name), entry.clone()));
+                        queue.push_back((path.join(name), entry));
                     }
                 }
                 FakeFsEntry::Symlink { .. } => {}
@@ -1725,13 +1828,10 @@ impl FakeFs {
     pub fn files_with_contents(&self, prefix: &Path) -> Vec<(PathBuf, Vec<u8>)> {
         let mut result = Vec::new();
         let mut queue = collections::VecDeque::new();
-        queue.push_back((
-            PathBuf::from(util::path!("/")),
-            self.state.lock().root.clone(),
-        ));
+        let state = &*self.state.lock();
+        queue.push_back((PathBuf::from(util::path!("/")), &state.root));
         while let Some((path, entry)) = queue.pop_front() {
-            let e = entry.lock();
-            match &*e {
+            match entry {
                 FakeFsEntry::File { content, .. } => {
                     if path.starts_with(prefix) {
                         result.push((path, content.clone()));
@@ -1739,7 +1839,7 @@ impl FakeFs {
                 }
                 FakeFsEntry::Dir { entries, .. } => {
                     for (name, entry) in entries {
-                        queue.push_back((path.join(name), entry.clone()));
+                        queue.push_back((path.join(name), entry));
                     }
                 }
                 FakeFsEntry::Symlink { .. } => {}
@@ -1805,10 +1905,7 @@ impl FakeFsEntry {
         }
     }
 
-    fn dir_entries(
-        &mut self,
-        path: &Path,
-    ) -> Result<&mut BTreeMap<String, Arc<Mutex<FakeFsEntry>>>> {
+    fn dir_entries(&mut self, path: &Path) -> Result<&mut BTreeMap<String, FakeFsEntry>> {
         if let Self::Dir { entries, .. } = self {
             Ok(entries)
         } else {
@@ -1855,12 +1952,12 @@ struct FakeHandle {
 impl FileHandle for FakeHandle {
     fn current_path(&self, fs: &Arc<dyn Fs>) -> Result<PathBuf> {
         let fs = fs.as_fake();
-        let state = fs.state.lock();
-        let Some(target) = state.moves.get(&self.inode) else {
+        let mut state = fs.state.lock();
+        let Some(target) = state.moves.get(&self.inode).cloned() else {
             anyhow::bail!("fake fd not moved")
         };
 
-        if state.try_read_path(&target, false).is_some() {
+        if state.try_entry(&target, false).is_some() {
             return Ok(target.clone());
         }
         anyhow::bail!("fake fd target not found")
@@ -1888,13 +1985,13 @@ impl Fs for FakeFs {
             state.write_path(&cur_path, |entry| {
                 entry.or_insert_with(|| {
                     created_dirs.push((cur_path.clone(), Some(PathEventKind::Created)));
-                    Arc::new(Mutex::new(FakeFsEntry::Dir {
+                    FakeFsEntry::Dir {
                         inode,
                         mtime,
                         len: 0,
                         entries: Default::default(),
                         git_repo_state: None,
-                    }))
+                    }
                 });
                 Ok(())
             })?
@@ -1909,13 +2006,13 @@ impl Fs for FakeFs {
         let mut state = self.state.lock();
         let inode = state.get_and_increment_inode();
         let mtime = state.get_and_increment_mtime();
-        let file = Arc::new(Mutex::new(FakeFsEntry::File {
+        let file = FakeFsEntry::File {
             inode,
             mtime,
             len: 0,
             content: Vec::new(),
             git_dir_path: None,
-        }));
+        };
         let mut kind = Some(PathEventKind::Created);
         state.write_path(path, |entry| {
             match entry {
@@ -1939,7 +2036,7 @@ impl Fs for FakeFs {
 
     async fn create_symlink(&self, path: &Path, target: PathBuf) -> Result<()> {
         let mut state = self.state.lock();
-        let file = Arc::new(Mutex::new(FakeFsEntry::Symlink { target }));
+        let file = FakeFsEntry::Symlink { target };
         state
             .write_path(path.as_ref(), move |e| match e {
                 btree_map::Entry::Vacant(e) => {
@@ -2002,7 +2099,7 @@ impl Fs for FakeFs {
             }
         })?;
 
-        let inode = match *moved_entry.lock() {
+        let inode = match moved_entry {
             FakeFsEntry::File { inode, .. } => inode,
             FakeFsEntry::Dir { inode, .. } => inode,
             _ => 0,
@@ -2051,8 +2148,8 @@ impl Fs for FakeFs {
         let mut state = self.state.lock();
         let mtime = state.get_and_increment_mtime();
         let inode = state.get_and_increment_inode();
-        let source_entry = state.read_path(&source)?;
-        let content = source_entry.lock().file_content(&source)?.clone();
+        let source_entry = state.entry(&source)?;
+        let content = source_entry.file_content(&source)?.clone();
         let mut kind = Some(PathEventKind::Created);
         state.write_path(&target, |e| match e {
             btree_map::Entry::Occupied(e) => {
@@ -2066,13 +2163,13 @@ impl Fs for FakeFs {
                 }
             }
             btree_map::Entry::Vacant(e) => Ok(Some(
-                e.insert(Arc::new(Mutex::new(FakeFsEntry::File {
+                e.insert(FakeFsEntry::File {
                     inode,
                     mtime,
                     len: content.len() as u64,
                     content,
                     git_dir_path: None,
-                })))
+                })
                 .clone(),
             )),
         })?;
@@ -2088,8 +2185,7 @@ impl Fs for FakeFs {
         let base_name = path.file_name().context("cannot remove the root")?;
 
         let mut state = self.state.lock();
-        let parent_entry = state.read_path(parent_path)?;
-        let mut parent_entry = parent_entry.lock();
+        let parent_entry = state.entry(parent_path)?;
         let entry = parent_entry
             .dir_entries(parent_path)?
             .entry(base_name.to_str().unwrap().into());
@@ -2100,15 +2196,14 @@ impl Fs for FakeFs {
                     anyhow::bail!("{path:?} does not exist");
                 }
             }
-            btree_map::Entry::Occupied(e) => {
+            btree_map::Entry::Occupied(mut entry) => {
                 {
-                    let mut entry = e.get().lock();
-                    let children = entry.dir_entries(&path)?;
+                    let children = entry.get_mut().dir_entries(&path)?;
                     if !options.recursive && !children.is_empty() {
                         anyhow::bail!("{path:?} is not empty");
                     }
                 }
-                e.remove();
+                entry.remove();
             }
         }
         state.emit_event([(path, Some(PathEventKind::Removed))]);
@@ -2122,8 +2217,7 @@ impl Fs for FakeFs {
         let parent_path = path.parent().context("cannot remove the root")?;
         let base_name = path.file_name().unwrap();
         let mut state = self.state.lock();
-        let parent_entry = state.read_path(parent_path)?;
-        let mut parent_entry = parent_entry.lock();
+        let parent_entry = state.entry(parent_path)?;
         let entry = parent_entry
             .dir_entries(parent_path)?
             .entry(base_name.to_str().unwrap().into());
@@ -2133,9 +2227,9 @@ impl Fs for FakeFs {
                     anyhow::bail!("{path:?} does not exist");
                 }
             }
-            btree_map::Entry::Occupied(e) => {
-                e.get().lock().file_content(&path)?;
-                e.remove();
+            btree_map::Entry::Occupied(mut entry) => {
+                entry.get_mut().file_content(&path)?;
+                entry.remove();
             }
         }
         state.emit_event([(path, Some(PathEventKind::Removed))]);
@@ -2149,12 +2243,10 @@ impl Fs for FakeFs {
 
     async fn open_handle(&self, path: &Path) -> Result<Arc<dyn FileHandle>> {
         self.simulate_random_delay().await;
-        let state = self.state.lock();
-        let entry = state.read_path(&path)?;
-        let entry = entry.lock();
-        let inode = match *entry {
-            FakeFsEntry::File { inode, .. } => inode,
-            FakeFsEntry::Dir { inode, .. } => inode,
+        let mut state = self.state.lock();
+        let inode = match state.entry(&path)? {
+            FakeFsEntry::File { inode, .. } => *inode,
+            FakeFsEntry::Dir { inode, .. } => *inode,
             _ => unreachable!(),
         };
         Ok(Arc::new(FakeHandle { inode }))
@@ -2204,8 +2296,8 @@ impl Fs for FakeFs {
         let path = normalize_path(path);
         self.simulate_random_delay().await;
         let state = self.state.lock();
-        let (_, canonical_path) = state
-            .try_read_path(&path, true)
+        let canonical_path = state
+            .canonicalize(&path, true)
             .with_context(|| format!("path does not exist: {path:?}"))?;
         Ok(canonical_path)
     }
@@ -2213,9 +2305,9 @@ impl Fs for FakeFs {
     async fn is_file(&self, path: &Path) -> bool {
         let path = normalize_path(path);
         self.simulate_random_delay().await;
-        let state = self.state.lock();
-        if let Some((entry, _)) = state.try_read_path(&path, true) {
-            entry.lock().is_file()
+        let mut state = self.state.lock();
+        if let Some((entry, _)) = state.try_entry(&path, true) {
+            entry.is_file()
         } else {
             false
         }
@@ -2232,17 +2324,16 @@ impl Fs for FakeFs {
         let path = normalize_path(path);
         let mut state = self.state.lock();
         state.metadata_call_count += 1;
-        if let Some((mut entry, _)) = state.try_read_path(&path, false) {
-            let is_symlink = entry.lock().is_symlink();
+        if let Some((mut entry, _)) = state.try_entry(&path, false) {
+            let is_symlink = entry.is_symlink();
             if is_symlink {
-                if let Some(e) = state.try_read_path(&path, true).map(|e| e.0) {
+                if let Some(e) = state.try_entry(&path, true).map(|e| e.0) {
                     entry = e;
                 } else {
                     return Ok(None);
                 }
             }
 
-            let entry = entry.lock();
             Ok(Some(match &*entry {
                 FakeFsEntry::File {
                     inode, mtime, len, ..
@@ -2274,12 +2365,11 @@ impl Fs for FakeFs {
     async fn read_link(&self, path: &Path) -> Result<PathBuf> {
         self.simulate_random_delay().await;
         let path = normalize_path(path);
-        let state = self.state.lock();
+        let mut state = self.state.lock();
         let (entry, _) = state
-            .try_read_path(&path, false)
+            .try_entry(&path, false)
             .with_context(|| format!("path does not exist: {path:?}"))?;
-        let entry = entry.lock();
-        if let FakeFsEntry::Symlink { target } = &*entry {
+        if let FakeFsEntry::Symlink { target } = entry {
             Ok(target.clone())
         } else {
             anyhow::bail!("not a symlink: {path:?}")
@@ -2294,8 +2384,7 @@ impl Fs for FakeFs {
         let path = normalize_path(path);
         let mut state = self.state.lock();
         state.read_dir_call_count += 1;
-        let entry = state.read_path(&path)?;
-        let mut entry = entry.lock();
+        let entry = state.entry(&path)?;
         let children = entry.dir_entries(&path)?;
         let paths = children
             .keys()
@@ -2359,6 +2448,7 @@ impl Fs for FakeFs {
                     dot_git_path: abs_dot_git.to_path_buf(),
                     repository_dir_path: repository_dir_path.to_owned(),
                     common_dir_path: common_dir_path.to_owned(),
+                    checkpoints: Arc::default(),
                 }) as _
             },
         )

crates/git/Cargo.toml 🔗

@@ -12,7 +12,7 @@ workspace = true
 path = "src/git.rs"
 
 [features]
-test-support = []
+test-support = ["rand"]
 
 [dependencies]
 anyhow.workspace = true
@@ -26,6 +26,7 @@ http_client.workspace = true
 log.workspace = true
 parking_lot.workspace = true
 regex.workspace = true
+rand = { workspace = true, optional = true }
 rope.workspace = true
 schemars.workspace = true
 serde.workspace = true
@@ -47,3 +48,4 @@ text = { workspace = true, features = ["test-support"] }
 unindent.workspace = true
 gpui = { workspace = true, features = ["test-support"] }
 tempfile.workspace = true
+rand.workspace = true

crates/git/src/git.rs 🔗

@@ -119,6 +119,13 @@ impl Oid {
         Ok(Self(oid))
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn random(rng: &mut impl rand::Rng) -> Self {
+        let mut bytes = [0; 20];
+        rng.fill(&mut bytes);
+        Self::from_bytes(&bytes).unwrap()
+    }
+
     pub fn as_bytes(&self) -> &[u8] {
         self.0.as_bytes()
     }